import hashlib
from collections import Counter
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple
import numpy as np
import pandas as pd
from sklearn.ensemble import (
AdaBoostRegressor,
ExtraTreesRegressor,
GradientBoostingClassifier,
RandomForestRegressor,
)
from .core import DecisionPredicateGraph
from .sklearn_normalizer import SklearnEnsembleNormalizer
from .visualizer import (
class_feature_predicate_counts,
class_lookup_from_target_names,
plot_class_feature_complexity,
plot_dpg_local_paths_aggregate,
plot_sample_using_bc_weights,
plot_dpg,
plot_dpg_class_bounds_vs_dataset_feature_ranges,
plot_dpg_communities,
plot_lrc_vs_rf_importance,
plot_top_lrc_predicate_splits,
sample_bc_weights,
)
from metrics.graph import GraphMetrics
from metrics.nodes import NodeMetrics
from metrics.edges import EdgeMetrics
[docs]
@dataclass
class DPGExplanation:
    """Global DPG output bundle returned by DPGExplainer.explain_global()."""

    graph: Any  # graph object produced by the builder's to_networkx()
    nodes: List[List[str]]  # node entries; explain_local unpacks them as (node_id, label)
    dot: Any  # object returned by the builder's fit()
    node_metrics: Any  # per-node metric table
    edge_metrics: Any  # per-edge metric table
    class_boundaries: Dict[str, Any]  # per-class boundary summary
    communities: Optional[Dict[str, Any]] = None  # only filled when communities were requested
    community_threshold: Optional[float] = None  # threshold used; None when communities were not computed

    def as_dict(self) -> Dict[str, Any]:
        """Return the fields as a shallow plain dict (values are not copied)."""
        return dict(
            graph=self.graph,
            nodes=self.nodes,
            dot=self.dot,
            node_metrics=self.node_metrics,
            edge_metrics=self.edge_metrics,
            class_boundaries=self.class_boundaries,
            communities=self.communities,
            community_threshold=self.community_threshold,
        )
[docs]
@dataclass
class DPGTreePathExplanation:
    """The executed path of one estimator for one sample, checked against the DPG."""

    tree_index: int  # position of the estimator in model.estimators_
    tree_prefix: str  # identifier of the form "sample<id>_dt<tree_index>"
    labels: List[str]  # split labels in execution order, ending with the leaf label
    node_ids: List[Optional[str]]  # graph node id per label (None when not resolved)
    predicate_truths: List[bool]  # truth value of each split predicate on the path
    edge_exists: List[bool]  # whether each consecutive label pair is an edge in the graph
    starts_from_root: bool
    ends_in_leaf: bool  # last label is a "Class ..." or "Pred ..." leaf
    graph_path_valid: bool  # every node and edge of the path exists in the graph
    mean_lrc: Optional[float] = None  # mean local reaching centrality of matched nodes
    mean_bc: Optional[float] = None  # mean betweenness centrality of matched nodes
    path_confidence: Optional[float] = None  # average of node and edge coverage

    def as_dict(self) -> Dict[str, Any]:
        """Return the fields as a shallow plain dict, in declaration order."""
        return dict(
            tree_index=self.tree_index,
            tree_prefix=self.tree_prefix,
            labels=self.labels,
            node_ids=self.node_ids,
            predicate_truths=self.predicate_truths,
            edge_exists=self.edge_exists,
            starts_from_root=self.starts_from_root,
            ends_in_leaf=self.ends_in_leaf,
            graph_path_valid=self.graph_path_valid,
            mean_lrc=self.mean_lrc,
            mean_bc=self.mean_bc,
            path_confidence=self.path_confidence,
        )
[docs]
@dataclass
class DPGLocalExplanation:
    """Per-sample explanation: one traced path per estimator plus vote summaries."""

    sample_id: int  # caller-supplied identifier
    sample: List[float]  # flattened sample values
    tree_paths: List[DPGTreePathExplanation]  # one entry per traced estimator
    graph_validated: bool  # whether node ids were resolved against the fitted graph
    all_trees_valid: bool  # True when every tree path is a valid graph path
    majority_vote: Optional[str]  # most-voted class; None when no classifier leaf voted
    class_votes: Dict[str, int]  # class name -> number of trees voting for it
    path_mode: str  # how paths were obtained (explain_local sets "execution_trace")
    sample_confidence: Optional[Dict[str, Any]] = None  # confidence / diagnostic summary

    def as_dict(self) -> Dict[str, Any]:
        """Return a plain dict; each tree path is converted with its own as_dict()."""
        serialized_paths = [tree_path.as_dict() for tree_path in self.tree_paths]
        return dict(
            sample_id=self.sample_id,
            sample=self.sample,
            tree_paths=serialized_paths,
            graph_validated=self.graph_validated,
            all_trees_valid=self.all_trees_valid,
            majority_vote=self.majority_vote,
            class_votes=self.class_votes,
            path_mode=self.path_mode,
            sample_confidence=self.sample_confidence,
        )
