import os
import re
import textwrap
import warnings
import copy
import numpy as np
import pandas as pd
import networkx as nx
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from graphviz import Source
from graphviz.backend.execute import ExecutableNotFound
import matplotlib.patches as mpatches
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.lines import Line2D
from PIL import Image
from .utils import delete_folder_contents
from .themes import (
DPG_COLORS,
resolve_theme_context,
)
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
# Raise Pillow's pixel-count limit to 500 million so that large rendered graph
# images can be opened by ``Image.open`` in the plotting functions.
Image.MAX_IMAGE_PIXELS = 500000000  # Adjust based on your needs

# Matches a predicate label of the form "<feature> <op> <number>":
#   group 1 - feature text (lazy, so trailing whitespace is left to ``\s*``)
#   group 2 - the operator, either "<=" or ">"
#   group 3 - the threshold: optional sign, optional integer part / decimal
#             point, mandatory digits, optional exponent
_PREDICATE_PATTERN = re.compile(
    r"(.+?)\s*(<=|>)\s*([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)"
)

# Named Graphviz layout presets. Each entry holds attribute dicts for the
# "graph", "node" and "edge" scopes; `_apply_layout_template` merges caller
# overrides on top of these and falls back to "default" for unknown names.
_LAYOUT_TEMPLATES = {
    # Current behavior, close to Graphviz defaults for DPG usage.
    "default": {
        "graph": {"rankdir": "LR"},
        "node": {},
        "edge": {},
    },
    # Good for reducing horizontal spread and making figures more compact.
    "compact": {
        "graph": {"rankdir": "TB", "nodesep": "0.2", "ranksep": "0.25"},
        "node": {"margin": "0.03,0.02"},
        "edge": {"arrowsize": "0.6"},
    },
    # Strong vertical layout for long/wide graphs.
    "vertical": {
        "graph": {"rankdir": "TB", "nodesep": "0.25", "ranksep": "0.35"},
        "node": {},
        "edge": {},
    },
    # Explicitly wide left-to-right style.
    "wide": {
        "graph": {"rankdir": "LR", "nodesep": "0.5", "ranksep": "0.6"},
        "node": {},
        "edge": {},
    },
}

# Readability presets consumed by `_apply_graph_readability_preset`: Graphviz
# "graph"/"node" attributes plus the label wrap width (in characters) that is
# handed to the label formatter.
_READABILITY_PRESETS = {
    "compact": {
        "graph": {"nodesep": "0.32", "ranksep": "0.42", "pad": "0.14"},
        "node": {"fontsize": "10", "margin": "0.05,0.03"},
        "wrap_width": 20,
    },
    "normal": {
        "graph": {"nodesep": "0.40", "ranksep": "0.58", "pad": "0.18"},
        "node": {"fontsize": "11", "margin": "0.06,0.04"},
        "wrap_width": 18,
    },
    "presentation": {
        "graph": {"nodesep": "0.52", "ranksep": "0.78", "pad": "0.22"},
        "node": {"fontsize": "12", "margin": "0.10,0.07"},
        "wrap_width": 16,
    },
}
def _apply_matplotlib_theme(theme_context: Dict[str, Any]) -> None:
    """Merge the theme's ``"mpl_style"`` mapping into Matplotlib's global rcParams."""
    rc_overrides = theme_context["mpl_style"]
    plt.rcParams.update(rc_overrides)
def _style_axes(ax, theme_context: Dict[str, Any], grid_axis: str = "y") -> None:
colors = theme_context["colors"]
ax.set_facecolor(colors["paper"])
for side in ("top", "right"):
ax.spines[side].set_visible(False)
for side in ("left", "bottom"):
ax.spines[side].set_color(colors["grid"])
ax.spines[side].set_linewidth(1.0)
if grid_axis in {"x", "y", "both"}:
ax.grid(axis=grid_axis, linestyle="--", alpha=0.55, zorder=0)
def _style_legend(legend, theme_context: Dict[str, Any]) -> None:
if legend is None:
return
colors = theme_context["colors"]
frame = legend.get_frame()
frame.set_facecolor("#FFFDF7" if theme_context["theme"] == "dpg" else colors["paper"])
frame.set_edgecolor(colors["light_gray"])
frame.set_alpha(0.98)
def _style_figure(fig, theme_context: Dict[str, Any], title: Optional[str] = None, subtitle: Optional[str] = None) -> None:
colors = theme_context["colors"]
fig.patch.set_facecolor(colors["paper"])
if title:
fig.suptitle(title, color=colors["ink"], fontsize=14, fontweight="semibold")
if subtitle:
fig.text(0.5, 0.015, subtitle, ha="center", color=colors["mid_gray"], fontsize=9, style="italic")
def _class_fill_color(theme_context: Dict[str, Any]) -> str:
return theme_context["colors"].get("class_fill", theme_context["colors"]["success"])
def _default_node_color(theme_context: Dict[str, Any]) -> str:
return theme_context["colors"]["node_fill"]
def _default_pred_node_color(theme_context: Dict[str, Any]) -> str:
return theme_context["colors"]["node_muted"]
def _apply_dpg_graphviz_skin(dot, theme_context: Dict[str, Any]) -> None:
    """Set the theme's baseline graph, node and edge attributes on ``dot``.

    Args:
        dot: Graphviz graph object; modified in place through ``dot.attr``.
        theme_context: Theme mapping whose ``"colors"`` entry provides the
            ``"paper"``, ``"light_gray"`` and ``"edge"`` colours.
    """
    palette = theme_context["colors"]
    font = "DejaVu Sans"

    graph_defaults = {
        "bgcolor": palette["paper"],
        "pad": "0.18",
        "nodesep": "0.38",
        "ranksep": "0.55",
    }
    dot.attr("graph", **graph_defaults)

    node_defaults = {
        "shape": "box",
        "style": "rounded,filled",
        "color": palette["light_gray"],
        "fillcolor": _default_node_color(theme_context),
        "fontname": font,
        "fontsize": "11",
        "margin": "0.06,0.04",
        "penwidth": "1.1",
    }
    dot.attr("node", **node_defaults)

    edge_defaults = {
        "color": palette["edge"],
        "fontname": font,
        "fontsize": "10",
        "arrowsize": "0.7",
        "penwidth": "1.1",
    }
    dot.attr("edge", **edge_defaults)
def _style_class_nodes(dot, df: Any, theme_context: Dict[str, Any]) -> None:
colors = theme_context["colors"]
class_rows = df[df["Label"].astype(str).str.contains("Class", regex=False, na=False)]
for _, row in class_rows.iterrows():
dot.node(
str(row["Node"]),
style="rounded,filled",
fillcolor=_class_fill_color(theme_context),
color=colors["danger"],
fontcolor="black",
penwidth="1.4",
)
def _apply_layout_template(dot, theme_context: Dict[str, Any], layout_template=None, graph_style=None, node_style=None, edge_style=None):
    """Apply the theme skin, a named layout template and caller overrides to ``dot``.

    The skin is applied first, then the template's attributes with the
    caller-supplied ``graph_style`` / ``node_style`` / ``edge_style`` merged on
    top. Template lookup is case-insensitive and an unknown or missing name
    falls back to ``"default"``. Keys and values are converted to strings
    before being handed to ``dot.attr``.
    """
    _apply_dpg_graphviz_skin(dot, theme_context)

    chosen = (layout_template or "default").lower()
    template = _LAYOUT_TEMPLATES.get(chosen, _LAYOUT_TEMPLATES["default"])

    # Merge every scope before touching the graph so that a bad override
    # fails before any template attribute has been applied.
    merged_by_scope = []
    for scope, override in (
        ("graph", graph_style),
        ("node", node_style),
        ("edge", edge_style),
    ):
        attributes = dict(template.get(scope, {}))
        if override:
            attributes.update(override)
        merged_by_scope.append((scope, attributes))

    for scope, attributes in merged_by_scope:
        if not attributes:
            continue
        as_strings = {str(key): str(value) for key, value in attributes.items()}
        dot.attr(scope, **as_strings)
def _graphviz_not_found_error() -> RuntimeError:
message = (
"Graphviz executable 'dot' was not found in PATH.\n"
"Install Graphviz and ensure 'dot' is available from your terminal.\n"
"Install examples:\n"
"- macOS (Homebrew): brew install graphviz\n"
"- Ubuntu/Debian: sudo apt-get install graphviz\n"
"- Windows (winget): winget install Graphviz.Graphviz"
)
return RuntimeError(message)
def _shorten_feature_name(feature: str) -> str:
shortened = str(feature)
replacements = {
"sepal": "sep.",
"petal": "pet.",
"length": "len.",
"width": "wid.",
"diameter": "diam.",
"feature": "feat.",
}
for old, new in replacements.items():
shortened = re.sub(rf"\b{old}\b", new, shortened, flags=re.IGNORECASE)
shortened = re.sub(r"\s*\([^)]*\)", "", shortened).strip()
return shortened
def _format_graph_label_for_readability(
label: str,
wrap_width: int = 18,
label_mode: str = "full",
) -> str:
text = str(label)
if not text:
return text
normalized_mode = str(label_mode or "full").lower()
if normalized_mode not in {"full", "wrapped", "short"}:
raise ValueError("label_mode must be one of: full, wrapped, short.")
parsed = _PREDICATE_PATTERN.fullmatch(text.strip())
if parsed:
feature, operator, threshold = parsed.groups()
feature_text = feature.strip()
if normalized_mode == "short":
feature_text = _shorten_feature_name(feature_text)
return f"{feature_text}\\n{operator} {threshold}"
if normalized_mode == "wrapped":
feature_block = textwrap.fill(feature_text, width=wrap_width)
return f"{feature_block}\\n{operator} {threshold}"
return text
if text.startswith("Class "):
class_name = text.replace("Class ", "", 1)
if normalized_mode == "short":
return f"Class\\n{_shorten_feature_name(class_name)}"
if normalized_mode == "wrapped":
wrapped_class_name = textwrap.fill(class_name, width=wrap_width).replace("\n", "\\n")
return f"Class\\n{wrapped_class_name}"
return text
if normalized_mode == "short":
return _shorten_feature_name(text)
if normalized_mode == "wrapped" and len(text) > wrap_width:
return textwrap.fill(text, width=wrap_width).replace("\n", "\\n")
return text
def _sanitize_dot_source(
source: str,
wrap_labels: bool = False,
wrap_width: int = 18,
label_mode: str = "full",
) -> str:
# Escape brackets and quotes in node labels to avoid DOT parse errors, and
# optionally reformat long labels for better readability.
def repl(m):
label = m.group(1)
if wrap_labels:
label = _format_graph_label_for_readability(
label,
wrap_width=wrap_width,
label_mode=label_mode,
)
label = label.replace("\\", "\\\\").replace('"', '\\"')
label = label.replace("[", "\\[").replace("]", "\\]")
return f'label="{label}"'
source = re.sub(r'label="([^"]*)"', repl, source)
source = re.sub(r'label=([^\\s\\]]+)', r'label="\\1"', source)
return source
def _pipe_graph_png_with_fallback(dot_source: str, sanitizer) -> bytes:
    """Render DOT source to PNG bytes, retrying once with sanitized source.

    Args:
        dot_source: DOT source text to render.
        sanitizer: Callable taking the DOT source and returning a cleaned-up
            version; it is only invoked if the first render attempt fails.

    Returns:
        The rendered PNG image as bytes.

    Raises:
        RuntimeError: If the Graphviz ``dot`` executable cannot be found.
        Exception: The error from the *first* attempt when the sanitized
            retry fails as well.
    """
    try:
        png_bytes = Source(dot_source).pipe(format="png")
    except ExecutableNotFound as missing:
        raise _graphviz_not_found_error() from missing
    except Exception as original_error:
        print(f"Plotting failed with {type(original_error).__name__}; retrying with sanitized DOT source.")
        try:
            cleaned_source = sanitizer(dot_source)
            png_bytes = Source(cleaned_source).pipe(format="png")
        except ExecutableNotFound as missing:
            raise _graphviz_not_found_error() from missing
        except Exception:
            # Report the original failure rather than the retry's.
            raise original_error
    return png_bytes
def _apply_graph_readability_preset(dot, readability: str = "normal") -> int:
    """Apply a named readability preset to ``dot`` and return its wrap width.

    Args:
        dot: Graphviz graph object; graph and node attributes are set in place.
        readability: Preset name, case-insensitive; a falsy value means
            ``"normal"``.

    Returns:
        The preset's label wrap width as an ``int``.

    Raises:
        ValueError: If the preset name is unknown.
    """
    key = str(readability or "normal").lower()
    try:
        chosen = _READABILITY_PRESETS[key]
    except KeyError:
        raise ValueError("readability must be one of: compact, normal, presentation.") from None
    for scope in ("graph", "node"):
        dot.attr(scope, **chosen[scope])
    return int(chosen["wrap_width"])
