Allow for Transformers on y #4143

Closed
cancan101 opened this issue Jan 22, 2015 · 57 comments · May be fixed by #13269

@cancan101

Following up on #3113 and #3112, what about arbitrary transforms to the y values? Those issues dealt primarily with "label transforms" but I would like to use transformers to mean or range center the y values as well.

Ideally I would have some transform that can be applied to the y values before fitting and then applied in the inverse to the predicted y values coming out of predict.

Ideally this Transformer could be added to a pipeline.

Currently, the signature of transform for StandardScaler allows for transforming y, but, as pointed out in the linked issues, not all transform methods have a signature that accepts y.

Furthermore, even for StandardScaler there is an inconsistency: inverse_transform does NOT take y.
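For the regression part of this request, later scikit-learn releases (0.20 onward) provide `sklearn.compose.TransformedTargetRegressor`, which handles it as a meta-estimator rather than a pipeline step. A minimal sketch, assuming a recent scikit-learn and using invented toy data: y is standardized before fitting, and predictions are passed back through the scaler's inverse_transform.

```python
import numpy as np
from sklearn.compose import TransformedTargetRegressor
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.RandomState(0)
X = rng.normal(size=(100, 3))
# Toy target with a large offset and scale, so centering/scaling y matters.
y = 1000.0 + 50.0 * (X @ np.array([1.0, -2.0, 0.5])) + rng.normal(size=100)

# X is scaled inside the pipeline; y is scaled by `transformer` before fit,
# and predict() applies the fitted transformer's inverse_transform.
model = TransformedTargetRegressor(
    regressor=make_pipeline(StandardScaler(), LinearRegression()),
    transformer=StandardScaler(),
)
model.fit(X, y)
y_pred = model.predict(X)  # already back in the original units of y
```

The fitted target scaler is available as `model.transformer_`, so the mean and scale learned from y can be inspected after fitting.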

@amueller amueller added this to the 1.0 milestone Jan 22, 2015
@amueller
Member

This is definitely on the todo. I can't find an issue for that, but it is planned, and one of the 1.0 goals. I am surprised there is no open issue for it.

@GaelVaroquaux
Member

This is something everybody has been wanting, and initially an API design mistake on my side.

However, the problem is that currently we cannot do that without changing the API and thus breaking everything.

A possible way forward is discussed in #3855

@amueller
Member

I don't think we need to break everything for this to be implemented, but I haven't had time to look into it yet.

@GaelVaroquaux
Member

I don't think we need to break everything for this to be implemented, but I
haven't had time to look into it yet.

If you can figure a path forward it would be great. It might require as a
temporary solution somewhat wacky code, but we can have it temporarily
in. What matters to me is that the API and the code in 1.0 are clean.

@amueller
Member

(Not really related to this issue.) My plan for the next couple of months is:

  1. bugfixes, then help @ogrisel release
  2. code review new feature PRs
  3. finally merge neural nets
  4. start on 1.0 API issues

@GaelVaroquaux
Member
  4. start on 1.0 API issues

I think that we should do a sprint on the road to 1.0. July?

@amueller
Member

I'm up for it :)

@jnothman
Member

A possible way forward is discussed in #3855

#3855 could potentially enable transforming y when training, but it doesn't consider the inverse transformation upon prediction. When we consider such inverse transformations, are we concerned about regression only? If we're concerned about transformations for classification problems, I think there's the additional concern that, while the transformation for fit and predict might be straightforward, it may be necessary to implement a non-trivial decision_function, in which case a meta-estimator like OvR is necessary.

So, considering the advantages of a pipeline over a meta-estimator: could you provide examples of reusable target transformer units that are (a) dependent on training data statistics; and (b) likely to be applied in sequence? For a stronger argument, might they be applied at different points in the Pipeline sequence?

In short, we need to consider the cases of resampling and target transformation separately; perhaps we will find they can share a design, and perhaps not.

@neuvirth
neuvirth commented Jan 27, 2015

I would like to suggest the following modification to the Pipeline implementation (which I call PipelineXY here); it supports transformers that return both X and y in fit_transform.

"""
The :mod:`sklearn.PipelineXY` module implements utilities to build a composite
estimator, as a chain of transforms and estimators.
"""
# Author: Edouard Duchesnay
#         Gael Varoquaux
#         Virgile Fritsch
#         Alexandre Gramfort
#         Lars Buitinck
# Licence: BSD

from collections import defaultdict

import numpy as np
from scipy import sparse

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.externals.joblib import Parallel, delayed
from sklearn.externals import six
from sklearn.utils import tosequence
from sklearn.externals.six import iteritems

__all__ = ['PipelineXY', 'FeatureUnion']


# One round of beers on me if someone finds out why the backslash
# is needed in the Attributes section so as not to upset sphinx.

class PipelineXY(BaseEstimator):
    """PipelineXY of transforms with a final estimator.

    Sequentially apply a list of transforms and a final estimator.
    Intermediate steps of the PipelineXY must be 'transforms', that is, they
    must implement fit and transform methods.
    The final estimator only needs to implement fit.

    The purpose of the PipelineXY is to assemble several steps that can be
    cross-validated together while setting different parameters.
    For this, it enables setting parameters of the various steps using their
    names and the parameter name separated by a '__', as in the example below.

    Parameters
    ----------
    steps: list
        List of (name, transform) tuples (implementing fit/transform) that are
        chained, in the order in which they are chained, with the last object
        an estimator.

    Examples
    --------
    >>> from sklearn import svm
    >>> from sklearn.datasets import samples_generator
    >>> from sklearn.feature_selection import SelectKBest
    >>> from sklearn.feature_selection import f_regression
    >>> from sklearn.PipelineXY import PipelineXY
    >>> # generate some data to play with
    >>> X, y = samples_generator.make_classification(
    ...     n_informative=5, n_redundant=0, random_state=42)
    >>> # ANOVA SVM-C
    >>> anova_filter = SelectKBest(f_regression, k=5)
    >>> clf = svm.SVC(kernel='linear')
    >>> anova_svm = PipelineXY([('anova', anova_filter), ('svc', clf)])
    >>> # You can set the parameters using the names issued
    >>> # For instance, fit using a k of 10 in the SelectKBest
    >>> # and a parameter 'C' of the svm
    >>> anova_svm.set_params(anova__k=10, svc__C=.1).fit(X, y)
    ...                                              # doctest: +ELLIPSIS
    PipelineXY(steps=[...])
    >>> prediction = anova_svm.predict(X)
    >>> anova_svm.score(X, y)                        # doctest: +ELLIPSIS
    0.77...
    """

    # BaseEstimator interface

    def __init__(self, steps):
        self.named_steps = dict(steps)
        names, estimators = zip(*steps)
        if len(self.named_steps) != len(steps):
            raise ValueError("Names provided are not unique: %s" % (names,))

        # shallow copy of steps
        self.steps = tosequence(zip(names, estimators))
        transforms = estimators[:-1]
        estimator = estimators[-1]

        for t in transforms:
            if (not (hasattr(t, "fit") or hasattr(t, "fit_transform")) or not
                    hasattr(t, "transform")):
                raise TypeError("All intermediate steps of the chain should "
                                "be transforms and implement fit and transform;"
                                " '%s' (type %s) doesn't" % (t, type(t)))

        if not hasattr(estimator, "fit"):
            raise TypeError("Last step of chain should implement fit; "
                            "'%s' (type %s) doesn't"
                            % (estimator, type(estimator)))

    def get_params(self, deep=True):
        if not deep:
            return super(PipelineXY, self).get_params(deep=False)
        else:
            out = self.named_steps.copy()
            for name, step in six.iteritems(self.named_steps):
                for key, value in six.iteritems(step.get_params(deep=True)):
                    out['%s__%s' % (name, key)] = value
            return out

    # Estimator interface

    def _pre_transform(self, X, y=None, **fit_params):
        fit_params_steps = dict((step, {}) for step, _ in self.steps)
        for pname, pval in six.iteritems(fit_params):
            step, param = pname.split('__', 1)
            fit_params_steps[step][param] = pval
        Xt = X
        yt = y
        for name, transform in self.steps[:-1]:
            if hasattr(transform, "fit_transform"):
                Xt = transform.fit_transform(Xt, yt, **fit_params_steps[name])
            else:
                Xt = transform.fit(Xt, yt, **fit_params_steps[name]) \
                              .transform(Xt)
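            # An XY-transformer returns an (X, y) tuple; ordinary
            # transformers return only the transformed X.
            if isinstance(Xt, tuple):
                Xt, yt = Xt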
        return Xt, yt, fit_params_steps[self.steps[-1][0]]

    def fit(self, X, y=None, **fit_params):
        """Fit all the transforms one after the other and transform the
        data, then fit the transformed data using the final estimator.
        """
        Xt, yt, fit_params = self._pre_transform(X, y, **fit_params)
        self.steps[-1][-1].fit(Xt, yt, **fit_params)
        return self

    def fit_transform(self, X, y=None, **fit_params):
        """Fit all the transforms one after the other and transform the
        data, then use fit_transform on transformed data using the final
        estimator."""
        Xt, yt, fit_params = self._pre_transform(X, y, **fit_params)
        if hasattr(self.steps[-1][-1], 'fit_transform'):
            return self.steps[-1][-1].fit_transform(Xt, yt, **fit_params)
        else:
            return self.steps[-1][-1].fit(Xt, yt, **fit_params).transform(Xt)

    def predict(self, X):
        """Applies transforms to the data, and the predict method of the
        final estimator. Valid only if the final estimator implements
        predict."""
        Xt = X
        for name, transform in self.steps[:-1]:
            Xt = transform.transform(Xt)
        return self.steps[-1][-1].predict(Xt)

    def predict_proba(self, X):
        """Applies transforms to the data, and the predict_proba method of the
        final estimator. Valid only if the final estimator implements
        predict_proba."""
        Xt = X
        for name, transform in self.steps[:-1]:
            Xt = transform.transform(Xt)
        return self.steps[-1][-1].predict_proba(Xt)

    def decision_function(self, X):
        """Applies transforms to the data, and the decision_function method of
        the final estimator. Valid only if the final estimator implements
        decision_function."""
        Xt = X
        for name, transform in self.steps[:-1]:
            Xt = transform.transform(Xt)
        return self.steps[-1][-1].decision_function(Xt)

    def predict_log_proba(self, X):
        Xt = X
        for name, transform in self.steps[:-1]:
            Xt = transform.transform(Xt)
        return self.steps[-1][-1].predict_log_proba(Xt)

    def transform(self, X):
        """Applies transforms to the data, and the transform method of the
        final estimator. Valid only if the final estimator implements
        transform."""
        Xt = X
        for name, transform in self.steps:
            Xt = transform.transform(Xt)
        return Xt

    def inverse_transform(self, X):
        if X.ndim == 1:
            X = X[None, :]
        Xt = X
        for name, step in self.steps[::-1]:
            Xt = step.inverse_transform(Xt)
        return Xt

    def score(self, X, y=None):
        """Applies transforms to the data, and the score method of the
        final estimator. Valid only if the final estimator implements
        score."""
        Xt = X
        for name, transform in self.steps[:-1]:
            Xt = transform.transform(Xt)
        return self.steps[-1][-1].score(Xt, y)

    @property
    def _pairwise(self):
        # check if first estimator expects pairwise input
        return getattr(self.steps[0][1], '_pairwise', False)


def _name_estimators(estimators):
    """Generate names for estimators."""

    names = [type(estimator).__name__.lower() for estimator in estimators]
    namecount = defaultdict(int)
    for est, name in zip(estimators, names):
        namecount[name] += 1

    for k, v in list(six.iteritems(namecount)):
        if v == 1:
            del namecount[k]

    for i in reversed(range(len(estimators))):
        name = names[i]
        if name in namecount:
            names[i] += "-%d" % namecount[name]
            namecount[name] -= 1

    return list(zip(names, estimators))


def make_PipelineXY(*steps):
    """Construct a PipelineXY from the given estimators.

    This is a shorthand for the PipelineXY constructor; it does not require, and
    does not permit, naming the estimators. Instead, they will be given names
    automatically based on their types.

    Examples
    --------
    >>> from sklearn.naive_bayes import GaussianNB
    >>> from sklearn.preprocessing import StandardScaler
    >>> make_PipelineXY(StandardScaler(), GaussianNB())    # doctest: +NORMALIZE_WHITESPACE
    PipelineXY(steps=[('standardscaler',
                     StandardScaler(copy=True, with_mean=True, with_std=True)),
                    ('gaussiannb', GaussianNB())])

    Returns
    -------
    p : PipelineXY
    """
    return PipelineXY(_name_estimators(steps))


def _fit_one_transformer(transformer, X, y):
    return transformer.fit(X, y)


def _transform_one(transformer, name, X, transformer_weights):
    if transformer_weights is not None and name in transformer_weights:
        # if we have a weight for this transformer, multiply output
        return transformer.transform(X) * transformer_weights[name]
    return transformer.transform(X)


def _fit_transform_one(transformer, name, X, y, transformer_weights,
                       **fit_params):
    if hasattr(transformer, 'fit_transform'):
        X_transformed = transformer.fit_transform(X, y, **fit_params)
    else:
        X_transformed = transformer.fit(X, y, **fit_params).transform(X)
    # An XY-transformer may return an (X, y) tuple. FeatureUnion only
    # concatenates features, so keep X and drop the transformed y; this
    # also keeps the return value a 2-tuple, as fit_transform expects.
    if isinstance(X_transformed, tuple):
        X_transformed = X_transformed[0]
    if transformer_weights is not None and name in transformer_weights:
        # if we have a weight for this transformer, multiply output
        X_transformed = X_transformed * transformer_weights[name]
    return X_transformed, transformer


class FeatureUnion(BaseEstimator, TransformerMixin):
    """Concatenates results of multiple transformer objects.

    This estimator applies a list of transformer objects in parallel to the
    input data, then concatenates the results. This is useful to combine
    several feature extraction mechanisms into a single transformer.

    Parameters
    ----------
    transformer_list: list of (string, transformer) tuples
        List of transformer objects to be applied to the data. The first
        half of each tuple is the name of the transformer.

    n_jobs: int, optional
        Number of jobs to run in parallel (default 1).

    transformer_weights: dict, optional
        Multiplicative weights for features per transformer.
        Keys are transformer names, values the weights.

    """
    def __init__(self, transformer_list, n_jobs=1, transformer_weights=None):
        self.transformer_list = transformer_list
        self.n_jobs = n_jobs
        self.transformer_weights = transformer_weights

    def get_feature_names(self):
        """Get feature names from all transformers.

        Returns
        -------
        feature_names : list of strings
            Names of the features produced by transform.
        """
        feature_names = []
        for name, trans in self.transformer_list:
            if not hasattr(trans, 'get_feature_names'):
                raise AttributeError("Transformer %s does not provide"
                                     " get_feature_names." % str(name))
            feature_names.extend([name + "__" + f for f in
                                  trans.get_feature_names()])
        return feature_names

    def fit(self, X, y=None):
        """Fit all transformers using X.

        Parameters
        ----------
        X : array-like or sparse matrix, shape (n_samples, n_features)
            Input data, used to fit transformers.
        """
        transformers = Parallel(n_jobs=self.n_jobs)(
            delayed(_fit_one_transformer)(trans, X, y)
            for name, trans in self.transformer_list)
        self._update_transformer_list(transformers)
        return self

    def fit_transform(self, X, y=None, **fit_params):
        """Fit all transformers using X, transform the data and concatenate
        results.

        Parameters
        ----------
        X : array-like or sparse matrix, shape (n_samples, n_features)
            Input data to be transformed.

        Returns
        -------
        X_t : array-like or sparse matrix, shape (n_samples, sum_n_components)
            hstack of results of transformers. sum_n_components is the
            sum of n_components (output dimension) over transformers.
        """
        result = Parallel(n_jobs=self.n_jobs)(
            delayed(_fit_transform_one)(trans, name, X, y,
                                        self.transformer_weights, **fit_params)
            for name, trans in self.transformer_list)

        Xs, transformers = zip(*result)
        self._update_transformer_list(transformers)
        if any(sparse.issparse(f) for f in Xs):
            Xs = sparse.hstack(Xs).tocsr()
        else:
            Xs = np.hstack(Xs)
        return Xs

    def transform(self, X):
        """Transform X separately by each transformer, concatenate results.

        Parameters
        ----------
        X : array-like or sparse matrix, shape (n_samples, n_features)
            Input data to be transformed.

        Returns
        -------
        X_t : array-like or sparse matrix, shape (n_samples, sum_n_components)
            hstack of results of transformers. sum_n_components is the
            sum of n_components (output dimension) over transformers.
        """
        Xs = Parallel(n_jobs=self.n_jobs)(
            delayed(_transform_one)(trans, name, X, self.transformer_weights)
            for name, trans in self.transformer_list)
        if any(sparse.issparse(f) for f in Xs):
            Xs = sparse.hstack(Xs).tocsr()
        else:
            Xs = np.hstack(Xs)
        return Xs

    def get_params(self, deep=True):
        if not deep:
            return super(FeatureUnion, self).get_params(deep=False)
        else:
            out = dict(self.transformer_list)
            for name, trans in self.transformer_list:
                for key, value in iteritems(trans.get_params(deep=True)):
                    out['%s__%s' % (name, key)] = value
            return out

    def _update_transformer_list(self, transformers):
        self.transformer_list[:] = [
            (name, new)
            for ((name, old), new) in zip(self.transformer_list, transformers)
        ]


# XXX it would be nice to have a keyword-only n_jobs argument to this function,
# but that's not allowed in Python 2.x.
def make_union(*transformers):
    """Construct a FeatureUnion from the given transformers.

    This is a shorthand for the FeatureUnion constructor; it does not require,
    and does not permit, naming the transformers. Instead, they will be given
    names automatically based on their types. It also does not allow weighting.

    Examples
    --------
    >>> from sklearn.decomposition import PCA, TruncatedSVD
    >>> make_union(PCA(), TruncatedSVD())    # doctest: +NORMALIZE_WHITESPACE
    FeatureUnion(n_jobs=1,
                 transformer_list=[('pca', PCA(copy=True, n_components=None,
                                               whiten=False)),
                                   ('truncatedsvd',
                                    TruncatedSVD(algorithm='randomized',
                                                 n_components=2, n_iter=5,
                                                 random_state=None, tol=0.0))],
                 transformer_weights=None)

    Returns
    -------
    f : FeatureUnion
    """
    return FeatureUnion(_name_estimators(transformers))

@jnothman
Member

It's very difficult to see what you mean when you post the code this way. But it's probably not the way to go... I don't think having transform sometimes return a tuple is going to be a pleasant way to deal with this API problem.

@GaelVaroquaux
Member

I don't think having transform sometimes return a tuple is going to be
a pleasant way to deal with this API problem.

Indeed. A signature of a method (input type and return type) should be
independent of the context in which it is called, or the object.

@jnothman
Member

To that end, we're explicitly going in the opposite direction to force
clusterers to allow the user to pass an ignored y parameter when training
(see #4064)

@GaelVaroquaux
Member

To that end, we're explicitly going in the opposite direction to force
clusterers to allow the user to pass an ignored y parameter when training
(see #4064)

And that is a Good thing :).

In hindsight, I believe that it was a mistake to have fit signatures
accept either "X" or "X, y".

@neuvirth

The only reason it doesn't always return both X and Y is to support the current implementation. As can be clearly seen from the various questions here, very often both of them need to be processed together.

@GaelVaroquaux
Member

The only reason it doesn't always return both X and Y is to support the
current implementation. As can be clearly seen from the various questions here,
very often both of them need to be processed together.

Agreed. But we need a smooth way forward to avoid breaking everybody's
code. It has been discussed (in the issue linked before) to add another
method, named differently, and deprecate this one.

@jnothman
Member

Except that sometimes the transformer will change not X nor y but
sample_weight, and then what?

@GaelVaroquaux
Member

Except that sometimes the transformer will change not X nor y but
sample_weight, and then what?

Indeed, we are back to the fact that we need to be able to take dict of
arrays as y.

That somewhat summarizes why I think that these problems are hard, and
that we need a week-long sprint to address them in July.

Joel, any chance you can make it to Europe this summer?

@neuvirth

Good point, although in my opinion less urgent than the not-processing-Y issue. Not sure I follow: are you suggesting passing the weights as part of Y? Why not pass a third argument with this dict? I believe Y has its unique role in ML.

@jnothman
Member

All I mean is that a tuple is not sufficient.

@amueller
Member

After my attempt at this in #4552, I don't think this is a good idea anymore.
The subsampling case can be dealt with much more easily using a meta-estimator, and I haven't really found any other good application.
Most features, like the one the OP asks for, are very specific, and I feel they are better handled by meta-estimators. For the "preprocessing and postprocessing" of y, you would need to attach something to the end of the pipeline. That seems very different from what pipelines currently do, but could easily be handled using a meta-estimator.
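As a rough illustration of the meta-estimator route for subsampling: `SubsampledEstimator` below is a hypothetical class (not part of scikit-learn) that fits a cloned estimator on a random fraction of the training rows. Because resampling happens only inside fit, prediction and scoring still see every sample, and `fraction` can be grid-searched like any other parameter.

```python
import numpy as np
from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.utils import check_random_state


class SubsampledEstimator(MetaEstimatorMixin, BaseEstimator):
    """Hypothetical meta-estimator: fit `estimator` on a random subsample.

    Resampling happens only in fit; predict and score use all rows.
    """

    def __init__(self, estimator, fraction=0.5, random_state=None):
        self.estimator = estimator
        self.fraction = fraction
        self.random_state = random_state

    def fit(self, X, y):
        X, y = np.asarray(X), np.asarray(y)
        rng = check_random_state(self.random_state)
        n_keep = max(1, int(self.fraction * X.shape[0]))
        keep = rng.choice(X.shape[0], size=n_keep, replace=False)
        self.estimator_ = clone(self.estimator).fit(X[keep], y[keep])
        self.n_fit_samples_ = n_keep
        return self

    def predict(self, X):
        return self.estimator_.predict(X)

    def score(self, X, y):
        return self.estimator_.score(X, y)


X, y = make_classification(n_samples=200, random_state=0)
model = SubsampledEstimator(LogisticRegression(), fraction=0.25, random_state=0)
model.fit(X, y)

# The resampling fraction is an ordinary hyperparameter for GridSearchCV.
search = GridSearchCV(
    SubsampledEstimator(LogisticRegression(), random_state=0),
    {"fraction": [0.5, 1.0]},
    cv=3,
).fit(X, y)
```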

@jnothman
Member
  1. class_weight='balanced' should be more-or-less equivalent (or better) relative to downsampling.
  2. #3855 ("Resampler estimators that change the sample size in fitting"), specifically #3855 (comment), is the proposed solution. Help welcome.
  3. A meta-estimator can be constructed to do what you want, even without #3855.
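To make point 1 (and the reply below) concrete: with `class_weight='balanced'`, scikit-learn weights each class by `n_samples / (n_classes * np.bincount(y))`, so every class contributes the same total weight. That is why it behaves more like upsampling the minority class than discarding majority samples. A small check with made-up 90/10 labels:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0] * 90 + [1] * 10)  # toy 90/10 class imbalance
classes = np.unique(y)

weights = compute_class_weight(class_weight="balanced", classes=classes, y=y)
manual = len(y) / (len(classes) * np.bincount(y))  # documented formula

# Per-class total weight: count * weight is the same for every class.
totals = np.bincount(y) * weights
print(weights, totals)
```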

@amueller
Member

I would argue class_weight='balanced' is equivalent to upsampling ;)
You could also use imbalanced-learn btw, which does implement this as a pipeline:
https://github.com/scikit-learn-contrib/imbalanced-learn

@eliasmistler
eliasmistler commented Nov 22, 2018 via email

@theoptips
Contributor
theoptips commented Nov 2, 2019

Related: #12587. Added a section on transforming y to the FAQ documentation (faq.rst) via #15484.

@dabana
dabana commented Nov 4, 2019

Quoting @jnothman from #15484:

we hope to solve for use cases where y should be transformed at training time and not at test time, for re-sampling and similar uses, like at imbalanced learn. In general these use cases should be solved with a custom meta estimator rather than a Pipeline.

My team and I have built such a meta-estimator for our application. But one problem remained: how can we use scikit-learn's cross-validation tooling (GridSearchCV, etc.) on this estimator? Indeed, scikit-learn's conventional evaluation metrics do not do the trick. Here are the two main concerns we had:

  1. Evaluating estimators that transform y WITHOUT applying an inverse-transform post-prediction (unlike with TransformedTargetRegressor). Our use-case is regression by classification: we would like to perform parameter search on target discretization using cross-validation.

  2. Evaluating estimators that change the number of samples, like re-sampling. For our use-case, we are simply aggregating grouped data before training: we want to do parameter search on aggregation functions (comparing mean, median, min, max, etc.) using cross-validation.

For now, the way we go around this problem involves two steps:

  1. Endow our custom meta-estimator with a getter method to retrieve transformed targets: the so-called get_transformed_targets method.
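from sklearn.utils.metaestimators import _BaseComposition  # private helper

class CustomMetaEstimator(_BaseComposition):
    [...]
    def get_transformed_targets(self, X, y_true):
        '''Returns the transformed targets
        '''
        X_transformed, y_true_transformed = X, y_true
        # preprocessing_transformers is assumed to hold step names
        for name in self.preprocessing_transformers:
            # note: this re-fits each step on the data being scored
            output = self.named_steps[name].fit_transform(X_transformed, y_true_transformed)
            # an (X, y) tuple means the step also transformed the targets;
            # checking len() alone would misfire on a 2-row array
            if isinstance(output, tuple) and len(output) == 2:
                X_transformed, y_true_transformed = output
            else:
                X_transformed = output
        return y_true_transformed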
  2. Perform a slight hack to sklearn's _PredictScorer class:
class _PredictScorer(_BaseScorer):
    def _score(self, method_caller, estimator, X, y_true, sample_weight=None):
        """[... docstring ...]
        """
        #Here starts the hack
        if hasattr(estimator, 'get_transformed_targets'):
            y_true = estimator.get_transformed_targets(X, y_true)
        #Here ends the hack

        y_pred = method_caller(estimator, "predict", X)
        if sample_weight is not None:
            return self._sign * self._score_func(y_true, y_pred,
                                                 sample_weight=sample_weight,
                                                 **self._kwargs)
        else:
            return self._sign * self._score_func(y_true, y_pred,
                                                 **self._kwargs)

@jnothman
Member
jnothman commented Dec 6, 2019

@dabana could you open a new issue re regression by classification, please? although I think this could be handled with our current proposal for resampling.

@adrinjalali
Member

Moving the milestone to 2.0

@adrinjalali adrinjalali modified the milestones: 1.0, 2.0 Aug 22, 2021
@Permafacture

We've been using our own PipelineXY at work for a couple of years and wouldn't be able to do without it. In addition to data loading and data augmentation (both very useful applications), another use case for this is getting StratifiedKFold to stratify on more than True/False.

For instance, we have several types of instruments and we need positive and negative cases from each instrument type to be evenly distributed when grid searching. So, we make the y values strings like 'Type1_True' and then our pipeline has an XyTransformer that converts these to True/False. Then we can use all the sklearn cross validation tools like GridSearchCV with StratifiedKFold and not have to modify any of the other transformers or estimators.
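Without a y-transformer in the pipeline, a similar effect can probably be had by stratifying on the composite label directly and passing the precomputed splits to `cv`, which accepts an iterable of (train, test) index arrays. A minimal sketch with invented toy data; the "Type1_True"-style strings follow the comment above:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedKFold

rng = np.random.RandomState(0)
X = rng.normal(size=(120, 4))
instrument = np.repeat(["Type1", "Type2", "Type3"], 40)  # toy instrument types
y = np.tile([True, True, False, False, False], 24)        # boolean target

# Stratify on instrument type AND label, e.g. "Type1_True".
strata = np.array([f"{t}_{label}" for t, label in zip(instrument, y)])
skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)
cv = list(skf.split(X, strata))

# The estimator only ever sees the boolean y; the strata drive the splits.
search = GridSearchCV(LogisticRegression(), {"C": [0.1, 1.0]}, cv=cv)
search.fit(X, y)
```

The trade-off is that the splits are fixed up front, so this does not help when the y transformation itself needs to be learned inside each fold.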

@pedroilidio

What about having a Transformer attribute like returns_y with values False (default), True or "both" that indicates whether transform returns transformed X (as usual), transformed y, or both X and y in a tuple?

The pros I see are that it seems easy to:

  1. convert a usual transformer to a target transformer;
  2. assure compatibility with old custom transformers by assuming in Pipeline that not having the attribute means setting it to False.

The cons I see are:

  1. the new target transformers would behave unpredictably for old code;
  2. we still have the issue of variable format and meaning of transform output.

Another (maybe simpler) alternative would be to have a third optional element in each step tuple of a Pipeline construction, with a string indicating "transform_target" or something similar to direct a standard transformer towards y instead of X. This syntax could be extended to accept the new transformers discussed here, which take both X and y and return either a new y or a new X and y, perhaps passing "transform_target" in the first case and "transform_both" in the second.

On the issue of ensuring consistent y dimensionality for scoring, we could have another keyword processed similarly to the current "passthrough" in pipelines, maybe "collect_true_labels" or "store_true_y", that indicates in which step the output labels are to be considered true target values for scoring. This seems hard to implement in cross-validation though, and would be more prone to data leakage.

I could work on a PR if any idea sounds interesting.
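A rough sketch of how a pipeline's fit loop might honour the proposed `returns_y` attribute. Everything here (`returns_y`, `DropNegativeTargets`, `fit_transform_steps`) is hypothetical and not scikit-learn API; the point is only that a missing attribute falls back to today's X-only behaviour.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler


class DropNegativeTargets:
    """Hypothetical step with returns_y="both": drops rows where y < 0."""

    returns_y = "both"

    def fit(self, X, y=None):
        return self

    def transform(self, X, y):
        keep = np.asarray(y) >= 0
        return X[keep], y[keep]


def fit_transform_steps(steps, X, y):
    """Hypothetical fit-time loop dispatching on the `returns_y` attribute."""
    for step in steps:
        mode = getattr(step, "returns_y", False)  # absent -> old behaviour
        step.fit(X, y)
        if mode == "both":
            X, y = step.transform(X, y)
        elif mode is True:
            y = step.transform(X, y)  # target-only transformer
        else:
            X = step.transform(X)  # ordinary transformer: X only
    return X, y


X = np.arange(12.0).reshape(6, 2)
y = np.array([1.0, -1.0, 2.0, -3.0, 4.0, 5.0])
Xt, yt = fit_transform_steps([DropNegativeTargets(), StandardScaler()], X, y)
```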

@dmbee
Contributor
dmbee commented Mar 19, 2022

Alternatively, you could have a base class, e.g. XYTransformerMixin, and use isinstance().

@Permafacture
Permafacture commented Mar 19, 2022

I have the transformer return a named tuple called XyTuple with attributes X and y, and do an isinstance check on the value returned by the transformer. I think either this or the XyTransformerMixin makes sense. I did both, but what really matters is what the transformer returns, so I opted for the named tuple rather than allowing for a possible mismatch between what the transformer technically inherits from and what it actually does. With the named tuple, transformers could be implemented to have a fit_transform that transforms y, while transform could either handle y or not, depending on what makes sense for that particular transformer.

As far as the PipelineXY goes, fit_transform should always return XyTuple. I think transform should also return XyTuple so that y can be used in estimator scoring, but since this is a change from expected behaviour the pipeline.transform could take a boolean kwarg transform_y that determines the output type.

For grid searching and other things like that, it would be best for y to come from the step of the pipeline right before the estimator.
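A minimal sketch of the XyTuple convention described above, dispatching with isinstance on the value each step returns. `XyTuple`, `StripTypePrefix`, `fit_xy` and `predict_xy` are hypothetical names for illustration; the y-aware step transforms y in fit_transform and leaves it alone in transform, so prediction never needs y.

```python
from typing import Any, NamedTuple

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler


class XyTuple(NamedTuple):
    X: Any
    y: Any


class StripTypePrefix:
    """Hypothetical Xy-step: maps labels like 'Type1_True' to booleans."""

    def fit(self, X, y=None):
        return self

    def fit_transform(self, X, y=None):
        y_bool = np.array([str(label).endswith("_True") for label in y])
        return XyTuple(X, y_bool)

    def transform(self, X):
        return X  # no y at predict time, so X passes through unchanged


def fit_xy(steps, estimator, X, y):
    for step in steps:
        out = step.fit_transform(X, y)
        if isinstance(out, XyTuple):  # y-aware step changed the targets
            X, y = out.X, out.y
        else:                         # ordinary step: only X changes
            X = out
    return estimator.fit(X, y)


def predict_xy(steps, estimator, X):
    for step in steps:
        X = step.transform(X)
    return estimator.predict(X)


X = np.arange(16.0).reshape(8, 2)
y = np.array(["Type1_True", "Type1_False", "Type2_True", "Type2_False"] * 2)
steps = [StripTypePrefix(), StandardScaler()]
clf = fit_xy(steps, LogisticRegression(), X, y)
pred = predict_xy(steps, clf, X)
```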

@elliot-hallmark

Commenting from a different account. The solution I've been using (last comment by @Permafacture) is broken by sklearn 1.2 and its set_output. My XyTuple gets converted to a regular tuple by a wrapper somewhere in the machinery. If anyone knows how to get my XyTransformerMixin to opt out of the set_output behavior, let me know. Until then I'm sticking to 1.1.3.

@adrinjalali
Member

cc @thomasjpfan

@thomasjpfan
Member

As noted in the set_output developer docs, one can pass auto_wrap_output_keys=None when defining the class to disable set_output entirely:

class XyTransformerMixin(TransformerMixin, BaseEstimator, auto_wrap_output_keys=None):
    ...

Overall, I think this is a bug and I opened #26121 to fix it.

@worthy7
worthy7 commented Apr 17, 2024

Bump for demand on this. I want to use XGBoost, but it wants the y targets to be ordered ints.
Performing label encoding before running cross_val may leave some classes missing from a training split, which breaks things.
So I need the LabelEncoder in the pipeline to act on y within each split instead, but that seems impossible.

@GaelVaroquaux
Member

2 things:

  1. the problem that you are mentioning should be fixed in XGB
  2. These days, I really think that transformers on y are not something that fits well in the logic of the scikit-learn pipeline / model selection tools (too many difficulties to make the model validation solid).

The right solution would be to have a TransformerTargetClassifier, as discussed in #20952
A Pull Request for this would be great.
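A rough sketch of what such a meta-estimator could look like for the XGBoost case above. `LabelEncodedClassifier` is a hypothetical name (not the class proposed in #20952), and `LogisticRegression` stands in for any classifier that wants integer labels 0..n_classes-1. Because the LabelEncoder is fitted inside `fit`, each cross-validation split encodes only the classes present in its own training fold.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import LabelEncoder


class LabelEncodedClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical meta-estimator: encode y in fit, decode in predict."""

    def __init__(self, classifier):
        self.classifier = classifier

    def fit(self, X, y):
        # Fitted per call, so every CV training fold gets consecutive ints.
        self.label_encoder_ = LabelEncoder().fit(y)
        self.classes_ = self.label_encoder_.classes_
        y_encoded = self.label_encoder_.transform(y)
        self.classifier_ = clone(self.classifier).fit(X, y_encoded)
        return self

    def predict(self, X):
        # Map the inner model's integer predictions back to original labels.
        return self.label_encoder_.inverse_transform(self.classifier_.predict(X))

    def predict_proba(self, X):
        # Columns follow self.classes_, the encoder's sorted label order.
        return self.classifier_.predict_proba(X)


rng = np.random.RandomState(0)
X = rng.normal(size=(60, 3))
y = np.array([3, 7, 9] * 20)  # non-consecutive labels
clf = LabelEncodedClassifier(LogisticRegression()).fit(X, y)
scores = cross_val_score(LabelEncodedClassifier(LogisticRegression()), X, y, cv=3)
```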

@worthy7
worthy7 commented Apr 17, 2024

Yes I agree with 1, there is an issue for it dmlc/xgboost#10078

@adrinjalali
Member

So can we actually close this one?

@GaelVaroquaux
Member
GaelVaroquaux commented Apr 17, 2024 via email

@adrinjalali adrinjalali closed this as not planned Apr 17, 2024