
DEP auto, binary_crossentropy, categorical_crossentropy in HGBT #23040

Merged — 15 commits into scikit-learn:main on Apr 7, 2022

Conversation

lorentzenchr
Member
@lorentzenchr lorentzenchr commented Apr 3, 2022

Reference Issues/PRs

Partially addresses #18248

What does this implement/fix? Explain your changes.

This PR introduces loss="log_loss" for HistGradientBoostingClassifier and deprecates other options.

Any other comments?

Currently, loss can be "auto", "binary_crossentropy" and "categorical_crossentropy". Can we remove the two options "binary_crossentropy" and "categorical_crossentropy"? I don't see a meaningful use case. For instance "categorical_crossentropy" raises ValueError on binary problems.

@lorentzenchr
Member Author

@NicolasHug might be interested.

@lorentzenchr lorentzenchr added this to the 1.1 milestone Apr 4, 2022
@ogrisel
Member
ogrisel commented Apr 6, 2022

I just checked whether the choice of loss function could be used to over-parameterize the binary classification case the way we do for the multiclass case, with one tree per class and per boosting iteration and a softmax inverse link function instead of the logistic sigmoid. At the moment it is not possible:

>>> from sklearn.ensemble import HistGradientBoostingClassifier
>>> from sklearn.datasets import make_classification
>>> X, y = make_classification(n_classes=2)
>>> HistGradientBoostingClassifier(loss="categorical_crossentropy").fit(X, y)
Traceback (most recent call last):
  Input In [19] in <cell line: 1>
    HistGradientBoostingClassifier(loss="categorical_crossentropy").fit(X, y)
  File ~/code/scikit-learn/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py:326 in fit
    self._loss = self._get_loss(sample_weight=sample_weight)
  File ~/code/scikit-learn/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py:1812 in _get_loss
    raise ValueError(
ValueError: loss='categorical_crossentropy' is not suitable for a binary classification problem. Please use loss='auto' or loss='binary_crossentropy' instead.

If we ever want to do this, we can probably introduce a dedicated parameter instead, but this is probably a YAGNI.
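As a side note, the reason one tree per class would over-parameterize the binary case: for two classes, the softmax probability of class 1 reduces to the logistic sigmoid of the score difference. A quick numerical check (random scores for illustration, not from this PR):

```python
import numpy as np

# For K=2 classes, softmax(s0, s1) for class 1 equals sigmoid(s1 - s0),
# so one tree per class would carry a redundant degree of freedom.
rng = np.random.default_rng(0)
s = rng.normal(size=(5, 2))  # raw (untransformed) scores for two classes
softmax_p1 = np.exp(s[:, 1]) / np.exp(s).sum(axis=1)
sigmoid_p1 = 1.0 / (1.0 + np.exp(s[:, 0] - s[:, 1]))
print(np.allclose(softmax_p1, sigmoid_p1))  # True
```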

@jeremiedbb
Member
jeremiedbb commented Apr 6, 2022

What's new entry after #23036 is merged.

#23036 is merged :)

@ogrisel
Member
ogrisel commented Apr 6, 2022

We also need to document the deprecation in what's new once we agree on the new loss.

@francoisgoupil
Member
francoisgoupil commented Apr 6, 2022

There are some tests that still pass loss="categorical_crossentropy" instead of "multiclass_log_loss" to HistGradientBoostingClassifier, and they fail when trying to match the expected error but obtain our new deprecation warning instead. A possibly non-exhaustive list of failing tests:

test_zero_sample_weights_classification
test_same_predictions_multiclass_classification[10000-8-1-3]
test_same_predictions_multiclass_classification[10000-8-20-3]
test_same_predictions_multiclass_classification[10000-8-1-2]
test_same_predictions_multiclass_classification[10000-8-20-0]
test_same_predictions_multiclass_classification[10000-8-1-1]
test_same_predictions_multiclass_classification[255-4096-1-0]
test_same_predictions_multiclass_classification[255-4096-1-3]
test_same_predictions_multiclass_classification[10000-8-1-4]
test_same_predictions_multiclass_classification[10000-8-20-2]
test_same_predictions_multiclass_classification[10000-8-20-4]
test_same_predictions_multiclass_classification[255-4096-1-4]
test_same_predictions_multiclass_classification[10000-8-1-0]
test_same_predictions_multiclass_classification[10000-8-20-1]
test_same_predictions_multiclass_classification[255-4096-1-1]
test_same_predictions_multiclass_classification[255-4096-1-2]
test_same_predictions_multiclass_classification[255-4096-20-4]
test_same_predictions_multiclass_classification[255-4096-20-0]
test_same_predictions_multiclass_classification[255-4096-20-1]
test_same_predictions_multiclass_classification[255-4096-20-2]
test_same_predictions_multiclass_classification[255-4096-20-3]

The same is happening for loss="binary_crossentropy" instead of "binary_log_loss". A possibly non-exhaustive list of failing tests:

test_same_predictions_classification[255-4096-1-2]
test_same_predictions_classification[255-4096-1-1]
test_same_predictions_classification[255-4096-1-3]
test_same_predictions_classification[255-4096-1-4]
test_same_predictions_classification[255-4096-20-0]
test_same_predictions_classification[255-4096-20-1]
test_same_predictions_classification[255-4096-20-4]
test_same_predictions_classification[255-4096-1-0]
test_same_predictions_classification[255-4096-20-2]
test_same_predictions_classification[255-4096-20-3]
test_same_predictions_classification[1000-8-1-2]
test_same_predictions_classification[1000-8-1-3]
test_same_predictions_classification[1000-8-1-4]
test_same_predictions_classification[1000-8-20-0]
test_same_predictions_classification[1000-8-20-1]
test_same_predictions_classification[1000-8-20-2]
test_same_predictions_classification[1000-8-20-3]
test_same_predictions_classification[1000-8-20-4]
test_same_predictions_classification[1000-8-1-0]
test_same_predictions_classification[1000-8-1-1]

@lorentzenchr
Member Author

@francoisgoupil Very good point. It took me a little longer to fix all occurrences and make the changes. I hope I've got them all.

@ArturoAmorQ
Member

We still have some failing tests. Maybe you can try a
git grep 'loss="binary_crossentropy"'
and
git grep 'loss="categorical_crossentropy"'

Member
@ogrisel ogrisel left a comment


LGTM once @jeremiedbb's comment above has been dealt with.

Member
@jeremiedbb jeremiedbb left a comment


LGTM

@jeremiedbb jeremiedbb merged commit 5b69652 into scikit-learn:main Apr 7, 2022
@lorentzenchr lorentzenchr deleted the call_it_log_loss_in_hgbt branch April 8, 2022 05:56
jjerphan pushed a commit to jjerphan/scikit-learn that referenced this pull request Apr 29, 2022
scikit-learn#23040)

Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>