Source code for sklego.meta.grouped_transformer

import numpy as np
import pandas as pd

from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.utils.validation import check_is_fitted

from ._grouped_utils import _split_groups_and_values

class GroupedTransformer(BaseEstimator, TransformerMixin):
    """
    Construct a transformer per data group. Splits data by groups from single or multiple columns
    and transforms remaining columns using the transformers corresponding to the groups.

    :param transformer: the transformer to be applied per group
    :param groups: the column(s) of the matrix/dataframe to select as a grouping parameter set.
                   If None, the transformer will be applied to the entire input without grouping
    :param use_global_model: Whether or not to fall back to a general transformation in case
                             a group is not found during `.transform()`
    """

    # Extra keyword arguments forwarded to `_split_groups_and_values`.
    _check_kwargs = {"accept_large_sparse": False}

    def __init__(self, transformer, groups, use_global_model=True):
        self.transformer = transformer
        self.groups = groups
        self.use_global_model = use_global_model

    def __fit_single_group(self, group, X, y=None):
        """
        Fit a fresh clone of the transformer on the rows of a single group.

        :param group: the group key, only used to enrich the error message
        :param X: the values belonging to this group
        :param y: (Optional) the targets belonging to this group
        :raises Exception: the exception type raised by the underlying `.fit()`,
                           with the group key prepended to the message
        """
        try:
            return clone(self.transformer).fit(X, y)
        except Exception as e:
            # Same exception type, but tell the user which group failed.
            raise type(e)(f"Exception for group {group}: {e}") from e

    def __fit_grouped_transformer(
        self, X_group: pd.DataFrame, X_value: np.ndarray, y=None
    ):
        """
        Fit a transformer to each group.

        :param X_group: dataframe holding only the grouping columns
        :param X_value: array holding the values to transform, row-aligned with `X_group`
        :param y: (Optional) target variable, row-aligned with `X_group`
        :return: dict mapping each group key to its fitted transformer
        """
        # Make the groups based on the groups dataframe, use the indices on the values array.
        # `.indices` maps each group key to the *positional* row numbers of that group.
        try:
            group_indices = X_group.groupby(X_group.columns.tolist()).indices
        except TypeError as err:
            # This one is needed because of line #918 of sklearn/utils/estimator_checks
            raise TypeError("argument must be a string, date or number") from err

        if y is None:
            return {
                group: self.__fit_single_group(group, X_value[indices, :])
                for group, indices in group_indices.items()
            }

        if isinstance(y, pd.Series):
            # Re-index a copy: assigning to `y.index` directly would silently
            # modify the Series owned by the caller.
            y = y.copy()
            y.index = X_group.index
            # The group indices are positional, so slice positionally.
            y_rows = y.iloc
        else:
            # Lists and other array-likes cannot be indexed with an index array.
            y_rows = np.asarray(y)

        return {
            # Fit a clone of the transformer to each group
            group: self.__fit_single_group(group, X_value[indices, :], y_rows[indices])
            for group, indices in group_indices.items()
        }

    def __check_transformer(self):
        """Raise a ValueError when the supplied transformer has no `transform` attribute."""
        if not hasattr(self.transformer, "transform"):
            raise ValueError(
                "The supplied transformer should have a 'transform' method"
            )

    def fit(self, X, y=None):
        """
        Fit the transformers to the groups in X

        :param X: Array-like with at least two columns, of which at least one corresponds to groups
                  defined in init, and the remaining columns represent the values to transform.
        :param y: (Optional) target variable
        :return: the fitted instance (`self`)
        """
        self.__check_transformer()

        self.fallback_ = None

        if self.groups is None:
            # No grouping requested: a single transformer is fitted on the whole input.
            self.transformers_ = clone(self.transformer).fit(X, y)
            return self

        X_group, X_value = _split_groups_and_values(
            X, self.groups, **self._check_kwargs
        )
        self.transformers_ = self.__fit_grouped_transformer(X_group, X_value, y)

        if self.use_global_model:
            # The fallback receives `y` just like the per-group transformers do,
            # otherwise supervised transformers cannot be fitted as a fallback.
            self.fallback_ = clone(self.transformer).fit(X_value, y)

        return self

    def __transform_single_group(self, group, X):
        """
        Transform a single group by getting its transformer from the fitted dict.

        :param group: the group key to look up in `transformers_`
        :param X: dataframe with the values of this group
        :return: dataframe with the transformed values, carrying the index of `X`
        :raises ValueError: if the group was not seen during fit and there is no fallback
        """
        # Keep track of the original index such that we can sort in __transform_groups
        index = X.index

        # Unknown groups fall back to the global transformer (None when disabled).
        group_transformer = self.transformers_.get(group, self.fallback_)
        if group_transformer is None:
            raise ValueError(
                f"Found new group {group} during transform with use_global_model = False"
            )

        return pd.DataFrame(group_transformer.transform(X)).set_index(index)

    def __transform_groups(self, X_group: pd.DataFrame, X_value: np.ndarray):
        """
        Transform all groups and return the result in the original row order.

        :param X_group: dataframe holding only the grouping columns
        :param X_value: array holding the values to transform, row-aligned with `X_group`
        :return: array with the transformed values
        """
        # Give the values a 0..n-1 index so the positional group indices below can be
        # used as labels, this way we can track the order of the result.
        # NOTE(review): the original comment states X_group's index is reset upstream
        # as well — confirm in `_split_groups_and_values`.
        X_value = pd.DataFrame(X_value).reset_index(drop=True)

        # Make the groups based on the groups dataframe, use the indices on the values array
        group_indices = X_group.groupby(X_group.columns.tolist()).indices

        transformed_groups = [
            self.__transform_single_group(group, X_value.loc[indices, :])
            for group, indices in group_indices.items()
        ]

        # The groups are concatenated group by group; sorting on the preserved
        # index restores the row order of the input.
        return pd.concat(transformed_groups, axis=0).sort_index().values

    def transform(self, X):
        """
        Transform X using the transformer fitted for the group of each row

        :param X: Array-like with columns corresponding to the ones in .fit()
        :return: the transformed values
        """
        check_is_fitted(self, ["fallback_", "transformers_"])

        if self.groups is None:
            return self.transformers_.transform(X)

        X_group, X_value = _split_groups_and_values(
            X, self.groups, **self._check_kwargs
        )

        return self.__transform_groups(X_group, X_value)