Census income classification with scikit-learn

This example uses the standard adult census income dataset from the UCI Machine Learning Repository. We train a k-nearest neighbors classifier using scikit-learn and then explain its predictions.

[1]:
import sklearn
import sklearn.model_selection  # train_test_split
import sklearn.neighbors        # KNeighborsClassifier
import shap

shap.initjs()

Load the census data

[2]:
X,y = shap.datasets.adult()
X_display,y_display = shap.datasets.adult(display=True)
X_train, X_valid, y_train, y_valid = sklearn.model_selection.train_test_split(X, y, test_size=0.2, random_state=7)

Train a k-nearest neighbors classifier

Here we train directly on the raw data, without any normalization.

[3]:
knn = sklearn.neighbors.KNeighborsClassifier()
knn.fit(X_train, y_train)
[3]:
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=1, n_neighbors=5, p=2,
           weights='uniform')

Explain predictions

Normally we would use a logit link function so that the additive feature attributions map better onto the model’s probabilistic output space. However, a KNN classifier can predict probabilities of exactly 0 or 1, which correspond to infinite log odds, so we do not use a link function in this example.
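To see why the log odds can be infinite: with the default `n_neighbors=5` and `weights='uniform'` shown in the model output above, `predict_proba` returns the fraction of the 5 neighbors that belong to each class, so only the values 0, 0.2, ..., 1.0 are possible. The minimal NumPy sketch below applies a logit to those values; the `logit` helper is our own illustration, not part of SHAP or scikit-learn.

```python
import numpy as np

def logit(p):
    """Log odds log(p / (1 - p)); this is -inf at p = 0 and +inf at p = 1."""
    p = np.asarray(p, dtype=float)
    # Suppress the divide-by-zero warnings np.log/np.log1p emit at 0 and 1
    with np.errstate(divide="ignore"):
        return np.log(p) - np.log1p(-p)

# With k = 5 uniformly weighted neighbors, the predicted probability is
# (number of positive neighbors) / 5, so only six values can occur
k = 5
probs = np.arange(k + 1) / k
log_odds = logit(probs)
print(dict(zip(probs.round(1), log_odds.round(3))))
```

Any sample whose five neighbors all share the same label lands on one of the two infinite endpoints, which is why the identity link is used here instead.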

It is important to note that fnlwgt (a statistical reweighting term) is the dominant feature in the 1000 predictions we explain below. This is because its values vary over a much larger range than those of the other features, so it dominates the Euclidean distances that the k-nearest neighbors classifier uses to choose neighbors.
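The effect is easy to reproduce on a toy example. The sketch below uses two hypothetical rows with a small-scale feature (years of education) and a large-scale feature (a sampling weight); the numbers are made up for illustration and are not taken from the census data. Under the default Euclidean distance, the row with the closer weight wins even though its education differs by ten years.

```python
import numpy as np

# Hypothetical rows with two features: [years of education, sampling weight].
# The values are illustrative only; they are not taken from the census data.
query = np.array([13.0, 200_000.0])
candidates = np.array([
    [13.0, 260_000.0],  # same education, weight differs by 60,000
    [3.0, 201_000.0],   # education differs by 10 years, weight by 1,000
])

# Euclidean distance (Minkowski with p=2, the KNeighborsClassifier default)
diffs = candidates - query
dists = np.sqrt((diffs ** 2).sum(axis=1))
nearest = int(np.argmin(dists))

# Fraction of each squared distance contributed by each feature
share = diffs ** 2 / (dists[:, None] ** 2)

print("distances:", dists.round(2))
print("nearest candidate:", nearest)
print("share of squared distance:", share.round(4))
```

In both rows, essentially all of the squared distance comes from the weight column, so that column alone decides which neighbors are chosen.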

[4]:
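# Model output to explain: predicted probability of the positive class (>50K)
f = lambda x: knn.predict_proba(x)[:,1]
# Use the training-set median of each feature as a single background sample
med = X_train.median().values.reshape((1,X_train.shape[1]))
explainer = shap.KernelExplainer(f, med)
# Explain a single prediction (the first row of X)
shap_values_single = explainer.shap_values(X.iloc[0,:], nsamples=1000)
shap.force_plot(explainer.expected_value, shap_values_single, X_display.iloc[0,:])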
[5]:
shap_values = explainer.shap_values(X_valid.iloc[0:1000,:], nsamples=1000)
shap.force_plot(explainer.expected_value, shap_values, X_valid.iloc[0:1000,:])
100%|██████████| 1000/1000 [00:12<00:00, 81.09it/s]

A summary plot is an even better way to see the relative impact of all features across the 1000 explained samples. Features are sorted by the sum of their SHAP value magnitudes across all samples.
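The ordering described above can be sketched in a few lines of NumPy. The SHAP matrix and feature names below are hypothetical (rows are samples, columns are features); this reproduces only the sort order, not the plot itself.

```python
import numpy as np

# Hypothetical SHAP values: 4 samples (rows) x 3 features (columns)
feature_names = ["feature_a", "feature_b", "feature_c"]
shap_vals = np.array([
    [ 0.10, -0.02,  0.30],
    [-0.20,  0.01, -0.25],
    [ 0.05, -0.03,  0.40],
    [-0.15,  0.02, -0.05],
])

# Total magnitude of each feature's contributions across all samples
importance = np.abs(shap_vals).sum(axis=0)

# Most important feature first, matching the plot's top-to-bottom order
order = np.argsort(-importance)
ranked = [feature_names[i] for i in order]
print(dict(zip(ranked, importance[order].round(2))))
```

Taking absolute values first matters: a feature that pushes some predictions up and others down would otherwise cancel itself out and look unimportant.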

[6]:
shap.summary_plot(shap_values, X_valid.iloc[0:1000,:])
[Figure: SHAP summary plot for the KNN model trained on unnormalized data]

Normalize the data before training the model

Here we retrain a KNN model on standardized data.
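Standardizing rescales each feature to zero mean and unit variance, using statistics computed on the training set only, so that no feature dominates the distance just because of its units. The NumPy sketch below uses hypothetical two-feature rows (not census data) and shows the nearest neighbor changing once both features are standardized. It passes `ddof=1` because pandas' `DataFrame.std()` defaults to the sample standard deviation, whereas NumPy's `std` defaults to `ddof=0`.

```python
import numpy as np

# Hypothetical training rows: [years of education, sampling weight].
# Values are illustrative only, not taken from the census data.
X_tr = np.array([
    [9.0, 150_000.0],
    [13.0, 250_000.0],
    [10.0, 180_000.0],
    [16.0, 300_000.0],
    [3.0, 120_000.0],
])
query = np.array([13.0, 200_000.0])
candidates = np.array([
    [13.0, 260_000.0],  # same education, different weight
    [3.0, 201_000.0],   # very different education, similar weight
])

# Statistics come from the training rows only; ddof=1 matches pandas' .std()
mean = X_tr.mean(axis=0)
std = X_tr.std(axis=0, ddof=1)

def standardize(a):
    """Apply the training-set mean and standard deviation to any rows."""
    return (a - mean) / std

raw_nearest = int(np.argmin(np.linalg.norm(candidates - query, axis=1)))
z_dists = np.linalg.norm(standardize(candidates) - standardize(query), axis=1)
z_nearest = int(np.argmin(z_dists))
print("raw nearest:", raw_nearest, "standardized nearest:", z_nearest)
```

Reusing the training statistics for every other row (validation data, background samples, new queries) keeps all inputs on the same scale the model was fit on.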

[7]:
# Standardize each feature with the training-set mean and standard deviation,
# and apply the same training statistics to the validation set
m = X_train.mean()
s = X_train.std()
X_train_norm = (X_train - m) / s
X_valid_norm = (X_valid - m) / s
[8]:
knn_norm = sklearn.neighbors.KNeighborsClassifier()
knn_norm.fit(X_train_norm, y_train)
[8]:
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=1, n_neighbors=5, p=2,
           weights='uniform')

Explain predictions

When we explain predictions from the new KNN model, we find that fnlwgt is no longer the dominant feature; instead, more natural features, such as marital status, drive most predictions. This is a simple example of how explaining why your model makes its predictions can uncover problems in the training process.

[9]:
# Same setup as before, but for the model trained on standardized data
f = lambda x: knn_norm.predict_proba(x)[:,1]
med = X_train_norm.median().values.reshape((1,X_train_norm.shape[1]))
explainer = shap.KernelExplainer(f, med)
shap_values_norm = explainer.shap_values(X_valid_norm.iloc[0:1000,:], nsamples=1000)
# Explain standardized inputs, but display the original feature values for readability
shap.force_plot(explainer.expected_value, shap_values_norm, X_valid.iloc[0:1000,:])
100%|██████████| 1000/1000 [00:43<00:00, 23.15it/s]

The summary plot shows that marital status is the most important feature on average, but other features (such as capital gain) can have a larger impact on a particular individual.
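Average importance and individual impact can disagree because one is a mean over samples and the other a single entry. A small NumPy sketch with hypothetical SHAP values: `feature_a` moderately affects every sample, while `feature_b` is near zero except for one very large contribution.

```python
import numpy as np

feature_names = ["feature_a", "feature_b"]
# Hypothetical SHAP values, 5 samples x 2 features
shap_vals = np.array([
    [ 0.15,  0.00],
    [-0.12,  0.01],
    [ 0.18,  0.00],
    [-0.10,  0.45],
    [ 0.14, -0.01],
])

# Average magnitude per feature decides the overall ranking
mean_abs = np.abs(shap_vals).mean(axis=0)
top_on_average = feature_names[int(np.argmax(mean_abs))]

# Largest single contribution, and the sample it belongs to
row, col = np.unravel_index(np.argmax(np.abs(shap_vals)), shap_vals.shape)
top_for_one_sample = feature_names[int(col)]
print(top_on_average, top_for_one_sample, int(row))
```

In the summary plot, each dot is one sample, so a feature like `feature_b` shows up as a single far-out dot even when it is ranked lower overall.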

[10]:
shap.summary_plot(shap_values_norm, X_valid.iloc[0:1000,:])
[Figure: SHAP summary plot for the KNN model trained on standardized data]

A dependence plot for the model trained on standardized data shows that more years of education increase the predicted chance of earning over 50K annually.
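A dependence plot is a scatter of one feature's value against that feature's SHAP value, with one point per sample. As a rough numeric stand-in for "is there a trend?", the sketch below computes the Pearson correlation between those two columns. The data are hypothetical, and the `trend_strength` helper is our own illustration (not a SHAP function); Pearson correlation only captures linear trends.

```python
import numpy as np

def trend_strength(feature_values, feature_shap):
    """Pearson correlation between a feature and its SHAP values."""
    return float(np.corrcoef(feature_values, feature_shap)[0, 1])

# Hypothetical data, not taken from the census models
education = np.arange(4, 16, dtype=float)               # 12 samples
shap_increasing = 0.03 * (education - 10)                # rises with education
shap_unrelated = np.tile([0.01, -0.01, -0.01, 0.01], 3)  # no linear relation

r_increasing = trend_strength(education, shap_increasing)
r_unrelated = trend_strength(education, shap_unrelated)
print(round(r_increasing, 3), round(r_unrelated, 3))
```

A correlation near 1 corresponds to the rising pattern described above, while a value near 0 corresponds to the flat, trendless scatter of the unnormalized model.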

[11]:
shap.dependence_plot("Education-Num", shap_values_norm, X_valid.iloc[0:1000,:])
[Figure: SHAP dependence plot of Education-Num for the KNN model trained on standardized data]

Contrast this with the KNN model trained on unnormalized data, where the number of years of education shows no clear trend:

[12]:
shap.dependence_plot("Education-Num", shap_values, X_valid.iloc[0:1000,:])
[Figure: SHAP dependence plot of Education-Num for the KNN model trained on unnormalized data]