[docs]
def plot_dpg(
    plot_name,
    dot,
    df,
    df_edges,
    save_dir="results/",
    attribute=None,
    clusters=None,
    threshold_clusters=None,
    class_flag=False,
    layout_template="default",
    graph_style=None,
    node_style=None,
    edge_style=None,
    fig_size=(16, 8),
    dpi=300,
    pdf_dpi=600,
    show=True,
    export_pdf=False,
    theme: str = "dpg",
    palette: str = "default",
    label_mode: str = "full",
    readability: str = "normal",
    title: Optional[str] = None,
):
    """
    Plot a Decision Predicate Graph (DPG) with optional node/edge styling.

    Nodes are coloured in one of three ways: the basic scheme (neither
    ``attribute`` nor ``clusters`` given), by a numeric node metric
    (``attribute``), or by cluster membership (``clusters``). Edges are always
    coloured and sized by their ``'Weight'``.

    Args:
        plot_name: Output base name for saved files (no extension). A suffix is
            appended for attribute plots (``_<attribute>``) and cluster plots
            (``_clusters_<threshold_clusters>``).
        dot: Graphviz Digraph instance representing the DPG structure. It is
            modified in place.
        df: DataFrame with node metrics; must include ``'Node'`` and ``'Label'``
            columns. The caller's frame is not modified.
        df_edges: DataFrame with edge metrics; must include ``'Source_id'``,
            ``'Target_id'``, and ``'Weight'``.
        save_dir: Directory where output images are saved. Default is ``"results/"``.
        attribute: Optional node metric column name to color nodes by (e.g. ``'Degree'``).
        clusters: Optional mapping ``{cluster_label: [node_id, ...]}`` to color nodes
            by cluster membership. Nodes not listed in any cluster are drawn in
            the palette's light gray.
        threshold_clusters: Optional value used only to annotate the output filename.
        class_flag: If ``True`` (attribute/cluster modes only), class nodes are
            filled with the theme's class fill colour and excluded from the
            attribute/cluster colouring.
        layout_template: Optional layout preset. One of
            ``{'default', 'compact', 'vertical', 'wide'}``.
        graph_style: Optional dict of Graphviz graph attributes to override template
            values.
        node_style: Optional dict of Graphviz node attributes to override template
            values.
        edge_style: Optional dict of Graphviz edge attributes to override template
            values.
        fig_size: Matplotlib figure size as ``(width, height)``.
        dpi: PNG export/display resolution.
        pdf_dpi: PDF export resolution when ``export_pdf=True``.
        show: If ``False`` the figure is closed after saving; otherwise it is
            left open. No explicit ``plt.show()`` call is made here.
        export_pdf: If ``True``, also writes a PDF next to the PNG.
        theme: Theme name forwarded to ``resolve_theme_context``.
        palette: Palette name forwarded to ``resolve_theme_context``.
        label_mode: Label formatting strategy. One of
            ``{'full', 'wrapped', 'short'}``.
        readability: Graph spacing/readability preset. One of
            ``{'compact', 'normal', 'presentation'}``.
        title: Optional axes title; defaults to the (suffixed) ``plot_name``.

    Returns:
        None

    Raises:
        AttributeError: If both ``attribute`` and ``clusters`` are given.
        ValueError: If ``readability`` is not a known preset.
        RuntimeError: If the Graphviz ``dot`` executable is not available.
    """
    print("Plotting DPG...")
    theme_context = resolve_theme_context(theme=theme, palette=palette)
    colors = theme_context["colors"]
    _apply_matplotlib_theme(theme_context)

    original_df = df.copy()
    # Work on a private, positionally indexed frame: colours are matched to
    # rows by position, and the cluster branch adds a helper column that must
    # not leak into the caller's DataFrame.
    df = df.reset_index(drop=True)

    _apply_layout_template(
        dot,
        theme_context=theme_context,
        layout_template=layout_template,
        graph_style=graph_style,
        node_style=node_style,
        edge_style=edge_style,
    )
    wrap_width = _apply_graph_readability_preset(dot, readability=readability)

    def _rgb_hex(rgba) -> str:
        # Truncating conversion of a 0..1 RGBA tuple to "#rrggbb".
        return "#{:02x}{:02x}{:02x}".format(
            int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255)
        )

    def _highlight_and_drop_class_nodes(frame):
        # Fill class nodes, then return the frame without them.
        for _, row in frame.iterrows():
            if 'Class' in row['Label']:
                change_node_color(dot, row['Node'], _class_fill_color(theme_context))
        return frame[~frame.Label.str.contains('Class')].reset_index(drop=True)

    colormap = None
    norm = None

    # Basic color scheme if no attribute or communities are specified
    if attribute is None and clusters is None:
        for _, row in df.iterrows():
            if 'Class' in row['Label']:
                change_node_color(dot, row['Node'], _class_fill_color(theme_context))
            else:
                change_node_color(dot, row['Node'], _default_node_color(theme_context))

    # Color nodes based on a specific attribute
    elif attribute is not None and clusters is None:
        colormap = theme_context["sequential_cmap"]
        if class_flag:
            df = _highlight_and_drop_class_nodes(df)

        # Scale the attribute from 0 to its maximum and map it to colours.
        max_score = df[attribute].max()
        norm = mcolors.Normalize(0, max_score)
        node_rgba = colormap(norm(df[attribute]))
        for (_, row), rgba in zip(df.iterrows(), node_rgba):
            change_node_color(dot, row['Node'], _rgb_hex(rgba))
        plot_name = plot_name + f"_{attribute}".replace(" ", "")

    # Color nodes based on cluster membership
    elif attribute is None and clusters is not None:
        if class_flag:
            df = _highlight_and_drop_class_nodes(df)

        node_to_cluster = {
            str(node_id): cluster_label
            for cluster_label, node_list in clusters.items()
            for node_id in node_list
        }
        df['Cluster'] = df['Node'].astype(str).map(lambda n: node_to_cluster.get(n, 'ambiguous'))
        unique_clusters = sorted(c for c in df['Cluster'].unique() if c != 'ambiguous')

        ambiguous_color = colors["light_gray"]
        base_palette = list(theme_context["class_palette"])
        # Cycle through the palette so that having more clusters than palette
        # entries re-uses colours instead of raising IndexError.
        cluster_colors = {
            cluster_label: (
                base_palette[position % len(base_palette)] if base_palette else ambiguous_color
            )
            for position, cluster_label in enumerate(unique_clusters)
        }
        for _, row in df.iterrows():
            change_node_color(dot, row['Node'], cluster_colors.get(row['Cluster'], ambiguous_color))
        plot_name = plot_name + f"_clusters_{threshold_clusters}"

    else:
        raise AttributeError("The plot can show the basic plot, clusters or a specific node-metric")

    # Highlight edges: colour and pen width both follow the normalised weight.
    colormap_edge = theme_context["edge_cmap"]
    max_edge_value = df_edges['Weight'].max()
    min_edge_value = df_edges['Weight'].min()
    norm_edge = mcolors.Normalize(vmin=min_edge_value, vmax=max_edge_value)
    for _, row in df_edges.iterrows():
        edge_value = row['Weight']
        color_hex = _rgb_hex(colormap_edge(norm_edge(edge_value)))
        penwidth = 1 + 3 * norm_edge(edge_value)
        change_edge_color(dot, row['Source_id'], row['Target_id'], new_color=color_hex, new_width=penwidth)

    # Class nodes always get the class border/fill, using the unfiltered frame.
    _style_class_nodes(dot, original_df, theme_context)

    # Render the graph to PNG bytes (avoid temp files).
    # NOTE(review): the sanitizer (and therefore label_mode / wrap_width) is
    # only invoked by _pipe_graph_png_with_fallback when the first render
    # attempt fails - confirm that this is intended.
    png_bytes = _pipe_graph_png_with_fallback(
        dot.source,
        lambda source: _sanitize_dot_source(
            source,
            wrap_labels=label_mode != "full",
            wrap_width=wrap_width,
            label_mode=label_mode,
        ),
    )

    # Open and display the rendered image
    img = Image.open(BytesIO(png_bytes))
    fig, ax = plt.subplots(figsize=fig_size)
    fig.patch.set_facecolor(colors["paper"])
    ax.set_axis_off()
    ax.set_title(title or plot_name, color=colors["ink"], fontsize=13, fontweight="semibold")
    ax.imshow(img)

    # Add a color bar if an attribute is specified
    if attribute is not None:
        # Place the colorbar just below the graph to reduce whitespace
        ax_pos = ax.get_position()
        cbar_height = 0.02
        cbar_pad = 0.02
        cbar_y = max(0.01, ax_pos.y0 - (cbar_height + cbar_pad))
        cax = fig.add_axes([ax_pos.x0, cbar_y, ax_pos.width, cbar_height])
        cbar = fig.colorbar(
            cm.ScalarMappable(norm=norm, cmap=colormap),
            cax=cax,
            orientation='horizontal',
        )
        cbar.set_label(attribute)
        cbar.outline.set_edgecolor(colors["light_gray"])
        cbar.ax.xaxis.label.set_color(colors["charcoal"])
        cbar.ax.tick_params(colors=colors["charcoal"])

    # Save the plot to the specified directory
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, plot_name + ".png"), dpi=dpi, bbox_inches="tight", pad_inches=0.02)
    if export_pdf:
        fig.savefig(
            os.path.join(save_dir, plot_name + ".pdf"),
            format="pdf",
            dpi=pdf_dpi,
            bbox_inches="tight",
            pad_inches=0.02,
        )

    if not show:
        plt.close(fig)
[docs]
def plot_dpg_communities(
    plot_name,
    dot,
    df,
    dpg_metrics,
    save_dir="results/",
    class_flag=False,
    df_edges=None,
    layout_template="default",
    graph_style=None,
    node_style=None,
    edge_style=None,
    fig_size=(16, 8),
    dpi=300,
    pdf_dpi=600,
    show=True,
    export_pdf=False,
    theme: str = "dpg",
    palette: str = "default",
    label_mode: str = "wrapped",
    readability: str = "presentation",
    title: Optional[str] = None,
):
    """
    Plot a DPG colored by community assignment.

    Args:
        plot_name: Output base name for saved files (no extension);
            ``"_communities"`` is appended.
        dot: Graphviz Digraph instance representing the DPG structure. It is
            modified in place.
        df: DataFrame with node metrics; must include ``'Node'`` and ``'Label'``
            columns. The caller's frame is not modified.
        dpg_metrics: Dict containing either ``'Communities'`` (list of sets/lists of
            node labels) or ``'Clusters'`` (mapping cluster_label -> list of node
            labels). ``'Communities'`` takes precedence when both are present.
        save_dir: Directory where output images are saved. Default is ``"results/"``.
        class_flag: If ``True``, class nodes are filled with the theme's class
            fill colour and excluded from the community colouring.
        df_edges: Optional DataFrame with edge metrics to color edges by weight.
        layout_template: Optional layout preset. One of
            ``{'default', 'compact', 'vertical', 'wide'}``.
        graph_style: Optional dict of Graphviz graph attributes to override template
            values.
        node_style: Optional dict of Graphviz node attributes to override template
            values.
        edge_style: Optional dict of Graphviz edge attributes to override template
            values.
        fig_size: Matplotlib figure size as ``(width, height)``.
        dpi: PNG export/display resolution.
        pdf_dpi: PDF export resolution when ``export_pdf=True``.
        show: If ``False`` the figure is closed after saving; otherwise it is
            left open. No explicit ``plt.show()`` call is made here.
        export_pdf: If ``True``, also writes a PDF next to the PNG.
        theme: Theme name forwarded to ``resolve_theme_context``.
        palette: Palette name forwarded to ``resolve_theme_context``.
        label_mode: Label formatting strategy. One of
            ``{'full', 'wrapped', 'short'}``.
        readability: Graph spacing/readability preset. One of
            ``{'compact', 'normal', 'presentation'}``.
        title: Optional axes title; defaults to the suffixed ``plot_name``.

    Returns:
        None

    Raises:
        AttributeError: If ``dpg_metrics`` is ``None``, lacks both
            ``'Communities'`` and ``'Clusters'``, or matches no node label.
        ValueError: If ``readability`` is not a known preset.
        RuntimeError: If the Graphviz ``dot`` executable is not available.
    """
    print("Plotting DPG (communities)...")
    # Fail fast, before the graph object is restyled.
    if dpg_metrics is None:
        raise AttributeError("dpg_metrics is required to plot communities.")

    theme_context = resolve_theme_context(theme=theme, palette=palette)
    colors = theme_context["colors"]
    _apply_matplotlib_theme(theme_context)

    original_df = df.copy()
    # Private, positionally indexed frame: the 'Community' helper column must
    # not leak into the caller's DataFrame and colours are matched by position.
    df = df.reset_index(drop=True)

    _apply_layout_template(
        dot,
        theme_context=theme_context,
        layout_template=layout_template,
        graph_style=graph_style,
        node_style=node_style,
        edge_style=edge_style,
    )
    wrap_width = _apply_graph_readability_preset(dot, readability=readability)

    def _rgb_hex(rgba) -> str:
        # Truncating conversion of a 0..1 RGBA tuple to "#rrggbb".
        return "#{:02x}{:02x}{:02x}".format(
            int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255)
        )

    colormap = theme_context["community_cmap"]

    # Highlight class nodes if class_flag is True
    if class_flag:
        for _, row in df.iterrows():
            if 'Class' in row['Label']:
                change_node_color(dot, row['Node'], _class_fill_color(theme_context))
        df = df[~df.Label.str.contains('Class')].reset_index(drop=True)  # Exclude class nodes from further processing

    # Map labels to community indices
    if "Communities" in dpg_metrics:
        communities = dpg_metrics.get("Communities", [])
    elif "Clusters" in dpg_metrics:
        communities = list(dpg_metrics.get("Clusters", {}).values())
    else:
        raise AttributeError("dpg_metrics must include 'Communities' or 'Clusters' to plot communities.")

    label_to_community = {
        label: community_index
        for community_index, community in enumerate(communities)
        for label in community
    }
    df['Community'] = df['Label'].map(label_to_community)
    if df['Community'].isna().all():
        raise AttributeError("No nodes matched communities/clusters labels.")

    # Normalise community indices; guard against a zero-width range.
    max_score = df['Community'].max()
    norm = mcolors.Normalize(0, 1) if max_score <= 0 else mcolors.Normalize(0, max_score)
    node_rgba = colormap(norm(df['Community']))

    for (_, row), rgba in zip(df.iterrows(), node_rgba):
        # Nodes that belong to no community are drawn in light gray.
        color = colors["light_gray"] if pd.isna(row['Community']) else _rgb_hex(rgba)
        change_node_color(dot, row['Node'], color)

    plot_name = plot_name + "_communities"

    # Highlight edges (optional)
    if df_edges is not None:
        colormap_edge = theme_context["edge_cmap"]
        max_edge_value = df_edges['Weight'].max()
        min_edge_value = df_edges['Weight'].min()
        norm_edge = mcolors.Normalize(vmin=min_edge_value, vmax=max_edge_value)
        for _, row in df_edges.iterrows():
            edge_value = row['Weight']
            color_hex = _rgb_hex(colormap_edge(norm_edge(edge_value)))
            penwidth = 1 + 3 * norm_edge(edge_value)
            change_edge_color(dot, row['Source_id'], row['Target_id'], new_color=color_hex, new_width=penwidth)

    # Class nodes always get the class border/fill, using the unfiltered frame.
    _style_class_nodes(dot, original_df, theme_context)

    # Render the graph to PNG bytes (avoid temp files).
    # NOTE(review): the sanitizer (and therefore label_mode / wrap_width) is
    # only invoked by _pipe_graph_png_with_fallback when the first render
    # attempt fails - confirm that this is intended.
    png_bytes = _pipe_graph_png_with_fallback(
        dot.source,
        lambda source: _sanitize_dot_source(
            source,
            wrap_labels=label_mode != "full",
            wrap_width=wrap_width,
            label_mode=label_mode,
        ),
    )

    # Open and display the rendered image
    img = Image.open(BytesIO(png_bytes))
    fig, ax = plt.subplots(figsize=fig_size)
    fig.patch.set_facecolor(colors["paper"])
    ax.set_axis_off()
    ax.set_title(title or plot_name, color=colors["ink"], fontsize=13, fontweight="semibold")
    ax.imshow(img)

    # Save the plot to the specified directory with tight borders
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(
        os.path.join(save_dir, plot_name + ".png"),
        dpi=dpi,
        bbox_inches="tight",
        pad_inches=0.02,
    )
    if export_pdf:
        fig.savefig(
            os.path.join(save_dir, plot_name + ".pdf"),
            format="pdf",
            dpi=pdf_dpi,
            bbox_inches="tight",
            pad_inches=0.02,
        )

    if not show:
        plt.close(fig)
