Simple Kernel SHAP

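This notebook provides a simple brute-force version of Kernel SHAP that enumerates the entire $2^M$ sample space of feature coalitions (for $M$ features). We also compare it to the full `KernelExplainer` implementation. Note that `KernelExplainer` uses a sampling approximation for large values of $M$, but for small values it is exact. The brute-force version approximates the (theoretically infinite) Shapley-kernel weight on the empty and full coalitions with a large finite weight, which is why its base value and $\sum \phi$ differ from $f(\text{reference})$ and $f(x)$ by a tiny amount.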

Brute Force Kernel SHAP

[1]:
import scipy.special
import numpy as np
import itertools

def powerset(iterable):
    """Return an iterator over every subset of ``iterable`` as a tuple.

    Subsets are produced in order of increasing size (the empty tuple first,
    the full tuple last); within one size, ``itertools.combinations`` order
    is used.
    """
    items = list(iterable)
    subsets_by_size = (
        itertools.combinations(items, size) for size in range(len(items) + 1)
    )
    return itertools.chain.from_iterable(subsets_by_size)

def shapley_kernel(M, s, large_weight=10000):
    """Shapley-kernel weight for a coalition of size ``s`` out of ``M`` features.

    For 0 < s < M this is the Kernel SHAP weight

        (M - 1) / (binom(M, s) * s * (M - s)).

    The empty (s == 0) and full (s == M) coalitions have infinite weight in
    the theory. Here that weight is replaced by the finite ``large_weight``
    so the regression stays an ordinary weighted least-squares problem.

    Parameters
    ----------
    M : int
        Total number of features.
    s : int
        Coalition size, expected in ``0 <= s <= M``.
    large_weight : float, optional
        Finite stand-in for the infinite weight on the empty and full
        coalitions (default 10000). Larger values enforce those two
        coalitions more tightly.

    Returns
    -------
    float
        The kernel weight.
    """
    if s == 0 or s == M:
        return large_weight
    return (M - 1) / (scipy.special.binom(M, s) * s * (M - s))

def f(X):
    """Toy linear model used as the function to explain.

    Computes ``X . beta + 10``. The coefficient vector ``beta`` has one entry
    per entry of the last axis of ``X`` and is drawn from a ``RandomState``
    seeded with 0. A private generator is used so that calling the model
    does not reseed NumPy's global random state. The coefficients are the
    same ones ``np.random.seed(0); np.random.rand(n)`` would produce.

    Parameters
    ----------
    X : ndarray
        A single sample (1-D) or a batch of samples (2-D); the last axis
        indexes features.

    Returns
    -------
    float or ndarray
        Model output for the sample or for each row of the batch.
    """
    beta = np.random.RandomState(0).rand(X.shape[-1])
    return np.dot(X, beta) + 10

def kernel_shap(f, x, reference, M):
    """Brute-force Kernel SHAP over all 2**M feature coalitions.

    Each subset ``s`` of the M features defines one synthetic sample:
    features in ``s`` take their value from ``x``, the rest take their value
    from ``reference``. The model outputs on these samples are regressed onto
    the binary coalition indicators plus an intercept column. The regression
    is weighted least squares with Shapley-kernel weights.

    Parameters
    ----------
    f : callable
        Model. It is called once with a ``(2**M, M)`` array of synthetic
        samples and must return one output per row.
    x : array_like, shape (M,)
        Instance to explain.
    reference : array_like
        Background values used for features absent from a coalition;
        broadcast into each row of the synthetic samples.
    M : int
        Number of features.

    Returns
    -------
    ndarray
        Shape ``(M + 1,)`` when ``f`` returns a 1-D array. The first M
        entries are the feature attributions and the last entry is the
        intercept (base value).
    """
    x = np.asarray(x)
    n_coalitions = 2 ** M

    # Design matrix: coalition indicators plus a constant intercept column.
    X = np.zeros((n_coalitions, M + 1))
    X[:, -1] = 1
    weights = np.zeros(n_coalitions)

    # Every synthetic sample starts at the reference; x's values are switched
    # in below for the features in each coalition.
    V = np.zeros((n_coalitions, M))
    V[:] = reference

    for i, s in enumerate(powerset(range(M))):
        s = list(s)
        V[i, s] = x[s]
        X[i, s] = 1
        weights[i] = shapley_kernel(M, len(s))
    y = f(V)

    # Weighted least squares: solve (X^T W X) phi = X^T W y.
    # Scaling the columns of X.T by the weights equals X.T @ diag(weights)
    # without materializing the dense 2**M x 2**M diagonal matrix, and
    # solve() is numerically preferable to forming an explicit inverse.
    XtW = X.T * weights
    return np.linalg.solve(np.dot(XtW, X), np.dot(XtW, y))

M = 4
np.random.seed(1)
x = np.random.randn(M)
reference = np.zeros(M)
phi = kernel_shap(f, x, reference, M)
base_value = phi[-1]
shap_values = phi[:-1]

print("  reference =", reference)
print("          x =", x)
print("shap_values =", shap_values)
print(" base_value =", base_value)
print("   sum(phi) =", np.sum(phi))
print("       f(x) =", f(x))

# Sanity checks for the two properties the large empty/full-coalition weights
# enforce. First, the base value matches the model at the reference. Second,
# the attributions plus the base value add up to the model output at x (local
# accuracy). The small tolerance allows for the finite stand-in weight.
np.testing.assert_allclose(base_value, f(reference), rtol=1e-8)
np.testing.assert_allclose(np.sum(phi), f(x), rtol=1e-8)
  reference = [0. 0. 0. 0.]
          x = [ 1.62434536 -0.61175641 -0.52817175 -1.07296862]
shap_values = [ 0.89146267 -0.43752168 -0.31836259 -0.58464256]
 base_value = 9.999999999999996
   sum(phi) = 9.55093584211863
       f(x) = 9.55093584213122

Using KernelExplainer

[2]:
import shap

# Explain the same instance with shap's KernelExplainer, using the single
# reference row as the background dataset.
explainer = shap.KernelExplainer(f, np.reshape(reference, (1, len(reference))))
shap_values = explainer.shap_values(x)
print("shap_values =", shap_values)
print("base value =", explainer.expected_value)

# For small M the two implementations should agree. Compare against `phi`
# from the brute-force cell; this cell does not overwrite `phi`, so it stays
# safe to re-run.
np.testing.assert_allclose(shap_values, phi[:-1], atol=1e-6)
np.testing.assert_allclose(explainer.expected_value, phi[-1], atol=1e-6)
shap_values = [ 0.89146267 -0.43752168 -0.31836259 -0.58464256]
base value = 10.0