# Source code for dpg.sklearn_normalizer
"""
Normalizer for scikit-learn ensemble models.
Handles differences in tree storage between RandomForest, GradientBoosting,
AdaBoost, and other ensemble methods to provide a consistent interface for DPG.
"""
import copy
import numpy as np
from sklearn.ensemble import (
GradientBoostingClassifier,
GradientBoostingRegressor,
)
# [docs] -- Sphinx "view source" link text carried over from the rendered page; not code.
class SklearnEnsembleNormalizer:
    """
    Give sklearn ensemble models one uniform, flat way to reach their trees.

    Ensemble families lay out ``estimators_`` differently:
        - RandomForest: 1D list of DecisionTree objects
        - AdaBoost: 1D list of DecisionTree objects
        - GradientBoosting: 2D (n_estimators, n_trees_per_iteration)

    Only the GradientBoosting layout is rewritten: it is flattened into a
    1D list so DPG can walk every supported ensemble the same way.
    """

    # Model types whose ``estimators_`` is 2D and therefore gets flattened.
    GB_MODELS = (GradientBoostingClassifier, GradientBoostingRegressor)

    @staticmethod
    def needs_normalization(model):
        """
        Tell whether a model's tree storage has to be flattened.

        Args:
            model: sklearn ensemble model

        Returns:
            bool: True when ``model`` is an instance of one of ``GB_MODELS``
        """
        gradient_boosting_types = SklearnEnsembleNormalizer.GB_MODELS
        return isinstance(model, gradient_boosting_types)

    @staticmethod
    def normalize(model):
        """
        Return a model whose ``estimators_`` is a flat 1D list of trees.

        Models that are not GradientBoosting, and GradientBoosting models
        whose ``estimators_`` is already a list, are returned as the very
        same object. Otherwise the input is left untouched and a shallow
        copy is returned carrying:
            - ``estimators_``: the trees, flattened row by row
            - ``_original_estimators_shape``: shape of the 2D array
            - ``_normalized_for_dpg``: always True
            - ``_dpg_tree_class_indices``: one entry per flattened tree; the
              column (class slot) the tree came from for a
              GradientBoostingClassifier with more than one tree per row,
              otherwise None

        Args:
            model: sklearn ensemble model

        Returns:
            model: the input itself, or a normalized shallow copy with the
            DPG-specific metadata listed above

        Raises:
            AttributeError: if a GradientBoosting model has no
                ``estimators_`` attribute (it is read directly).
        """
        if not SklearnEnsembleNormalizer.needs_normalization(model):
            return model

        stages = model.estimators_
        # A plain list means the flattening below has already been applied.
        if isinstance(stages, list):
            return model

        # Shallow copy: the caller's model keeps its 2D ``estimators_``.
        flattened = copy.copy(model)
        flattened._original_estimators_shape = stages.shape
        flattened._normalized_for_dpg = True

        # Only classifiers carry a meaningful class slot, and only when a
        # row holds more than one tree; every other tree is tagged None.
        is_classifier = isinstance(model, GradientBoostingClassifier)
        tagged_trees = [
            (tree, slot if is_classifier and len(stage) > 1 else None)
            for stage in stages
            for slot, tree in enumerate(stage)
        ]

        flattened.estimators_ = [tree for tree, _ in tagged_trees]
        flattened._dpg_tree_class_indices = [slot for _, slot in tagged_trees]
        return flattened

    @staticmethod
    def get_tree_class_index(model, tree_index):
        """
        Look up the class slot recorded for one flattened tree.

        Args:
            model: model, normally one returned by ``normalize``
            tree_index: position of the tree in the flattened ``estimators_``

        Returns:
            The stored class-slot index, or None when the model carries no
            ``_dpg_tree_class_indices``, when ``tree_index`` is negative or
            past the end, or when None was stored for that tree.
        """
        slots = getattr(model, "_dpg_tree_class_indices", None)
        if slots is None:
            return None

        # Negative positions are rejected rather than wrapping around.
        out_of_range = tree_index < 0 or tree_index >= len(slots)
        return None if out_of_range else slots[tree_index]