[docs]
def plot_dpg_local_paths_aggregate(
    plot_name,
    dot,
    df,
    df_edges,
    paths_node_ids,
    path_confidences=None,
    sample_id=None,
    true_class_label=None,
    obtained_class_label=None,
    sample_metrics=None,
    save_dir="results/",
    class_flag=True,
    layout_template="default",
    graph_style=None,
    node_style=None,
    edge_style=None,
    fig_size=(16, 8),
    dpi=300,
    pdf_dpi=600,
    show=True,
    export_pdf=False,
    theme: str = "dpg",
    palette: str = "default",
    label_mode: str = "wrapped",
    readability: str = "presentation",
    title: Optional[str] = None,
):
    """
    Plot a fitted DPG with one sample's local paths highlighted on top.

    The graph is drawn on a deep copy of ``dot``; the caller's graph object is
    left untouched. Nodes and edges that lie on at least one path are
    highlighted, everything else is drawn in subdued colours.

    Args:
        plot_name: Output base name for saved files (no extension).
        dot: Graphviz Digraph instance representing the DPG structure.
        df: DataFrame with ``'Node'`` and ``'Label'`` columns.
        df_edges: DataFrame with ``'Source_id'`` and ``'Target_id'`` columns;
            edges are only restyled when it is non-empty.
        paths_node_ids: Iterable of paths, each an ordered, indexable sequence
            of node ids. ``None`` entries inside a path are skipped.
        path_confidences: Optional per-path weights (matched to paths by
            position). Missing or ``None`` weights count as ``1.0``. Edge pen
            width grows with the summed weight of the paths using the edge.
        sample_id: Optional sample identifier shown in the title.
        true_class_label: Optional true class; its class node (when visited)
            is filled with the palette's ``"success"`` colour. A ``"Class "``
            prefix is added when missing.
        obtained_class_label: Optional predicted class; its class node (when
            visited and different from the true class) is filled with the
            palette's ``"danger"`` colour.
        sample_metrics: Optional mapping; ``vote_confidence``,
            ``evidence_score_pred`` and ``trace_coverage_score`` are shown in
            the title when present and not ``None``.
        save_dir: Directory where output images are saved.
        class_flag: If ``True``, class nodes get the class styling (border,
            pen width) before the visited class fills are re-applied.
        layout_template: Layout preset name.
        graph_style: Optional Graphviz graph attribute overrides.
        node_style: Optional Graphviz node attribute overrides.
        edge_style: Optional Graphviz edge attribute overrides.
        fig_size: Matplotlib figure size as ``(width, height)``.
        dpi: PNG export resolution.
        pdf_dpi: PDF export resolution when ``export_pdf=True``.
        show: If ``False`` the figure is closed after saving.
        export_pdf: If ``True``, also writes a PDF next to the PNG.
        theme: Theme name forwarded to ``resolve_theme_context``.
        palette: Palette name forwarded to ``resolve_theme_context``.
        label_mode: Label formatting strategy.
        readability: Graph spacing/readability preset.
        title: Optional first title line; defaults to ``plot_name``.

    Returns:
        matplotlib.figure.Figure
    """
    print("Plotting local DPG paths...")
    theme_context = resolve_theme_context(theme=theme, palette=palette)
    colors = theme_context["colors"]
    _apply_matplotlib_theme(theme_context)

    local_dot = copy.deepcopy(dot)
    original_df = df.copy()
    _apply_layout_template(
        local_dot,
        theme_context=theme_context,
        layout_template=layout_template,
        graph_style=graph_style,
        node_style=node_style,
        edge_style=edge_style,
    )
    wrap_width = _apply_graph_readability_preset(local_dot, readability=readability)

    # Accumulate, per node and per edge, the total weight of the paths using it.
    visited_node_weights: Dict[str, float] = {}
    visited_edge_weights: Dict[Tuple[str, str], float] = {}
    path_confidences = list(path_confidences or [])
    for path_index, path in enumerate(paths_node_ids):
        weight = 1.0
        if path_index < len(path_confidences) and path_confidences[path_index] is not None:
            weight = float(path_confidences[path_index])
        for node_id in path:
            if node_id is None:
                continue
            node_key = str(node_id)
            visited_node_weights[node_key] = visited_node_weights.get(node_key, 0.0) + weight
        for i in range(len(path) - 1):
            src = path[i]
            dst = path[i + 1]
            if src is None or dst is None:
                continue
            edge_key = (str(src), str(dst))
            visited_edge_weights[edge_key] = visited_edge_weights.get(edge_key, 0.0) + weight

    subdued_pred_color = colors.get("node_muted", colors["light_gray"])
    subdued_class_color = _class_fill_color(theme_context)
    highlight_node_color = colors.get(
        "charcoal",
        colors.get("edge", colors.get("danger", "#333333")),
    )
    true_class_fill = colors.get("success", highlight_node_color)
    obtained_class_fill = colors.get("danger", highlight_node_color)

    def _as_class_label(value):
        # Normalise a class value to the "Class <name>" form used by node labels.
        if value is None:
            return None
        text = str(value)
        return text if text.startswith("Class ") else f"Class {text}"

    normalized_true_class_label = _as_class_label(true_class_label)
    normalized_obtained_class_label = _as_class_label(obtained_class_label)

    def _visited_class_fill(label: str) -> str:
        # The true class wins when it coincides with the obtained class.
        if label == normalized_true_class_label:
            return true_class_fill
        if label == normalized_obtained_class_label:
            return obtained_class_fill
        return subdued_class_color

    # Materialise (node id, label) once; several passes below use it.
    node_labels = [(str(row["Node"]), str(row["Label"])) for _, row in df.iterrows()]

    # Pass 1: subdue every node that no path visits.
    for node_id, label in node_labels:
        if node_id in visited_node_weights:
            continue
        if label.startswith("Class "):
            change_node_color(local_dot, node_id, subdued_class_color)
        else:
            change_node_color(local_dot, node_id, subdued_pred_color)

    # Pass 2: highlight visited nodes.
    for node_id, label in node_labels:
        if node_id not in visited_node_weights:
            continue
        if label.startswith("Class "):
            fillcolor = _visited_class_fill(label)
        else:
            fillcolor = highlight_node_color
        change_node_color(local_dot, node_id, fillcolor)

    if not df_edges.empty:
        # Loop-invariant: the strongest visited edge defines full pen width.
        max_strength = max(visited_edge_weights.values()) if visited_edge_weights else 1.0
        for _, row in df_edges.iterrows():
            source_id = str(row["Source_id"])
            target_id = str(row["Target_id"])
            strength = visited_edge_weights.get((source_id, target_id))
            if strength is not None:
                normalized_strength = strength / max_strength if max_strength > 0 else 1.0
                color_hex = highlight_node_color
                penwidth = 1.8 + 4.2 * normalized_strength
            else:
                color_hex = colors["light_gray"]
                penwidth = 0.7
            change_edge_color(local_dot, source_id, target_id, new_color=color_hex, new_width=penwidth)

    if class_flag:
        _style_class_nodes(local_dot, original_df, theme_context)
        # _style_class_nodes resets the fill of every class node, so the fills
        # of visited class nodes have to be applied again.
        for node_id, label in node_labels:
            if node_id not in visited_node_weights or not label.startswith("Class "):
                continue
            change_node_color(local_dot, node_id, _visited_class_fill(label))

    title_lines = [title if title else plot_name]
    meta_parts = []
    if sample_id is not None:
        meta_parts.append(f"sample={sample_id}")
    if obtained_class_label is not None:
        meta_parts.append(f"pred={obtained_class_label}")
    if true_class_label is not None:
        meta_parts.append(f"true={true_class_label}")
    if meta_parts:
        title_lines.append(" | ".join(meta_parts))
    if sample_metrics:
        metric_parts = []
        for key in ("vote_confidence", "evidence_score_pred", "trace_coverage_score"):
            value = sample_metrics.get(key)
            if value is not None:
                metric_parts.append(f"{key}={float(value):.2f}")
        if metric_parts:
            title_lines.append(" | ".join(metric_parts))

    # NOTE(review): the sanitizer (and therefore label_mode / wrap_width) is
    # only invoked by _pipe_graph_png_with_fallback when the first render
    # attempt fails - confirm that this is intended.
    png_bytes = _pipe_graph_png_with_fallback(
        local_dot.source,
        lambda source: _sanitize_dot_source(
            source,
            wrap_labels=label_mode != "full",
            wrap_width=wrap_width,
            label_mode=label_mode,
        ),
    )

    img = Image.open(BytesIO(png_bytes))
    fig, ax = plt.subplots(figsize=fig_size)
    fig.patch.set_facecolor(colors["paper"])
    ax.set_axis_off()
    ax.set_title("\n".join(title_lines), color=colors["ink"], fontsize=13, fontweight="semibold")
    ax.imshow(img)

    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, plot_name + ".png"), dpi=dpi, bbox_inches="tight", pad_inches=0.02)
    if export_pdf:
        fig.savefig(
            os.path.join(save_dir, plot_name + ".pdf"),
            format="pdf",
            dpi=pdf_dpi,
            bbox_inches="tight",
            pad_inches=0.02,
        )

    if not show:
        plt.close(fig)
    return fig
[docs]
def change_node_color(dot, node_id: str, fillcolor: str) -> None:
    """Update a node's fill color and set an appropriate contrasting font color.

    Args:
        dot: Graphviz Digraph object to modify in-place.
        node_id: Node identifier as used in the Digraph. Non-string values
            (e.g. integers read from a DataFrame) are converted with ``str()``.
        fillcolor: Hex color string of the form ``"#rrggbb"`` (e.g. ``"#a4c2f4"``).

    Raises:
        ValueError: If ``fillcolor`` does not contain valid hex digits in the
            expected positions.
    """
    r, g, b = int(fillcolor[1:3], 16), int(fillcolor[3:5], 16), int(fillcolor[5:7], 16)
    # Perceptual brightness formula (weighted sum of the RGB channels).
    brightness = (r * 299 + g * 587 + b * 114) / 1000
    fontcolor = "white" if brightness < 100 else "black"
    # Re-declare the node on the Graphviz object with the new colours.
    dot.node(str(node_id), style="filled", fillcolor=fillcolor, fontcolor=fontcolor)
[docs]
def normalize_data(df: Any, attribute: str, colormap) -> Dict[Any, str]:
    """Map a numeric DataFrame column to hex color strings via a colormap.

    Values are scaled linearly between the column's minimum and maximum before
    being passed to ``colormap``.

    Args:
        df: DataFrame containing at least a ``'Node'`` column and the ``attribute`` column.
        attribute: Column name whose values drive the colormap.
        colormap: Matplotlib colormap instance.

    Returns:
        Dict mapping node identifier to hex color string.
    """
    values = df[attribute]
    scaler = Normalize(vmin=values.min(), vmax=values.max())
    rgba_per_row = [colormap(scaler(value)) for value in values]

    node_to_hex = {}
    for node, rgba in zip(df['Node'], rgba_per_row):
        red, green, blue = (int(rgba[channel] * 255) for channel in range(3))
        node_to_hex[node] = "#{:02x}{:02x}{:02x}".format(red, green, blue)
    return node_to_hex
[docs]
def plot_dpg_reg(
    plot_name: str,
    dot,
    df: Any,
    df_dpg: Dict[str, Any],
    save_dir: str = "examples/",
    attribute: Optional[str] = None,
    communities: bool = False,
    leaf_flag: bool = False,
    theme: str = "dpg",
    palette: str = "default",
) -> None:
    """Plot a regression DPG with optional node coloring by attribute or community.

    The Graphviz graph is rendered to ``<save_dir>/<plot_name>_temp.gv.png``,
    embedded in a Matplotlib figure and saved as
    ``<save_dir>/<plot_name>_REG.png``. The caller's ``df`` is never modified.

    Args:
        plot_name: Output base name for saved files (no extension).
        dot: Graphviz Digraph instance representing the DPG structure.
        df: DataFrame with node metrics; must include ``'Node'`` and ``'Label'`` columns.
        df_dpg: Dict of DPG metrics; used for ``'Communities'`` when ``communities=True``.
        save_dir: Directory where output images are saved. Default is ``"examples/"``.
        attribute: Optional node metric column name to color nodes by. Takes
            precedence over ``communities`` when both are given.
        communities: If True, color nodes by community index instead of a single attribute.
        leaf_flag: If True and ``attribute`` is set, exclude prediction (leaf) nodes from
            attribute coloring.
        theme: Theme name passed to ``resolve_theme_context``.
        palette: Palette name passed to ``resolve_theme_context``.

    Raises:
        Exception: The error built by ``_graphviz_not_found_error()`` when the
            Graphviz executable is missing.
    """
    print("Rendering plot...")
    theme_context = resolve_theme_context(theme=theme, palette=palette)
    colors = theme_context["colors"]
    _apply_matplotlib_theme(theme_context)

    if attribute:
        if leaf_flag:
            df = df[~df['Label'].str.contains('Pred')]
        node_colors = normalize_data(df, attribute, theme_context["sequential_cmap"])
        plot_name += f"_{attribute.replace(' ', '')}"
    elif communities:
        label_to_community = {
            label: idx
            for idx, members in enumerate(df_dpg['Communities'])
            for label in members
        }
        # assign() returns a new frame, so the caller's DataFrame does not
        # silently gain a 'Community' column.
        df = df.assign(Community=df['Label'].map(label_to_community))
        node_colors = normalize_data(df, 'Community', theme_context["community_cmap"])
        plot_name += "_communities"
    else:
        node_colors = {}
        for _, row in df.iterrows():
            fill = _class_fill_color(theme_context) if "Pred" in str(row["Label"]) else _default_pred_node_color(theme_context)
            node_colors[row["Node"]] = fill

    # Apply node colors
    for node, color in node_colors.items():
        change_node_color(dot, node, color)

    graph_path = os.path.join(save_dir, f"{plot_name}_temp.gv")
    try:
        dot.render(graph_path, view=False, format='png')
    except ExecutableNotFound as exc:
        raise _graphviz_not_found_error() from exc

    # Display and save the image
    img_path = f"{graph_path}.png"
    fig, ax = plt.subplots(figsize=(16, 8))
    fig.patch.set_facecolor(colors["paper"])
    ax.axis('off')
    ax.set_title(plot_name, color=colors["ink"], fontsize=13, fontweight="semibold")
    # imshow copies the pixel data, so the file handle can be released
    # as soon as the image has been handed to Matplotlib.
    with Image.open(img_path) as img:
        ax.imshow(img)

    if attribute:
        cax = fig.add_axes([0.11, 0.1, 0.8, 0.025])
        norm = Normalize(vmin=df[attribute].min(), vmax=df[attribute].max())
        cbar = fig.colorbar(ScalarMappable(norm=norm, cmap=theme_context["sequential_cmap"]), cax=cax, orientation='horizontal')
        cbar.set_label(attribute)
        cbar.outline.set_edgecolor(colors["light_gray"])
        cbar.ax.tick_params(colors=colors["charcoal"])

    fig.savefig(os.path.join(save_dir, f"{plot_name}_REG.png"), dpi=300, bbox_inches="tight", pad_inches=0.04)
    plt.close(fig)  # Free up memory by closing the plot

    # Clean up temporary files
    if os.path.isdir("temp"):
        delete_folder_contents("temp")
