API Reference¶
Complete API documentation for SHAP-Go packages.
Package Overview¶
| Package | Description |
|---|---|
| explainer | Core interfaces, types, and configuration options |
| explainer/tree | TreeSHAP for tree ensembles (XGBoost, LightGBM, CatBoost) |
| explainer/linear | LinearSHAP for linear models |
| explainer/kernel | KernelSHAP for model-agnostic explanations |
| explainer/exact | ExactSHAP for brute-force exact computation |
| explainer/deepshap | DeepSHAP for neural networks |
| explainer/gradient | GradientSHAP using expected gradients |
| explainer/partition | PartitionSHAP for hierarchical feature groupings |
| explainer/additive | AdditiveSHAP for Generalized Additive Models |
| explainer/permutation | PermutationSHAP for black-box models |
| explainer/sampling | SamplingSHAP (Monte Carlo approximation) |
| explanation | Explanation types and methods |
| model | Model interfaces and adapters |
| model/onnx | ONNX Runtime integration |
| background | Background data utilities |
| masker | Feature masking strategies |
| render | Visualization chart generation |
explainer¶
Core interfaces, types, and configuration shared by all explainers.
Explainer Interface¶
type Explainer interface {
// Explain computes SHAP values for a single instance
Explain(ctx context.Context, instance []float64) (*explanation.Explanation, error)
// ExplainBatch computes SHAP values for multiple instances
ExplainBatch(ctx context.Context, instances [][]float64) ([]*explanation.Explanation, error)
// BaseValue returns the expected model output E[f(X)]
BaseValue() float64
// FeatureNames returns the feature names
FeatureNames() []string
}
Config¶
type Config struct {
// NumSamples is the number of Monte Carlo samples (default: 100)
NumSamples int
// Seed is the random seed for reproducibility (nil uses current time)
Seed *int64
// NumWorkers is the number of parallel workers (0 = sequential)
NumWorkers int
// ModelID is an optional identifier for the model
ModelID string
// FeatureNames are the names of the input features
FeatureNames []string
// ConfidenceLevel for confidence intervals (0 = disabled, e.g., 0.95 for 95% CI)
ConfidenceLevel float64
// UseBatchedPredictions enables batched model predictions for efficiency
UseBatchedPredictions bool
}
Options¶
// WithNumSamples sets the number of samples for sampling-based explainers
func WithNumSamples(n int) Option
// WithSeed sets random seed for reproducibility
func WithSeed(seed int64) Option
// WithNumWorkers sets parallel workers for computation
func WithNumWorkers(n int) Option
// WithModelID sets the model identifier
func WithModelID(id string) Option
// WithFeatureNames sets feature names for explanations
func WithFeatureNames(names []string) Option
// WithConfidenceLevel sets confidence level for intervals (e.g., 0.95 for 95% CI)
func WithConfidenceLevel(level float64) Option
// WithBatchedPredictions enables batched model predictions for efficiency
func WithBatchedPredictions(enabled bool) Option
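A minimal sketch of how these options compose with a constructor (using `kernel.New`, documented below; the model `m`, the `background` matrix, and the feature names are placeholders):

```go
exp, err := kernel.New(m, background,
    explainer.WithNumSamples(500), // more samples: lower variance, more compute
    explainer.WithSeed(42),        // fixed seed for reproducible attributions
    explainer.WithNumWorkers(4),   // parallelize the sampling work
    explainer.WithFeatureNames([]string{"age", "income", "score"}),
)
if err != nil {
    log.Fatal(err)
}
```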
Batch Parallel API¶
// ExplainBatchParallel explains multiple instances in parallel using any explainer
func ExplainBatchParallel[E Explainer](
ctx context.Context,
exp E,
instances [][]float64,
config BatchConfig,
) ([]*explanation.Explanation, error)
// ExplainBatchWithProgress explains with progress callback
func ExplainBatchWithProgress[E Explainer](
ctx context.Context,
exp E,
instances [][]float64,
config BatchConfig,
progress func(completed, total int),
) ([]*explanation.Explanation, error)
type BatchConfig struct {
Workers int // Number of parallel workers (0 = GOMAXPROCS)
StopOnError bool // Stop all workers on first error
}
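Fanning one explainer out over many instances might look like the following sketch (`exp`, `ctx`, and `instances` are assumed to be in scope):

```go
results, err := explainer.ExplainBatchParallel(ctx, exp, instances, explainer.BatchConfig{
    Workers:     8,    // cap parallelism at 8 workers
    StopOnError: true, // abandon remaining work on the first failure
})
if err != nil {
    return err
}
for i, r := range results {
    fmt.Printf("instance %d: prediction=%.3f base=%.3f\n", i, r.Prediction, r.BaseValue)
}
```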
explanation¶
Explanation types and methods.
Explanation¶
type Explanation struct {
// Values maps feature name to SHAP value
Values map[string]float64
// FeatureNames in order
FeatureNames []string
// FeatureValues maps feature name to instance value
FeatureValues map[string]float64
// Prediction is the model output for this instance
Prediction float64
// BaseValue is the expected value E[f(X)]
BaseValue float64
// ModelID identifies the model (optional)
ModelID string
// Timestamp when explanation was computed
Timestamp time.Time
// Metadata contains algorithm-specific information
Metadata ExplanationMetadata
}
ExplanationMetadata¶
type ExplanationMetadata struct {
// Algorithm used (e.g., "tree", "kernel", "permutation")
Algorithm string
// NumSamples used for sampling-based methods
NumSamples int
// BackgroundSize is the number of background samples
BackgroundSize int
// ComputeTimeMS is the computation time in milliseconds
ComputeTimeMS int64
// ConfidenceIntervals if computed
ConfidenceIntervals *ConfidenceIntervals
}
ConfidenceIntervals¶
type ConfidenceIntervals struct {
// Level is the confidence level (e.g., 0.95 for 95%)
Level float64
// Lower bounds for each feature
Lower map[string]float64
// Upper bounds for each feature
Upper map[string]float64
// StandardErrors for each feature
StandardErrors map[string]float64
}
Explanation Methods¶
// TopFeatures returns the n features with highest absolute SHAP values
func (e *Explanation) TopFeatures(n int) []FeatureContribution
// Verify checks local accuracy: sum(SHAP) ≈ prediction - baseValue
func (e *Explanation) Verify(tolerance float64) VerificationResult
// HasConfidenceIntervals returns true if confidence intervals are available
func (e *Explanation) HasConfidenceIntervals() bool
// GetConfidenceInterval returns the CI for a feature (lower, upper, ok)
func (e *Explanation) GetConfidenceInterval(feature string) (float64, float64, bool)
// ToJSON serializes to JSON
func (e *Explanation) ToJSON() ([]byte, error)
// ToJSONPretty serializes to formatted JSON
func (e *Explanation) ToJSONPretty() ([]byte, error)
Supporting Types¶
type FeatureContribution struct {
Name string
SHAPValue float64
Index int
}
type VerificationResult struct {
Valid bool // Whether within tolerance
Expected float64 // prediction - baseValue
SumSHAP float64 // sum of SHAP values
Difference float64 // |Expected - SumSHAP|
}
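Putting these methods together on an explanation `expl` (a sketch; the tolerance value is illustrative):

```go
// Report the three largest attributions by absolute SHAP value.
for _, fc := range expl.TopFeatures(3) {
    fmt.Printf("%-12s %+.4f\n", fc.Name, fc.SHAPValue)
}

// Check local accuracy: sum(SHAP) should equal prediction - BaseValue.
if res := expl.Verify(1e-6); !res.Valid {
    log.Printf("local accuracy off by %.2e (expected %.4f, got %.4f)",
        res.Difference, res.Expected, res.SumSHAP)
}
```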
explainer/tree¶
TreeSHAP for tree-based models (XGBoost, LightGBM, CatBoost).
Explainer¶
// New creates a TreeSHAP explainer from a tree ensemble
func New(ensemble *TreeEnsemble, opts ...explainer.Option) (*Explainer, error)
// Explain computes exact SHAP values for an instance (O(TLD²))
func (e *Explainer) Explain(ctx context.Context, instance []float64) (*explanation.Explanation, error)
// ExplainBatch computes SHAP values for multiple instances
func (e *Explainer) ExplainBatch(ctx context.Context, instances [][]float64) ([]*explanation.Explanation, error)
// ExplainInteractions computes SHAP interaction values
func (e *Explainer) ExplainInteractions(ctx context.Context, instance []float64) (*InteractionResult, error)
InteractionResult¶
type InteractionResult struct {
// Interactions[i][j] is the interaction between features i and j
// Diagonal elements are main effects
Interactions [][]float64
// FeatureNames for indexing
FeatureNames []string
// Prediction for this instance
Prediction float64
// BaseValue E[f(X)]
BaseValue float64
}
// MainEffect returns the main effect for a feature
func (r *InteractionResult) MainEffect(feature int) float64
// Interaction returns the interaction between two features
func (r *InteractionResult) Interaction(i, j int) float64
TreeEnsemble¶
type TreeEnsemble struct {
Trees []*Tree
NumTrees int
NumFeatures int
FeatureNames []string
BaseScore float64
Objective string
}
// Predict computes model output for an instance
func (e *TreeEnsemble) Predict(instance []float64) float64
Model Loading¶
// XGBoost
func LoadXGBoostModel(path string) (*TreeEnsemble, error)
func LoadXGBoostModelFromReader(r io.Reader) (*TreeEnsemble, error)
func ParseXGBoostJSON(data []byte) (*TreeEnsemble, error)
// LightGBM (JSON format)
func LoadLightGBMModel(path string) (*TreeEnsemble, error)
func LoadLightGBMModelFromReader(r io.Reader) (*TreeEnsemble, error)
func ParseLightGBMJSON(data []byte) (*TreeEnsemble, error)
// LightGBM (text format)
func LoadLightGBMTextModel(path string) (*TreeEnsemble, error)
func ParseLightGBMText(data []byte) (*TreeEnsemble, error)
// CatBoost
func LoadCatBoostModel(path string) (*TreeEnsemble, error)
func ParseCatBoostJSON(data []byte) (*TreeEnsemble, error)
// ONNX-ML TreeEnsemble
func ParseONNXTreeEnsemble(modelPath string) (*TreeEnsemble, error)
func ParseONNXTreeEnsembleFromBytes(data []byte) (*TreeEnsemble, error)
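End to end, loading a model and explaining one instance might look as follows (a sketch; `model.json`, `ctx`, and `instance` are placeholders):

```go
ensemble, err := tree.LoadXGBoostModel("model.json")
if err != nil {
    log.Fatal(err)
}
exp, err := tree.New(ensemble)
if err != nil {
    log.Fatal(err)
}
expl, err := exp.Explain(ctx, instance)
if err != nil {
    log.Fatal(err)
}
fmt.Printf("prediction=%.3f base=%.3f\n", expl.Prediction, expl.BaseValue)
```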
explainer/linear¶
LinearSHAP for linear models with closed-form solution.
Explainer¶
// New creates a LinearSHAP explainer
// weights: model coefficients, intercept: bias term
func New(weights []float64, intercept float64, background [][]float64, opts ...explainer.Option) (*Explainer, error)
// Explain computes exact SHAP values (O(n) complexity)
func (e *Explainer) Explain(ctx context.Context, instance []float64) (*explanation.Explanation, error)
// ExplainBatch computes SHAP values for multiple instances
func (e *Explainer) ExplainBatch(ctx context.Context, instances [][]float64) ([]*explanation.Explanation, error)
explainer/kernel¶
KernelSHAP for model-agnostic explanations using weighted linear regression.
Explainer¶
// New creates a KernelSHAP explainer
func New(m model.Model, background [][]float64, opts ...explainer.Option) (*Explainer, error)
// Explain computes SHAP values using weighted linear regression
func (e *Explainer) Explain(ctx context.Context, instance []float64) (*explanation.Explanation, error)
// ExplainBatch computes SHAP values for multiple instances
func (e *Explainer) ExplainBatch(ctx context.Context, instances [][]float64) ([]*explanation.Explanation, error)
Supports: WithNumSamples, WithSeed, WithNumWorkers, WithBatchedPredictions
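Because KernelSHAP only needs predictions, it pairs naturally with `model.NewFuncModel` (documented under `model` below). A sketch with a toy two-feature function and an assumed `background` matrix:

```go
m := model.NewFuncModel(func(ctx context.Context, in []float64) (float64, error) {
    return 3*in[0] - 2*in[1], nil // any black-box scoring function works here
}, 2)
exp, err := kernel.New(m, background,
    explainer.WithNumSamples(200),
    explainer.WithSeed(1),
)
```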
explainer/exact¶
ExactSHAP for brute-force exact Shapley value computation.
Explainer¶
// New creates an ExactSHAP explainer (max 20 features due to O(2^n) complexity)
func New(m model.Model, background [][]float64, opts ...explainer.Option) (*Explainer, error)
// Explain computes exact SHAP values by enumerating all 2^n coalitions
func (e *Explainer) Explain(ctx context.Context, instance []float64) (*explanation.Explanation, error)
// ExplainBatch computes SHAP values for multiple instances
func (e *Explainer) ExplainBatch(ctx context.Context, instances [][]float64) ([]*explanation.Explanation, error)
Supports: WithBatchedPredictions
Limitations: Maximum 20 features (configurable via MaxFeatures constant)
explainer/deepshap¶
DeepSHAP for neural networks using DeepLIFT attribution rules.
Explainer¶
// New creates a DeepSHAP explainer from an ONNX activation session
func New(session *onnx.ActivationSession, background [][]float64, opts ...explainer.Option) (*Explainer, error)
// Explain computes SHAP values using DeepLIFT rescale rule
func (e *Explainer) Explain(ctx context.Context, instance []float64) (*explanation.Explanation, error)
// ExplainBatch computes SHAP values for multiple instances
func (e *Explainer) ExplainBatch(ctx context.Context, instances [][]float64) ([]*explanation.Explanation, error)
Supported Layers: Dense/Gemm, ReLU, Sigmoid, Tanh, Softmax, Add, Identity
explainer/gradient¶
GradientSHAP using expected gradients with numerical differentiation.
Explainer¶
// New creates a GradientSHAP explainer
func New(m model.Model, background [][]float64, opts ...explainer.Option) (*Explainer, error)
// Explain computes SHAP values using expected gradients
func (e *Explainer) Explain(ctx context.Context, instance []float64) (*explanation.Explanation, error)
// ExplainBatch computes SHAP values for multiple instances
func (e *Explainer) ExplainBatch(ctx context.Context, instances [][]float64) ([]*explanation.Explanation, error)
GradientSHAP-Specific Options¶
// WithEpsilon sets the finite difference step size (default: 1e-7)
func WithEpsilon(eps float64) Option
// WithNoiseStdev sets noise standard deviation for SmoothGrad (default: 0)
func WithNoiseStdev(stdev float64) Option
// WithLocalSmoothing sets the number of noise samples (default: 1)
func WithLocalSmoothing(n int) Option
Supports: WithNumSamples, WithSeed, WithNumWorkers, WithConfidenceLevel
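A sketch combining the shared and package-local options (assuming, as the signatures above suggest, that `gradient`'s options can be passed to `gradient.New` alongside `explainer` options):

```go
exp, err := gradient.New(m, background,
    explainer.WithNumSamples(100),   // interpolation samples along each path
    explainer.WithSeed(7),
    gradient.WithEpsilon(1e-6),      // finite-difference step size
    gradient.WithNoiseStdev(0.1),    // SmoothGrad-style input noise
    gradient.WithLocalSmoothing(10), // average gradients over 10 noisy copies
)
```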
explainer/partition¶
PartitionSHAP for hierarchical Owen values with feature groupings.
Explainer¶
// New creates a PartitionSHAP explainer
// hierarchy: feature hierarchy tree (nil for flat mode)
func New(m model.Model, background [][]float64, hierarchy *Node, opts ...explainer.Option) (*Explainer, error)
// Explain computes SHAP values using hierarchical Owen values
func (e *Explainer) Explain(ctx context.Context, instance []float64) (*explanation.Explanation, error)
// ExplainBatch computes SHAP values for multiple instances
func (e *Explainer) ExplainBatch(ctx context.Context, instances [][]float64) ([]*explanation.Explanation, error)
// Hierarchy returns the feature hierarchy
func (e *Explainer) Hierarchy() *Node
Node¶
type Node struct {
// Name of this node (feature or group name)
Name string
// FeatureIdx for leaf nodes (-1 for internal nodes)
FeatureIdx int
// Children for internal nodes (empty for leaves)
Children []*Node
}
// IsLeaf returns true if this is a leaf node
func (n *Node) IsLeaf() bool
// GetFeatureIndices returns all feature indices under this node
func (n *Node) GetFeatureIndices() []int
Supports: WithNumSamples, WithSeed, WithBatchedPredictions
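Building a small hierarchy by hand might look like this (a sketch; the group and feature names are placeholders):

```go
hierarchy := &partition.Node{
    Name:       "root",
    FeatureIdx: -1, // internal node
    Children: []*partition.Node{
        {
            Name:       "demographics",
            FeatureIdx: -1,
            Children: []*partition.Node{
                {Name: "age", FeatureIdx: 0},
                {Name: "income", FeatureIdx: 1},
            },
        },
        {Name: "credit_score", FeatureIdx: 2}, // a leaf directly under root
    },
}
exp, err := partition.New(m, background, hierarchy)
```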
explainer/additive¶
AdditiveSHAP for Generalized Additive Models (GAMs).
Explainer¶
// New creates an AdditiveSHAP explainer for additive models (no interactions)
func New(m model.Model, background [][]float64, opts ...explainer.Option) (*Explainer, error)
// Explain computes exact SHAP values (O(n) complexity)
func (e *Explainer) Explain(ctx context.Context, instance []float64) (*explanation.Explanation, error)
// ExplainBatch computes SHAP values for multiple instances
func (e *Explainer) ExplainBatch(ctx context.Context, instances [][]float64) ([]*explanation.Explanation, error)
// Reference returns the reference point (mean of background)
func (e *Explainer) Reference() []float64
// ExpectedEffects returns precomputed E[fᵢ(Xᵢ)] for each feature
func (e *Explainer) ExpectedEffects() []float64
explainer/permutation¶
PermutationSHAP for black-box models with guaranteed local accuracy.
Explainer¶
// New creates a PermutationSHAP explainer with antithetic sampling
func New(m model.Model, background [][]float64, opts ...explainer.Option) (*Explainer, error)
// Explain computes SHAP values using permutation sampling
func (e *Explainer) Explain(ctx context.Context, instance []float64) (*explanation.Explanation, error)
// ExplainBatch computes SHAP values for multiple instances
func (e *Explainer) ExplainBatch(ctx context.Context, instances [][]float64) ([]*explanation.Explanation, error)
Supports: WithNumSamples, WithSeed, WithNumWorkers, WithConfidenceLevel
explainer/sampling¶
SamplingSHAP using Monte Carlo estimation.
Explainer¶
// New creates a SamplingSHAP explainer
func New(m model.Model, background [][]float64, opts ...explainer.Option) (*Explainer, error)
// Explain computes approximate SHAP values
func (e *Explainer) Explain(ctx context.Context, instance []float64) (*explanation.Explanation, error)
// ExplainBatch computes SHAP values for multiple instances
func (e *Explainer) ExplainBatch(ctx context.Context, instances [][]float64) ([]*explanation.Explanation, error)
Supports: WithNumSamples, WithSeed, WithConfidenceLevel
model¶
Model interfaces and adapters.
Model Interface¶
type Model interface {
// Predict returns model output for an input
Predict(ctx context.Context, input []float64) (float64, error)
// PredictBatch returns outputs for multiple inputs
PredictBatch(ctx context.Context, inputs [][]float64) ([]float64, error)
// NumFeatures returns the number of input features
NumFeatures() int
}
FuncModel¶
Wraps a prediction function as a Model:
// NewFuncModel creates a Model from a prediction function
func NewFuncModel(
predict func(ctx context.Context, input []float64) (float64, error),
numFeatures int,
) Model
Example:
predict := func(ctx context.Context, input []float64) (float64, error) {
return input[0]*2 + input[1]*3, nil
}
m := model.NewFuncModel(predict, 2)
model/onnx¶
ONNX Runtime integration for model inference.
Runtime Management¶
// InitializeRuntime loads the ONNX Runtime library
func InitializeRuntime(libraryPath string) error
// DestroyRuntime releases ONNX Runtime resources
func DestroyRuntime()
Session¶
// NewSession creates an ONNX inference session
func NewSession(config Config) (*Session, error)
// Predict runs inference for a single input
func (s *Session) Predict(ctx context.Context, input []float64) (float64, error)
// PredictBatch runs inference for multiple inputs
func (s *Session) PredictBatch(ctx context.Context, inputs [][]float64) ([]float64, error)
// Close releases session resources
func (s *Session) Close() error
// NumFeatures returns the number of input features
func (s *Session) NumFeatures() int
ActivationSession¶
For DeepSHAP with intermediate layer access:
// NewActivationSession creates a session that captures intermediate activations
func NewActivationSession(config ActivationConfig) (*ActivationSession, error)
// PredictWithActivations returns prediction and layer activations
func (s *ActivationSession) PredictWithActivations(ctx context.Context, input []float64) (*ActivationResult, error)
type ActivationConfig struct {
Config
IntermediateOutputs []string // Layer outputs to capture
}
type ActivationResult struct {
Prediction float64
Activations map[string][]float32
}
Graph Parsing¶
// ParseGraph parses ONNX model graph structure
func ParseGraph(modelPath string) (*GraphInfo, error)
// ParseGraphFromBytes parses from model bytes
func ParseGraphFromBytes(data []byte) (*GraphInfo, error)
type GraphInfo struct {
Nodes []NodeInfo
TopologicalOrder []string
}
type NodeInfo struct {
Name string
OpType string
LayerType LayerType
Inputs []string
Outputs []string
}
Config¶
type Config struct {
ModelPath string // Path to ONNX model file
InputName string // Input tensor name
OutputName string // Output tensor name
NumFeatures int // Number of input features
OutputIndex int // Output index for multi-output models
}
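A typical session lifecycle, sketched (the library path, model file, and tensor names below are placeholders for your deployment):

```go
if err := onnx.InitializeRuntime("/usr/local/lib/libonnxruntime.so"); err != nil {
    log.Fatal(err)
}
defer onnx.DestroyRuntime()

sess, err := onnx.NewSession(onnx.Config{
    ModelPath:   "model.onnx",
    InputName:   "input",
    OutputName:  "output",
    NumFeatures: 4,
})
if err != nil {
    log.Fatal(err)
}
defer sess.Close()

score, err := sess.Predict(ctx, []float64{0.1, 0.2, 0.3, 0.4})
```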
background¶
Background data utilities.
Dataset¶
type Dataset struct {
Data [][]float64
FeatureNames []string
}
// NewDataset creates a background dataset
func NewDataset(data [][]float64, featureNames []string) (*Dataset, error)
// Sample returns a random subset of the background data
func (d *Dataset) Sample(n int, rng *rand.Rand) [][]float64
// Mean returns the mean of each feature
func (d *Dataset) Mean() []float64
// Subset returns rows at the given indices
func (d *Dataset) Subset(indices []int) [][]float64
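For instance, down-sampling a large background set before handing it to an explainer (a sketch; `data` and the feature names are placeholders):

```go
ds, err := background.NewDataset(data, []string{"age", "income", "score"})
if err != nil {
    log.Fatal(err)
}
rng := rand.New(rand.NewSource(42)) // seeded for reproducibility
sample := ds.Sample(100, rng)       // 100 rows for the explainer's background
```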
masker¶
Feature masking strategies.
IndependentMasker¶
type IndependentMasker struct {
Background [][]float64
}
// NewIndependentMasker creates a masker using independent feature assumption
func NewIndependentMasker(background [][]float64) *IndependentMasker
// Mask replaces masked features with background values
func (m *IndependentMasker) Mask(instance []float64, mask []bool) ([]float64, error)
// MaskWithBackground uses a specific background sample
func (m *IndependentMasker) MaskWithBackground(instance []float64, mask []bool, bgIndex int) ([]float64, error)
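A sketch of masking one instance (note: the polarity of the mask flags — whether `true` keeps the instance value or replaces it from background — is an assumption in this sketch):

```go
mk := masker.NewIndependentMasker(background)
mask := []bool{true, false, true} // per-feature mask flags (polarity assumed)
masked, err := mk.Mask(instance, mask)

// Or pin the replacement values to one specific background row:
masked2, err := mk.MaskWithBackground(instance, mask, 0)
```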
render¶
Chart generation in ChartIR format.
Waterfall¶
// Waterfall creates a waterfall chart specification
func Waterfall(explanation *explanation.Explanation, opts WaterfallOptions) *chartir.Chart
type WaterfallOptions struct {
Title string
MaxFeatures int
ShowValues bool
Features []string // Specific features to show
}
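Generating a waterfall chart from an explanation `expl` might look like this (a sketch; the title is illustrative):

```go
chart := render.Waterfall(expl, render.WaterfallOptions{
    Title:       "Prediction breakdown",
    MaxFeatures: 10, // show at most 10 features
    ShowValues:  true,
})
```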
FeatureImportance¶
// FeatureImportance creates a bar chart of feature importance
func FeatureImportance(explanations []*explanation.Explanation, opts ImportanceOptions) *chartir.Chart
type ImportanceOptions struct {
Title string
MaxFeatures int
SortBy string // "mean_abs", "max_abs", "variance"
ExcludeFeatures []string
}
Summary¶
// Summary creates a beeswarm/summary plot
func Summary(
explanations []*explanation.Explanation,
featureValues [][]float64,
opts SummaryOptions,
) *chartir.Chart
type SummaryOptions struct {
Title string
MaxFeatures int
ColorScale string // "bluered", "viridis", "plasma"
}
Dependence¶
// Dependence creates a dependence scatter plot
func Dependence(
explanations []*explanation.Explanation,
featureValues [][]float64,
opts DependenceOptions,
) *chartir.Chart
type DependenceOptions struct {
Feature string // Feature to analyze
ColorFeature string // Feature for color coding
Title string
}
Error Types¶
Common Errors¶
var (
ErrNilModel = errors.New("model cannot be nil")
ErrNoBackground = errors.New("background data cannot be empty")
ErrFeatureMismatch = errors.New("feature count mismatch")
ErrTooManyFeatures = errors.New("too many features for exact computation")
ErrInstanceFeatureMismatch = errors.New("instance feature count mismatch")
ErrMaskFeatureMismatch = errors.New("mask feature count mismatch")
ErrBackgroundFeatureMismatch = errors.New("background feature count mismatch")
)
Checking Errors¶
explanation, err := exp.Explain(ctx, instance)
if err != nil {
if errors.Is(err, explainer.ErrFeatureMismatch) {
// Handle feature mismatch
}
return err
}
Context Support¶
All Explain methods accept context.Context for:
- Cancellation: Stop long computations
- Timeouts: Limit computation time
- Tracing: Integrate with observability tools
// With timeout
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
explanation, err := exp.Explain(ctx, instance)
if errors.Is(err, context.DeadlineExceeded) {
log.Println("Explanation timed out")
}
Thread Safety¶
- All explainers are safe for concurrent use after creation
- Create once, use from multiple goroutines
- Internal state is read-only after initialization
exp, _ := tree.New(ensemble)
// Safe: concurrent explains
var wg sync.WaitGroup
for _, instance := range instances {
wg.Add(1)
go func(inst []float64) {
defer wg.Done()
exp.Explain(ctx, inst) // Safe
}(instance)
}
wg.Wait()
Explainer Comparison¶
| Explainer | Model Type | Complexity | Exact | Interactions |
|---|---|---|---|---|
| TreeSHAP | Tree ensembles | O(TLD²) | ✅ | ✅ |
| LinearSHAP | Linear | O(n) | ✅ | ❌ |
| AdditiveSHAP | GAMs | O(n×b) | ✅ | ❌ |
| ExactSHAP | Any | O(n×2ⁿ×b) | ✅ | ❌ |
| KernelSHAP | Any | O(s×b) | ❌ | ❌ |
| PermutationSHAP | Any | O(s×n×b) | ❌ | ❌ |
| SamplingSHAP | Any | O(s×n×b) | ❌ | ❌ |
| GradientSHAP | Differentiable | O(s×n×b) | ❌ | ❌ |
| PartitionSHAP | Structured | O(s×g×b) | ❌ | ❌ |
| DeepSHAP | Neural nets | O(L×b) | ❌ | ❌ |
Where: n = features, b = background samples, s = samples, T = trees, L = leaves per tree (TreeSHAP) or layers (DeepSHAP), D = maximum tree depth, g = feature groups