# Source code for sklego.preprocessing.outlier_remover

from sklearn import clone
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_array, check_is_fitted

from sklego.common import TrainOnlyTransformerMixin

class OutlierRemover(TrainOnlyTransformerMixin, BaseEstimator):
    """
    Removes outliers (train-time only) using the supplied removal model.

    :param outlier_detector: must implement `fit` and `predict` methods
    :param refit: If True, fits the estimator during `fit`
    """

    def __init__(self, outlier_detector, refit=True):
        self.outlier_detector = outlier_detector
        self.refit = refit
        self.estimator_ = None

    def fit(self, X, y=None):
        """
        Clone the outlier detector and, when `refit` is True, fit the clone.

        :param X: training data, passed to the parent `fit` and to the detector
        :param y: optional targets, forwarded unchanged to both `fit` calls
        :return: self
        """
        # Work on a clone so the detector passed to the constructor is never
        # mutated by this transformer.
        self.estimator_ = clone(self.outlier_detector)
        if self.refit:
            super().fit(X, y)
            self.estimator_.fit(X, y)
        # NOTE(review): with refit=False the clone above is left unfitted and
        # the parent `fit` is skipped -- confirm this is the intended contract.
        return self

    def transform_train(self, X):
        """
        Return the rows of `X` whose detector prediction is not -1.

        :param X: data to filter; must support boolean-mask indexing
        :return: `X` restricted to the rows with a prediction other than -1
        """
        check_is_fitted(self, "estimator_")
        predictions = self.estimator_.predict(X)
        # Called for validation only; the returned array is not used.
        check_array(predictions, estimator=self.outlier_detector, ensure_2d=False)
        # Rows predicted as -1 are dropped; everything else is kept.
        return X[predictions != -1]