TreeSHAP¶
TreeSHAP computes exact SHAP values for tree ensemble models (XGBoost, LightGBM) using a polynomial-time algorithm that exploits the tree structure.
Why TreeSHAP?¶
| Property | TreeSHAP | Permutation SHAP |
|---|---|---|
| Accuracy | Exact | Approximate (converges to exact with enough samples) |
| Speed | O(TLD²) | O(samples × features × model calls) |
| Typical time | 8-200μs | 15-150ms |
| Background data | Not needed | Required |
For a model with 100 trees, TreeSHAP is typically 100-1000x faster than sampling-based methods.
Basic Usage¶
From XGBoost¶
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/plexusone/shap-go/explainer/tree"
)

func main() {
	// Load XGBoost model (exported from Python as JSON)
	ensemble, err := tree.LoadXGBoostModel("model.json")
	if err != nil {
		log.Fatal(err)
	}

	// Create explainer
	exp, err := tree.New(ensemble)
	if err != nil {
		log.Fatal(err)
	}

	// Explain a prediction
	ctx := context.Background()
	instance := []float64{0.5, 0.3, 0.8, 0.2, 0.9}
	explanation, err := exp.Explain(ctx, instance)
	if err != nil {
		log.Fatal(err)
	}

	// Print results
	fmt.Printf("Prediction: %.4f\n", explanation.Prediction)
	fmt.Printf("Base Value: %.4f\n", explanation.BaseValue)
	fmt.Println("\nTop contributing features:")
	for _, feat := range explanation.TopFeatures(5) {
		sign := "+"
		if feat.SHAPValue < 0 {
			sign = "-"
		}
		fmt.Printf(" %s %s: %.4f\n", sign, feat.Name, feat.SHAPValue)
	}
}
From LightGBM¶
// Load LightGBM model (exported with dump_model())
ensemble, err := tree.LoadLightGBMModel("model.json")
if err != nil {
	log.Fatal(err)
}
exp, err := tree.New(ensemble)
// ... same as above
Batch Processing¶
Explain multiple instances efficiently with parallel workers:
// Create explainer with parallel processing
exp, err := tree.New(ensemble,
	explainer.WithNumWorkers(4), // Use 4 parallel workers
)
if err != nil {
	log.Fatal(err)
}

// Prepare batch
instances := [][]float64{
	{0.1, 0.2, 0.3, 0.4, 0.5},
	{0.5, 0.4, 0.3, 0.2, 0.1},
	// ... more instances
}

// Explain all at once
ctx := context.Background()
explanations, err := exp.ExplainBatch(ctx, instances)
if err != nil {
	log.Fatal(err)
}

// Process results (the loop variable is named e so it does not shadow exp)
for i, e := range explanations {
	fmt.Printf("Instance %d: prediction=%.4f\n", i, e.Prediction)
}
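To illustrate what a worker setting like the one above typically controls, here is a minimal worker-pool sketch using only the standard library. It is not the library's implementation: `ExplainFunc` and `ExplainParallel` are hypothetical names, and the "explanation" is reduced to a single float64 to keep the example short.

```go
package main

import (
	"fmt"
	"sync"
)

// ExplainFunc is a hypothetical stand-in for a single-instance explain call.
type ExplainFunc func(instance []float64) float64

// ExplainParallel applies explain to every instance using numWorkers
// goroutines. Results are returned in the same order as the inputs.
func ExplainParallel(instances [][]float64, numWorkers int, explain ExplainFunc) []float64 {
	if numWorkers < 1 {
		numWorkers = 1
	}
	results := make([]float64, len(instances))
	jobs := make(chan int)

	var wg sync.WaitGroup
	for w := 0; w < numWorkers; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := range jobs {
				// Each index is handed to exactly one worker, so no lock is needed.
				results[i] = explain(instances[i])
			}
		}()
	}
	for i := range instances {
		jobs <- i
	}
	close(jobs)
	wg.Wait()
	return results
}

func main() {
	// Toy "explainer": the sum of the feature values.
	sum := func(x []float64) float64 {
		s := 0.0
		for _, v := range x {
			s += v
		}
		return s
	}
	fmt.Println(ExplainParallel([][]float64{{1, 2}, {3, 4}, {5}}, 2, sum)) // [3 7 5]
}
```

Because each instance is explained independently, throughput generally scales with the number of workers until the CPU cores are saturated.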
Understanding TreeSHAP Output¶
Base Value¶
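The base value is the model's expected output when no feature values are known, that is, the average prediction over the training data distribution. It is computed directly from the tree structure by weighting each leaf value by its cover (the training-sample weight recorded at each node).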
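The cover-weighted average can be sketched for a single tree as follows. The `Node` type here is a hypothetical minimal structure for illustration, not the library's ensemble format; for a full ensemble one would presumably add the per-tree values to the model's base score.

```go
package main

import "fmt"

// Node is a hypothetical minimal tree node used only for this sketch.
type Node struct {
	Feature     int     // split feature index (unused for leaves)
	Threshold   float64 // split threshold (unused for leaves)
	Left, Right *Node   // both nil for a leaf
	Value       float64 // leaf output
	Cover       float64 // training weight that reached this node
}

// ExpectedValue returns the cover-weighted average of the leaf values below n.
func ExpectedValue(n *Node) float64 {
	if n.Left == nil && n.Right == nil {
		return n.Value
	}
	l, r := n.Left, n.Right
	return (l.Cover*ExpectedValue(l) + r.Cover*ExpectedValue(r)) / (l.Cover + r.Cover)
}

func main() {
	root := &Node{
		Feature: 0, Threshold: 0.5, Cover: 10,
		Left:  &Node{Value: 1.0, Cover: 4},
		Right: &Node{Value: 2.0, Cover: 6},
	}
	// (4*1.0 + 6*2.0) / 10
	fmt.Printf("%.2f\n", ExpectedValue(root)) // 1.60
}
```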
SHAP Values¶
Each feature gets a SHAP value indicating how much it pushed the prediction away from the base value:
- Positive SHAP: Feature increased the prediction
- Negative SHAP: Feature decreased the prediction
- Zero SHAP: Feature had no effect
Local Accuracy¶
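TreeSHAP guarantees that the base value plus the SHAP values equals the prediction:
sum(SHAP) + base_value == prediction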
Verify this:
result := explanation.Verify(1e-9)
if !result.Valid {
	log.Printf("Warning: accuracy issue, diff=%.2e\n", result.Difference)
}
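For readers who want to see what such a check amounts to, the following standalone sketch computes the gap between the reconstructed and the actual prediction. `LocalAccuracyGap` is a hypothetical helper, not part of the library; how `Verify` is implemented internally is not shown in this guide.

```go
package main

import (
	"fmt"
	"math"
)

// LocalAccuracyGap returns |baseValue + sum(shap) - prediction|.
// A value within the chosen tolerance means local accuracy holds.
func LocalAccuracyGap(baseValue float64, shap []float64, prediction float64) float64 {
	sum := baseValue
	for _, v := range shap {
		sum += v
	}
	return math.Abs(sum - prediction)
}

func main() {
	shap := []float64{0.25, -0.5, 0.125}
	gap := LocalAccuracyGap(1.0, shap, 0.875)
	fmt.Println(gap <= 1e-9) // true
}
```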
Configuration Options¶
exp, err := tree.New(ensemble,
	// Feature names (optional - uses model's if available)
	explainer.WithFeatureNames([]string{"age", "income", "score"}),
	// Model identifier for tracking
	explainer.WithModelID("loan-approval-v2"),
	// Parallel workers for batch processing
	explainer.WithNumWorkers(4),
)
Working with Tree Ensembles¶
Inspecting the Model¶
ensemble, _ := tree.LoadXGBoostModel("model.json")
fmt.Printf("Trees: %d\n", ensemble.NumTrees)
fmt.Printf("Features: %d\n", ensemble.NumFeatures)
fmt.Printf("Max depth: %d\n", ensemble.MaxDepth())
fmt.Printf("Base score: %.4f\n", ensemble.BaseScore)
Feature Names¶
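If your model has feature names, the explainer uses them by default. To set names explicitly, pass explainer.WithFeatureNames when creating the explainer (see Configuration Options).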
Serialization¶
Save/load the unified ensemble format:
// Save
jsonData, _ := ensemble.ToJSON()
os.WriteFile("ensemble.json", jsonData, 0644)
// Load
data, _ := os.ReadFile("ensemble.json")
loaded, _ := tree.EnsembleFromJSON(data)
Interaction Values¶
TreeSHAP can compute pairwise feature interactions, which reveal how features work together to influence predictions. The interaction matrix is symmetric and satisfies key SHAP properties.
Computing Interactions¶
// Compute SHAP interaction values
result, err := exp.ExplainInteractions(ctx, instance)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Prediction: %.4f\n", result.Prediction)
fmt.Printf("Base Value: %.4f\n", result.BaseValue)
Interaction Matrix Properties¶
The interaction matrix Φ has these properties:
- Diagonal Φ[i][i]: Main effect of feature i
- Off-diagonal Φ[i][j]: Interaction between features i and j
- Symmetric: Φ[i][j] == Φ[j][i]
- Rows sum to SHAP: sum(Φ[i][:]) == SHAP[i]
- Total sum: sum(all Φ) == prediction - base_value
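The properties above can be checked numerically on any interaction matrix. The sketch below works on a plain `[][]float64`; `CheckInteractions` is a hypothetical helper written for this example and does not use the library's result type.

```go
package main

import (
	"fmt"
	"math"
)

// CheckInteractions verifies symmetry, row sums, and the total sum of an
// interaction matrix phi against the SHAP values, within tolerance tol.
func CheckInteractions(phi [][]float64, shap []float64, prediction, baseValue, tol float64) error {
	n := len(phi)
	if len(shap) != n {
		return fmt.Errorf("got %d SHAP values for a %d-row matrix", len(shap), n)
	}
	for i := range phi {
		if len(phi[i]) != n {
			return fmt.Errorf("row %d has %d columns, want %d", i, len(phi[i]), n)
		}
	}
	total := 0.0
	for i := 0; i < n; i++ {
		rowSum := 0.0
		for j := 0; j < n; j++ {
			if math.Abs(phi[i][j]-phi[j][i]) > tol {
				return fmt.Errorf("not symmetric at (%d,%d)", i, j)
			}
			rowSum += phi[i][j]
		}
		if math.Abs(rowSum-shap[i]) > tol {
			return fmt.Errorf("row %d sums to %g, want %g", i, rowSum, shap[i])
		}
		total += rowSum
	}
	if math.Abs(total-(prediction-baseValue)) > tol {
		return fmt.Errorf("total %g, want %g", total, prediction-baseValue)
	}
	return nil
}

func main() {
	phi := [][]float64{
		{0.5, 0.25},
		{0.25, -0.5},
	}
	shap := []float64{0.75, -0.25}
	fmt.Println(CheckInteractions(phi, shap, 1.5, 1.0, 1e-9)) // <nil>
}
```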
Working with Interactions¶
// Get interaction between two features
interaction := result.GetInteraction(0, 1)
// Get main effect (diagonal)
mainEffect := result.GetMainEffect(0)
// Get derived SHAP value (row sum)
shapValue := result.GetSHAPValue(0)
// Get top k strongest interactions (by absolute value)
topK := result.TopInteractions(5)
for _, inter := range topK {
	fmt.Printf("%s <-> %s: %.4f\n", inter.Name1, inter.Name2, inter.Value)
}
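Ranking by absolute value can be sketched on a raw matrix as follows. This is an illustration only: `Pair` and `TopPairs` are hypothetical names, and the sketch assumes that only off-diagonal pairs with i < j are ranked, which may differ from what `TopInteractions` does.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// Pair is a hypothetical record for one off-diagonal interaction.
type Pair struct {
	I, J  int
	Value float64
}

// TopPairs returns the k off-diagonal interactions with the largest
// absolute value. Each unordered pair appears once (i < j).
func TopPairs(phi [][]float64, k int) []Pair {
	var pairs []Pair
	for i := range phi {
		for j := i + 1; j < len(phi[i]); j++ {
			pairs = append(pairs, Pair{I: i, J: j, Value: phi[i][j]})
		}
	}
	sort.SliceStable(pairs, func(a, b int) bool {
		return math.Abs(pairs[a].Value) > math.Abs(pairs[b].Value)
	})
	if k < 0 {
		k = 0
	}
	if k < len(pairs) {
		pairs = pairs[:k]
	}
	return pairs
}

func main() {
	phi := [][]float64{
		{0.5, 0.1, -0.4},
		{0.1, 0.2, 0.05},
		{-0.4, 0.05, 0.3},
	}
	for _, p := range TopPairs(phi, 2) {
		fmt.Printf("%d <-> %d: %.2f\n", p.I, p.J, p.Value)
	}
	// Prints:
	// 0 <-> 2: -0.40
	// 0 <-> 1: 0.10
}
```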
When to Use Interactions¶
Interaction values are useful when:
- Understanding feature synergies: Which features amplify each other's effects?
- Model debugging: Identify unexpected feature dependencies
- Feature engineering: Discover features that should be combined
- Regulatory compliance: Explain complex model behavior
Complexity¶
Computing interactions is more expensive than regular SHAP values:
- Regular TreeSHAP: O(TLD²), where T = trees, L = leaves per tree, D = depth
- TreeSHAP Interactions: O(TLD² × M), where M is the number of features
For large feature sets, consider computing interactions only for instances of interest.
Algorithm Details¶
TreeSHAP uses a path-based algorithm that:
- Traverses each tree from root to leaf
- Tracks which features are "on path" vs "off path"
- Computes contributions using combinatorial weights
- Accumulates SHAP values across all trees
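To make the quantities in these steps concrete, the sketch below computes the same kind of values by brute force for a single small tree: it enumerates every feature coalition, follows the instance's branch for known ("on path") features, and averages both branches by cover for unknown ("off path") features. This is not the polynomial-time algorithm and not the library's code; it is exponential in the number of features and only useful as a reference. The `Node` type and the `x < threshold goes left` convention are assumptions made for the example.

```go
package main

import (
	"fmt"
	"math/bits"
)

// Node is a hypothetical minimal tree node used only for this sketch.
type Node struct {
	Feature     int     // feature index tested at this split
	Threshold   float64 // assumed convention: go left when x[Feature] < Threshold
	Left, Right *Node   // both nil for a leaf
	Value       float64 // leaf output
	Cover       float64 // training weight that reached this node
}

// condExp returns the expected output when only the features in mask are known.
func condExp(n *Node, x []float64, mask uint) float64 {
	if n.Left == nil {
		return n.Value
	}
	if mask&(uint(1)<<uint(n.Feature)) != 0 {
		if x[n.Feature] < n.Threshold {
			return condExp(n.Left, x, mask)
		}
		return condExp(n.Right, x, mask)
	}
	l, r := n.Left, n.Right
	return (l.Cover*condExp(l, x, mask) + r.Cover*condExp(r, x, mask)) / (l.Cover + r.Cover)
}

func factorial(n int) float64 {
	f := 1.0
	for i := 2; i <= n; i++ {
		f *= float64(i)
	}
	return f
}

// BruteForceSHAP applies the Shapley formula directly over all coalitions.
func BruteForceSHAP(root *Node, x []float64) []float64 {
	m := len(x)
	phi := make([]float64, m)
	for i := 0; i < m; i++ {
		bit := uint(1) << uint(i)
		for mask := uint(0); mask < uint(1)<<uint(m); mask++ {
			if mask&bit != 0 {
				continue // coalition already contains feature i
			}
			size := bits.OnesCount(mask)
			weight := factorial(size) * factorial(m-size-1) / factorial(m)
			phi[i] += weight * (condExp(root, x, mask|bit) - condExp(root, x, mask))
		}
	}
	return phi
}

// exampleTree builds a two-feature tree with three leaves.
func exampleTree() *Node {
	return &Node{Feature: 0, Threshold: 0.5, Cover: 10,
		Left: &Node{Value: 1.0, Cover: 4},
		Right: &Node{Feature: 1, Threshold: 0.5, Cover: 6,
			Left:  &Node{Value: 2.0, Cover: 3},
			Right: &Node{Value: 4.0, Cover: 3},
		},
	}
}

func main() {
	phi := BruteForceSHAP(exampleTree(), []float64{0.8, 0.7})
	fmt.Printf("%.2f %.2f\n", phi[0], phi[1]) // 1.00 0.80
}
```

In this example the prediction is 4.0 and the cover-weighted base value is 2.2, so the two SHAP values sum to 1.8, as local accuracy requires.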
Complexity¶
- Time: O(T × L × D²) where T=trees, L=leaves per tree, D=depth
- Space: O(D) for the path tracking
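As a rough sense of scale, the leading terms can be evaluated for a concrete model. The numbers below (100 trees, 32 leaves per tree, depth 5, 20 features) are illustrative assumptions, and the functions count abstract operations rather than predicting wall-clock time; the interaction estimate multiplies by the number of features, as described in the Interaction Values section.

```go
package main

import "fmt"

// TreeSHAPOps returns T * L * D^2, the leading term of the time complexity.
func TreeSHAPOps(trees, leaves, depth int) int {
	return trees * leaves * depth * depth
}

// InteractionOps multiplies the regular cost by the number of features.
func InteractionOps(trees, leaves, depth, features int) int {
	return TreeSHAPOps(trees, leaves, depth) * features
}

func main() {
	fmt.Println(TreeSHAPOps(100, 32, 5))        // 80000
	fmt.Println(InteractionOps(100, 32, 5, 20)) // 1600000
}
```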
Handling Missing Values¶
TreeSHAP respects the model's missing value handling. If a feature is NaN:
instance := []float64{0.5, math.NaN(), 0.8} // Feature 1 is missing
explanation, _ := exp.Explain(ctx, instance)
// Uses the tree's default direction for missing values
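The default-direction behavior can be illustrated with a small traversal sketch. The `Node` type, its `DefaultLeft` field, and the `x < threshold goes left` convention are assumptions for this example and may not match the library's internal representation.

```go
package main

import (
	"fmt"
	"math"
)

// Node is a hypothetical minimal tree node used only for this sketch.
type Node struct {
	Feature     int     // feature index tested at this split
	Threshold   float64 // assumed convention: go left when x[Feature] < Threshold
	Left, Right *Node   // both nil for a leaf
	DefaultLeft bool    // branch taken when the feature value is missing (NaN)
	Value       float64 // leaf output
}

// Predict walks the tree, sending NaN values down each split's default branch.
func Predict(n *Node, x []float64) float64 {
	for n.Left != nil {
		v := x[n.Feature]
		switch {
		case math.IsNaN(v):
			if n.DefaultLeft {
				n = n.Left
			} else {
				n = n.Right
			}
		case v < n.Threshold:
			n = n.Left
		default:
			n = n.Right
		}
	}
	return n.Value
}

func main() {
	root := &Node{Feature: 1, Threshold: 0.5, DefaultLeft: true,
		Left:  &Node{Value: 1.0},
		Right: &Node{Value: 2.0},
	}
	fmt.Println(Predict(root, []float64{0.5, math.NaN(), 0.8})) // 1 (default direction is left)
	fmt.Println(Predict(root, []float64{0.5, 0.9, 0.8}))        // 2
}
```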
Performance Tips¶
- Use batch processing for multiple instances
- Set appropriate workers based on CPU cores
- Reuse the explainer - don't recreate for each instance
// Good: Create once, use many times
exp, _ := tree.New(ensemble, explainer.WithNumWorkers(4))
for _, instance := range instances {
	explanation, _ := exp.Explain(ctx, instance)
	_ = explanation // process...
}

// Bad: Recreating explainer each time
for _, instance := range instances {
	exp, _ := tree.New(ensemble) // Wasteful!
	explanation, _ := exp.Explain(ctx, instance)
	_ = explanation // process...
}
Troubleshooting¶
"model has no trees"¶
The JSON file doesn't contain tree data. Ensure you're using the correct export format:
- XGBoost: model.save_model("model.json")
- LightGBM: json.dump(model.dump_model(), f)
Feature count mismatch¶
The length of your input slice doesn't match the number of features the model expects. Compare len(instance) with ensemble.NumFeatures.
Local accuracy violation¶
TreeSHAP should always satisfy local accuracy. If you see violations, it might indicate:
- Floating point precision issues (use tolerance like 1e-9)
- Bug in model loading (please report!)
Next Steps¶
- XGBoost Integration - Export models from Python
- LightGBM Integration - Export models from Python
- Visualization - Create charts from explanations
- Benchmarks - Performance data