import numpy as np
import pandas as pd
from sklearn import clone
from sklearn.base import BaseEstimator
from sklearn.utils.metaestimators import if_delegate_has_method
from sklearn.utils.validation import (
check_is_fitted,
check_array,
)
from sklego.common import as_list, expanding_list
from ._grouped_utils import relative_shrinkage, constant_shrinkage, min_n_obs_shrinkage
from ._grouped_utils import _split_groups_and_values
class GroupedPredictor(BaseEstimator):
    """
    Construct an estimator per data group. Splits data by values of the
    grouping column(s) and fits one estimator per such group.

    :param estimator: the model/pipeline to be applied per group
    :param groups: the column(s) of the matrix/dataframe to select as a grouping parameter set
    :param shrinkage: How to perform shrinkage.
        None: No shrinkage (default)
        {"constant", "min_n_obs", "relative"} or a callable
        * constant: shrunk prediction for a level is weighted average of its prediction and its
            parents prediction
        * min_n_obs: shrunk prediction is the prediction for the smallest group with at least
            n observations in it
        * relative: each group-level is weight according to its size
        * function: a function that takes a list of group lengths and returns an array of the
            same size with the weights for each group
    :param use_global_model: With shrinkage: whether to have a model over the entire input as first group
        Without shrinkage: whether or not to fall back to a general model in case the group
        parameter is not found during `.predict()`
    :param check_X: Whether to validate X to be non-empty 2D array of finite values and attempt to cast X to float.
        If disabled, the model/pipeline is expected to handle e.g. missing,
        non-numeric, or non-finite values.
    :param **shrinkage_kwargs: keyword arguments to the shrinkage function
    """

    # Number of features in value df can be 0, e.g. for dummy models
    _check_kwargs = {"ensure_min_features": 0, "accept_large_sparse": False}
    # Sentinel column name/value injected to represent the "all data" group
    # when shrinkage is used together with use_global_model=True.
    _global_col_name = "a-column-that-is-constant-for-all-data"
    _global_col_value = "global"

    def __init__(
        self,
        estimator,
        groups,
        shrinkage=None,
        use_global_model=True,
        check_X=True,
        **shrinkage_kwargs,
    ):
        self.estimator = estimator
        self.groups = groups
        self.shrinkage = shrinkage
        self.use_global_model = use_global_model
        self.shrinkage_kwargs = shrinkage_kwargs
        self.check_X = check_X

    def __set_shrinkage_function(self):
        """Resolve ``self.shrinkage`` into ``self.shrinkage_function_``.

        Accepts one of the predefined shrinkage names, or a callable that is
        validated via ``__check_shrinkage_func``. Raises ValueError for any
        other specification, or for the unsupported single-group /
        no-global-model combination.
        """
        if (
            self.shrinkage
            and len(as_list(self.groups)) == 1
            and not self.use_global_model
        ):
            # With a single grouping column and no global fallback there is
            # no hierarchy to shrink towards.
            raise ValueError(
                "Cannot do shrinkage with a single group if use_global_model is False"
            )

        if isinstance(self.shrinkage, str):
            # Predefined shrinkage functions
            shrink_options = {
                "constant": constant_shrinkage,
                "relative": relative_shrinkage,
                "min_n_obs": min_n_obs_shrinkage,
            }

            try:
                self.shrinkage_function_ = shrink_options[self.shrinkage]
            except KeyError as exc:
                raise ValueError(
                    f"The specified shrinkage function {self.shrinkage} is not valid, "
                    f"choose from {list(shrink_options.keys())} or supply a callable."
                ) from exc
        elif callable(self.shrinkage):
            self.__check_shrinkage_func()
            self.shrinkage_function_ = self.shrinkage
        else:
            raise ValueError(
                "Invalid shrinkage specified. Should be either None (no shrinkage), str or callable."
            )

    def __check_shrinkage_func(self):
        """Validate the shrinkage function if a function is specified.

        The callable must accept a list of group sizes and return an
        ``np.ndarray`` of the same shape; anything else raises ValueError.
        """
        group_lengths = [10, 5, 2]
        expected_shape = np.array(group_lengths).shape
        try:
            result = self.shrinkage(group_lengths)
        except Exception as e:
            raise ValueError(
                f"Caught an exception while checking the shrinkage function: {str(e)}"
            ) from e
        else:
            if not isinstance(result, np.ndarray):
                raise ValueError(
                    f"shrinkage_function({group_lengths}) should return an np.ndarray"
                )
            if result.shape != expected_shape:
                raise ValueError(
                    f"shrinkage_function({group_lengths}).shape should be {expected_shape}"
                )

    def __get_shrinkage_factor(self, X_group):
        """Get for all complete groups an array of shrinkages.

        :param X_group: DataFrame with only the grouping columns.
        :return: dict mapping each most-granular group tuple to a normalized
            (sums to 1) array of weights, one per hierarchy level.
        """
        group_colnames = X_group.columns.to_list()
        counts = X_group.groupby(group_colnames).size()

        # Groups that are split on all grouping columns (deepest hierarchy level)
        most_granular_groups = [
            grp for grp in self.groups_ if len(as_list(grp)) == len(group_colnames)
        ]

        # For each hierarchy level in each most granular group, get the number of observations
        hierarchical_counts = {
            granular_group: [
                counts[tuple(subgroup)].sum()
                for subgroup in expanding_list(granular_group, tuple)
            ]
            for granular_group in most_granular_groups
        }

        # For each hierarchy level in each most granular group, get the shrinkage factor
        shrinkage_factors = {
            group: self.shrinkage_function_(counts, **self.shrinkage_kwargs)
            for group, counts in hierarchical_counts.items()
        }

        # Normalize: make sure that the factors sum to one per group
        return {
            group: value / value.sum() for group, value in shrinkage_factors.items()
        }

    def __fit_single_group(self, group, X, y=None):
        """Fit a fresh clone of the estimator on one group's rows.

        Re-raises any fitting error with the offending group in the message,
        chaining the original exception for debuggability.
        """
        try:
            return clone(self.estimator).fit(X, y)
        except Exception as e:
            raise type(e)(f"Exception for group {group}: {e}") from e

    def __fit_grouped_estimator(self, X_group, X_value, y=None, columns=None):
        """Fit one estimator per group defined by ``columns`` of X_group.

        :param X_group: DataFrame with the grouping columns.
        :param X_value: 2D array with the feature values.
        :param y: optional target; aligned to X_group's index when a Series.
        :param columns: grouping columns to use; defaults to all of X_group.
        :return: dict mapping group keys to fitted estimators.
        """
        if not columns:
            columns = X_group.columns.tolist()

        # Make the groups based on the groups dataframe, use the indices on the values array
        try:
            group_indices = X_group.groupby(columns).indices
        except TypeError:
            # This one is needed because of line #918 of sklearn/utils/estimator_checks
            raise TypeError("argument must be a string, date or number")

        if y is not None:
            if isinstance(y, pd.Series):
                # Reset indices such that they are the same in X and y
                y.index = X_group.index

            grouped_estimators = {
                # Fit a clone of the estimator to each group
                group: self.__fit_single_group(group, X_value[indices, :], y[indices])
                for group, indices in group_indices.items()
            }
        else:
            grouped_estimators = {
                group: self.__fit_single_group(group, X_value[indices, :])
                for group, indices in group_indices.items()
            }

        return grouped_estimators

    def __fit_shrinkage_groups(self, X_group, X_value, y):
        """Fit a grouped estimator for every hierarchical subset of columns."""
        estimators = {}

        for grouping_colnames in self.group_colnames_hierarchical_:
            # Fit a grouped estimator to each (sub)group hierarchically
            estimators.update(
                self.__fit_grouped_estimator(
                    X_group, X_value, y, columns=grouping_colnames
                )
            )

        return estimators

    def __add_shrinkage_column(self, X_group):
        """Add global group as first column if needed for shrinkage."""
        if self.shrinkage is not None and self.use_global_model:
            return pd.concat(
                [
                    pd.Series(
                        [self._global_col_value] * len(X_group),
                        name=self._global_col_name,
                    ),
                    X_group,
                ],
                axis=1,
            )

        return X_group

    def fit(self, X, y=None):
        """
        Fit the model using X, y as training data. Will also learn the groups
        that exist within the dataset.

        :param X: array-like, shape=(n_columns, n_samples,) training data.
        :param y: array-like, shape=(n_samples,) training data.
        :return: Returns an instance of self.
        """
        X_group, X_value = _split_groups_and_values(
            X, self.groups, min_value_cols=0, check_X=self.check_X, **self._check_kwargs
        )

        X_group = self.__add_shrinkage_column(X_group)

        if y is not None:
            y = check_array(y, ensure_2d=False)

        if self.shrinkage is not None:
            self.__set_shrinkage_function()

        # List of all hierarchical subsets of columns
        self.group_colnames_hierarchical_ = expanding_list(X_group.columns, list)

        # Global fallback model, used for unseen groups when shrinkage is off
        self.fallback_ = None
        if self.shrinkage is None and self.use_global_model:
            self.fallback_ = clone(self.estimator).fit(X_value, y)

        if self.shrinkage is not None:
            self.estimators_ = self.__fit_shrinkage_groups(X_group, X_value, y)
        else:
            self.estimators_ = self.__fit_grouped_estimator(X_group, X_value, y)

        self.groups_ = as_list(self.estimators_.keys())

        if self.shrinkage is not None:
            self.shrinkage_factors_ = self.__get_shrinkage_factor(X_group)

        return self

    def __predict_shrinkage_groups(self, X_group, X_value, method="predict"):
        """Make predictions for all shrinkage groups.

        Predicts at every hierarchy level, then combines the per-level
        predictions row-wise as a weighted sum using the fitted shrinkage
        factors. Missing-group errors surface from ``__predict_groups``.
        """
        # DataFrame with predictions for each hierarchy level, per row.
        hierarchical_predictions = pd.concat(
            [
                pd.Series(
                    self.__predict_groups(
                        X_group, X_value, level_columns, method=method
                    )
                )
                for level_columns in self.group_colnames_hierarchical_
            ],
            axis=1,
        )

        # This is a Series with values the tuples of hierarchical grouping
        prediction_groups = pd.Series(
            [tuple(row) for row in X_group.itertuples(index=False)]
        )

        # This is a Series of arrays
        shrinkage_factors = prediction_groups.map(self.shrinkage_factors_)

        # Convert the Series of arrays to a DataFrame (one weight per level)
        shrinkage_factors = pd.DataFrame.from_dict(shrinkage_factors.to_dict()).T

        return (hierarchical_predictions * shrinkage_factors).sum(axis=1)

    def __predict_single_group(self, group, X, method="predict"):
        """Predict a single group by getting its estimator from the fitted dict.

        Falls back to the global model for unseen groups when available;
        raises ValueError otherwise.
        """
        # Keep track of the original index such that we can sort in __predict_groups
        index = X.index
        try:
            group_predictor = self.estimators_[group]
        except KeyError:
            # Explicit None check: estimator truthiness is not a defined contract
            if self.fallback_ is not None:
                group_predictor = self.fallback_
            else:
                raise ValueError(
                    f"Found new group {group} during predict with use_global_model = False"
                )

        # getattr(group_predictor, method) returns the predict method of the fitted model
        # if the method argument is "predict" and the predict_proba method if method argument is "predict_proba"
        return pd.DataFrame(getattr(group_predictor, method)(X)).set_index(index)

    def __predict_groups(
        self,
        X_group: pd.DataFrame,
        X_value: np.ndarray,
        group_cols=None,
        method="predict",
    ):
        """Predict for all groups and restore the original row order."""
        # Reset indices such that they are the same in X_group (reset in __check_grouping_columns),
        # this way we can track the order of the result
        X_value = pd.DataFrame(X_value).reset_index(drop=True)

        if group_cols is None:
            group_cols = X_group.columns.tolist()

        # Make the groups based on the groups dataframe, use the indices on the values array
        group_indices = X_group.groupby(group_cols).indices

        return (
            pd.concat(
                [
                    self.__predict_single_group(
                        group, X_value.loc[indices, :], method=method
                    )
                    for group, indices in group_indices.items()
                ],
                axis=0,
            )
            .sort_index()  # restore input row order
            .values.squeeze()
        )

    def predict(self, X):
        """
        Predict on new data.

        :param X: array-like, shape=(n_columns, n_samples,) training data.
        :return: array, shape=(n_samples,) the predicted data
        """
        check_is_fitted(self, ["estimators_", "groups_", "fallback_"])
        X_group, X_value = _split_groups_and_values(
            X, self.groups, min_value_cols=0, check_X=self.check_X, **self._check_kwargs
        )

        X_group = self.__add_shrinkage_column(X_group)

        if self.shrinkage is None:
            return self.__predict_groups(X_group, X_value, method="predict")
        else:
            return self.__predict_shrinkage_groups(X_group, X_value, method="predict")

    # This ensures that the meta-estimator only has the predict_proba method if the estimator has it
    @if_delegate_has_method("estimator")
    def predict_proba(self, X):
        """
        Predict probabilities on new data.

        :param X: array-like, shape=(n_columns, n_samples,) training data.
        :return: array, shape=(n_samples, n_classes) the predicted data
        """
        check_is_fitted(self, ["estimators_", "groups_", "fallback_"])
        X_group, X_value = _split_groups_and_values(
            X, self.groups, min_value_cols=0, check_X=self.check_X, **self._check_kwargs
        )

        X_group = self.__add_shrinkage_column(X_group)

        if self.shrinkage is None:
            return self.__predict_groups(X_group, X_value, method="predict_proba")
        else:
            return self.__predict_shrinkage_groups(
                X_group, X_value, method="predict_proba"
            )

    # This ensures that the meta-estimator only has the decision_function method if the estimator has it
    @if_delegate_has_method("estimator")
    def decision_function(self, X):
        """
        Evaluate the decision function for the samples in X.

        :param X: array-like, shape=(n_columns, n_samples,) training data.
        :return: the decision function of the sample for each class in the model.
        """
        check_is_fitted(self, ["estimators_", "groups_", "fallback_"])
        X_group, X_value = _split_groups_and_values(
            X, self.groups, min_value_cols=0, check_X=self.check_X, **self._check_kwargs
        )

        X_group = self.__add_shrinkage_column(X_group)

        if self.shrinkage is None:
            return self.__predict_groups(X_group, X_value, method="decision_function")
        else:
            return self.__predict_shrinkage_groups(
                X_group, X_value, method="decision_function"
            )