
FIX correct Lasso.dual_gap_ to match the objective in its docstring #19172

Merged (8 commits), Jan 20, 2021


mathurinm
Contributor

Reference Issues/PRs

closes #19160

What does this implement/fix? Explain your changes.

The Lasso objective in the docstring is 1 / n_samples times the objective minimized by the Cython solvers (which use the penalty alpha * n_samples). The duality gap returned by the solvers should therefore be divided by n_samples, so that after fitting, Lasso.dual_gap_ corresponds to the formulation with the 1 / n_samples factor.

Importantly, this does not affect the optimization process: the stopping criterion is untouched, and the same number of iterations are run.
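The scaling argument can be written out as follows. This is a sketch in our own notation, not taken from the code: n is n_samples, λ = nα is the penalty the Cython solver works with, θ is a dual variable, and the dual shown is the standard Lasso dual.

$$
P(w) = \frac{1}{2n}\lVert y - Xw \rVert_2^2 + \alpha \lVert w \rVert_1,
\qquad
P_{\mathrm{cd}}(w) = \frac{1}{2}\lVert y - Xw \rVert_2^2 + \lambda \lVert w \rVert_1 = n\,P(w), \quad \lambda = n\alpha .
$$

$$
D_{\mathrm{cd}}(\theta) = \frac{1}{2}\lVert y \rVert_2^2 - \frac{1}{2}\lVert y - \theta \rVert_2^2
\quad \text{s.t.} \quad \lVert X^\top \theta \rVert_\infty \le \lambda,
\qquad
D(\theta) = \frac{1}{n}\, D_{\mathrm{cd}}(\theta).
$$

$$
\mathrm{gap}(w, \theta) = P(w) - D(\theta) = \frac{1}{n}\bigl(P_{\mathrm{cd}}(w) - D_{\mathrm{cd}}(\theta)\bigr) = \frac{1}{n}\,\mathrm{gap}_{\mathrm{cd}}(w, \theta).
$$

The solver's stopping test compares gap_cd with its own tolerance, so dividing by n only after the solver returns leaves that test, and hence the iterates, unchanged.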

@lorentzenchr
Member
lorentzenchr commented Jan 18, 2021

I would (edit) not change the Cython code. AFAIU, that part is correct. I would change enet_path

def enet_path(X, y, *, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,

and correct the dual gap after calling the solver from cd_fast. (And add a comment explaining why we need to rescale.)
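The suggested approach (rescale after the solver call instead of inside it) can be sketched in plain NumPy. Everything below is hypothetical and is not scikit-learn's code: `toy_cd_solver` is a simplified stand-in for the Cython solver that minimizes the unscaled objective, checks the gap after every epoch, and stops once the gap drops below `tol * ||y||^2` (these details are assumptions of the sketch), while `lasso_fit` plays the role of the rescaling step in enet_path.

```python
import numpy as np


def soft_threshold(z, t):
    """Scalar soft-thresholding: sign(z) * max(|z| - t, 0)."""
    return np.sign(z) * max(abs(z) - t, 0.0)


def unscaled_gap(X, y, lam, w, R):
    """Duality gap of 0.5 * ||y - Xw||^2 + lam * ||w||_1 (the solver's scale)."""
    primal = 0.5 * (R @ R) + lam * np.abs(w).sum()
    # Scale the residual into the dual feasible set ||X^T theta||_inf <= lam.
    dual_norm = np.max(np.abs(X.T @ R))
    c = 1.0 if dual_norm <= lam else lam / dual_norm
    theta = c * R
    dual = theta @ y - 0.5 * (theta @ theta)
    return primal - dual


def toy_cd_solver(X, y, lam, tol=1e-4, max_iter=1000):
    """Hypothetical stand-in for the Cython solver: cyclic coordinate descent
    on the *unscaled* objective; returns (w, unscaled gap, n_iter)."""
    n_features = X.shape[1]
    w = np.zeros(n_features)
    R = np.array(y, dtype=float)  # residual y - X w for w = 0
    col_sq = (X ** 2).sum(axis=0)
    gap = np.inf
    for n_iter in range(1, max_iter + 1):
        for j in range(n_features):
            if col_sq[j] == 0.0:
                continue
            old = w[j]
            # Correlation of column j with the residual excluding feature j.
            z = X[:, j] @ R + col_sq[j] * old
            w[j] = soft_threshold(z, lam) / col_sq[j]
            if w[j] != old:
                R -= X[:, j] * (w[j] - old)
        gap = unscaled_gap(X, y, lam, w, R)
        if gap < tol * (y @ y):  # stopping test sees the unscaled gap
            break
    return w, gap, n_iter


def lasso_fit(X, y, alpha, tol=1e-4, max_iter=1000):
    """Hypothetical wrapper for 1/(2n) ||y - Xw||^2 + alpha * ||w||_1."""
    n_samples = X.shape[0]
    w, gap_cd, n_iter = toy_cd_solver(X, y, alpha * n_samples, tol, max_iter)
    # The solver's objective is n_samples times the docstring objective, so
    # its gap is rescaled only after it returns; iterations are unaffected.
    return w, gap_cd / n_samples, n_iter
```

Calling `toy_cd_solver(X, y, alpha * n_samples)` and `lasso_fit(X, y, alpha)` on the same data gives identical coefficients and iteration counts; only the reported gap differs, by a factor of n_samples.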

@mathurinm
Contributor Author

Thanks @lorentzenchr. I'm guessing you mean "I would NOT change" the Cython code? Indeed, the gap there matches the Cython formulation.

@lorentzenchr
Member

I think MultiTaskElasticNet.fit also needs this rescaling as it does not use enet_path but calls enet_coordinate_descent_multi_task itself.
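The same 1 / n_samples relationship holds for the multi-task problem. The NumPy sketch below computes the multi-task gap on both scales; for brevity it covers only the pure ℓ2,1 penalty (l1_ratio = 1), the function names are hypothetical, and W is taken with shape (n_features, n_tasks), i.e. the transpose of scikit-learn's coef_.

```python
import numpy as np


def multitask_gap_cd(X, Y, lam, W):
    """Duality gap of 0.5 * ||Y - X W||_F^2 + lam * ||W||_21 (solver scale).

    W has shape (n_features, n_tasks); ||W||_21 sums the l2 norms of its rows.
    """
    R = Y - X @ W
    primal = 0.5 * np.sum(R ** 2) + lam * np.linalg.norm(W, axis=1).sum()
    # Dual feasibility: each row of X^T theta must have l2 norm <= lam.
    dual_norm = np.max(np.linalg.norm(X.T @ R, axis=1))
    c = 1.0 if dual_norm <= lam else lam / dual_norm
    theta = c * R
    dual = np.sum(theta * Y) - 0.5 * np.sum(theta ** 2)
    return primal - dual


def multitask_gap_docstring(X, Y, alpha, W):
    """Gap for 1/(2n) ||Y - X W||_F^2 + alpha * ||W||_21: the solver-scale
    gap computed with lam = n * alpha, then divided by n."""
    n_samples = X.shape[0]
    return multitask_gap_cd(X, Y, alpha * n_samples, W) / n_samples
```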

@mathurinm
Contributor Author

thanks, fixed.

I may be missing something: MultiTaskElasticNetCV uses path = staticmethod(enet_path). Is there any reason why MultiTaskElasticNet does not do the same instead of hard-coding the call to enet_coordinate_descent_multi_task?

@lorentzenchr
Member
lorentzenchr commented Jan 18, 2021

@mathurinm Can you add a test for the correct dual gap?
A derivative can be checked by numerical differentiation. Is there something similar for the dual gap?

@lorentzenchr
Member

MultiTaskElasticNet also inherits from Lasso instead of ElasticNet. Making it work with path = staticmethod(enet_path) is perhaps a topic for a separate issue ;-)

@mathurinm
Contributor Author

I added a test that computes the gap manually after fitting and checks that it equals Lasso.dual_gap_.
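A manual gap computation of this kind can be sketched in NumPy. This is not the test added in the PR; `lasso_gap_docstring` is a hypothetical helper that evaluates the gap directly on the docstring scale. It also gives a partial answer to the "numerical check" question: by weak duality the gap is nonnegative for every w, and it is exactly zero at w = 0 when alpha is at least alpha_max = max|X^T y| / n_samples, since w = 0 is then optimal.

```python
import numpy as np


def lasso_gap_docstring(X, y, alpha, w):
    """Duality gap of P(w) = 1/(2n) ||y - X w||^2 + alpha * ||w||_1,
    computed directly on the docstring scale."""
    n_samples = X.shape[0]
    R = y - X @ w
    primal = 0.5 * (R @ R) / n_samples + alpha * np.abs(w).sum()
    # Dual point: the residual scaled so that ||X^T theta||_inf / n <= alpha.
    dual_norm = np.max(np.abs(X.T @ R)) / n_samples
    c = 1.0 if dual_norm <= alpha else alpha / dual_norm
    theta = c * R
    dual = (theta @ y - 0.5 * (theta @ theta)) / n_samples
    return primal - dual
```

With scikit-learn, one would compare `lasso_gap_docstring(X, y, clf.alpha, clf.coef_)` against `clf.dual_gap_` for `clf = Lasso(alpha=alpha, fit_intercept=False).fit(X, y)`. Close agreement assumes the solver builds its dual point the same way, which is an assumption of this sketch; before the fix, `dual_gap_` was n_samples times larger.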

Member
@lorentzenchr left a comment

LGTM

sklearn/linear_model/tests/test_coordinate_descent.py (outdated, resolved)
Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>
Member
@agramfort left a comment

LGTM!

do we need a what's new entry?

@lorentzenchr
Member

> do we need a what's new entry?

Yes, a short bugfix note is a good idea.

@lorentzenchr changed the title from "FIX: have Lasso.dual_gap_ match the objective in its docstring" to "FIX correct Lasso.dual_gap_ to match the objective in its docstring" on Jan 20, 2021
@lorentzenchr merged commit 9183486 into scikit-learn:master on Jan 20, 2021
@lorentzenchr
Member

@mathurinm Thank you very much for detecting and fixing this bug.


Linked issue: Lasso.dual_gap_ does not match the formulation in the doc

3 participants