DeepSHAP¶
DeepSHAP is an efficient algorithm for computing SHAP values for deep neural networks. It combines the DeepLIFT algorithm with Shapley values to provide theoretically grounded feature attributions.
Overview¶
DeepSHAP works by:
- Running a forward pass to capture activations at each layer
- Computing reference activations from a background dataset
- Propagating attribution multipliers backward using the DeepLIFT rescale rule
- Averaging attributions over multiple background samples
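The averaging step can be sketched concretely for the simplest case, a single linear layer, where the rescale rule reduces to exact linear attribution. This is illustrative code, not the library's implementation:

```go
package main

import "fmt"

// For a single linear layer y = w·x + b, DeepLIFT attributions reduce to
// w[i] * (x[i] - ref[i]). DeepSHAP computes this per background sample and
// averages, which is what New does internally across its reference set.
func attribute(w, x []float64, background [][]float64) []float64 {
	attr := make([]float64, len(x))
	for _, ref := range background {
		for i := range x {
			attr[i] += w[i] * (x[i] - ref[i])
		}
	}
	for i := range attr {
		attr[i] /= float64(len(background))
	}
	return attr
}

func main() {
	w := []float64{2.0, -1.0}
	x := []float64{1.0, 1.0}
	background := [][]float64{{0, 0}, {1, 0}}
	fmt.Println(attribute(w, x, background)) // averaged over both references
}
```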
Key Properties¶
| Property | Value |
|---|---|
| Accuracy | Approximate (DeepLIFT-based) |
| Complexity | O(layers × neurons × background samples) |
| Background data | Required |
| Local accuracy | Approximately satisfied |
Quick Start¶
```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/plexusone/shap-go/explainer"
	"github.com/plexusone/shap-go/explainer/deepshap"
	"github.com/plexusone/shap-go/model/onnx"
)

func main() {
	// Initialize ONNX Runtime
	if err := onnx.InitializeRuntime("/path/to/libonnxruntime.so"); err != nil {
		log.Fatal(err)
	}
	defer onnx.DestroyRuntime()

	// Parse the ONNX model graph
	graphInfo, err := onnx.ParseGraph("model.onnx")
	if err != nil {
		log.Fatal(err)
	}

	// Create activation session with intermediate outputs
	config := onnx.ActivationConfig{
		Config: onnx.Config{
			ModelPath:   "model.onnx",
			InputName:   "input",
			OutputName:  "output",
			NumFeatures: 10,
		},
		IntermediateOutputs: graphInfo.GetAllLayerOutputs(),
	}
	session, err := onnx.NewActivationSession(config)
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()

	// Background data for SHAP computation
	background := [][]float64{
		{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
		// ... more background samples
	}

	// Create DeepSHAP explainer
	exp, err := deepshap.New(session, graphInfo, background,
		explainer.WithFeatureNames([]string{"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9", "f10"}),
	)
	if err != nil {
		log.Fatal(err)
	}

	// Explain a prediction
	ctx := context.Background()
	instance := []float64{0.5, 0.3, 0.8, 0.2, 0.9, 0.1, 0.7, 0.4, 0.6, 0.5}
	explanation, err := exp.Explain(ctx, instance)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Printf("Prediction: %.4f\n", explanation.Prediction)
	fmt.Printf("Base Value: %.4f\n", explanation.BaseValue)
	for _, feat := range explanation.TopFeatures(5) {
		fmt.Printf("  %s: %.4f\n", feat.Name, feat.Value)
	}
}
```
Simplified Usage¶
If you don't need the full graph structure, use NewSimple, which applies a simpler attribution method that does not require graph parsing.
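A sketch of the simplified path, assuming NewSimple accepts the session, background data, and the same options as New, minus the graph info (the exact signature is an assumption — check the package documentation):

```go
// Hypothetical sketch: signature assumed to mirror deepshap.New without graphInfo.
exp, err := deepshap.NewSimple(session, background,
	explainer.WithFeatureNames(names),
)
if err != nil {
	log.Fatal(err)
}
explanation, err := exp.Explain(ctx, instance)
```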
Supported Layer Types¶
DeepSHAP supports common neural network layer types:
| Layer Type | ONNX Op Types | Attribution Rule |
|---|---|---|
| Dense | Gemm, MatMul | Linear backprop through weights |
| ReLU | Relu | DeepLIFT rescale rule |
| Sigmoid | Sigmoid | DeepLIFT rescale rule |
| Tanh | Tanh | DeepLIFT rescale rule |
| Softmax | Softmax | DeepLIFT rescale rule |
| Add | Add | Proportional split |
| Identity | Identity, Dropout, Flatten | Pass-through |
Attribution Rules¶
DeepSHAP uses the DeepLIFT rescale rule for propagating attributions:

```
mult_in = mult_out × (y - y_ref) / (x - x_ref)
```

Where:

- `mult_in`, `mult_out` are the input/output multipliers
- `x`, `x_ref` are the input activations for the instance/reference
- `y`, `y_ref` are the output activations for the instance/reference

When the denominator `(x - x_ref)` is near zero, the rule falls back to gradient computation.
Background Dataset¶
The background dataset serves as the baseline for attribution:
- Use representative samples from your training data
- 100-1000 samples typically provide good results
- More samples improve accuracy but increase computation time
```go
// Use k-means to summarize large datasets
bgDataset := background.NewDataset(trainingData, featureNames)
summary := bgDataset.KMeansSummary(100, 10, rng)
```
Configuration Options¶
```go
exp, err := deepshap.New(session, graphInfo, background,
	explainer.WithFeatureNames(names), // Human-readable names
	explainer.WithModelID("my-nn"),    // Model identifier
)
```
When to Use DeepSHAP¶
Use DeepSHAP when:
- You have a neural network model in ONNX format
- The network uses supported layer types (Dense, ReLU, Sigmoid, etc.)
- You need efficient explanations for deep networks
- Local accuracy (sum of SHAP ≈ prediction - baseline) is important
Don't use DeepSHAP when:
- Your model is a tree ensemble (use TreeSHAP)
- Your model is linear (use LinearSHAP)
- You need convolutional layer support (not yet implemented)
- You need guaranteed exact values (use PermutationSHAP or ExactSHAP)
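Since local accuracy is the key property above, it is easy to sanity-check any explanation by comparing the sum of its attributions against prediction minus base value. This is an illustrative helper, not part of the library:

```go
package main

import (
	"fmt"
	"math"
)

// localAccuracyGap returns |(prediction - baseValue) - sum(shapValues)|.
// For DeepSHAP this should be near zero, though only approximately so,
// since the attributions are DeepLIFT-based rather than exact.
func localAccuracyGap(shapValues []float64, prediction, baseValue float64) float64 {
	sum := 0.0
	for _, v := range shapValues {
		sum += v
	}
	return math.Abs(prediction - baseValue - sum)
}

func main() {
	// 0.2 - 0.1 + 0.4 = 0.5, and 0.9 - 0.4 = 0.5, so the gap is ~0.
	gap := localAccuracyGap([]float64{0.2, -0.1, 0.4}, 0.9, 0.4)
	fmt.Println(gap < 1e-9)
}
```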
Limitations¶
Current limitations of the DeepSHAP implementation:
- Sequential networks only: Residual connections have limited support
- No convolutional layers: Conv2D, MaxPool, BatchNorm not yet supported
- Single output: Multi-output models not yet supported
- Dense networks: Best suited for fully-connected architectures
Technical Details¶
ONNX Graph Parsing¶
DeepSHAP parses the ONNX model structure to understand layer connectivity:
```go
graphInfo, err := onnx.ParseGraph("model.onnx")

// Get all layer outputs for activation capture
outputs := graphInfo.GetAllLayerOutputs()

// Get nodes in reverse topological order for backprop
reversed := graphInfo.ReverseTopologicalOrder()
```
Activation Capture¶
The ActivationSession captures intermediate layer outputs:
```go
config := onnx.ActivationConfig{
	Config:              onnx.Config{...},
	IntermediateOutputs: []string{"dense1_out", "relu1_out", "dense2_out"},
}
session, _ := onnx.NewActivationSession(config)
result, _ := session.PredictWithActivations(ctx, input)
// result.Prediction contains final output
// result.Activations contains intermediate values
```
Backward Propagation¶
The propagation engine traverses the graph in reverse order:
```go
engine := deepshap.NewPropagationEngine(graphInfo)
result, _ := engine.Propagate(instanceAct, referenceAct, 1.0)
// result.Attributions contains SHAP values
```
References¶
- SHAP paper: A Unified Approach to Interpreting Model Predictions
- DeepLIFT paper: Learning Important Features Through Propagating Activation Differences
- Python SHAP library: Original implementation