[docs]
def _non_overlapping_constraint_features(
    normalized_constraints: Dict,
    class_names: List,
    features: List[str],
) -> set:
    """Return features whose ranges are strictly disjoint for at least one class pair.

    Ranges that merely touch (equal bounds) count as overlapping, and a
    missing bound never establishes disjointness.
    """
    disjoint = set()
    for feat in features:
        for i, c1 in enumerate(class_names):
            for c2 in class_names[i + 1:]:
                c1_bounds = normalized_constraints.get(c1, {}).get(feat, {})
                c2_bounds = normalized_constraints.get(c2, {}).get(feat, {})
                c1_min, c1_max = c1_bounds.get('min'), c1_bounds.get('max')
                c2_min, c2_max = c2_bounds.get('min'), c2_bounds.get('max')
                # c1 ends strictly before c2 starts, or the other way round.
                if c1_max is not None and c2_min is not None and c1_max < c2_min:
                    disjoint.add(feat)
                if c2_max is not None and c1_min is not None and c2_max < c1_min:
                    disjoint.add(feat)
    return disjoint


def _constraint_axis_values(
    normalized_constraints: Dict,
    class_names: List,
    features: List[str],
    original_sample: Optional[Dict],
) -> List[Any]:
    """Collect every bound and original-sample value that should drive x-axis scaling."""
    values = []
    for cname in class_names:
        for feat in features:
            if feat in normalized_constraints.get(cname, {}):
                bounds = normalized_constraints[cname][feat]
                if bounds.get('min') is not None:
                    values.append(bounds['min'])
                if bounds.get('max') is not None:
                    values.append(bounds['max'])
    if original_sample:
        for feat in features:
            if feat in original_sample:
                values.append(original_sample[feat])
    return values


def plot_dpg_constraints_overview(
    normalized_constraints: Dict,
    feature_names: List[str],
    class_colors_list: List[str],
    output_path: Optional[str] = None,
    title: str = "DPG Constraints Overview",
    original_sample: Optional[Dict] = None,
    original_class: Optional[int] = None,
    target_class: Optional[int] = None,
    theme: str = "dpg",
    palette: str = "default",
) -> Optional[Any]:
    """Create a horizontal bar chart showing DPG constraints for all features.

    Similar to the "Feature Changes" chart style, this shows:
    - Original sample values as markers
    - Constraint boundaries (min/max) for each class as colored regions
    - A highlight band and a star on features whose ranges do not overlap
      between at least two classes

    Args:
        normalized_constraints: Dict with structure {class_name: {feature: {min, max}}}
        feature_names: List of feature names to display
        class_colors_list: List of colors for each class (cycled if shorter
            than the number of classes)
        output_path: Optional path to save the figure
        title: Title for the figure
        original_sample: Optional dict of original sample feature values
        original_class: Optional original class index (shown in the title)
        target_class: Optional target class index (shown in the title)
        theme: Theme name passed to ``resolve_theme_context``.
        palette: Palette name passed to ``resolve_theme_context``.

    Returns:
        matplotlib Figure object, or None (after printing a warning) when
        there is nothing to plot.
    """
    theme_context = resolve_theme_context(theme=theme, palette=palette)
    colors = theme_context["colors"]
    _apply_matplotlib_theme(theme_context)

    if not normalized_constraints:
        print("WARNING: No constraints available for visualization")
        return None

    class_names = sorted(normalized_constraints.keys())
    n_classes = len(class_names)

    # Keep only features constrained in at least one class.
    features_with_constraints = [
        feat
        for feat in feature_names
        if any(feat in normalized_constraints.get(cname, {}) for cname in class_names)
    ]
    if not features_with_constraints:
        print("WARNING: No features with constraints found")
        return None
    n_features = len(features_with_constraints)

    non_overlapping_features = _non_overlapping_constraint_features(
        normalized_constraints, class_names, features_with_constraints
    )

    # Resolve the x-axis range BEFORE creating the figure, so the early
    # returns below cannot leave an orphaned figure registered with pyplot.
    all_values = _constraint_axis_values(
        normalized_constraints, class_names, features_with_constraints, original_sample
    )
    if not all_values:
        print("WARNING: No constraint values found")
        return None

    # Filter out NaN and Inf values
    all_values = [v for v in all_values if v is not None and np.isfinite(v)]
    if not all_values:
        print("WARNING: No valid (finite) constraint values found")
        return None

    lowest, highest = min(all_values), max(all_values)
    value_range = highest - lowest
    # Handle case where all values are the same (range = 0)
    if value_range == 0 or not np.isfinite(value_range):
        value_range = abs(highest) * 0.2 if highest != 0 else 1.0
    x_min = lowest - 0.1 * value_range
    x_max = highest + 0.1 * value_range

    # Create figure
    fig, ax = plt.subplots(figsize=(14, max(6, n_features * 0.5)))
    fig.patch.set_facecolor(colors["paper"])
    ax.set_facecolor(colors["paper"])

    # Y positions for features
    y_positions = np.arange(n_features)
    bar_height = 0.35
    highlight_color = colors.get("gold", colors["success"])

    # For each feature, draw constraint regions for each class
    for feat_idx, feat in enumerate(features_with_constraints):
        y = y_positions[feat_idx]

        # Highlight non-overlapping features
        if feat in non_overlapping_features:
            ax.axhspan(y - 0.45, y + 0.45, alpha=0.12, color=highlight_color, zorder=0)

        # Draw constraints for each class
        for class_idx, cname in enumerate(class_names):
            color = class_colors_list[class_idx % len(class_colors_list)]
            if feat not in normalized_constraints.get(cname, {}):
                continue
            bounds = normalized_constraints[cname][feat]
            feat_min = bounds.get('min')
            feat_max = bounds.get('max')

            # Spread the classes vertically around the feature's row.
            y_offset = (class_idx - (n_classes - 1) / 2) * bar_height * 0.8
            label_box = dict(boxstyle='round,pad=0.2', facecolor='#FFFDF7',
                             edgecolor=color, alpha=0.8, linewidth=0.5)

            if feat_min is not None and feat_max is not None:
                # Both bounds - draw filled rectangle
                rect = mpatches.Rectangle(
                    (feat_min, y + y_offset - bar_height/2),
                    feat_max - feat_min,
                    bar_height,
                    linewidth=2,
                    edgecolor=color,
                    facecolor=color,
                    alpha=0.3,
                    zorder=2
                )
                ax.add_patch(rect)
                # Add min/max value labels
                ax.text(feat_min, y + y_offset, f'{feat_min:.2f}',
                        ha='right', va='center', fontsize=7, color=color, weight='bold',
                        bbox=label_box)
                ax.text(feat_max, y + y_offset, f'{feat_max:.2f}',
                        ha='left', va='center', fontsize=7, color=color, weight='bold',
                        bbox=label_box)
            elif feat_min is not None:
                # Only min bound - dashed line from the bound to the right axis limit
                ax.plot([feat_min, x_max], [y + y_offset, y + y_offset],
                        color=color, linewidth=3, alpha=0.5, linestyle='--', zorder=2)
                ax.scatter([feat_min], [y + y_offset], color=color, s=100,
                           marker='|', zorder=3, linewidths=3)
                ax.text(feat_min, y + y_offset + bar_height/2, f'min:{feat_min:.2f}',
                        ha='center', va='bottom', fontsize=7, color=color, weight='bold',
                        bbox=label_box)
            elif feat_max is not None:
                # Only max bound - dashed line from the left axis limit to the bound
                ax.plot([x_min, feat_max], [y + y_offset, y + y_offset],
                        color=color, linewidth=3, alpha=0.5, linestyle='--', zorder=2)
                ax.scatter([feat_max], [y + y_offset], color=color, s=100,
                           marker='|', zorder=3, linewidths=3)
                ax.text(feat_max, y + y_offset + bar_height/2, f'max:{feat_max:.2f}',
                        ha='center', va='bottom', fontsize=7, color=color, weight='bold',
                        bbox=label_box)

        # Draw original sample value if provided
        if original_sample and feat in original_sample:
            sample_val = original_sample[feat]
            # Draw as a prominent marker
            ax.scatter([sample_val], [y], color=colors["ink"], s=150, marker='o',
                       zorder=10, edgecolors='#FFFDF7', linewidths=2)
            ax.plot([sample_val, sample_val], [y - 0.4, y + 0.4],
                    color=colors["ink"], linewidth=2, linestyle='-', zorder=9, alpha=0.7)
            ax.text(sample_val, y + 0.42, f'{sample_val:.2f}',
                    ha='center', va='bottom', fontsize=8, color=colors["ink"], weight='bold',
                    bbox=dict(boxstyle='round,pad=0.2', facecolor=highlight_color,
                              edgecolor=colors["charcoal"], alpha=0.92, linewidth=1))

    # Configure axes
    ax.set_yticks(y_positions)
    # Format y-tick labels with discriminative feature highlighting
    y_labels = [
        f'★ {feat}' if feat in non_overlapping_features else feat
        for feat in features_with_constraints
    ]
    ax.set_yticklabels(y_labels, fontsize=10)
    # Color discriminative feature labels
    for tick_label, feat in zip(ax.get_yticklabels(), features_with_constraints):
        if feat in non_overlapping_features:
            tick_label.set_color(colors["success"])
            tick_label.set_weight('bold')
    ax.tick_params(axis="y", length=0)
    ax.set_xlim(x_min, x_max)
    ax.set_xlabel('Feature value', fontsize=12, loc='left')
    ax.axvline(x=0, color=colors["mid_gray"], linestyle=':', alpha=0.5, zorder=1)
    _style_axes(ax, theme_context, grid_axis="x")

    # Create legend
    legend_elements = []
    for class_idx, cname in enumerate(class_names):
        color = class_colors_list[class_idx % len(class_colors_list)]
        legend_elements.append(
            mpatches.Patch(facecolor=color, edgecolor=color, alpha=0.3,
                           linewidth=2, label=f'{cname} Constraints')
        )
    if original_sample:
        legend_elements.append(
            Line2D([0], [0], marker='o', color='w', markerfacecolor='black',
                   markeredgecolor='white', markersize=10, label='Original Sample')
        )
    if non_overlapping_features:
        legend_elements.append(
            mpatches.Patch(facecolor=highlight_color, alpha=0.2,
                           label=f'★ Non-overlapping ({len(non_overlapping_features)} features)')
        )
    legend = ax.legend(handles=legend_elements, loc='upper right', fontsize=10)
    _style_legend(legend, theme_context)

    # Title with class info
    title_text = title
    if original_class is not None and target_class is not None:
        title_text += f'\nOriginal Class: {original_class} → Target Class: {target_class}'
    ax.set_title(title_text, fontsize=14, weight='semibold', pad=10, color=colors["ink"])

    # Add statistics subtitle
    n_non_overlap = len(non_overlapping_features)
    subtitle = f"Features: {n_features} | Non-overlapping: {n_non_overlap} | Classes: {n_classes}"
    fig.text(0.5, 0.01, subtitle, ha='center', fontsize=10, style='italic', color=colors["mid_gray"])

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.08)
    if output_path:
        fig.savefig(output_path, bbox_inches='tight', dpi=150)
        print(f"INFO: Saved DPG constraints overview to {output_path}")
    return fig
[docs]
def parse_predicate_parts(label: str) -> Optional[Tuple[str, str, float]]:
    """Split a predicate label such as 'feature <= 1.23' or 'feature > 0.7'.

    Returns:
        ``(feature, operator, threshold)`` with the feature name stripped of
        surrounding whitespace and the threshold converted to float, or None
        when ``_PREDICATE_PATTERN`` finds no predicate in ``str(label)``.
    """
    found = _PREDICATE_PATTERN.search(str(label))
    if found is None:
        return None
    feature, operator, threshold = found.group(1), found.group(2), found.group(3)
    return feature.strip(), operator, float(threshold)
[docs]
def parse_feature_from_predicate(label: str) -> str:
    """Return the feature name that a predicate label refers to.

    Args:
        label: Predicate string such as ``"petal_length <= 2.45"``.

    Returns:
        The feature part of the predicate; when the label cannot be parsed
        as a predicate, ``str(label)`` is returned unchanged.
    """
    parts = parse_predicate_parts(label)
    if not parts:
        return str(label)
    return parts[0]
def _feature_color_map(features: List[str]) -> Dict[str, Any]:
    """Private alias that forwards ``features`` to ``feature_color_map`` and returns its result.

    NOTE(review): ``feature_color_map`` is not among the imports visible at the
    top of this file; confirm it is defined or imported elsewhere in the module.
    """
    return feature_color_map(features)
