[MRG] Allow for refit=callable in *SearchCV to add flexibility in identifying the best estimator #11269 #11354
Conversation
sklearn/model_selection/_search.py
Outdated
    "refit should be set to False "
    "explicitly. %r was passed"
    % self.refit)
refit_metric = scorer_key
I don't get why you need this. You don't use refit_metric if refit is callable below. I also think making inferences from the name of the function is inappropriate.
I think refit_metric is needed to compute self.best_score_ as shown here in the original code base.
Ah, I see. I would just disable best_score_ when refit is callable. Please test and document that behaviour.
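From the user's side, the behaviour being discussed can be sketched like this (a minimal illustration, not this PR's code; it assumes a scikit-learn version where callable refit is supported, and the dummy selector is made up):

```python
import numpy as np
from sklearn import datasets, svm
from sklearn.model_selection import GridSearchCV

def lowest_mean_test_score(cv_results):
    # Dummy selector: pick the candidate with the lowest mean test score.
    return int(np.argmin(cv_results["mean_test_score"]))

X, y = datasets.load_iris(return_X_y=True)
search = GridSearchCV(svm.SVC(), {"C": [1, 10]}, cv=5,
                      refit=lowest_mean_test_score)
search.fit(X, y)

# best_index_ comes from the callable; best_score_ is intentionally
# not set, since the search cannot know which metric the callable used.
print(search.best_index_)
print(hasattr(search, "best_score_"))
```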
sklearn/model_selection/_search.py
Outdated
Refit an estimator using the best found parameters on the whole
dataset.

For multiple metric evaluation, this needs to be a string denoting the
scorer is used to find the best parameters for refitting the estimator
at the end.

Where there are considerations other than maximum model performance in
choosing a best estimator, ``refit`` can be set to a function which returns
thre selected ``best_index_`` given ``cv_results_``.
thre -> the
sklearn/model_selection/_search.py
Outdated
Refit an estimator using the best found parameters on the whole
dataset.

For multiple metric evaluation, this needs to be a string denoting the
scorer is used to find the best parameters for refitting the estimator
at the end.

Where there are considerations other than maximum model performance in
model performance -> score
sklearn/model_selection/_search.py
Outdated
Refit an estimator using the best found parameters on the whole
dataset.

For multiple metric evaluation, this needs to be a string denoting the
scorer that would be used to find the best parameters for refitting
the estimator at the end.

Where there are considerations other than maximum model performance in
Please use the same text and formatting in both places.
sklearn/model_selection/_search.py
Outdated
The refitted estimator is made available at the ``best_estimator_``
attribute and permits using ``predict`` directly on this
``GridSearchCV`` instance.

Also for multiple metric evaluation, the attributes ``best_index_``,
``best_score_`` and ``best_parameters_`` will only be available if
``refit`` is set and all of them will be determined w.r.t this specific
scorer.
scorer. If a callable is passed to parameter refit, the function's name
This is an unnecessary and unhelpful condition.
For multi-metric evaluation, the name of refit callable function must
end with a scorer key(`_<scorer_name>`).
"""
def refit_prec(cv_results):
We should have a realistic example in examples/model_selection/ rather than here.
As a simple example, I would consider using maximising score while minimising the number of selected features or PCA components.
Here we should merely be testing interface, and a dummy function (for instance, one that always chooses the lowest-score model) is sufficient / most appropriate, as it is then easy for us to be sure what correct behaviour is.
@jnothman Would you say a dummy function like below is good enough to test our interface?

def refit_callable(cv_results):
    return cv_results['mean_test_score'].argmin()
It seems that you're suggesting two things here :(
Yes, that looks good. I might add to that an assertion that all the keys we expect to be in results are in there.
Yes, I am indeed suggesting a second thing here. An example in examples/model_selection will hugely increase the visibility and practical usability of this feature. The example gallery is how we advise users how to use the features described in technical detail in the docstrings (and before StackOverflow has all the answers).
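A sketch of what such an interface test could look like (hypothetical test code, not the PR's actual tests; dataset and estimator are chosen arbitrarily):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

def refit_callable(cv_results):
    # Assert the keys the callable relies on are present, as suggested.
    assert "mean_test_score" in cv_results
    assert "params" in cv_results
    return cv_results["mean_test_score"].argmin()

X, y = make_classification(random_state=0)
clf = GridSearchCV(DecisionTreeClassifier(random_state=0),
                   {"max_depth": [1, 2, 3]}, cv=3, refit=refit_callable)
clf.fit(X, y)

# The search should have used the callable's return value verbatim.
assert clf.best_index_ == refit_callable(clf.cv_results_)
```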
Thanks! I'm adding an example from examples/model_selection for this feature in the docstring.
@jnothman is it appropriate to add one more example for refit=callable in the docstring under the GridSearchCV class, after this one?
scikit-learn/sklearn/model_selection/_search.py
Lines 931 to 958 in 3b5abf7
Examples
--------
>>> from sklearn import svm, datasets
>>> from sklearn.model_selection import GridSearchCV
>>> iris = datasets.load_iris()
>>> parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
>>> svc = svm.SVC(gamma="scale")
>>> clf = GridSearchCV(svc, parameters)
>>> clf.fit(iris.data, iris.target)
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
GridSearchCV(cv=None, error_score=...,
    estimator=SVC(C=1.0, cache_size=..., class_weight=..., coef0=...,
                  decision_function_shape='ovr', degree=..., gamma=...,
                  kernel='rbf', max_iter=-1, probability=False,
                  random_state=None, shrinking=True, tol=...,
                  verbose=False),
    fit_params=None, iid=..., n_jobs=1,
    param_grid=..., pre_dispatch=..., refit=..., return_train_score=...,
    scoring=..., verbose=...)
>>> sorted(clf.cv_results_.keys())
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
['mean_fit_time', 'mean_score_time', 'mean_test_score',...
 'mean_train_score', 'param_C', 'param_kernel', 'params',...
 'rank_test_score', 'split0_test_score',...
 'split0_train_score', 'split1_test_score', 'split1_train_score',...
 'split2_test_score', 'split2_train_score',...
 'std_fit_time', 'std_score_time', 'std_test_score', 'std_train_score'...]
I think a meaningful example is too large, and too much of a power-user feature, to be in the docstring.
@jnothman It seems that we don't need to write test cases for our example under the examples directory, right? ;)
Feel free to use GitHub's todo list feature in the PR description.
@jnothman Thanks for your input! I'll improve my implementation based on your feedback.
              enumerate(cv_results['mean_test_prec'])}
# Select models which have test precisions within 1 standard deviation
# of the best 'mean_test_prec'
candidates = dict(filter(lambda i: (i[1] >= test_prec_lower
btw, a dict comprehension is much easier to read than this
So is test_prec_upper > i[1] >= test_prec_lower
👍
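The suggested rewrite can be sketched with made-up bounds and scores (variable names taken from the diff above; the data here is purely illustrative):

```python
cv_results = {'mean_test_prec': [0.80, 0.91, 0.94, 0.88, 0.95]}
test_prec_lower, test_prec_upper = 0.90, 0.96

# filter/lambda version, as in the diff (harder to read):
candidates = dict(filter(lambda i: test_prec_upper > i[1] >= test_prec_lower,
                         enumerate(cv_results['mean_test_prec'])))

# Equivalent dict comprehension with a chained comparison:
candidates = {i: prec for i, prec in enumerate(cv_results['mean_test_prec'])
              if test_prec_upper > prec >= test_prec_lower}
# candidates == {1: 0.91, 2: 0.94, 4: 0.95}
```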
            enumerate(cv_results['mean_fit_time'])}
fit_time_rank = sorted(fit_time)
for i in fit_time_rank:
    if fit_time[i] in candidates:
This isn't working in AppVeyor. The function is returning None there.
Yes, I'm replacing these two test cases with simpler ones.
Circle CI should fail if the example does
Please reference the example from doc/modules/grid_search.rst. You should probably put the motivation / use case there more than in the example.
Documentation is rendered at https://26300-843222-gh.circle-artifacts.com/0/doc/_changed.html
    }
]

grid = GridSearchCV(pipe, cv=3, n_jobs=1, param_grid=param_grid,
I don't think we should be encouraging users to calculate a standard deviation over 3 samples. Make cv=10.
interface can also be used in multiple metrics evaluation.

This example balances model complexity and cross-validated score by
finding a decent accuracy within 1 standard deviation of the best accuracy
You might want to say that this is a rule of thumb for insignificant difference.
We could determine insignificant difference in a more proper way, such as with a Wilcoxon rank-sum test.
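The more principled alternative the reviewer mentions could be sketched like this (illustrative per-split scores and a conventional 0.05 threshold, both assumptions, not part of the PR):

```python
import numpy as np
from scipy import stats

# Made-up per-split test scores for two candidate parameter settings.
best_splits = np.array([0.96, 0.94, 0.95, 0.97, 0.95])
simpler_splits = np.array([0.95, 0.93, 0.96, 0.94, 0.95])

# Wilcoxon rank-sum test: is the score difference between the best and
# the simpler candidate statistically significant?
stat, p_value = stats.ranksums(best_splits, simpler_splits)

# A large p-value suggests the difference is insignificant, so the
# simpler model would be an acceptable choice under this criterion.
print(p_value)
```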
@@ -0,0 +1,125 @@
"""
=======================================================================
Balance model complexity and cross-validated score using refit=callable
Drop "using *"
upper/lower bounds within 1 standard deviation of the
best `mean_test_scores`.
"""
std_test_score = np.std(scores)
Should be using std_test_score: you want standard deviation across cv splits, not across parameter candidates
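What the reviewer points at can be sketched as follows (function name and toy data assumed; `cv_results_` already exposes `std_test_score`, the per-candidate standard deviation across CV splits):

```python
import numpy as np

def lower_bound(cv_results):
    """1-std lower bound on the best mean test score.

    Uses std_test_score (std across CV splits for the best candidate),
    not np.std over the candidates' mean scores.
    """
    best_idx = np.argmax(cv_results['mean_test_score'])
    return (cv_results['mean_test_score'][best_idx]
            - cv_results['std_test_score'][best_idx])

# Toy cv_results with two parameter candidates:
cv_results = {'mean_test_score': np.array([0.80, 0.90]),
              'std_test_score': np.array([0.05, 0.02])}
print(lower_bound(cv_results))
```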
@jnothman @adrinjalali Probably need your help to fix travis-ci issue... :-/
Thanks @jiaowoshabi , LGTM!
Awesome, @jiaowoshabi! Please add an entry to the change log at doc/whats_new/v0.21.rst.
doc/whats_new/v0.21.rst
Outdated
@@ -144,6 +144,14 @@ Support for Python 3.4 and below has been officially dropped.
:func:`~model_selection.validation_curve` only the latter is required.
:issue:`12613` and :issue:`12669` by :user:`Marc Torrellas <marctorrellas>`.

- |Enhancement| :class:`~model_selection.BaseSearchCV` now allows for
BaseSearchCV is not listed in doc/modules/classes.rst, so this link won't work. Ordinarily we'd reference GridSearchCV and RandomizedSearchCV. You could also consider referencing the user guide rather than the example?
sklearn/model_selection/_search.py
Outdated
See ``scoring`` parameter to know more about multiple metric
evaluation.

.. versionadded:: 0.20
I think versionchanged may be more appropriate, since the parameter was not added.
sklearn/model_selection/_search.py
Outdated
See ``scoring`` parameter to know more about multiple metric
evaluation.

.. versionadded:: 0.20
    GridSearchCV supports ``refit`` = callable to add flexibility in
Don't mention GridSearchCV here. Simply say "Support for callable added." The rest is documented above.
Thanks @jiaowoshabi!!
self.best_index_ = self.refit(results)
if not isinstance(self.best_index_, (int, np.integer)):
    raise TypeError('best_index_ returned is not an integer')
if self.best_index_ < 0 or self.best_index_ >= len(results):
Pretty sure this is a bug: results is a dictionary of things, and each value is an array the size of the grid.
opened #13413
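The fix implied by the bug report could look roughly like this (a sketch only; the helper name is hypothetical). The point is that the number of candidates is the length of one of the per-candidate arrays, e.g. results['params'], whereas len(results) counts the dict's keys:

```python
import numpy as np

def check_refit_index(best_index, results):
    # Hypothetical helper sketching the corrected bounds check.
    if not isinstance(best_index, (int, np.integer)):
        raise TypeError('best_index_ returned is not an integer')
    # results is a dict of arrays; use a per-candidate array's length,
    # not len(results), which is just the number of keys.
    n_candidates = len(results['params'])
    if not 0 <= best_index < n_candidates:
        raise IndexError('best_index_ index out of range')
    return best_index
```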
…ng the best estimator (scikit-learn#11354)" This reverts commit b4f76cf.
Reference Issues/PRs
Fixes #11269. Fixes #12865. See also #9499
What does this implement/fix? Explain your changes.
Allow a callable to be passed to refit in *SearchCV to balance score and model complexity. This interface adds flexibility in identifying the "best" estimator. The function passed to the refit parameter can incorporate the choice of which metric to optimise, so users can use multi-metric evaluation with this interface.
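A minimal sketch of that multi-metric use (illustrative estimator, data, and callable; not code from this PR):

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

def refit_on_precision(cv_results):
    # With multi-metric scoring, the callable itself decides which
    # metric to optimise: here, mean precision across CV splits.
    return cv_results['mean_test_precision'].argmax()

X, y = make_classification(random_state=0)
search = GridSearchCV(DecisionTreeClassifier(random_state=0),
                      {'max_depth': [2, 5]},
                      scoring={'precision': 'precision', 'recall': 'recall'},
                      cv=5, refit=refit_on_precision)
search.fit(X, y)
```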
Any other comments?
- A dummy refit function based on mean_test_score in _search.py under the model_selection directory
- An example (plot_grid_search_refit_callable.py) demonstrating the usage of this interface under examples/model_selection/

Checklist:
- Test refit=callable using a simple dummy refit function.
- Test refit=callable using a similar example in multi-metric eval settings
- Make changes in _search.py to pass the above tests