[docs]
class DPGExplainer:
"""
High-level, user-friendly API for building and plotting DPG explanations.
This class wraps DecisionPredicateGraph and the metrics/visualization utilities
into a cohesive workflow.
"""
def __init__(
    self,
    model: Any,
    feature_names: Iterable[str],
    target_names: Optional[Iterable[str]] = None,
    config_file: str = "config.yaml",
    dpg_config: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Create an explainer for ``model``.

    Args:
        model: Model handed to DecisionPredicateGraph.
        feature_names: Feature names; materialised into a list.
        target_names: Optional class names; materialised into a list when given.
        config_file: Configuration file name forwarded to DecisionPredicateGraph.
        dpg_config: Optional in-memory configuration forwarded to DecisionPredicateGraph.
    """
    names = list(feature_names)
    targets = None if target_names is None else list(target_names)
    self._builder = DecisionPredicateGraph(
        model=model,
        feature_names=names,
        target_names=targets,
        config_file=config_file,
        dpg_config=dpg_config,
    )
    # Graph state is filled by fit(); metric tables are computed lazily by the
    # _get_* helpers and reset whenever fit() runs again.
    self._is_fitted = False
    self._dot = None
    self._graph = None
    self._nodes = None
    self._node_metrics = None
    self._node_metrics_lookup = None
    self._edge_metrics = None
@property
def builder(self) -> DecisionPredicateGraph:
    """The wrapped DecisionPredicateGraph instance (read-only attribute)."""
    wrapped = self._builder
    return wrapped
[docs]
def fit(self, X: Any) -> "DPGExplainer":
    """
    Build the DPG from training data and drop any cached metrics.

    Args:
        X: Training data forwarded to DecisionPredicateGraph.fit().

    Returns:
        The explainer itself, so calls can be chained.
    """
    builder = self._builder
    self._dot = builder.fit(X)
    self._graph, self._nodes = builder.to_networkx(self._dot)
    # Metric tables belong to the previous graph, so they are now stale.
    self._node_metrics = self._node_metrics_lookup = self._edge_metrics = None
    self._is_fitted = True
    return self
[docs]
def explain_global(
    self,
    X: Optional[Any] = None,
    communities: bool = False,
    community_threshold: float = 0.2,
) -> DPGExplanation:
    """
    Build global DPG metrics and return a structured explanation object.

    Args:
        X: Optional training data. If provided, fit() is called before extracting metrics.
        communities: Whether to compute cluster-based communities.
        community_threshold: Threshold used by community extraction.

    Returns:
        DPGExplanation with the graph, node list, DOT object, node and edge metric
        tables, class boundaries and, when requested, communities plus the threshold
        used (``community_threshold`` is None when communities were not computed).

    Raises:
        ValueError: If the explainer is not fitted and ``X`` is None.
    """
    if X is not None:
        self.fit(X)
    if not self._is_fitted:
        raise ValueError("DPGExplainer is not fitted. Call fit(X) or explain_global(X=...).")
    node_metrics = self._get_node_metrics()
    # Go through the shared cache (the same EdgeMetrics call) so edge metrics are
    # not recomputed on every call; fit() clears the cache when the graph changes.
    edge_metrics = self._get_edge_metrics()
    class_boundaries = GraphMetrics.extract_class_boundaries(
        self._graph,
        self._nodes,
        target_names=self._builder.target_names or [],
    )
    communities_out = None
    if communities:
        communities_out = GraphMetrics.extract_communities(
            self._graph,
            node_metrics,
            self._nodes,
            threshold_clusters=community_threshold,
        )
    return DPGExplanation(
        graph=self._graph,
        nodes=self._nodes,
        dot=self._dot,
        node_metrics=node_metrics,
        edge_metrics=edge_metrics,
        class_boundaries=class_boundaries,
        communities=communities_out,
        community_threshold=community_threshold if communities else None,
    )
[docs]
def explain_local(
    self,
    sample: Any,
    sample_id: int = 0,
    X: Optional[Any] = None,
    validate_graph: bool = True,
) -> DPGLocalExplanation:
    """
    Trace one sample through every estimator and map each executed path onto
    the fitted DPG graph.

    Args:
        sample: One sample with the same feature dimension used to fit the model.
        sample_id: Identifier to attach to the returned explanation.
        X: Optional training data. If provided, fit() is called before tracing.
        validate_graph: Whether to validate node and edge presence in the fitted graph.

    Returns:
        DPGLocalExplanation with one path per estimator, class votes, the majority
        vote and a sample-confidence summary (``path_mode`` is "execution_trace").

    Raises:
        ValueError: If the explainer is not fitted, or the flattened sample does not
            have one value per feature name.
    """
    if X is not None:
        self.fit(X)
    if not self._is_fitted:
        raise ValueError("DPGExplainer is not fitted. Call fit(X) or explain_local(X=...).")

    flat_sample = np.asarray(sample).reshape(-1)
    n_expected = len(self._builder.feature_names)
    n_given = flat_sample.shape[0]
    if n_given != n_expected:
        raise ValueError(f"Sample has {n_given} features, expected {n_expected}.")

    label_to_id = {label: node_id for node_id, label in self._nodes}
    metrics_by_node = self._get_node_metrics_lookup()
    tree_paths = [
        self._trace_tree_path(
            tree=estimator,
            sample=flat_sample,
            sample_id=sample_id,
            tree_index=position,
            node_lookup=label_to_id,
            node_metrics_lookup=metrics_by_node,
            validate_graph=validate_graph,
        )
        for position, estimator in enumerate(self._builder.model.estimators_)
    ]

    # Only classifier leaves ("Class ...") vote; regression leaves ("Pred ...") do not.
    class_votes = Counter(
        self._normalize_class_vote_label(path.labels[-1])
        for path in tree_paths
        if path.labels and path.labels[-1].startswith("Class ")
    )
    majority_vote = class_votes.most_common(1)[0][0] if class_votes else None

    sample_confidence = self._compute_sample_confidence(
        tree_paths,
        dict(class_votes),
        flat_sample,
    )
    return DPGLocalExplanation(
        sample_id=sample_id,
        sample=flat_sample.tolist(),
        tree_paths=tree_paths,
        graph_validated=validate_graph,
        all_trees_valid=all(path.graph_path_valid for path in tree_paths),
        majority_vote=majority_vote,
        class_votes=dict(class_votes),
        path_mode="execution_trace",
        sample_confidence=sample_confidence,
    )
[docs]
def plot_local_on_dpg(
    self,
    plot_name: str,
    local_explanation: Optional[DPGLocalExplanation] = None,
    sample: Optional[Any] = None,
    sample_id: int = 0,
    X: Optional[Any] = None,
    validate_graph: bool = True,
    path_indices: Optional[List[int]] = None,
    true_class_label: Optional[str] = None,
    obtained_class_label: Optional[str] = None,
    sample_metrics: Optional[Dict[str, Any]] = None,
    save_dir: str = "results/",
    class_flag: bool = True,
    layout_template: str = "default",
    graph_style: Optional[Dict[str, Any]] = None,
    node_style: Optional[Dict[str, Any]] = None,
    edge_style: Optional[Dict[str, Any]] = None,
    fig_size: Tuple[float, float] = (16, 8),
    dpi: int = 300,
    pdf_dpi: int = 600,
    show: bool = True,
    export_pdf: bool = False,
    theme: str = "dpg",
    palette: str = "default",
    label_mode: str = "wrapped",
    readability: str = "presentation",
    title: Optional[str] = None,
) -> Any:
    """
    Plot the paths of a local explanation on top of the fitted DPG.

    Either pass a precomputed ``local_explanation`` or a raw ``sample``. In the
    second case explain_local() is called with ``sample_id`` and ``validate_graph``.

    Args:
        plot_name: Name forwarded to plot_dpg_local_paths_aggregate.
        local_explanation: Precomputed local explanation; takes precedence over ``sample``.
        sample: Sample to explain when ``local_explanation`` is not given.
        sample_id: Identifier used when explaining ``sample``.
        X: Optional training data. If provided, fit() is called first.
        validate_graph: Forwarded to explain_local() when explaining ``sample``.
        path_indices: Optional positions in ``local_explanation.tree_paths`` selecting
            which paths to draw. Python or NumPy integers are accepted. All paths are
            drawn by default.
        true_class_label: Optional ground-truth label forwarded to the plotting helper.
        obtained_class_label: Forwarded label; defaults to the explanation's majority vote.
        sample_metrics: Forwarded metrics; default to the explanation's sample_confidence.
        The remaining keyword arguments are forwarded unchanged to
        plot_dpg_local_paths_aggregate.

    Returns:
        Whatever plot_dpg_local_paths_aggregate returns.

    Raises:
        ValueError: If the explainer is not fitted, if neither ``local_explanation`` nor
            ``sample`` is given, or if ``path_indices`` has a non-integer (including bool)
            or out-of-range entry.
    """
    if X is not None:
        self.fit(X)
    if not self._is_fitted:
        raise ValueError("DPGExplainer is not fitted. Call fit(X) or plot_local_on_dpg(X=...).")
    if local_explanation is None:
        if sample is None:
            raise ValueError("Either local_explanation or sample must be provided.")
        local_explanation = self.explain_local(
            sample=sample,
            sample_id=sample_id,
            X=None,
            validate_graph=validate_graph,
        )
    selected_paths = list(local_explanation.tree_paths)
    if path_indices is not None:
        n_paths = len(local_explanation.tree_paths)
        selected_paths = []
        for idx in path_indices:
            # NumPy integer scalars (e.g. from np.arange / np.where) are not
            # subclasses of int, whereas bool is. Accept the former and reject
            # the latter, because True/False are almost certainly not meant as
            # path positions.
            is_index = isinstance(idx, (int, np.integer)) and not isinstance(idx, bool)
            if not is_index or not 0 <= idx < n_paths:
                raise ValueError("path_indices must reference valid path indices.")
            selected_paths.append(local_explanation.tree_paths[int(idx)])
    if obtained_class_label is None:
        obtained_class_label = local_explanation.majority_vote
    if sample_metrics is None:
        sample_metrics = local_explanation.sample_confidence
    return plot_dpg_local_paths_aggregate(
        plot_name=plot_name,
        dot=self._dot,
        df=self._get_node_metrics(),
        df_edges=self._get_edge_metrics(),
        paths_node_ids=[path.node_ids for path in selected_paths],
        path_confidences=[path.path_confidence for path in selected_paths],
        sample_id=local_explanation.sample_id,
        true_class_label=true_class_label,
        obtained_class_label=obtained_class_label,
        sample_metrics=sample_metrics,
        save_dir=save_dir,
        class_flag=class_flag,
        layout_template=layout_template,
        graph_style=graph_style,
        node_style=node_style,
        edge_style=edge_style,
        fig_size=fig_size,
        dpi=dpi,
        pdf_dpi=pdf_dpi,
        show=show,
        export_pdf=export_pdf,
        theme=theme,
        palette=palette,
        label_mode=label_mode,
        readability=readability,
        title=title,
    )
[docs]
def local_path_dataframe(self, local_explanation: DPGLocalExplanation) -> Any:
    """
    Flatten a local explanation into a table with one row per path step.

    Trees are ordered by ``tree_index``; steps keep their execution order.

    Args:
        local_explanation: Structured local explanation returned by explain_local().

    Returns:
        pd.DataFrame: Ordered path-step table (empty, with the same columns, when
        there are no steps).
    """
    columns = [
        "sample_id",
        "tree_index",
        "step_index",
        "label",
        "node_id",
        "is_leaf",
        "predicate_true",
        "edge_exists_from_prev",
        "starts_from_root",
        "ends_in_leaf",
        "graph_path_valid",
        "mean_lrc",
        "mean_bc",
        "path_confidence",
    ]
    leaf_prefixes = ("Class ", "Pred ")
    rows = []
    for tree_path in sorted(local_explanation.tree_paths, key=lambda p: p.tree_index):
        n_truths = len(tree_path.predicate_truths)
        n_ids = len(tree_path.node_ids)
        for step, label in enumerate(tree_path.labels):
            rows.append(
                {
                    "sample_id": local_explanation.sample_id,
                    "tree_index": tree_path.tree_index,
                    "step_index": step,
                    "label": label,
                    "node_id": tree_path.node_ids[step] if step < n_ids else None,
                    "is_leaf": label.startswith(leaf_prefixes),
                    # The leaf step has no predicate, so it reports None.
                    "predicate_true": tree_path.predicate_truths[step] if step < n_truths else None,
                    # The root has no incoming edge; it is reported as present.
                    "edge_exists_from_prev": tree_path.edge_exists[step - 1] if step > 0 else True,
                    "starts_from_root": tree_path.starts_from_root,
                    "ends_in_leaf": tree_path.ends_in_leaf,
                    "graph_path_valid": tree_path.graph_path_valid,
                    "mean_lrc": tree_path.mean_lrc,
                    "mean_bc": tree_path.mean_bc,
                    "path_confidence": tree_path.path_confidence,
                }
            )
    return pd.DataFrame(rows, columns=columns)
[docs]
def evaluate_faithfulness(
    self,
    X,
    y_true=None,
    max_samples=None,
    weights=None,
    return_details=False,
    sample_ids=None,
):
    """
    Evaluate local DPG explanations against the fitted black-box model.

    This measures output fidelity to the underlying model plus structural
    faithfulness diagnostics derived from local DPG traces. It does not
    measure ground-truth correctness unless ``y_true`` is provided, and the
    returned composite score is a heuristic summary, not a calibrated
    probability.

    Args:
        X: Samples to evaluate, one row per sample (pandas DataFrame or array-like).
        y_true: Optional ground-truth labels aligned with ``X``. When given,
            ``local_accuracy`` is added to the details.
        max_samples: Optional positive cap. Only the first ``max_samples`` rows, and
            the matching ``y_true`` / ``sample_ids`` entries, are evaluated.
        weights: Optional composite weights. See _validate_faithfulness_weights for
            the accepted keys.
        return_details: If True, return a dict of aggregate metrics plus a
            ``per_sample`` DataFrame instead of the bare composite score.
        sample_ids: Optional identifiers attached to each evaluated sample.
            Defaults to 0..n-1.

    Returns:
        The composite faithfulness score (float), or the details dict when
        ``return_details`` is True.

    Raises:
        ValueError: If the explainer is not fitted, ``max_samples`` is not positive,
            there are no samples, ``y_true`` / ``sample_ids`` are misaligned with the
            evaluated rows, or every local explanation fails.
    """
    if not self._is_fitted:
        raise ValueError("DPGExplainer is not fitted. Call fit(X) before evaluate_faithfulness().")
    if max_samples is not None and max_samples <= 0:
        raise ValueError("max_samples must be a positive integer when provided.")
    weights = self._validate_faithfulness_weights(weights)

    # Per-sample confidence fields copied into each record, in column order.
    metric_keys = (
        "vote_confidence",
        "evidence_score_pred",
        "evidence_score_margin",
        "trace_coverage_score",
        "recombination_rate",
        "graph_path_valid_rate",
        "node_recall",
        "node_precision",
        "edge_recall",
        "edge_precision",
        "evidence_margin_pred_vs_competitor",
        "path_purity",
        "competitor_exposure",
        "explanation_confidence",
    )
    # Fields averaged into the "mean_<key>" details entries, in output order.
    summary_keys = (
        "node_recall",
        "node_precision",
        "edge_recall",
        "edge_precision",
        "trace_coverage_score",
        "recombination_rate",
        "vote_confidence",
        "evidence_score_pred",
        "evidence_score_margin",
        "evidence_margin_pred_vs_competitor",
        "path_purity",
        "competitor_exposure",
        "explanation_confidence",
    )

    is_frame = isinstance(X, pd.DataFrame)
    if is_frame:
        X_eval = X.iloc[:max_samples].copy() if max_samples is not None else X.copy()
        row_values = [X_eval.iloc[i].values for i in range(len(X_eval))]
    else:
        X_array = np.asarray(X)
        X_eval = X_array[:max_samples] if max_samples is not None else X_array
        row_values = [np.asarray(X_eval[i]).reshape(-1) for i in range(len(X_eval))]
    n_samples = len(X_eval)
    if n_samples == 0:
        raise ValueError("X must contain at least one sample for faithfulness evaluation.")
    if y_true is not None:
        y_true_seq = list(y_true[:n_samples] if max_samples is not None else y_true)
        if len(y_true_seq) != n_samples:
            raise ValueError("y_true length must match the number of evaluated samples.")
    else:
        y_true_seq = None
    if sample_ids is not None:
        sample_ids_seq = list(sample_ids[:n_samples] if max_samples is not None else sample_ids)
        if len(sample_ids_seq) != n_samples:
            raise ValueError("sample_ids length must match the number of evaluated samples.")
    else:
        sample_ids_seq = list(range(n_samples))

    # Predict all evaluated rows in one batched call instead of one call per row.
    # For DataFrames this also keeps each column's dtype, which a per-row
    # ``Series.to_frame().T`` would collapse to a single common dtype.
    if is_frame:
        raw_predictions = self._builder.model.predict(X_eval)
    else:
        raw_predictions = self._builder.model.predict(X_eval.reshape(n_samples, -1))

    per_sample_records = []
    successful_records = []
    n_local_failures = 0
    for idx, sample_values in enumerate(row_values):
        sample_id = sample_ids_seq[idx]
        true_label_normalized = None
        if y_true_seq is not None:
            true_label_normalized = self._normalize_prediction_label(y_true_seq[idx])
        model_pred = self._normalize_prediction_label(raw_predictions[idx])
        try:
            local = self.explain_local(sample=sample_values, sample_id=sample_id)
            local_pred = local.majority_vote
            sample_confidence = local.sample_confidence or {}
            record = {
                "sample_id": sample_id,
                "model_pred": model_pred,
                "local_pred": local_pred,
                "matches_model": bool(local_pred == model_pred),
                **{key: self._maybe_float(sample_confidence.get(key)) for key in metric_keys},
                "error": None,
            }
            if true_label_normalized is not None:
                record["correct"] = bool(local_pred == true_label_normalized)
            per_sample_records.append(record)
            successful_records.append(record)
        except Exception as exc:
            # Best effort: a sample whose explanation fails is recorded with its
            # error message and excluded from the aggregates.
            n_local_failures += 1
            failure_record = {
                "sample_id": sample_id,
                "model_pred": model_pred,
                "local_pred": None,
                "matches_model": False,
                **dict.fromkeys(metric_keys),
                "error": str(exc),
            }
            if y_true_seq is not None:
                failure_record["correct"] = None
            per_sample_records.append(failure_record)

    if not successful_records:
        raise ValueError(
            "All local explanations failed during faithfulness evaluation; "
            "no faithfulness metrics could be computed."
        )
    output_fidelity = float(np.mean([record["matches_model"] for record in successful_records]))
    local_accuracy = None
    if y_true_seq is not None:
        local_accuracy = float(
            np.mean([record["correct"] for record in successful_records if record.get("correct") is not None])
        )
    means = {key: self._mean_records(successful_records, key) for key in summary_keys}
    composite = (
        weights["output_fidelity"] * output_fidelity
        + weights["trace_coverage"] * means["trace_coverage_score"]
        + weights["anti_recombination"] * (1.0 - means["recombination_rate"])
        + weights["evidence_margin"] * means["evidence_score_margin"]
    )
    if not return_details:
        return float(composite)
    details = {
        "faithfulness_score": float(composite),
        "weights": weights,
        "n_samples": n_samples,
        "n_successful": len(successful_records),
        "n_local_failures": n_local_failures,
        "output_fidelity": output_fidelity,
        **{f"mean_{key}": means[key] for key in summary_keys},
        "per_sample": pd.DataFrame(per_sample_records),
    }
    if local_accuracy is not None:
        details["local_accuracy"] = local_accuracy
    return details
def _trace_tree_path(
    self,
    tree: Any,
    sample: np.ndarray,
    sample_id: int,
    tree_index: int,
    node_lookup: Dict[str, str],
    node_metrics_lookup: Dict[str, Dict[str, Any]],
    validate_graph: bool,
) -> DPGTreePathExplanation:
    """
    Trace ``sample`` through one estimator and score the path against the DPG.

    Args:
        tree: Estimator exposing a sklearn-style ``tree_`` structure.
        sample: Flattened feature vector.
        sample_id: Used only to build ``tree_prefix``.
        tree_index: Position of ``tree`` in ``model.estimators_``.
        node_lookup: Map from label to node id in the fitted graph.
        node_metrics_lookup: Map from node id to node metric row.
        validate_graph: If True, ``node_ids`` holds the graph's id for each label
            (None when the label is absent from the graph). Otherwise the
            hash-derived id from _label_to_node_id is used unconditionally.

    Returns:
        DPGTreePathExplanation with labels, node ids, edge presence, mean centralities
        of the matched nodes, and a path confidence equal to the average of node
        coverage and edge coverage.
    """
    # Reuse the shared execution tracer so this path and the reference traces
    # used by _compute_trace_diagnostics come from one routine and cannot drift.
    labels = self._trace_execution_labels_for_tree(tree, sample, tree_index=tree_index)
    # Every recorded split label is the branch the sample satisfied, so each
    # predicate holds by construction. The last label is the leaf and has none.
    predicate_truths = [True] * (len(labels) - 1)

    native_node_ids = [self._label_to_node_id(label) for label in labels]
    if validate_graph:
        node_ids = [node_lookup.get(label) for label in labels]
    else:
        node_ids = list(native_node_ids)
    edge_exists = [
        self._graph.has_edge(src, dst)
        for src, dst in zip(native_node_ids, native_node_ids[1:])
    ]
    graph_path_valid = all(
        native_id in node_metrics_lookup for native_id in native_node_ids
    ) and all(edge_exists)

    active_metric_rows = [
        node_metrics_lookup[native_id]
        for native_id in native_node_ids
        if native_id in node_metrics_lookup
    ]
    mean_lrc = None
    mean_bc = None
    if active_metric_rows:
        mean_lrc = float(np.mean([row["Local reaching centrality"] for row in active_metric_rows]))
        mean_bc = float(np.mean([row["Betweenness centrality"] for row in active_metric_rows]))

    node_coverage = len(active_metric_rows) / len(labels) if labels else 0.0
    # A single-leaf path has no edges to miss, so its edge coverage counts as full.
    edge_coverage = sum(edge_exists) / len(edge_exists) if edge_exists else 1.0
    path_confidence = float((node_coverage + edge_coverage) / 2.0) if labels else 0.0

    return DPGTreePathExplanation(
        tree_index=tree_index,
        tree_prefix=f"sample{sample_id}_dt{tree_index}",
        labels=labels,
        node_ids=node_ids,
        predicate_truths=predicate_truths,
        edge_exists=edge_exists,
        starts_from_root=len(labels) > 0,
        ends_in_leaf=bool(labels) and labels[-1].startswith(("Class ", "Pred ")),
        graph_path_valid=graph_path_valid,
        mean_lrc=mean_lrc,
        mean_bc=mean_bc,
        path_confidence=path_confidence,
    )
@staticmethod
def _label_to_node_id(label: str) -> str:
    """
    Derive a node id from a label: the decimal string of the label's SHA-1 digest.

    NOTE(review): presumably mirrors the id scheme used when the DPG graph was
    built (not visible here), since these ids are looked up in the graph.
    """
    digest = hashlib.sha1(label.encode()).digest()
    return str(int.from_bytes(digest, "big"))
def _leaf_class_label(self, tree_index: int, tree_: Any, node_index: int) -> str:
    """
    Return the "Class <name>" label for a classifier leaf.

    The class index is chosen in this order:
    1. the index SklearnEnsembleNormalizer assigns to this tree, if any;
    2. for a binary GradientBoostingClassifier, 1 when the raw leaf value is
       positive and 0 otherwise;
    3. otherwise, the argmax of the leaf's value array.

    The index is then mapped through ``target_names``, else ``model.classes_``,
    else used as-is.
    """
    model = self._builder.model
    class_index = SklearnEnsembleNormalizer.get_tree_class_index(model, tree_index)
    if class_index is None:
        binary_gb = isinstance(model, GradientBoostingClassifier) and getattr(model, "n_classes_", 0) == 2
        if binary_gb:
            # NOTE(review): presumably the leaf holds a raw additive score for the
            # positive class, so its sign is read as this tree's vote.
            class_index = 1 if float(tree_.value[node_index][0][0]) > 0 else 0
        else:
            class_index = int(tree_.value[node_index].argmax())
    if self._builder.target_names is not None:
        class_name = self._builder.target_names[class_index]
    elif hasattr(model, "classes_"):
        class_name = model.classes_[class_index]
    else:
        class_name = class_index
    return f"Class {class_name}"
@staticmethod
def _normalize_class_vote_label(label: str) -> str:
    """Strip a leading "Class " from a leaf label; any other label is returned unchanged."""
    prefix = "Class "
    if not label.startswith(prefix):
        return label
    return label[len(prefix):]
def _get_node_metrics(self) -> Any:
    """Return the node metric table, computing it with NodeMetrics on first use and caching it."""
    cached = self._node_metrics
    if cached is None:
        cached = NodeMetrics.extract_node_metrics(self._graph, self._nodes)
        self._node_metrics = cached
    return cached
def _get_node_metrics_lookup(self) -> Dict[str, Dict[str, Any]]:
    """Return a cached map from each row's "Node" value to that full node-metric row (as a dict)."""
    if self._node_metrics_lookup is not None:
        return self._node_metrics_lookup
    records = self._get_node_metrics().to_dict(orient="records")
    self._node_metrics_lookup = {record["Node"]: record for record in records}
    return self._node_metrics_lookup
def _get_edge_metrics(self) -> Any:
    """Return the edge metric table, computing it with EdgeMetrics on first use and caching it."""
    cached = self._edge_metrics
    if cached is None:
        cached = EdgeMetrics.extract_edge_metrics(self._graph, self._nodes)
        self._edge_metrics = cached
    return cached
def _compute_sample_confidence(
    self,
    tree_paths: "List[DPGTreePathExplanation]",
    class_votes: Dict[str, int],
    sample_array: np.ndarray,
) -> Dict[str, Any]:
    """
    Summarise how well a sample's traced paths are supported by the fitted DPG.

    Args:
        tree_paths: Per-tree paths traced for this sample.
        class_votes: Leaf-class vote counts (class name -> number of trees).
        sample_array: Flattened sample, forwarded to _compute_trace_diagnostics.

    Returns:
        Dict with:
        - path and active-node counts;
        - mean centralities of the distinct active nodes;
        - vote shares and vote margin;
        - evidence support and scores;
        - the trace diagnostics.

        ``top_competitor_class_pred`` is the best-supported class (by evidence) other
        than the majority-vote class. ``evidence_margin_pred_vs_competitor`` is the
        majority-vote class's evidence minus that competitor's; it is negative when
        the competitor has more evidence. ``evidence_score_margin`` is the gap between
        the two highest evidence scores, whichever classes they belong to.
    """
    num_paths = len(tree_paths)
    num_valid_paths = sum(path.graph_path_valid for path in tree_paths)
    # Distinct resolved node ids across all paths, in first-seen order.
    unique_active_node_ids = list(
        dict.fromkeys(
            node_id
            for path in tree_paths
            for node_id in path.node_ids
            if node_id is not None
        )
    )
    node_metrics_lookup = self._get_node_metrics_lookup()
    active_metric_rows = [
        node_metrics_lookup[node_id]
        for node_id in unique_active_node_ids
        if node_id in node_metrics_lookup
    ]
    if active_metric_rows:
        mean_lrc_active_nodes = float(np.mean([row["Local reaching centrality"] for row in active_metric_rows]))
        mean_bc_active_nodes = float(np.mean([row["Betweenness centrality"] for row in active_metric_rows]))
    else:
        mean_lrc_active_nodes = 0.0
        mean_bc_active_nodes = 0.0

    total_votes = sum(class_votes.values())
    class_scores = (
        {label: votes / total_votes for label, votes in class_votes.items()}
        if total_votes > 0
        else {}
    )
    sorted_scores = sorted(class_scores.values(), reverse=True)
    if not sorted_scores:
        vote_confidence = 0.0
        score_margin = 0.0
    else:
        vote_confidence = float(sorted_scores[0])
        score_margin = float(
            sorted_scores[0] - sorted_scores[1]
            if len(sorted_scores) > 1
            else sorted_scores[0]
        )

    class_support, evidence_scores = self._compute_evidence_scores(
        tree_paths=tree_paths,
        class_scores=class_scores,
    )
    trace_diagnostics = self._compute_trace_diagnostics(tree_paths, sample_array)

    majority_vote = max(class_votes, key=class_votes.get) if class_votes else None
    evidence_score_pred = evidence_scores.get(majority_vote) if majority_vote is not None else None

    sorted_evidence = sorted(
        evidence_scores.items(),
        key=lambda item: item[1],
        reverse=True,
    )
    evidence_score_margin = None
    top_competitor_class_pred = None
    evidence_score_competitor_pred = None
    evidence_margin_pred_vs_competitor = None
    if sorted_evidence:
        top_score = float(sorted_evidence[0][1])
        evidence_score_margin = (
            float(top_score - sorted_evidence[1][1]) if len(sorted_evidence) > 1 else top_score
        )
        # The competitor is the best-supported class other than the predicted
        # one. Blindly taking the second-ranked entry would select the predicted
        # class itself whenever another class out-evidences it.
        reference_class = majority_vote if majority_vote is not None else sorted_evidence[0][0]
        competitor = next(
            (item for item in sorted_evidence if item[0] != reference_class),
            None,
        )
        if competitor is not None:
            top_competitor_class_pred = competitor[0]
            evidence_score_competitor_pred = float(competitor[1])
        if evidence_score_pred is not None:
            evidence_margin_pred_vs_competitor = float(
                evidence_score_pred - evidence_score_competitor_pred
                if evidence_score_competitor_pred is not None
                else evidence_score_pred
            )

    return {
        "num_paths": num_paths,
        "num_valid_paths": num_valid_paths,
        "num_active_nodes": len(unique_active_node_ids),
        "mean_lrc_active_nodes": mean_lrc_active_nodes,
        "mean_bc_active_nodes": mean_bc_active_nodes,
        "graph_path_valid_rate": (num_valid_paths / num_paths) if num_paths > 0 else 0.0,
        "vote_confidence": vote_confidence,
        "class_scores": class_scores,
        "score_margin": score_margin,
        "class_support": class_support,
        "evidence_scores": evidence_scores,
        "evidence_score_pred": evidence_score_pred,
        "evidence_score_margin": evidence_score_margin,
        "top_competitor_class_pred": top_competitor_class_pred,
        "evidence_score_competitor_pred": evidence_score_competitor_pred,
        "evidence_margin_pred_vs_competitor": evidence_margin_pred_vs_competitor,
        **trace_diagnostics,
    }
def _compute_evidence_scores(
    self,
    tree_paths: List[DPGTreePathExplanation],
    class_scores: Dict[str, float],
) -> Tuple[Dict[str, float], Dict[str, float]]:
    """
    Weight each class vote by the confidence of the path that produced it.

    Args:
        tree_paths: Traced paths. Only those ending in a "Class ..." leaf contribute.
        class_scores: Vote shares, used as the fallback evidence scores.

    Returns:
        A ``(class_support, evidence_scores)`` pair:
        - ``class_support`` sums the path confidences (None counted as 0.0) per class;
        - ``evidence_scores`` is that support normalised to sum to 1, or a copy of
          ``class_scores`` when the total support is not positive.
    """
    class_support: Dict[str, float] = {}
    voting_paths = (
        path
        for path in tree_paths
        if path.labels and path.labels[-1].startswith("Class ")
    )
    for path in voting_paths:
        name = self._normalize_class_vote_label(path.labels[-1])
        class_support[name] = class_support.get(name, 0.0) + float(path.path_confidence or 0.0)
    total = sum(class_support.values())
    if total > 0:
        return class_support, {name: value / total for name, value in class_support.items()}
    return class_support, dict(class_scores)
def _extract_execution_trace_labels(self, sample_arr: np.ndarray) -> List[List[str]]:
    """Trace ``sample_arr`` through every estimator and return one label list per tree, in estimator order."""
    return [
        self._trace_execution_labels_for_tree(tree, sample_arr, tree_index=tree_index)
        for tree_index, tree in enumerate(self._builder.model.estimators_)
    ]
def _trace_reference_sets(
    self,
    sample_arr: np.ndarray,
) -> Tuple[set, set]:
    """
    Collect the reference node labels and edges from the raw execution traces.

    Returns:
        ``(node_labels, edge_labels)``:
        - ``node_labels`` holds every split label (leaf labels starting with "Class "
          or "Pred " are excluded);
        - ``edge_labels`` holds every consecutive ``(label, next_label)`` pair,
          leaves included.
    """
    leaf_prefixes = ("Class ", "Pred ")
    node_labels = set()
    edge_labels = set()
    for labels in self._extract_execution_trace_labels(sample_arr):
        node_labels.update(label for label in labels if not label.startswith(leaf_prefixes))
        edge_labels.update(zip(labels, labels[1:]))
    return node_labels, edge_labels
def _compute_trace_diagnostics(
    self,
    tree_paths: List[DPGTreePathExplanation],
    sample_arr: np.ndarray,
) -> Dict[str, Any]:
    """
    Compare the explanation's nodes and edges with the raw execution traces.

    An explanation node is a non-leaf label whose node id was resolved. An
    explanation edge is a consecutive label pair where both ids are resolved
    and ``edge_exists`` is True.

    Returns:
        Dict with:
        - unique node/edge counts for the reference trace and the explanation;
        - node/edge recall and precision (1.0 when the denominator set is empty);
        - ``trace_coverage_score``, the mean of node and edge recall;
        - ``recombination_rate``, the share of explanation edges absent from the
          reference trace (0.0 when there are none).
    """
    reference_nodes, reference_edges = self._trace_reference_sets(sample_arr)
    leaf_prefixes = ("Class ", "Pred ")

    explained_nodes = set()
    explained_edges = set()
    for path in tree_paths:
        explained_nodes.update(
            label
            for label, node_id in zip(path.labels, path.node_ids)
            if node_id is not None and not label.startswith(leaf_prefixes)
        )
        n_ids = len(path.node_ids)
        n_flags = len(path.edge_exists)
        for i, (src_label, dst_label) in enumerate(zip(path.labels, path.labels[1:])):
            src_id = path.node_ids[i] if i < n_ids else None
            dst_id = path.node_ids[i + 1] if i + 1 < n_ids else None
            edge_ok = path.edge_exists[i] if i < n_flags else False
            if src_id is not None and dst_id is not None and edge_ok:
                explained_edges.add((src_label, dst_label))

    def ratio(numerator, denominator_set, when_empty):
        # Guarded division: returns ``when_empty`` when the denominator set is empty.
        return float(numerator / len(denominator_set)) if denominator_set else when_empty

    node_overlap = len(reference_nodes & explained_nodes)
    edge_overlap = len(reference_edges & explained_edges)
    node_recall = ratio(node_overlap, reference_nodes, 1.0)
    node_precision = ratio(node_overlap, explained_nodes, 1.0)
    edge_recall = ratio(edge_overlap, reference_edges, 1.0)
    edge_precision = ratio(edge_overlap, explained_edges, 1.0)
    recombination_rate = ratio(len(explained_edges - reference_edges), explained_edges, 0.0)

    return {
        "trace_node_count_unique": int(len(reference_nodes)),
        "trace_edge_count_unique": int(len(reference_edges)),
        "explanation_node_count_unique": int(len(explained_nodes)),
        "explanation_edge_count_unique": int(len(explained_edges)),
        "node_recall": node_recall,
        "node_precision": node_precision,
        "edge_recall": edge_recall,
        "edge_precision": edge_precision,
        "trace_coverage_score": float((node_recall + edge_recall) / 2.0),
        "recombination_rate": recombination_rate,
    }
def _trace_execution_labels_for_tree(
    self,
    tree: Any,
    sample: np.ndarray,
    tree_index: Optional[int] = None,
) -> List[str]:
    """
    Walk ``sample`` from the root of one sklearn-style tree down to a leaf.

    Args:
        tree: Estimator exposing ``tree_`` (children, feature, threshold, value arrays).
        sample: Flattened feature vector indexed by ``tree_.feature``.
        tree_index: Position in ``model.estimators_``. When given, classifier leaves
            are labelled via _leaf_class_label; otherwise via the leaf value argmax.

    Returns:
        Split labels (``"<feature> <= <t>"`` / ``"<feature> > <t>"``) followed by one
        leaf label: ``"Pred <value>"`` for the listed regressors, else ``"Class <name>"``.
    """
    regression_model = isinstance(
        self._builder.model,
        (RandomForestRegressor, ExtraTreesRegressor, AdaBoostRegressor),
    )
    structure = tree.tree_
    node = 0
    trace: List[str] = []
    # A node is a leaf exactly when its left and right child entries are equal.
    while structure.children_left[node] != structure.children_right[node]:
        feature_index = structure.feature[node]
        # NOTE(review): the split test compares against the *rounded* threshold
        # (the value shown in the label), not the raw sklearn threshold. A sample
        # lying between the two may therefore be routed differently from
        # model.predict. Presumably intentional, to match how DPG labels are
        # built — confirm.
        cut = round(structure.threshold[node], self._builder.decimal_threshold)
        feature_name = self._builder.feature_names[feature_index]
        if sample[feature_index] <= cut:
            trace.append(f"{feature_name} <= {cut}")
            node = structure.children_left[node]
        else:
            trace.append(f"{feature_name} > {cut}")
            node = structure.children_right[node]

    if regression_model:
        leaf_value = round(structure.value[node][0][0], 2)
        trace.append(f"Pred {leaf_value}")
    elif tree_index is not None:
        trace.append(self._leaf_class_label(tree_index, structure, node))
    else:
        class_index = int(structure.value[node].argmax())
        if self._builder.target_names is not None:
            class_name = self._builder.target_names[class_index]
        elif hasattr(self._builder.model, "classes_"):
            class_name = self._builder.model.classes_[class_index]
        else:
            class_name = class_index
        trace.append(f"Class {class_name}")
    return trace
def _normalize_prediction_label(self, value: Any) -> Any:
    """
    Map a model output or label to the string form used for class votes.

    Resolution order:
    1. When ``target_names`` and ``model.classes_`` have the same length, a value
       equal to a class is mapped to its target name.
    2. A "Class " prefix is stripped from string values.
    3. Anything else is returned as ``str(value)``.
    """
    builder = self._builder
    if builder.target_names is not None and hasattr(builder.model, "classes_"):
        classes = list(builder.model.classes_)
        names = list(builder.target_names)
        if len(classes) == len(names):
            for class_value, name in zip(classes, names):
                if value == class_value:
                    return str(name)
    prefix = "Class "
    if isinstance(value, str) and value.startswith(prefix):
        return value[len(prefix):]
    return str(value)
def _validate_faithfulness_weights(self, weights: Optional[Dict[str, float]]) -> Dict[str, float]:
default_weights = {
"output_fidelity": 0.35,
"trace_coverage": 0.30,
"anti_recombination": 0.20,
"evidence_margin": 0.15,
}
if weights is None:
return default_weights
weights = dict(weights)
unsupported = set(weights) - set(default_weights)
if unsupported:
raise ValueError(
f"Unsupported faithfulness weight keys: {sorted(unsupported)}. "
f"Supported keys are: {sorted(default_weights)}"
)
missing = set(default_weights) - set(weights)
if missing:
raise ValueError(
f"Missing faithfulness weight keys: {sorted(missing)}. "
f"Supported keys are: {sorted(default_weights)}"
)
total = float(sum(weights.values()))
if not np.isclose(total, 1.0, atol=1e-6):
raise ValueError("Faithfulness weights must sum to 1.0.")
return {key: float(value) for key, value in weights.items()}
@staticmethod
def _mean_records(records: List[Dict[str, Any]], key: str) -> float:
values = [
float(record[key])
for record in records
if record.get(key) is not None and not pd.isna(record.get(key))
]
if not values:
return 0.0
return float(np.mean(values))
@staticmethod
def _maybe_float(value: Any) -> Optional[float]:
if value is None or pd.isna(value):
return None
return float(value)
[docs]
def plot(
    self,
    plot_name: str,
    explanation: Optional[DPGExplanation] = None,
    save_dir: str = "results/",
    attribute: Optional[str] = None,
    class_flag: bool = False,
    layout_template: str = "default",
    graph_style: Optional[Dict[str, Any]] = None,
    node_style: Optional[Dict[str, Any]] = None,
    edge_style: Optional[Dict[str, Any]] = None,
    fig_size: Tuple[float, float] = (16, 8),
    dpi: int = 300,
    pdf_dpi: int = 600,
    show: bool = True,
    export_pdf: bool = False,
    theme: str = "dpg",
    palette: str = "default",
    label_mode: str = "full",
    readability: str = "normal",
    title: Optional[str] = None,
) -> None:
    """Render a standard DPG plot.

    Thin wrapper around ``plot_dpg``. It passes the explanation's ``dot``
    graph, node metrics and edge metrics positionally, and forwards every
    other argument unchanged as a keyword.

    Args:
        plot_name: Name handed to ``plot_dpg`` as its first argument.
        explanation: Global explanation to draw. When ``None``, one is
            computed with ``self.explain_global()``.
        save_dir, attribute, class_flag, layout_template, graph_style,
        node_style, edge_style, fig_size, dpi, pdf_dpi, show, export_pdf,
        theme, palette, label_mode, readability, title: Forwarded verbatim
            to ``plot_dpg``.
    """
    resolved = self.explain_global() if explanation is None else explanation
    render_kwargs = dict(
        save_dir=save_dir,
        attribute=attribute,
        class_flag=class_flag,
        layout_template=layout_template,
        graph_style=graph_style,
        node_style=node_style,
        edge_style=edge_style,
        fig_size=fig_size,
        dpi=dpi,
        pdf_dpi=pdf_dpi,
        show=show,
        export_pdf=export_pdf,
        theme=theme,
        palette=palette,
        label_mode=label_mode,
        readability=readability,
        title=title,
    )
    plot_dpg(
        plot_name,
        resolved.dot,
        resolved.node_metrics,
        resolved.edge_metrics,
        **render_kwargs,
    )
[docs]
def plot_communities(
    self,
    plot_name: str,
    explanation: Optional[DPGExplanation] = None,
    save_dir: str = "results/",
    class_flag: bool = True,
    layout_template: str = "default",
    graph_style: Optional[Dict[str, Any]] = None,
    node_style: Optional[Dict[str, Any]] = None,
    edge_style: Optional[Dict[str, Any]] = None,
    fig_size: Tuple[float, float] = (16, 8),
    dpi: int = 300,
    pdf_dpi: int = 600,
    show: bool = True,
    export_pdf: bool = False,
    community_threshold: float = 0.2,
    theme: str = "dpg",
    palette: str = "default",
    label_mode: str = "wrapped",
    readability: str = "presentation",
    title: Optional[str] = None,
) -> None:
    """Render a community-colored DPG plot.

    If ``explanation`` is ``None`` or carries no ``communities``, a fresh
    explanation is computed with
    ``self.explain_global(communities=True, community_threshold=...)``.
    Otherwise the given explanation is used as-is, and
    ``community_threshold`` is not consulted.

    All remaining arguments are forwarded verbatim as keywords to
    ``plot_dpg_communities``.
    """
    has_communities = explanation is not None and explanation.communities is not None
    if not has_communities:
        explanation = self.explain_global(
            communities=True,
            community_threshold=community_threshold,
        )
    render_kwargs = dict(
        save_dir=save_dir,
        class_flag=class_flag,
        layout_template=layout_template,
        graph_style=graph_style,
        node_style=node_style,
        edge_style=edge_style,
        fig_size=fig_size,
        dpi=dpi,
        pdf_dpi=pdf_dpi,
        show=show,
        export_pdf=export_pdf,
        theme=theme,
        palette=palette,
        label_mode=label_mode,
        readability=readability,
        title=title,
    )
    plot_dpg_communities(
        plot_name,
        explanation.dot,
        explanation.node_metrics,
        explanation.communities,
        **render_kwargs,
    )
[docs]
def plot_lrc_importance(
    self,
    X_df: Any,
    explanation: Optional[DPGExplanation] = None,
    top_k: int = 10,
    dataset_name: str = "Dataset",
    save_path: Optional[str] = None,
    show: bool = True,
    theme: str = "dpg",
    palette: str = "default",
) -> Any:
    """Plot top LRC predicates vs RF feature importances.

    Delegates to ``plot_lrc_vs_rf_importance``, supplying the builder's
    fitted model. An explanation is computed with ``self.explain_global()``
    when none is given.

    Returns:
        Whatever ``plot_lrc_vs_rf_importance`` returns.
    """
    chosen = explanation if explanation is not None else self.explain_global()
    plot_kwargs = dict(
        explanation=chosen,
        model=self._builder.model,
        X_df=X_df,
        top_k=top_k,
        dataset_name=dataset_name,
        save_path=save_path,
        show=show,
        theme=theme,
        palette=palette,
    )
    return plot_lrc_vs_rf_importance(**plot_kwargs)
[docs]
def plot_top_lrc_splits(
    self,
    X_df: Any,
    y,
    explanation: Optional[DPGExplanation] = None,
    top_predicates: int = 5,
    top_features: int = 2,
    dataset_name: str = "Dataset",
    class_names: Optional[Any] = None,
    save_path: Optional[str] = None,
    show: bool = True,
    theme: str = "dpg",
    palette: str = "default",
) -> Optional[Any]:
    """Plot top-LRC split lines over the top LRC feature space.

    ``top_features`` controls how many features span that space (2 by
    default). All arguments are forwarded as keywords to
    ``plot_top_lrc_predicate_splits``. An explanation is computed with
    ``self.explain_global()`` when none is given.

    Returns:
        Whatever ``plot_top_lrc_predicate_splits`` returns.
    """
    chosen = explanation if explanation is not None else self.explain_global()
    plot_kwargs = dict(
        explanation=chosen,
        X_df=X_df,
        y=y,
        top_predicates=top_predicates,
        top_features=top_features,
        dataset_name=dataset_name,
        class_names=class_names,
        save_path=save_path,
        show=show,
        theme=theme,
        palette=palette,
    )
    return plot_top_lrc_predicate_splits(**plot_kwargs)
[docs]
def class_feature_predicate_counts(
    self,
    explanation: Optional[DPGExplanation] = None,
    community_threshold: float = 0.2,
) -> Any:
    """Return class-vs-feature predicate count matrix from communities.

    When ``explanation`` is ``None`` or lacks ``communities``, one is
    computed with
    ``self.explain_global(communities=True, community_threshold=...)``.
    """
    has_communities = explanation is not None and explanation.communities is not None
    if not has_communities:
        explanation = self.explain_global(communities=True, community_threshold=community_threshold)
    # Class scope is not searched from inside a method body, so this bare
    # name is the module-level helper imported from .visualizer, not this
    # method.
    return class_feature_predicate_counts(explanation)
[docs]
def plot_class_feature_complexity(
    self,
    explanation: Optional[DPGExplanation] = None,
    dataset_name: str = "Dataset",
    top_n_features: int = 10,
    save_prefix: Optional[str] = None,
    show: bool = True,
    community_threshold: float = 0.2,
    theme: str = "dpg",
    palette: str = "default",
) -> Tuple[Any, Any]:
    """Plot community class-feature complexity using PCA-consistent class colors.

    Builds the class-vs-feature count matrix from the explanation's
    communities, then passes it to the module-level
    ``plot_class_feature_complexity`` together with the builder's
    ``target_names`` as class names. A communities-enabled explanation is
    computed when ``explanation`` is ``None`` or lacks ``communities``.
    """
    has_communities = explanation is not None and explanation.communities is not None
    if not has_communities:
        explanation = self.explain_global(communities=True, community_threshold=community_threshold)
    counts_matrix = class_feature_predicate_counts(explanation)
    # Bare name resolves to the visualizer helper, not this method.
    plot_kwargs = dict(
        heat_df=counts_matrix,
        dataset_name=dataset_name,
        class_names=self._builder.target_names,
        top_n_features=top_n_features,
        save_prefix=save_prefix,
        show=show,
        theme=theme,
        palette=palette,
    )
    return plot_class_feature_complexity(**plot_kwargs)
[docs]
def sample_bc_weights(
    self,
    X_df: Any,
    explanation: Optional[DPGExplanation] = None,
    top_k: int = 10,
) -> Any:
    """Return the per-sample BC-derived bottleneck exposure weights.

    Delegates to the module-level ``sample_bc_weights`` helper. An
    explanation is computed with ``self.explain_global()`` when none is
    given.
    """
    chosen = explanation if explanation is not None else self.explain_global()
    # Bare name resolves to the imported visualizer function, not this method.
    return sample_bc_weights(explanation=chosen, X_df=X_df, top_k=top_k)
[docs]
def plot_sample_using_bc_weights(
    self,
    X_df: Any,
    y,
    explanation: Optional[DPGExplanation] = None,
    top_k: int = 10,
    dataset_name: str = "Dataset",
    class_names: Optional[Any] = None,
    save_path: Optional[str] = None,
    show: bool = True,
    theme: str = "dpg",
    palette: str = "default",
) -> Any:
    """Plot samples in PCA space with marker size set by BC-derived weights.

    Delegates to the module-level ``plot_sample_using_bc_weights`` helper,
    forwarding every argument as a keyword. An explanation is computed with
    ``self.explain_global()`` when none is given.
    """
    chosen = explanation if explanation is not None else self.explain_global()
    plot_kwargs = dict(
        explanation=chosen,
        X_df=X_df,
        y=y,
        top_k=top_k,
        dataset_name=dataset_name,
        class_names=class_names,
        save_path=save_path,
        show=show,
        theme=theme,
        palette=palette,
    )
    # Bare name resolves to the imported visualizer function, not this method.
    return plot_sample_using_bc_weights(**plot_kwargs)
[docs]
def plot_class_bounds_vs_dataset_ranges(
    self,
    X_df: Any,
    y,
    explanation: Optional[DPGExplanation] = None,
    dataset_name: str = "Dataset",
    top_features: int = 4,
    feature_cols_per_row: int = 4,
    class_lookup: Optional[Dict[str, int]] = None,
    class_filter: Optional[List[str]] = None,
    density_tol_ratio: float = 0.03,
    predicate_alpha: float = 0.55,
    dataset_range_lw: float = 10,
    save_path: Optional[str] = None,
    show: bool = True,
    community_threshold: float = 0.2,
    theme: str = "dpg",
    palette: str = "default",
) -> Optional[Any]:
    """Plot DPG class bounds against empirical dataset feature ranges.

    A communities-enabled explanation is computed when ``explanation`` is
    ``None`` or lacks ``communities``. When ``class_lookup`` is omitted, it
    is derived from the builder's ``target_names`` via
    ``class_lookup_from_target_names``. Everything is then forwarded as
    keywords to ``plot_dpg_class_bounds_vs_dataset_feature_ranges``.
    """
    has_communities = explanation is not None and explanation.communities is not None
    if not has_communities:
        explanation = self.explain_global(communities=True, community_threshold=community_threshold)
    lookup = (
        class_lookup
        if class_lookup is not None
        else class_lookup_from_target_names(self._builder.target_names)
    )
    plot_kwargs = dict(
        explanation=explanation,
        X_df=X_df,
        y=y,
        dataset_name=dataset_name,
        top_features=top_features,
        feature_cols_per_row=feature_cols_per_row,
        class_lookup=lookup,
        class_filter=class_filter,
        density_tol_ratio=density_tol_ratio,
        predicate_alpha=predicate_alpha,
        dataset_range_lw=dataset_range_lw,
        save_path=save_path,
        show=show,
        theme=theme,
        palette=palette,
    )
    return plot_dpg_class_bounds_vs_dataset_feature_ranges(**plot_kwargs)