[docs]
def lrc_predicate_scores(explanation, top_k: int = 10) -> Any:
    """Return top-k predicate rows ranked by Local reaching centrality.

    Rows of ``explanation.node_metrics`` whose label contains ``<=`` or ``>``
    are sorted by the ``"Local reaching centrality"`` column (descending) and
    truncated to ``top_k``; rows whose label then fails to parse as a
    predicate are dropped.

    Args:
        explanation: Object exposing a ``node_metrics`` DataFrame with
            ``"Label"`` and ``"Local reaching centrality"`` columns.
        top_k: Maximum number of predicate nodes to consider.

    Returns:
        DataFrame with columns ``predicate``, ``feature``, ``operator``,
        ``threshold`` and ``lrc``. The columns are present even when no
        predicate is found, so callers can sort or group an empty result.
    """
    columns = ["predicate", "feature", "operator", "threshold", "lrc"]
    nm = explanation.node_metrics.copy()
    labels = nm["Label"].astype(str)
    mask = (
        labels.str.contains("<=", regex=False, na=False)
        | labels.str.contains(">", regex=False, na=False)
    )
    nm = nm[mask].sort_values("Local reaching centrality", ascending=False).head(top_k)

    rows = []
    for _, row in nm.iterrows():
        parsed = parse_predicate_parts(row["Label"])
        if not parsed:
            continue
        feature, operator, threshold = parsed
        rows.append(
            {
                "predicate": str(row["Label"]),
                "feature": feature,
                "operator": operator,
                "threshold": threshold,
                "lrc": float(row["Local reaching centrality"]),
            }
        )
    # Pinning the columns keeps the schema stable for an empty result.
    return pd.DataFrame(rows, columns=columns)
[docs]
def plot_lrc_vs_rf_importance(
    explanation,
    model,
    X_df: Any,
    top_k: int = 10,
    dataset_name: str = "Dataset",
    save_path: Optional[str] = None,
    show: bool = True,
    theme: str = "dpg",
    palette: str = "default",
) -> Any:
    """
    Draw two horizontal bar charts side by side: the top LRC predicates (left)
    and the model's top feature importances (right), with one color per feature
    shared across both panels.

    Raises:
        ValueError: If no predicate scores are available, or if ``model`` has
            no ``feature_importances_`` attribute.

    Returns:
        Matplotlib figure.
    """
    theme_context = resolve_theme_context(theme=theme, palette=palette)
    colors = theme_context["colors"]
    _apply_matplotlib_theme(theme_context)

    lrc_table = lrc_predicate_scores(explanation, top_k=top_k).copy()
    if lrc_table.empty:
        raise ValueError("No predicate labels available to compute LRC scores.")
    if not hasattr(model, "feature_importances_"):
        raise ValueError("Model must expose feature_importances_.")

    importance_table = pd.DataFrame(
        {
            "feature": list(getattr(model, "feature_names_in_", X_df.columns)),
            "rf_importance": np.asarray(model.feature_importances_, dtype=float),
        }
    )
    rf_table = importance_table.sort_values("rf_importance", ascending=False).head(top_k)

    # barh draws bottom-up, so ascending order puts the largest bar on top.
    lrc_sorted = lrc_table.sort_values("lrc", ascending=True)
    rf_sorted = rf_table.sort_values("rf_importance", ascending=True)

    plotted_features = lrc_sorted["feature"].tolist() + rf_sorted["feature"].tolist()
    color_of = theme_context["feature_color_map"](plotted_features)

    fig, axes = plt.subplots(1, 2, figsize=(16, max(5, top_k * 0.45)))
    fig.patch.set_facecolor(colors["paper"])

    # (axis, data, bar-label column, bar-length column, title, x label, y label)
    panels = (
        (
            axes[0],
            lrc_sorted,
            "predicate",
            "lrc",
            f"{dataset_name}: Top {top_k} LRC predicates",
            "Local reaching centrality",
            "Predicate",
        ),
        (
            axes[1],
            rf_sorted,
            "feature",
            "rf_importance",
            f"{dataset_name}: Top {top_k} RF feature importances",
            "Random Forest feature importance",
            "Feature",
        ),
    )
    for axis, table, label_col, value_col, panel_title, x_label, y_label in panels:
        axis.barh(
            table[label_col],
            table[value_col],
            color=[color_of[name] for name in table["feature"]],
            edgecolor=colors["paper"],
            linewidth=0.8,
        )
        axis.set_title(panel_title)
        axis.set_xlabel(x_label)
        axis.set_ylabel(y_label)
        _style_axes(axis, theme_context, grid_axis="x")

    # One legend entry per distinct feature, in first-seen order.
    legend_handles = []
    for name in dict.fromkeys(plotted_features):
        legend_handles.append(
            Line2D(
                [0],
                [0],
                marker="s",
                color="w",
                label=name,
                markerfacecolor=color_of[name],
                markeredgecolor=colors["charcoal"],
                markersize=8,
            )
        )
    legend = fig.legend(
        handles=legend_handles,
        title="Feature colors",
        loc="lower center",
        ncol=min(4, max(1, len(legend_handles))),
        frameon=True,
    )
    _style_legend(legend, theme_context)

    # Reserve the bottom strip of the figure for the shared legend.
    plt.tight_layout(rect=(0, 0.08, 1, 1))
    if save_path is not None:
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
    if show:
        plt.show()
    else:
        plt.close(fig)
    return fig
[docs]
def plot_lec_vs_rf_importance(*args, **kwargs) -> Any:
    """
    Backward-compatible alias for a common typo.

    Emits a ``DeprecationWarning`` and then forwards all positional and
    keyword arguments unchanged to `plot_lrc_vs_rf_importance`, returning
    its result.

    Use `plot_lrc_vs_rf_importance` instead.
    """
    warnings.warn(
        "plot_lec_vs_rf_importance is deprecated; use plot_lrc_vs_rf_importance.",
        DeprecationWarning,
        # stacklevel=2 attributes the warning to the caller of this alias.
        stacklevel=2,
    )
    return plot_lrc_vs_rf_importance(*args, **kwargs)
[docs]
def plot_top_lrc_predicate_splits(
    explanation,
    X_df: Any,
    y,
    top_predicates: int = 5,
    top_features: int = 2,
    dataset_name: str = "Dataset",
    class_names: Optional[Any] = None,
    save_path: Optional[str] = None,
    show: bool = True,
    theme: str = "dpg",
    palette: str = "default",
) -> Optional[Any]:
    """
    Scatter the top-2 LRC features and overlay top-LRC predicate split lines.

    Features are ranked by the summed LRC of their predicates; the two
    best-ranked features become the x and y axes. Each of the
    ``top_predicates`` highest-LRC predicates that tests one of those two
    features is drawn as a vertical (x feature) or horizontal (y feature)
    line, dashed for ``<=`` and solid for ``>``.

    Args:
        explanation: Object accepted by ``lrc_predicate_scores``.
        X_df: DataFrame holding the feature columns to scatter.
        y: Class labels, one per row of ``X_df``; numeric or string-like.
        top_predicates: Number of highest-LRC predicates eligible for split lines.
        top_features: Number of ranked features to keep; only the first two
            are plotted, and fewer than two yields ``None``.
        dataset_name: Name shown in the plot title.
        class_names: Optional dict (class value -> name) or sequence indexed
            by integer class value, used for the class legend.
        save_path: Optional path to save the figure.
        show: If True call ``plt.show()``, otherwise close the figure.
        theme: Theme name passed to ``resolve_theme_context``.
        palette: Palette name passed to ``resolve_theme_context``.

    Returns:
        Matplotlib figure, or None when top features cannot be resolved (no
        predicate scores, fewer than two ranked features, or a ranked feature
        missing from ``X_df``).
    """
    theme_context = resolve_theme_context(theme=theme, palette=palette)
    colors = theme_context["colors"]
    _apply_matplotlib_theme(theme_context)

    top_lrc = lrc_predicate_scores(explanation, top_k=max(top_predicates, 10)).copy()
    if top_lrc.empty:
        # An empty score table may not even carry the "lrc"/"feature"
        # columns, so sorting or grouping it would raise KeyError.
        return None

    top_pred = top_lrc.sort_values("lrc", ascending=False).head(top_predicates).copy()
    feature_rank = (
        top_lrc.groupby("feature", as_index=False)["lrc"]
        .sum()
        .sort_values("lrc", ascending=False)
        .head(top_features)
    )
    selected_features = feature_rank["feature"].tolist()
    if len(selected_features) < 2:
        return None
    fx, fy = selected_features[0], selected_features[1]
    if fx not in X_df.columns or fy not in X_df.columns:
        return None
    split_rows = top_pred[top_pred["feature"].isin([fx, fy])].copy()

    fig, ax = plt.subplots(figsize=(8, 6))

    y_series = pd.Series(np.asarray(y))
    y_numeric = pd.to_numeric(y_series, errors="coerce")
    class_codes = None
    class_labels: List[str] = []
    if y_numeric.notna().all():
        # Keep class colors discrete even when labels are numeric.
        unique_classes = sorted(y_numeric.unique().tolist())
        class_to_code = {cls: idx for idx, cls in enumerate(unique_classes)}
        class_codes = y_numeric.map(class_to_code).to_numpy(dtype=int)
        for cls in unique_classes:
            # Show 1.0 as "1"; keep genuinely fractional class values as-is.
            cls_idx = int(cls) if float(cls).is_integer() else cls
            if class_names is None:
                class_labels.append(str(cls_idx))
            elif isinstance(class_names, dict):
                class_labels.append(str(class_names.get(cls_idx, cls_idx)))
            elif isinstance(cls_idx, int) and 0 <= cls_idx < len(class_names):
                class_labels.append(str(class_names[cls_idx]))
            else:
                class_labels.append(str(cls_idx))
    else:
        # Support string/categorical class labels (e.g., "F1", "F2") in scatter coloring.
        class_codes, unique_labels = pd.factorize(y_series.astype(str), sort=True)
        class_labels = [str(label) for label in unique_labels]
    if not class_labels:
        class_labels = ["unknown"]
        class_codes = np.zeros(len(y_series), dtype=int)
    n_classes = len(class_labels)
    class_map = theme_context["class_cmap"](n_classes)

    fig.patch.set_facecolor(colors["paper"])
    ax.set_facecolor(colors["paper"])
    ax.scatter(
        X_df[fx],
        X_df[fy],
        c=class_codes,
        cmap=class_map,
        s=42,
        alpha=0.82,
        edgecolor="#FFFDF7",
        linewidth=0.7,
        zorder=2,
    )

    feature_to_color = theme_context["predicate_line_color_map"]([fx, fy])
    labels_seen = set()
    for _, row in split_rows.iterrows():
        feature = row["feature"]
        operator = row["operator"]
        threshold = row["threshold"]
        score = row["lrc"]
        linestyle = "--" if operator == "<=" else "-"
        label = f"{feature} {operator} {threshold:.2f} (LRC={score:.3f})"
        # A threshold on the x feature is a vertical line; on the y feature
        # a horizontal one. Repeated labels are only listed once in the legend.
        if feature == fx:
            ax.axvline(
                threshold,
                color=feature_to_color[feature],
                linestyle=linestyle,
                linewidth=2.4,
                alpha=0.9,
                label=label if label not in labels_seen else None,
            )
            labels_seen.add(label)
        elif feature == fy:
            ax.axhline(
                threshold,
                color=feature_to_color[feature],
                linestyle=linestyle,
                linewidth=2.4,
                alpha=0.9,
                label=label if label not in labels_seen else None,
            )
            labels_seen.add(label)

    ax.set_title(f"{dataset_name}: Top-{top_predicates} LRC predicate splits")
    ax.set_xlabel(fx)
    ax.set_ylabel(fy)
    _style_axes(ax, theme_context, grid_axis="both")

    class_handles = [
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            markerfacecolor=class_map(i),
            markeredgecolor="#FFFDF7",
            markeredgewidth=0.7,
            markersize=7,
            linestyle="None",
            alpha=0.85,
            label=class_labels[i],
        )
        for i in range(n_classes)
    ]
    class_legend = ax.legend(
        handles=class_handles,
        title="Class",
        loc="upper left",
        fontsize=8,
        frameon=True,
    )
    _style_legend(class_legend, theme_context)
    # Re-add the class legend as an artist so the second ax.legend() call
    # below does not replace it.
    ax.add_artist(class_legend)

    handles, labels = ax.get_legend_handles_labels()
    if handles:
        predicate_legend = ax.legend(handles, labels, title="Top LRC predicate lines", loc="lower right", fontsize=8)
        _style_legend(predicate_legend, theme_context)

    plt.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
    if show:
        plt.show()
    else:
        plt.close(fig)
    return fig
[docs]
def sample_bc_weights(
    explanation,
    X_df: Any,
    top_k: int = 10,
) -> Any:
    """
    Score each sample by the bottleneck predicates it satisfies.

    The score is not a betweenness centrality of the samples themselves.
    The ``top_k`` predicate nodes with the highest betweenness centrality
    (BC) are selected, and every sample accumulates the BC of each selected
    predicate that holds for it:

        weight(sample_i) = sum_p 1[sample_i satisfies predicate_p] * BC(predicate_p)

    Predicates that cannot be parsed, or whose feature is not a column of
    ``X_df``, contribute nothing.

    Args:
        explanation: Global DPG explanation containing ``node_metrics``.
        X_df: Feature matrix as a pandas DataFrame.
        top_k: Number of highest-BC predicate nodes to include.

    Returns:
        pandas Series named ``bc_weight``, indexed like ``X_df``, with one
        BC-derived weight per sample.

    Raises:
        ValueError: If ``X_df`` has no ``columns`` attribute.
    """
    if not hasattr(X_df, "columns"):
        raise ValueError("X_df must be a pandas DataFrame with named columns.")

    metrics = explanation.node_metrics.copy()
    has_le = metrics["Label"].astype(str).str.contains("<=", regex=False, na=False)
    has_gt = metrics["Label"].astype(str).str.contains(">", regex=False, na=False)
    ranked = (
        metrics[has_le | has_gt]
        .sort_values("Betweenness centrality", ascending=False)
        .head(top_k)
    )

    totals = np.zeros(len(X_df), dtype=float)
    for _, node_row in ranked.iterrows():
        parts = parse_predicate_parts(node_row["Label"])
        if not parts:
            continue
        feature, operator, threshold = parts
        if feature not in X_df.columns:
            continue
        column = X_df[feature].to_numpy()
        # Any operator other than "<=" is treated as ">".
        satisfied = column <= threshold if operator == "<=" else column > threshold
        totals += satisfied.astype(float) * float(node_row["Betweenness centrality"])

    return pd.Series(totals, index=X_df.index, name="bc_weight")
