ONNX Runtime Integration¶
SHAP-Go supports ONNX models through ONNX Runtime bindings, enabling explanations for models from any framework that exports to ONNX.
Prerequisites¶
Install ONNX Runtime¶
Download the build for your platform from the ONNX Runtime releases page and make the shared library discoverable: on Windows, add the DLL directory to your PATH; on Linux/macOS, install the .so/.dylib to a standard location such as /usr/local/lib (or set LD_LIBRARY_PATH / DYLD_LIBRARY_PATH).
Export Models to ONNX¶
From scikit-learn¶
import sklearn
from sklearn.ensemble import RandomForestClassifier
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
# Train model
model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)
# Convert to ONNX
initial_type = [('float_input', FloatTensorType([None, X_train.shape[1]]))]
onnx_model = convert_sklearn(model, initial_types=initial_type)
# Save
with open("model.onnx", "wb") as f:
f.write(onnx_model.SerializeToString())
From PyTorch¶
import torch
# Your trained model
model = MyModel()
model.load_state_dict(torch.load("model.pt"))
model.eval()
# Export to ONNX
dummy_input = torch.randn(1, num_features)
torch.onnx.export(
model,
dummy_input,
"model.onnx",
input_names=['float_input'],
output_names=['output'],
dynamic_axes={'float_input': {0: 'batch_size'}}
)
From TensorFlow/Keras¶
import tf2onnx
import tensorflow as tf
# Load your model
model = tf.keras.models.load_model("model.h5")
# Convert to ONNX
spec = (tf.TensorSpec((None, num_features), tf.float32, name="input"),)
output_path = "model.onnx"
model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, output_path=output_path)
Basic Usage¶
package main
import (
"context"
"fmt"
"log"
"github.com/plexusone/shap-go/explainer"
"github.com/plexusone/shap-go/explainer/permutation"
"github.com/plexusone/shap-go/model/onnx"
)
func main() {
// Initialize ONNX Runtime (once per application)
err := onnx.InitializeRuntime("/usr/local/lib/libonnxruntime.dylib")
if err != nil {
log.Fatal(err)
}
defer onnx.DestroyRuntime()
// Create ONNX session
session, err := onnx.NewSession(onnx.Config{
ModelPath: "model.onnx",
InputName: "float_input",
OutputName: "output",
NumFeatures: 10,
})
if err != nil {
log.Fatal(err)
}
defer session.Close()
// Background data for SHAP
background := [][]float64{
{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0},
{0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5},
// ... more representative samples
}
// Create explainer
exp, err := permutation.New(session, background,
explainer.WithNumSamples(100),
explainer.WithFeatureNames([]string{"f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"}),
)
if err != nil {
log.Fatal(err)
}
// Explain a prediction
ctx := context.Background()
instance := []float64{0.3, 0.7, 0.2, 0.8, 0.4, 0.6, 0.1, 0.9, 0.5, 0.5}
explanation, err := exp.Explain(ctx, instance)
if err != nil {
log.Fatal(err)
}
// Print results
fmt.Printf("Prediction: %.4f\n", explanation.Prediction)
for _, f := range explanation.TopFeatures(5) {
fmt.Printf(" %s: %.4f\n", f.Name, f.SHAPValue)
}
}
Configuration¶
Session Config¶
session, _ := onnx.NewSession(onnx.Config{
// Required: path to ONNX model file
ModelPath: "model.onnx",
// Required: input tensor name (from model)
InputName: "float_input",
// Required: output tensor name (from model)
OutputName: "output",
// Required: number of input features
NumFeatures: 10,
// Optional: output index for multi-output models
OutputIndex: 0,
})
Finding Input/Output Names¶
Use Netron to visualize your ONNX model and inspect its input/output tensor names:
netron model.onnx
Or use Python:
import onnx
model = onnx.load("model.onnx")
print("Inputs:", [i.name for i in model.graph.input])
print("Outputs:", [o.name for o in model.graph.output])
Batch Predictions¶
The ONNX session supports efficient batch predictions:
// Single prediction
pred, _ := session.Predict(ctx, instance)
// Batch prediction (more efficient)
instances := [][]float64{
{0.1, 0.2, ...},
{0.3, 0.4, ...},
}
predictions, _ := session.PredictBatch(ctx, instances)
Performance Tips¶
Reuse Session¶
Create the session once and reuse it:
// Good: Create once
session, _ := onnx.NewSession(config)
defer session.Close()
for _, instance := range instances {
pred, _ := session.Predict(ctx, instance)
}
// Bad: Creating session each time
for _, instance := range instances {
session, _ := onnx.NewSession(config) // Slow!
pred, _ := session.Predict(ctx, instance)
session.Close()
}
Optimize Background Size¶
ONNX models can be slow, so minimize background data:
// Use k-means to summarize background data
background := kMeansSummarize(fullData, 20) // 20 centroids
// Reduce sample count for faster explanations
exp, _ := permutation.New(session, background,
explainer.WithNumSamples(50), // Fewer samples
)
Use Parallel Workers¶
exp, _ := permutation.New(session, background,
explainer.WithNumSamples(100),
explainer.WithNumWorkers(4), // Parallel computation
)
Troubleshooting¶
"failed to initialize runtime"¶
The ONNX Runtime library wasn't found:
Common locations:
- macOS (Homebrew): /usr/local/lib/libonnxruntime.dylib
- Linux: /usr/local/lib/libonnxruntime.so
- Windows: C:\Program Files\onnxruntime\lib\onnxruntime.dll
"input name not found"¶
The input name doesn't match the model:
# Check your model's input names
import onnx
model = onnx.load("model.onnx")
print([i.name for i in model.graph.input])
"shape mismatch"¶
Input dimensions don't match:
// Make sure NumFeatures matches model's expected input
session, _ := onnx.NewSession(onnx.Config{
NumFeatures: 10, // Must match model's input shape
})
Multi-Output Models¶
For models with multiple outputs (e.g., probabilities for each class):
session, _ := onnx.NewSession(onnx.Config{
ModelPath: "model.onnx",
InputName: "float_input",
OutputName: "probabilities",
NumFeatures: 10,
OutputIndex: 1, // Use second output (e.g., probability of class 1)
})
Next Steps¶
- PermutationSHAP Guide - Deep dive into black-box explanations
- Visualization - Create charts from explanations
- Benchmarks - Performance characteristics