Source code for sklego.meta.estimator_transformer

from sklearn import clone
from sklearn.base import (
    BaseEstimator,
    TransformerMixin,
    MetaEstimatorMixin,
)
from sklearn.utils.validation import (
    check_is_fitted,
    check_X_y,
    FLOAT_DTYPES,
)


[docs]class EstimatorTransformer(TransformerMixin, MetaEstimatorMixin, BaseEstimator): """ Allows using an estimator such as a model as a transformer in an earlier step of a pipeline :param estimator: An instance of the estimator that should be used for the transformation :param predict_func: The function called on the estimator when transforming e.g. (`predict`, `predict_proba`) """ def __init__(self, estimator, predict_func="predict"): self.estimator = estimator self.predict_func = predict_func
[docs] def fit(self, X, y, **kwargs): """Fits the estimator""" X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES, multi_output=True) self.multi_output_ = len(y.shape) > 1 self.estimator_ = clone(self.estimator) self.estimator_.fit(X, y, **kwargs) return self
[docs] def transform(self, X): """ Applies the `predict_func` on the fitted estimator. Returns array of shape `(X.shape[0], )` if estimator is not multi output. For multi output estimators an array of shape `(X.shape[0], y.shape[1])` is returned. """ check_is_fitted(self, "estimator_") output = getattr(self.estimator_, self.predict_func)(X) return output if self.multi_output_ else output.reshape(-1, 1)