[docs]
def plot_sample_using_bc_weights(
    explanation,
    X_df: Any,
    y,
    top_k: int = 10,
    dataset_name: str = "Dataset",
    class_names: Optional[Any] = None,
    save_path: Optional[str] = None,
    show: bool = True,
    theme: str = "dpg",
    palette: str = "default",
) -> Any:
    """
    Scatter the samples on their first two principal components, colored by
    class and sized by BC-derived weight.

    The weight of a sample is the summed BC of the top-k predicate nodes it
    satisfies (see ``sample_bc_weights``); weights are min-max scaled before
    being turned into marker sizes, so larger points are samples that
    activate more bottleneck predicates in the DPG.

    Raises:
        ValueError: If ``X_df`` has no ``columns`` attribute.

    Returns:
        Matplotlib figure.
    """
    theme_context = resolve_theme_context(theme=theme, palette=palette)
    colors = theme_context["colors"]
    _apply_matplotlib_theme(theme_context)
    from sklearn.decomposition import PCA

    if not hasattr(X_df, "columns"):
        raise ValueError("X_df must be a pandas DataFrame with named columns.")

    weights = sample_bc_weights(explanation=explanation, X_df=X_df, top_k=top_k)
    # The small epsilon avoids a zero division when all weights are equal.
    scaled_weights = (weights - weights.min()) / (weights.max() - weights.min() + 1e-12)

    projector = PCA(n_components=2, random_state=42)
    projected = projector.fit_transform(X_df)

    target_index = X_df.index if len(X_df) == len(y) else None
    y_series = pd.Series(np.asarray(y), index=target_index)
    y_numeric = pd.to_numeric(y_series, errors="coerce")

    def _display_name(class_value) -> str:
        # Show 1.0 as "1"; keep genuinely fractional class values as-is.
        key = int(class_value) if float(class_value).is_integer() else class_value
        if class_names is None:
            return str(key)
        if isinstance(class_names, dict):
            return str(class_names.get(key, key))
        if isinstance(key, int) and 0 <= key < len(class_names):
            return str(class_names[key])
        return str(key)

    class_codes = None
    class_labels: List[str] = []
    if y_numeric.notna().all():
        distinct = sorted(y_numeric.unique().tolist())
        code_of = {value: position for position, value in enumerate(distinct)}
        class_codes = y_numeric.map(code_of).to_numpy(dtype=int)
        class_labels = [_display_name(value) for value in distinct]
    else:
        # Non-numeric labels: encode their string form in sorted order.
        class_codes, distinct_labels = pd.factorize(y_series.astype(str), sort=True)
        class_labels = [str(name) for name in distinct_labels]
    if not class_labels:
        class_labels = ["unknown"]
        class_codes = np.zeros(len(y_series), dtype=int)

    n_classes = len(class_labels)
    class_map = theme_context["class_cmap"](n_classes)

    fig, ax = plt.subplots(figsize=(8, 6))
    fig.patch.set_facecolor(colors["paper"])
    ax.set_facecolor(colors["paper"])
    ax.scatter(
        projected[:, 0],
        projected[:, 1],
        c=class_codes,
        cmap=class_map,
        s=36 + 132 * scaled_weights.to_numpy(),
        alpha=0.78,
        edgecolor="#FFFDF7",
        linewidth=0.7,
    )

    # One legend entry per class, plus one explaining the marker size.
    legend_handles = []
    for position, name in enumerate(class_labels):
        legend_handles.append(
            Line2D(
                [0],
                [0],
                marker="o",
                color="w",
                markerfacecolor=class_map(position),
                markersize=8,
                label=name,
            )
        )
    legend_handles.append(
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            markerfacecolor=colors["success"],
            markersize=12,
            label="large = high BC-derived weight",
        )
    )
    legend = ax.legend(handles=legend_handles, loc="best", fontsize=9, frameon=True)
    _style_legend(legend, theme_context)

    explained = projector.explained_variance_ratio_
    ax.set_xlabel(f"PC1 ({explained[0] * 100:.1f}%)")
    ax.set_ylabel(f"PC2 ({explained[1] * 100:.1f}%)")
    ax.set_title(f"PCA - BC bottleneck weight per sample\n{dataset_name}")
    _style_axes(ax, theme_context, grid_axis="both")

    formula = "weight(i) = sum 1[predicate active on i] * BC(predicate)"
    formula_box = {
        "boxstyle": "round,pad=0.25",
        "facecolor": "#FFFDF7",
        "alpha": 0.92,
        "edgecolor": colors["light_gray"],
    }
    ax.text(
        0.01,
        0.01,
        formula,
        transform=ax.transAxes,
        ha="left",
        va="bottom",
        fontsize=9,
        color=colors["charcoal"],
        bbox=formula_box,
    )

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
    if show:
        plt.show()
    else:
        plt.close(fig)
    return fig
def _resolve_named_class_palette(
observed_labels: List[str],
class_names: Optional[Any],
theme_context: Dict[str, Any],
) -> Tuple[List[str], Dict[str, Any]]:
"""Resolve class order and colors using the same palette logic as PCA plots."""
observed = [str(label) for label in observed_labels]
ordered_labels: List[str]
if class_names is None:
ordered_labels = list(dict.fromkeys(observed))
elif isinstance(class_names, dict):
ordered_labels = [
str(value)
for _, value in sorted(class_names.items(), key=lambda item: item[0])
if str(value) in observed
]
else:
ordered_labels = [str(value) for value in class_names if str(value) in observed]
for label in observed:
if label not in ordered_labels:
ordered_labels.append(label)
class_map = theme_context["class_cmap"](max(1, len(ordered_labels)))
color_map = {label: class_map(i) for i, label in enumerate(ordered_labels)}
return ordered_labels, color_map
def _resolve_graph_node(graph: nx.DiGraph, candidate):
if candidate in graph:
return candidate
candidate_str = str(candidate)
for node in graph.nodes:
if str(node) == candidate_str:
return node
return None
def _normalize_class_label(label: Any) -> str:
text = str(label)
if text.startswith("Class "):
return text.replace("Class ", "", 1)
return text
def _community_specs(explanation, graph: nx.DiGraph, node_df: Any) -> List[Dict[str, Any]]:
    """Resolve ``explanation.communities`` into graph-node sets.

    Two layouts are understood: a dict with a ``"Clusters"`` mapping (key ->
    members, where the key names the class unless it reads "ambiguous"), or a
    dict with a ``"Communities"`` list of member collections (no class name).
    Each member is matched to a graph node directly, or failing that through
    the nodes in ``node_df`` whose ``Label`` equals the member's string form.

    Returns:
        One dict per community that resolved to at least one node, with keys
        ``community_id`` (position in the source collection), ``class_name``
        (str or None) and ``nodes`` (set of graph nodes). Empty list when the
        explanation has no communities.
    """
    communities = getattr(explanation, "communities", None)
    if not communities:
        return []

    # (class_name, members) pairs in source order.
    pending: List[Tuple[Optional[str], Any]] = []
    if isinstance(communities, dict):
        if "Clusters" in communities:
            for key, members in communities.get("Clusters", {}).items():
                name = _normalize_class_label(key)
                pending.append((None if name.lower() == "ambiguous" else name, members))
        elif "Communities" in communities:
            for members in communities.get("Communities", []):
                pending.append((None, members))

    nodes_by_label: Dict[str, List[Any]] = {}
    for _, row in node_df.iterrows():
        nodes_by_label.setdefault(str(row["Label"]), []).append(row["Node"])

    specs = []
    for community_id, (class_name, members) in enumerate(pending):
        matched = set()
        for member in members:
            direct = _resolve_graph_node(graph, member)
            if direct is not None:
                matched.add(direct)
            else:
                # Fall back to treating the member as a node label.
                for labelled in nodes_by_label.get(str(member), []):
                    via_label = _resolve_graph_node(graph, labelled)
                    if via_label is not None:
                        matched.add(via_label)
        if matched:
            specs.append(
                {
                    "community_id": community_id,
                    "class_name": class_name,
                    "nodes": matched,
                }
            )
    return specs
def _class_nodes_map(explanation) -> Dict[Any, str]:
    """Map each class node of the graph to its class name.

    Class nodes are the ``node_metrics`` rows whose label starts with
    ``"Class "``; the name is the label with that first prefix removed.
    Rows whose node cannot be resolved in the graph are skipped.

    Raises:
        ValueError: If ``explanation.graph`` is missing or None.
    """
    metrics = explanation.node_metrics.copy()
    graph = getattr(explanation, "graph", None)
    if graph is None:
        raise ValueError("explanation.graph is required")

    is_class_row = metrics["Label"].astype(str).str.startswith("Class ")
    name_by_node = {}
    for _, row in metrics[is_class_row].iterrows():
        resolved = _resolve_graph_node(graph, row["Node"])
        if resolved is None:
            continue
        name_by_node[resolved] = str(row["Label"]).replace("Class ", "", 1)
    return name_by_node
def _predicate_node_lookup(explanation) -> Dict[Any, Tuple[str, str, float]]:
    """Map each predicate node of the graph to ``(feature, operator, threshold)``.

    Only ``node_metrics`` rows whose label parses as a predicate and whose
    node can be resolved in the graph are included.

    Raises:
        ValueError: If ``explanation.graph`` is missing or None.
    """
    metrics = explanation.node_metrics.copy()
    graph = getattr(explanation, "graph", None)
    if graph is None:
        raise ValueError("explanation.graph is required")

    candidates = metrics.copy()
    candidates["parsed"] = candidates["Label"].apply(parse_predicate_parts)
    # Unparseable labels yield None and are dropped here.
    parseable = candidates[candidates["parsed"].notna()].copy()

    by_node: Dict[Any, Tuple[str, str, float]] = {}
    for _, row in parseable.iterrows():
        resolved = _resolve_graph_node(graph, row["Node"])
        if resolved is not None:
            feature, operator, threshold = row["parsed"]
            by_node[resolved] = (str(feature), str(operator), float(threshold))
    return by_node
[docs]
def class_feature_predicate_counts(explanation) -> Any:
    """
    Count, per class, how many predicate nodes test each feature.

    Predicate nodes are taken community by community. A community that
    carries a class name credits its predicates to that class; otherwise a
    predicate is credited to every class node reachable from it in the
    graph. Without any usable community, all predicate nodes form a single
    unnamed community.

    Raises:
        ValueError: If ``explanation.graph`` is missing, or ``node_metrics``
            lacks the ``Node`` or ``Label`` column.

    Returns:
        DataFrame indexed by class, columns as features (ordered by total
        count, descending), values as integer counts. Empty DataFrame when
        nothing could be counted.
    """
    node_df = explanation.node_metrics.copy()
    graph = getattr(explanation, "graph", None)
    if graph is None:
        raise ValueError("explanation.graph is required for class-path analysis")
    if not ("Node" in node_df.columns and "Label" in node_df.columns):
        raise ValueError("node_metrics must contain Node and Label columns")

    class_nodes = _class_nodes_map(explanation)
    pred_lookup = _predicate_node_lookup(explanation)
    comm_specs = _community_specs(explanation, graph, node_df)
    if not comm_specs:
        comm_specs = [{"community_id": 0, "class_name": None, "nodes": set(pred_lookup.keys())}]

    features_by_class: Dict[str, List[str]] = {}
    for spec in comm_specs:
        cluster_class = spec["class_name"]
        for node in spec["nodes"]:
            if node not in pred_lookup:
                continue
            feature = pred_lookup[node][0]
            if cluster_class is None:
                # No class attached to the community: follow the graph.
                reachable = nx.descendants(graph, node)
                targets = [name for class_node, name in class_nodes.items() if class_node in reachable]
            else:
                targets = [str(cluster_class)]
            for cls in targets:
                features_by_class.setdefault(cls, []).append(feature)

    if not features_by_class:
        return pd.DataFrame()
    counts_by_class = {
        cls: pd.Series(features).value_counts()
        for cls, features in features_by_class.items()
        if features
    }
    if not counts_by_class:
        return pd.DataFrame()

    heat = pd.DataFrame(counts_by_class).T.fillna(0).astype(int)
    column_order = heat.sum(axis=0).sort_values(ascending=False).index
    return heat.loc[:, column_order]
