class sklego.common.TrainOnlyTransformerMixin[source]

Bases: sklearn.base.TransformerMixin

Allows using separate functions for transforming train and test data.

>>> import numpy as np
>>> from sklearn.base import BaseEstimator
>>> from sklego.common import TrainOnlyTransformerMixin
>>> class TrainOnlyTransformer(TrainOnlyTransformerMixin, BaseEstimator):
...     def fit(self, X, y):
...         # Let the mixin record which data counts as the training set
...         super().fit(X, y)
...     def transform_train(self, X, y=None):
...         # Noise is added only when transforming the data passed to fit
...         return X + np.random.normal(0, 1, size=X.shape)
>>> X_train, X_test = np.random.randn(100, 4), np.random.randn(100, 4)
>>> y_train, y_test = np.random.randn(100), np.random.randn(100)
>>> trf = TrainOnlyTransformer()
>>> trf.fit(X_train, y_train)
>>> assert np.all(trf.transform(X_train) != X_train)
>>> assert np.all(trf.transform(X_test) == X_test)


Transformers using this class as a mixin should at a minimum:

  • call super().fit in their fit method

  • implement transform_train()

They may also implement transform_test(). If it is not implemented, transform_test simply returns the input data untransformed.

fit(X, y=None)[source]

Calculates the hash of X_train

transform(X, y=None)[source]

Dispatcher for transform method.

It dispatches to self.transform_train if X is the same as the X passed to fit; otherwise, it dispatches to self.transform_test.
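
The dispatch described above can be sketched in plain Python. This is an illustration only, not sklego's implementation: the class name HashDispatchSketch, the helper _fingerprint, and the choice of hashing repr(X) with SHA-256 are all assumptions made for the example, and plain nested lists stand in for arrays.

```python
import hashlib


class HashDispatchSketch:
    """Hypothetical stand-in for the mixin; not sklego's actual code."""

    @staticmethod
    def _fingerprint(X):
        # Assumption: hash the textual representation of the data, so that
        # inputs with equal contents produce equal digests.
        return hashlib.sha256(repr(X).encode("utf-8")).hexdigest()

    def fit(self, X, y=None):
        # Remember a fingerprint of the training data.
        self.X_hash_ = self._fingerprint(X)
        return self

    def transform(self, X, y=None):
        # Same contents as the data seen in fit: use the train-only path.
        if self._fingerprint(X) == self.X_hash_:
            return self.transform_train(X)
        return self.transform_test(X)

    def transform_train(self, X, y=None):
        # Example train-only behaviour: add 1 to every value.
        return [[value + 1 for value in row] for row in X]

    def transform_test(self, X, y=None):
        # Default test behaviour: return the input unchanged.
        return X


X_train = [[1.0, 2.0], [3.0, 4.0]]
X_test = [[5.0, 6.0], [7.0, 8.0]]

trf = HashDispatchSketch().fit(X_train)
train_out = trf.transform(X_train)
test_out = trf.transform(X_test)
```

In this sketch, any input whose contents match the fitted data takes the train path, including a copy of it, because the comparison is on a digest of the contents rather than on object identity.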

transform_test(X, y=None)[source]
transform_train(X, y=None)[source]

sklego.common.as_list(val)[source]

Helper function, always returns a list of the input value.


Parameters: val – the input value.


Returns: the input value as a list.


>>> as_list('test')
['test']
>>> as_list(['test1', 'test2'])
['test1', 'test2']
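
A minimal sketch of how a helper with this behaviour could be written. The name as_list_sketch is hypothetical and this is not necessarily how sklego implements as_list. It assumes that a string counts as a single value and that any other iterable is expanded into a list.

```python
def as_list_sketch(val):
    """Hypothetical re-implementation, for illustration only."""
    # A string is iterable, but it is treated as one value here.
    if isinstance(val, str):
        return [val]
    # Other iterables (lists, tuples, generators) are copied into a new list.
    if hasattr(val, "__iter__"):
        return list(val)
    # Anything else (numbers, None, ...) is wrapped in a one-element list.
    return [val]


single = as_list_sketch("test")
several = as_list_sketch(["test1", "test2"])
```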
sklego.common.expanding_list(list_to_extent, return_type=<class 'list'>)[source]

Make an expanding list of lists by collecting the first element, then the first 2 elements, and so on.

Parameters:

  • list_to_extent – the list (or single value) to expand

  • return_type – type of the elements of the list (tuple or list)


>>> expanding_list('test')
[['test']]
>>> expanding_list(['test1', 'test2', 'test3'])
[['test1'], ['test1', 'test2'], ['test1', 'test2', 'test3']]
>>> expanding_list(['test1', 'test2', 'test3'], tuple)
[('test1',), ('test1', 'test2'), ('test1', 'test2', 'test3')]
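
The expanding behaviour can be illustrated with a short sketch. The name expanding_list_sketch and its parameter items are hypothetical, and the code is an assumption about one way to get the outputs shown above, not sklego's implementation.

```python
def expanding_list_sketch(items, return_type=list):
    """Hypothetical re-implementation, for illustration only."""
    # Treat a bare string as a single item, not as a sequence of characters.
    if isinstance(items, str):
        items = [items]
    items = list(items)
    # The n-th entry holds the first n items, converted to return_type.
    return [return_type(items[:n]) for n in range(1, len(items) + 1)]


as_lists = expanding_list_sketch(["test1", "test2", "test3"])
as_tuples = expanding_list_sketch(["test1", "test2", "test3"], tuple)
```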

sklego.common.flatten(nested_iterable)[source]

Helper function, returns an iterator of flattened values from an arbitrarily nested iterable.

>>> list(flatten([['test1', 'test2'], ['a', 'b', ['c', 'd']]]))
['test1', 'test2', 'a', 'b', 'c', 'd']
>>> list(flatten(['test1', ['test2']]))
['test1', 'test2']
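
Flattening of this kind is commonly written as a recursive generator. The sketch below uses the hypothetical name flatten_sketch and assumes that strings and bytes are kept whole instead of being split into characters. It is not necessarily how sklego implements flatten.

```python
from collections.abc import Iterable


def flatten_sketch(nested):
    """Hypothetical re-implementation, for illustration only."""
    for item in nested:
        # Recurse into nested iterables, but keep strings and bytes whole.
        if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
            yield from flatten_sketch(item)
        else:
            yield item


flat = list(flatten_sketch([["test1", "test2"], ["a", "b", ["c", "d"]]]))
```

Because the function is a generator, values are produced lazily; wrapping the call in list() materializes them, as in the doctest examples above.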