[docs]
def plot_class_feature_complexity(
    heat_df: Any,
    dataset_name: str = "Dataset",
    class_names: Optional[Any] = None,
    top_n_features: int = 10,
    save_prefix: Optional[str] = None,
    show: bool = True,
    theme: str = "dpg",
    palette: str = "default",
) -> Tuple[Any, Any]:
    """
    Draw a heatmap and a grouped bar chart of class-feature predicate counts.

    Class colours and row order come from ``_resolve_named_class_palette`` so
    both figures use the same colour per class.

    Args:
        heat_df: Non-empty DataFrame of counts (rows = classes, columns = features).
        dataset_name: Prefix used in both figure titles.
        class_names: Forwarded to ``_resolve_named_class_palette``.
        top_n_features: Keep only this many highest-total feature columns;
            ``None`` or a non-positive value keeps all of them.
        save_prefix: When given, writes ``{save_prefix}_heatmap.png`` and
            ``{save_prefix}_bars.png``.
        show: Call ``plt.show()`` when true; otherwise close both figures.
        theme: Theme name passed to ``resolve_theme_context``.
        palette: Palette name passed to ``resolve_theme_context``.

    Returns:
        ``(fig_heat, fig_bar)``: the heatmap figure and the bar-chart figure.

    Raises:
        ValueError: If ``heat_df`` is ``None`` or empty.
    """
    if heat_df is None or getattr(heat_df, "empty", True):
        raise ValueError("heat_df must be a non-empty class-feature count DataFrame.")

    theme_context = resolve_theme_context(theme=theme, palette=palette)
    colors = theme_context["colors"]
    _apply_matplotlib_theme(theme_context)

    # Normalise labels to strings, rank features by total count, trim, reorder rows.
    table = heat_df.copy()
    table.index = table.index.map(str)
    table.columns = table.columns.map(str)
    ranked_columns = table.sum(axis=0).sort_values(ascending=False).index
    table = table.loc[:, ranked_columns]
    if top_n_features is not None and int(top_n_features) > 0:
        table = table.iloc[:, : int(top_n_features)]
    class_order, class_color_map = _resolve_named_class_palette(
        observed_labels=table.index.tolist(),
        class_names=class_names,
        theme_context=theme_context,
    )
    table = table.loc[[label for label in class_order if label in table.index]]

    n_classes = len(table.index)
    n_features = len(table.columns)

    # ---- Heatmap with a class colour strip on the left ----
    heat_size = (max(8.5, 1.2 * n_features + 3.8), max(4.5, 1.0 * n_classes + 1.8))
    fig_heat = plt.figure(figsize=heat_size)
    fig_heat.patch.set_facecolor(colors["paper"])
    grid = fig_heat.add_gridspec(1, 2, width_ratios=[0.18, 6.0], wspace=0.08)
    ax_strip = fig_heat.add_subplot(grid[0, 0])
    ax_heat = fig_heat.add_subplot(grid[0, 1])

    strip_pixels = np.array(
        [[mcolors.to_rgba(class_color_map[label])] for label in table.index],
        dtype=float,
    )
    ax_strip.imshow(strip_pixels, aspect="auto")
    ax_strip.set_xticks([])
    ax_strip.set_yticks([])
    for spine in ax_strip.spines.values():
        spine.set_visible(False)

    image = ax_heat.imshow(
        table.to_numpy(dtype=float),
        aspect="auto",
        cmap=theme_context["sequential_cmap"],
    )
    ax_heat.set_xticks(np.arange(n_features))
    ax_heat.set_xticklabels(table.columns, rotation=35, ha="right")
    ax_heat.set_yticks(np.arange(n_classes))
    ax_heat.set_yticklabels(table.index, rotation=90, va="center")
    for tick_label in ax_heat.get_yticklabels():
        tick_label.set_color(class_color_map[str(tick_label.get_text())])
        tick_label.set_fontweight("semibold")
    ax_heat.set_xlabel("Feature")
    ax_heat.set_ylabel("Class")
    ax_heat.set_title(
        f"{dataset_name}: Community class-feature complexity",
        color=colors["ink"],
        fontsize=13,
        fontweight="semibold",
    )
    ax_heat.set_facecolor(colors["paper"])
    # Minor ticks on cell borders give a thin grid between heatmap cells.
    ax_heat.set_xticks(np.arange(-0.5, n_features, 1), minor=True)
    ax_heat.set_yticks(np.arange(-0.5, n_classes, 1), minor=True)
    ax_heat.grid(which="minor", color=colors["paper"], linestyle="-", linewidth=1.2)
    ax_heat.tick_params(which="minor", bottom=False, left=False)

    # Annotate each cell; switch text colour on dark cells for contrast.
    for row_idx, col_idx in np.ndindex(n_classes, n_features):
        cell_value = int(table.iat[row_idx, col_idx])
        if image.norm(cell_value) > 0.55:
            cell_text_color = colors["paper"]
        else:
            cell_text_color = colors["charcoal"]
        ax_heat.text(
            col_idx,
            row_idx,
            str(cell_value),
            ha="center",
            va="center",
            fontsize=9,
            color=cell_text_color,
        )

    colorbar = fig_heat.colorbar(image, ax=ax_heat, fraction=0.046, pad=0.04)
    colorbar.set_label("Predicate count")
    colorbar.outline.set_edgecolor(colors["light_gray"])
    _style_figure(fig_heat, theme_context)
    fig_heat.subplots_adjust(left=0.12, right=0.94, bottom=0.22, top=0.88, wspace=0.08)

    # ---- Grouped bar chart, one bar colour per class ----
    positions = np.arange(n_features, dtype=float)
    total_width = 0.82
    bar_width = total_width / max(1, n_classes)
    offsets = (np.arange(n_classes) - (n_classes - 1) / 2.0) * bar_width
    bar_size = (max(8.8, 1.2 * n_features + 3.6), max(4.8, 0.75 * n_classes + 3.2))
    fig_bar, ax_bar = plt.subplots(figsize=bar_size)
    fig_bar.patch.set_facecolor(colors["paper"])
    ax_bar.set_facecolor(colors["paper"])
    for offset, class_label in zip(offsets, table.index):
        ax_bar.bar(
            positions + offset,
            table.loc[class_label].to_numpy(dtype=float),
            width=bar_width * 0.92,
            color=class_color_map[str(class_label)],
            edgecolor=colors["paper"],
            linewidth=0.7,
            label=str(class_label),
            alpha=0.95,
        )
    ax_bar.set_xticks(positions)
    ax_bar.set_xticklabels(table.columns, rotation=35, ha="right")
    ax_bar.set_ylabel("Predicate count")
    ax_bar.set_xlabel("Feature")
    ax_bar.set_title(
        f"{dataset_name}: Community class-feature complexity by class",
        color=colors["ink"],
        fontsize=13,
        fontweight="semibold",
    )
    _style_axes(ax_bar, theme_context, grid_axis="y")
    legend = ax_bar.legend(loc="best", fontsize=9, frameon=True, title="Class")
    _style_legend(legend, theme_context)
    for legend_text in legend.get_texts():
        legend_label = str(legend_text.get_text())
        if legend_label in class_color_map:
            legend_text.set_color(class_color_map[legend_label])
    _style_figure(fig_bar, theme_context)
    fig_bar.tight_layout()

    if save_prefix is not None:
        fig_heat.savefig(f"{save_prefix}_heatmap.png", dpi=200, bbox_inches="tight")
        fig_bar.savefig(f"{save_prefix}_bars.png", dpi=200, bbox_inches="tight")

    if show:
        plt.show()
    else:
        plt.close(fig_heat)
        plt.close(fig_bar)
    return fig_heat, fig_bar
[docs]
def classwise_feature_bounds_from_communities(explanation) -> Any:
    """
    Derive per-class, per-community feature intervals from predicate thresholds.

    For every (class, community, feature) triple, the lower bound is the
    smallest ``>`` threshold and the upper bound the largest ``<=`` threshold;
    a missing side is ``-inf`` / ``+inf``. If those bounds cross
    (lower > upper), the min and max over all collected thresholds are used
    instead.

    Args:
        explanation: Object exposing ``node_metrics`` and ``graph``.

    Returns:
        DataFrame with columns ``class_name``, ``community_id``, ``feature``,
        ``lower_bound``, ``upper_bound`` and ``range_width`` (NaN when either
        bound is infinite). Empty, with those columns, when nothing is found.

    Raises:
        ValueError: If ``explanation.graph`` is missing or ``None``.
    """
    metrics = explanation.node_metrics.copy()
    dpg_graph = getattr(explanation, "graph", None)
    if dpg_graph is None:
        raise ValueError("explanation.graph is required")

    class_by_node = _class_nodes_map(explanation)
    predicates = _predicate_node_lookup(explanation)
    communities = _community_specs(explanation, dpg_graph, metrics)
    if not communities:
        # No community structure: one anonymous community with all predicates.
        communities = [
            {"community_id": 0, "class_name": None, "nodes": set(predicates.keys())}
        ]

    thresholds: Dict[Tuple[str, int, str], Dict[str, List[float]]] = {}
    for community in communities:
        cid = int(community["community_id"])
        community_class = community["class_name"]
        for member in community["nodes"]:
            if member not in predicates:
                continue
            feature, operator, threshold = predicates[member]
            if community_class is None:
                reachable = nx.descendants(dpg_graph, member)
                credited = [
                    class_by_node[class_node]
                    for class_node in class_by_node
                    if class_node in reachable
                ]
            else:
                credited = [str(community_class)]
            for class_label in credited:
                slot = thresholds.setdefault(
                    (class_label, cid, feature), {"gt": [], "le": [], "all": []}
                )
                if operator == ">":
                    slot["gt"].append(threshold)
                elif operator == "<=":
                    slot["le"].append(threshold)
                slot["all"].append(threshold)

    records = []
    for (class_label, cid, feature), slot in thresholds.items():
        lower = min(slot["gt"]) if slot["gt"] else float("-inf")
        upper = max(slot["le"]) if slot["le"] else float("inf")
        if lower > upper:
            # Contradictory bounds: widen to the hull of every threshold seen.
            every_threshold = slot["all"]
            lower = min(every_threshold) if every_threshold else float("-inf")
            upper = max(every_threshold) if every_threshold else float("inf")
        both_finite = np.isfinite(lower) and np.isfinite(upper)
        records.append(
            {
                "class_name": class_label,
                "community_id": cid,
                "feature": feature,
                "lower_bound": float(lower),
                "upper_bound": float(upper),
                "range_width": float(upper - lower) if both_finite else np.nan,
            }
        )

    if records:
        return pd.DataFrame(records)
    return pd.DataFrame(
        columns=[
            "class_name",
            "community_id",
            "feature",
            "lower_bound",
            "upper_bound",
            "range_width",
        ]
    )
[docs]
def class_feature_predicate_positions(explanation) -> Any:
    """
    List every predicate threshold together with the class it is credited to.

    One row is emitted per (predicate node, credited class) pair. A community
    with a class name credits that class; otherwise the classes are the class
    nodes found among the predicate node's graph descendants.

    Args:
        explanation: Object exposing ``node_metrics`` and ``graph``.

    Returns:
        DataFrame with columns ``class_name``, ``community_id``, ``feature``,
        ``operator`` and ``threshold``; empty (with those columns) when no
        predicate could be attributed.

    Raises:
        ValueError: If ``explanation.graph`` is missing or ``None``.
    """
    metrics = explanation.node_metrics.copy()
    dpg_graph = getattr(explanation, "graph", None)
    if dpg_graph is None:
        raise ValueError("explanation.graph is required")

    class_by_node = _class_nodes_map(explanation)
    predicates = _predicate_node_lookup(explanation)
    communities = _community_specs(explanation, dpg_graph, metrics)
    if not communities:
        # No community structure: one anonymous community with all predicates.
        communities = [
            {"community_id": 0, "class_name": None, "nodes": set(predicates.keys())}
        ]

    records = []
    for community in communities:
        cid = int(community["community_id"])
        community_class = community["class_name"]
        for member in community["nodes"]:
            if member not in predicates:
                continue
            feature, operator, threshold = predicates[member]
            if community_class is None:
                reachable = nx.descendants(dpg_graph, member)
                credited = [
                    class_by_node[class_node]
                    for class_node in class_by_node
                    if class_node in reachable
                ]
            else:
                credited = [str(community_class)]
            records.extend(
                {
                    "class_name": class_label,
                    "community_id": cid,
                    "feature": feature,
                    "operator": operator,
                    "threshold": threshold,
                }
                for class_label in credited
            )

    if records:
        return pd.DataFrame(records)
    return pd.DataFrame(columns=["class_name", "community_id", "feature", "operator", "threshold"])
def _aggregate_close_positions(values, tol: float):
vals = np.sort(np.asarray(values, dtype=float))
if vals.size == 0:
return []
groups = [[vals[0]]]
for value in vals[1:]:
if abs(value - groups[-1][-1]) <= tol:
groups[-1].append(value)
else:
groups.append([value])
return [(float(np.mean(group)), len(group)) for group in groups]
[docs]
def class_lookup_from_target_names(target_names: Optional[List[str]]) -> Dict[str, int]:
    """Map each class name to its position in ``target_names``.

    Args:
        target_names: Ordered class names (each is converted with ``str``),
            or ``None``.

    Returns:
        Dict from class-name string to its integer index; empty when
        ``target_names`` is ``None``. A repeated name keeps its last index.
    """
    lookup: Dict[str, int] = {}
    if target_names is not None:
        for position, label in enumerate(target_names):
            lookup[str(label)] = position
    return lookup
def _class_mask(class_name: str, y, class_lookup: Optional[Dict[str, int]] = None):
y_arr = np.asarray(y)
# First try direct label matching; this supports string labels (e.g., "F1")
# even when a class_lookup mapping is provided.
direct_mask = pd.Series(y_arr).astype(str).values == str(class_name)
if np.any(direct_mask):
return direct_mask
# Fallback to lookup mapping (e.g., class name -> integer id).
if class_lookup and str(class_name) in class_lookup:
mapped_mask = y_arr == class_lookup[str(class_name)]
if np.any(mapped_mask):
return mapped_mask
try:
as_int = int(class_name)
return y_arr == as_int
except Exception:
pass
return direct_mask
[docs]
def dataset_feature_bounds_by_class(
    X_df: Any,
    y,
    class_names: List[str],
    class_lookup: Optional[Dict[str, int]] = None,
) -> Any:
    """Compute empirical per-class min/max ranges for every feature in ``X_df``.

    Rows belonging to each class are selected with ``_class_mask``; classes
    that select no rows are skipped.

    Args:
        X_df: Feature matrix with named columns.
        y: Target labels aligned with ``X_df`` rows.
        class_names: List of class name strings to compute bounds for.
        class_lookup: Optional mapping from class name to numeric label used in ``y``.

    Returns:
        DataFrame with columns ``['class_name', 'feature', 'ds_lower_bound',
        'ds_upper_bound']``. The columns are present even when no class
        matched any row, so callers can index them on an empty result.
    """
    columns = ["class_name", "feature", "ds_lower_bound", "ds_upper_bound"]
    rows = []
    for cls in class_names:
        mask = _class_mask(cls, y, class_lookup=class_lookup)
        class_frame = X_df.loc[mask]
        if class_frame.empty:
            continue
        class_label = str(cls)
        for feature in X_df.columns:
            column_values = class_frame[feature]
            rows.append(
                {
                    "class_name": class_label,
                    "feature": str(feature),
                    "ds_lower_bound": float(column_values.min()),
                    "ds_upper_bound": float(column_values.max()),
                }
            )
    # Passing ``columns`` keeps the documented schema for an empty result.
    return pd.DataFrame(rows, columns=columns)
[docs]
def plot_dpg_class_bounds_vs_dataset_feature_ranges(
    explanation,
    X_df: Any,
    y,
    dataset_name: str = "Dataset",
    top_features: int = 4,
    feature_cols_per_row: int = 4,
    class_lookup: Optional[Dict[str, int]] = None,
    predicate_positions: Optional[Any] = None,
    class_bounds: Optional[Any] = None,
    class_filter: Optional[List[str]] = None,
    density_tol_ratio: float = 0.03,
    predicate_alpha: float = 0.55,
    dataset_range_lw: float = 10,
    save_path: Optional[str] = None,
    show: bool = True,
    theme: str = "dpg",
    palette: str = "default",
) -> Optional[Any]:
    """
    Plot DPG class bounds against empirical dataset ranges per feature.

    The figure is a grid with one row per class and one column per feature.
    When more features are selected than fit in one row, further blocks of
    class rows are stacked below, separated by a blank row.

    Args:
        explanation: DPGExplanation instance.
        X_df: Feature dataframe.
        y: Class labels aligned with X_df.
        dataset_name: Prefix used in the figure title.
        top_features: Number of top-ranked features to plot (ranked by
            community support, then class coverage, then mean finite width).
        feature_cols_per_row: Number of feature panels per row block.
        class_lookup: Optional class-name to class-id mapping.
        predicate_positions: Optional precomputed output from class_feature_predicate_positions.
        class_bounds: Optional precomputed output from classwise_feature_bounds_from_communities.
        class_filter: Optional class or list of classes to render.
        density_tol_ratio: Fraction of the x-axis span within which neighbouring
            predicate thresholds are merged into one density marker.
        predicate_alpha: Opacity of the predicate density markers.
        dataset_range_lw: Line width of the dataset range bar.
        save_path: When given, the figure is saved to this path.
        show: Call ``plt.show()`` when true; otherwise close the figure.
        theme: Theme name passed to ``resolve_theme_context``.
        palette: Palette name passed to ``resolve_theme_context``.

    Returns:
        Matplotlib figure, or None when no plottable classes exist.
    """
    theme_context = resolve_theme_context(theme=theme, palette=palette)
    colors = theme_context["colors"]
    _apply_matplotlib_theme(theme_context)

    if class_bounds is None:
        class_bounds = classwise_feature_bounds_from_communities(explanation)
    if class_bounds.empty:
        return None
    if predicate_positions is None:
        predicate_positions = class_feature_predicate_positions(explanation)

    # Collapse communities: widest interval per (class, feature).
    dpg_bounds = (
        class_bounds.groupby(["class_name", "feature"], as_index=False)
        .agg(
            lower_bound=("lower_bound", "min"),
            upper_bound=("upper_bound", "max"),
            community_support=("community_id", "nunique"),
        )
    )
    dpg_bounds["range_width"] = np.where(
        np.isfinite(dpg_bounds["lower_bound"]) & np.isfinite(dpg_bounds["upper_bound"]),
        dpg_bounds["upper_bound"] - dpg_bounds["lower_bound"],
        np.nan,
    )

    classes = sorted(dpg_bounds["class_name"].unique())
    if class_filter is not None:
        if isinstance(class_filter, (list, tuple, set, np.ndarray, pd.Series)):
            allowed = {str(value) for value in class_filter}
        else:
            allowed = {str(class_filter)}
        classes = [cls for cls in classes if str(cls) in allowed]
    if not classes:
        return None

    ds_bounds = dataset_feature_bounds_by_class(X_df, y, classes, class_lookup=class_lookup)
    if ds_bounds.empty:
        return None

    feature_rank = (
        dpg_bounds.groupby("feature", as_index=False)
        .agg(
            total_support=("community_support", "sum"),
            class_coverage=("class_name", "nunique"),
            finite_width_mean=("range_width", "mean"),
        )
        .sort_values(
            ["total_support", "class_coverage", "finite_width_mean"],
            ascending=[False, False, False],
        )
        .head(top_features)
    )
    selected_features = feature_rank["feature"].tolist()
    if not selected_features:
        return None

    n_cols = max(1, min(int(feature_cols_per_row), len(selected_features)))
    n_feature_blocks = int(np.ceil(len(selected_features) / n_cols))
    # Insert one separator row between feature blocks to improve visual grouping.
    n_rows = (len(classes) * n_feature_blocks) + max(0, n_feature_blocks - 1)
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(
            5.2 * n_cols,
            (2.6 * len(classes) * n_feature_blocks) + (0.8 * max(0, n_feature_blocks - 1)),
        ),
        squeeze=False,
    )
    fig.patch.set_facecolor(colors["paper"])

    # One shared x-range per feature so class rows are directly comparable.
    feature_axis_limits: Dict[str, Tuple[float, float]] = {}
    for feat in selected_features:
        ds_feat = ds_bounds[ds_bounds["feature"] == feat]
        dpg_feat = dpg_bounds[dpg_bounds["feature"] == feat]
        if ds_feat.empty and dpg_feat.empty:
            continue
        ds_min = float(ds_feat["ds_lower_bound"].min()) if not ds_feat.empty else np.inf
        ds_max = float(ds_feat["ds_upper_bound"].max()) if not ds_feat.empty else -np.inf
        feat_global_min = float(X_df[feat].min()) if feat in X_df.columns else ds_min
        feat_global_max = float(X_df[feat].max()) if feat in X_df.columns else ds_max
        dpg_lo = dpg_feat["lower_bound"].astype(float).to_numpy(copy=True)
        dpg_hi = dpg_feat["upper_bound"].astype(float).to_numpy(copy=True)
        finite_dpg_lo = dpg_lo[np.isfinite(dpg_lo)]
        finite_dpg_hi = dpg_hi[np.isfinite(dpg_hi)]
        dpg_min = float(finite_dpg_lo.min()) if finite_dpg_lo.size else ds_min
        dpg_max = float(finite_dpg_hi.max()) if finite_dpg_hi.size else ds_max

        x_min = min(ds_min, feat_global_min, dpg_min)
        x_max = max(ds_max, feat_global_max, dpg_max)
        if not (np.isfinite(x_min) and np.isfinite(x_max)):
            # No finite extent available; the default (0, 1) range is used
            # instead of passing inf/NaN limits to matplotlib.
            continue
        pad = max((x_max - x_min) * 0.2, 1e-6)
        left = x_min - pad
        if x_min >= 0.0:
            # Keep the padding from dipping below zero for non-negative
            # features, but never cut off genuinely negative values.
            left = max(0.0, left)
        feature_axis_limits[feat] = (left, x_max + pad)

    # Each legend label is attached to the first artist that uses it, wherever
    # that artist is drawn, so no entry is lost when panel (0, 0) lacks it.
    used_labels = set()

    def _label_once(text: str) -> Optional[str]:
        if text in used_labels:
            return None
        used_labels.add(text)
        return text

    for block_idx in range(n_feature_blocks):
        start = block_idx * n_cols
        end = min(len(selected_features), start + n_cols)
        block_features = selected_features[start:end]
        for r, cls in enumerate(classes):
            class_ds = ds_bounds[ds_bounds["class_name"] == cls]
            class_dpg = dpg_bounds[dpg_bounds["class_name"] == cls]
            class_pred = (
                predicate_positions[predicate_positions["class_name"] == cls]
                if predicate_positions is not None and not predicate_positions.empty
                else pd.DataFrame(columns=["feature", "operator", "threshold"])
            )
            # +1 per block accounts for the separator row after each block.
            row_idx = block_idx * (len(classes) + 1) + r
            for c in range(n_cols):
                ax = axes[row_idx, c]
                if c >= len(block_features):
                    ax.axis("off")
                    continue
                feat = block_features[c]
                left_lim, right_lim = feature_axis_limits.get(feat, (0.0, 1.0))
                ds_row = class_ds[class_ds["feature"] == feat]
                dpg_row = class_dpg[class_dpg["feature"] == feat]
                y0 = 0.0

                if not ds_row.empty:
                    ds_lo = float(ds_row["ds_lower_bound"].iloc[0])
                    ds_hi = float(ds_row["ds_upper_bound"].iloc[0])
                    ax.hlines(
                        y0,
                        ds_lo,
                        ds_hi,
                        color=colors["range_fill"],
                        linewidth=dataset_range_lw,
                        alpha=0.85,
                        label=_label_once("dataset class range"),
                    )
                    ax.scatter(
                        [ds_lo, ds_hi],
                        [y0, y0],
                        color=colors["range_marker"],
                        s=28,
                        label=_label_once("dataset min/max"),
                    )

                if not dpg_row.empty:
                    dpg_lo = float(dpg_row["lower_bound"].iloc[0])
                    dpg_hi = float(dpg_row["upper_bound"].iloc[0])
                    lo_inf = not np.isfinite(dpg_lo)
                    hi_inf = not np.isfinite(dpg_hi)
                    # Unbounded sides are drawn out to the axis edge.
                    draw_lo = left_lim if lo_inf else dpg_lo
                    draw_hi = right_lim if hi_inf else dpg_hi
                    ax.hlines(
                        y0,
                        draw_lo,
                        draw_hi,
                        color=_class_fill_color(theme_context),
                        linewidth=3,
                        alpha=0.95,
                        label=_label_once("DPG community range"),
                    )
                    if not lo_inf:
                        ax.scatter(
                            [dpg_lo],
                            [y0],
                            color=colors["success"],
                            s=38,
                            label=_label_once("DPG min bound"),
                        )
                    if not hi_inf:
                        ax.scatter(
                            [dpg_hi],
                            [y0],
                            color=colors["danger"],
                            s=38,
                            label=_label_once("DPG max bound"),
                        )
                    if lo_inf:
                        ax.scatter(
                            [left_lim],
                            [y0],
                            marker="<",
                            color=colors["success"],
                            s=70,
                            label=_label_once("DPG min = -inf"),
                        )
                    if hi_inf:
                        ax.scatter(
                            [right_lim],
                            [y0],
                            marker=">",
                            color=colors["danger"],
                            s=70,
                            label=_label_once("DPG max = +inf"),
                        )

                if not class_pred.empty:
                    pred_feature = class_pred[class_pred["feature"] == feat]
                    tol = max((right_lim - left_lim) * float(density_tol_ratio), 1e-9)
                    vals_gt = pred_feature.loc[pred_feature["operator"] == ">", "threshold"].astype(float).to_numpy()
                    vals_le = pred_feature.loc[pred_feature["operator"] == "<=", "threshold"].astype(float).to_numpy()
                    vals_gt = vals_gt[(vals_gt >= left_lim) & (vals_gt <= right_lim)]
                    vals_le = vals_le[(vals_le >= left_lim) & (vals_le <= right_lim)]
                    dense_gt = _aggregate_close_positions(vals_gt, tol) if vals_gt.size else []
                    dense_le = _aggregate_close_positions(vals_le, tol) if vals_le.size else []
                    if dense_gt:
                        xs = np.array([d[0] for d in dense_gt], dtype=float)
                        counts = np.array([d[1] for d in dense_gt], dtype=float)
                        # Marker area grows with the square root of the cluster size.
                        sizes = 14 + 16 * np.sqrt(counts)
                        ax.scatter(
                            xs,
                            np.full_like(xs, y0 + 0.10, dtype=float),
                            s=sizes,
                            marker="^",
                            c=colors["success"],
                            alpha=predicate_alpha,
                            edgecolors=colors["paper"],
                            linewidths=0.35,
                            label=_label_once("predicate density (>)"),
                            zorder=4,
                        )
                    if dense_le:
                        xs = np.array([d[0] for d in dense_le], dtype=float)
                        counts = np.array([d[1] for d in dense_le], dtype=float)
                        sizes = 14 + 16 * np.sqrt(counts)
                        ax.scatter(
                            xs,
                            np.full_like(xs, y0 - 0.10, dtype=float),
                            s=sizes,
                            marker="v",
                            c=colors["danger"],
                            alpha=predicate_alpha,
                            edgecolors=colors["paper"],
                            linewidths=0.35,
                            label=_label_once("predicate density (<=)"),
                            zorder=4,
                        )

                ax.set_xlim(left_lim, right_lim)
                ax.set_ylim(-0.35, 0.35)
                ax.set_yticks([])
                _style_axes(ax, theme_context, grid_axis="x")
                if r == 0:
                    ax.set_title(str(feat), color=colors["ink"], fontsize=12, fontweight="semibold")
                is_bottom_class_row = r == len(classes) - 1
                ax.tick_params(axis="x", labelbottom=is_bottom_class_row)
                if not is_bottom_class_row:
                    ax.set_xticklabels([])
                if is_bottom_class_row:
                    ax.set_xlabel("Feature value")
                if c == 0:
                    ax.set_ylabel(f"Class {cls}", color=colors["charcoal"])

        # Hide separator row axes between feature blocks.
        if block_idx < n_feature_blocks - 1:
            sep_row = block_idx * (len(classes) + 1) + len(classes)
            for c in range(n_cols):
                axes[sep_row, c].axis("off")

    # Gather legend entries from every panel (labels are unique by construction).
    handles: List[Any] = []
    labels: List[str] = []
    for panel in axes.ravel():
        panel_handles, panel_labels = panel.get_legend_handles_labels()
        handles.extend(panel_handles)
        labels.extend(panel_labels)
    if handles:
        legend = fig.legend(handles, labels, loc="lower center", ncol=3, frameon=True)
        _style_legend(legend, theme_context)

    _style_figure(fig, theme_context, title=(
        f"{dataset_name}: DPG vs dataset ranges "
        f"(rows=classes, features/row={n_cols})"
    ))
    # Reserve the bottom strip of the figure for the legend.
    fig.tight_layout(rect=(0, 0.10, 1, 1))

    if save_path is not None:
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
    if show:
        plt.show()
    else:
        plt.close(fig)
    return fig
[docs]
def change_edge_color(graph, source_id, target_id, new_color, new_width):
    """
    Changes the color and dimension (penwidth) of a specified edge in the Graphviz Digraph.

    Only the first body line containing the edge ``source_id -> target_id`` is
    modified. The ids must match whole tokens, so asking for ``1 -> 2`` does
    not touch ``11 -> 22``. The new attributes are inserted before the final
    ``]`` of that line, leaving any ``]`` inside a quoted label intact. The
    rewritten line is stored with trailing whitespace stripped; if the line
    has no ``]`` nothing is appended.

    Args:
        graph: A Graphviz Digraph object (its ``body`` list is edited in place).
        source_id: The source node of the edge.
        target_id: The target node of the edge.
        new_color: The new color to be applied to the edge.
        new_width: The new penwidth (edge thickness) to be applied.

    Returns:
        None
    """
    # Guard both ends so one id cannot match as a fragment of a longer id.
    edge_pattern = re.compile(
        rf"(?<![\w.]){re.escape(str(source_id))} -> {re.escape(str(target_id))}(?![\w.])"
    )
    for index, line in enumerate(graph.body):
        if not edge_pattern.search(line):
            continue
        updated = line.rstrip()
        head, closing, tail = updated.rpartition("]")
        if closing:
            # Append to the attribute list by re-closing it after the new attributes.
            updated = f'{head} color="{new_color}" penwidth="{new_width}"]{tail}'
        graph.body[index] = updated
        break