Regularisation for Decision Tree. #1296
Conversation
Early stopping in the literature is the process of calculating the error on the validation set before each split, and only carrying out the split if the error drops by more than a threshold. However, for implementing post-pruning, what would be best to implement? The minimum-error tree, the smallest tree, or a weighted sum of the two? (All of these would be evaluated on the CV set.)
Thanks for the contribution. If you can fix the style and build issues and handle the comments, I think this is close to ready. Also, if you can add a test to decision_tree_test.cpp to ensure that the regularization is working, I think that would also be a good idea.
For implementing post-pruning, what shall be the best to implement? Minimum error tree or smallest tree or the weighted sum of the two? (All these will be performed on the CV set).
Let's not go with just one---when you implement that (and I think maybe a separate PR would be a good idea since it will be a big change), you should add a template parameter to the decision tree, so that users can decide which post-pruning algorithm they want.
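To make the idea concrete, here is a minimal sketch of the policy-template approach (all names here are hypothetical illustrations, not mlpack's actual API):

```cpp
#include <cstddef>

// Hypothetical pruning policies: each exposes a static Prune() with the same
// signature, so the tree can select the strategy at compile time.
struct MinimumErrorPruning
{
  template<typename TreeType, typename MatType, typename LabelsType>
  static void Prune(TreeType& /* tree */,
                    const MatType& /* cvData */,
                    const LabelsType& /* cvLabels */)
  {
    // Collapse any subtree whose removal does not increase CV error.
  }
};

struct SmallestTreePruning
{
  template<typename TreeType, typename MatType, typename LabelsType>
  static void Prune(TreeType&, const MatType&, const LabelsType&)
  {
    // Prefer the smallest tree whose CV error is within tolerance of the best.
  }
};

// The tree takes the policy as a template parameter and invokes it after
// training; FitnessFunction is kept only for symmetry with the real class.
template<typename FitnessFunction, typename PruningStrategy = MinimumErrorPruning>
class DecisionTree
{
 public:
  template<typename MatType, typename LabelsType>
  void Train(const MatType& data, const LabelsType& labels,
             const MatType& cvData, const LabelsType& cvLabels)
  {
    // ... grow the tree on (data, labels), then prune on the CV set ...
    PruningStrategy::Prune(*this, cvData, cvLabels);
  }
};
```

Users would then pick the post-pruning algorithm via the template argument, with no runtime dispatch.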
```diff
@@ -533,7 +542,7 @@ void DecisionTree<FitnessFunction,

   // Look through the list of dimensions and obtain the gain of the best split.
   // We'll cache the best numeric and categorical split auxiliary information in
-  // numericAux and categoricalAux (and clear them later if we make not split),
+  // numericAux and categoricalAux (and clear them later if we make no split),
```
Nice catch, thanks. :)
```diff
@@ -39,7 +39,8 @@ PROGRAM_INFO("Decision tree",
     "may not be specified when the " + PRINT_PARAM_STRING("training") + " "
     "parameter is specified. The " + PRINT_PARAM_STRING("minimum_leaf_size") +
     " parameter specifies the minimum number of training points that must fall"
-    " into each leaf for it to be split. If " +
+    " into each leaf for it to be split. The " + PRINT_PARAM_STRING("minimum_gain_split") +
+    " parameter specifies the minimum gain that is needed for the node to split. If " +
```
I think it might be useful to add a little intuition here---maybe we should point out that a larger minimum_gain_split is a form of regularization?
I think adding this intuition to the header file, in the function documentation, would be better, just like how it is done for the other parameters.
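For illustration, the header documentation could carry that intuition roughly like this (the wording is a sketch only, not the final text):

```cpp
/**
 * ...
 * @param minimumLeafSize Minimum number of points in each leaf node.
 * @param minimumGainSplit Minimum gain required for a node to split; larger
 *     values act as a form of regularization, producing smaller trees that
 *     are less likely to overfit the training data.
 */
```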
```diff
@@ -430,6 +436,7 @@ void DecisionTree<FitnessFunction,
     const size_t numClasses,
     WeightsType&& weights,
     const size_t minimumLeafSize,
+    const double minimumGainSplit = 1e-7,
```
I think it's not actually legal to set a default parameter when it's already been specified in the declaration, so no need for the = 1e-7 here.
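For context, a minimal self-contained example of the C++ rule being referenced: a default argument may be specified in the declaration or in the definition, but not in both.

```cpp
class DecisionTree
{
 public:
  // Declaration: the default argument is specified here, once.
  void Train(const double minimumGainSplit = 1e-7);
};

// Out-of-line definition: writing "= 1e-7" again here would be ill-formed,
// because a default argument may only be specified once for a parameter.
void DecisionTree::Train(const double /* minimumGainSplit */)
{
  // ...
}
```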
```diff
@@ -98,6 +102,7 @@ class DecisionTree :
    * @param numClasses Number of classes in the dataset.
    * @param weights The weight list of given label.
    * @param minimumLeafSize Minimum number of points in each leaf node.
+   * @param minimumGainSplit Minimum Gain for the node to split.
```
This is super picky (sorry!) but you can change Gain to gain here---there's no need to capitalize. :)
```diff
@@ -62,7 +62,8 @@ class AllCategoricalSplit
       const WeightVecType& weights,
       const size_t minimumLeafSize,
       arma::Col<typename VecType::elem_type>& classProbabilities,
-      AuxiliarySplitInfo<typename VecType::elem_type>& aux);
+      AuxiliarySplitInfo<typename VecType::elem_type>& aux,
+      const double minimumGainSplit);
```
This is also a bit picky, but we typically order const (input) parameters before output parameters, so I would put minimumGainSplit before classProbabilities (and similar changes through the rest of the code).
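Concretely, the suggested ordering would look roughly like this (a sketch with a simplified stand-in for the auxiliary type, not the exact mlpack signature):

```cpp
#include <armadillo>
#include <cstddef>

// Simplified stand-in for the split's auxiliary information type.
template<typename T>
struct AuxiliarySplitInfo { };

// Const inputs first (bestGain, data, weights, minimumLeafSize,
// minimumGainSplit); mutable outputs last (classProbabilities, aux).
template<typename VecType, typename WeightVecType>
double SplitIfBetter(const double bestGain,
                     const VecType& data,
                     const WeightVecType& weights,
                     const size_t minimumLeafSize,
                     const double minimumGainSplit,
                     arma::Col<typename VecType::elem_type>& classProbabilities,
                     AuxiliarySplitInfo<typename VecType::elem_type>& aux);
```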
@rcurtin

```cpp
  BOOST_REQUIRE_EQUAL(probabilitiesRegularised.n_elem, 3);
}
```

```cpp
  BOOST_REQUIRE_GT(count, 0);
```
In this test we need to make sure regularization is working (note that this is different from the main tests, where we only need to make sure that the option is being properly passed to the code), but it looks like this test just checks whether the predictions produced by the regularized tree differ from those of the non-regularized tree.
I think a more effective test might be to make sure that a tree built with a high minimum gain split has fewer nodes in it than a tree with a low minimum gain split. What do you think?
We could also test that the accuracy of a regularized tree on a training set is higher, but I am not sure that will always be true, so I might be hesitant about that testing strategy.
In fact, for this test here, I think it would be worthwhile to adapt it and move it to main_tests/decision_tree_test.cpp.
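A sketch of what that node-count test might look like (assuming the tree exposes NumChildren()/Child() accessors and the constructor ordering added in this PR; the test name and thresholds are hypothetical):

```cpp
#include <mlpack/methods/decision_tree/decision_tree.hpp>
#include <boost/test/unit_test.hpp>

using namespace mlpack::tree;

// Count the nodes of a tree recursively.
template<typename TreeType>
size_t CountNodes(const TreeType& tree)
{
  size_t count = 1;
  for (size_t i = 0; i < tree.NumChildren(); ++i)
    count += CountNodes(tree.Child(i));
  return count;
}

BOOST_AUTO_TEST_CASE(MinimumGainSplitTreeSizeTest)
{
  arma::mat data(10, 1000, arma::fill::randu);
  arma::Row<size_t> labels(1000);
  for (size_t i = 0; i < 1000; ++i)
    labels[i] = i % 3; // Three classes.

  // Low threshold: the tree grows nearly unconstrained.
  DecisionTree<> lowGainTree(data, labels, 3, 1 /* minimumLeafSize */, 1e-7);
  // High threshold: most candidate splits are rejected, so the tree is smaller.
  DecisionTree<> highGainTree(data, labels, 3, 1, 0.5);

  BOOST_REQUIRE_LT(CountNodes(highGainTree), CountNodes(lowGainTree));
}
```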
I agree that it is not a strict test for regularisation. I considered this hypothesis: since the unregularised tree gives 100% accuracy on the training data, the fact that count > 0 means that the tree is regularised and contains fewer nodes. Had it not been pruned to fewer nodes, it couldn't have count > 0. So what you said is indirectly tested; this is the reason I adopted this strategy.
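A self-contained sketch of that reasoning (variable names mirror the test under discussion; the constructor ordering is the one added in this PR):

```cpp
arma::mat data(10, 500, arma::fill::randu);
arma::Row<size_t> labels(500);
for (size_t i = 0; i < 500; ++i)
  labels[i] = i % 2;

// Unregularised: minimumLeafSize = 1 and a negligible gain threshold, so the
// tree can fit the training data perfectly.
DecisionTree<> tree(data, labels, 2, 1, 1e-7);
// Regularised: a large gain threshold suppresses most splits.
DecisionTree<> regularisedTree(data, labels, 2, 1, 0.5);

arma::Row<size_t> predictions, predictionsRegularised;
tree.Classify(data, predictions);                       // 100% training accuracy.
regularisedTree.Classify(data, predictionsRegularised);

// If the regularised tree were identical (no splits suppressed), every
// prediction would match and count would be 0; count > 0 implies pruning.
size_t count = 0;
for (size_t i = 0; i < predictions.n_elem; ++i)
  if (predictions[i] != predictionsRegularised[i])
    ++count;

BOOST_REQUIRE_GT(count, 0);
```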
Ah, clever. I can agree with that. 👍
I guess this should be good to go, then? Any views, @rcurtin?
Agreed, but I do still think it would be useful to add a similar test to main_tests/decision_tree_test.cpp as paranoia, just to ensure the regularization is being passed down properly.
Sure, I will add it then.
3df9216 to 66c7ac4 (Compare)
@rcurtin
Thanks, this looks good to me. The static code analysis issues don't need to be worried about---that is intentional integer division.
I'll go ahead and merge this in 4 days to leave time for any other feedback.
Thanks, yeah, I thought so about the static code analysis.
Looks good to me, no comment from my side.
Thanks for the contribution! :) I think that this is your first contribution, so if you want to add your name to the list of contributors in
Early stopping is a regularisation method for decision trees, i.e., it reduces overfitting. In this method, a node is prevented from splitting if the information gain is less than the given threshold.
This PR introduces the parameter minimumGainSplit to the decision tree API so that users can stop the growth of the tree when the gain is less than the threshold. Tests for early stopping and an implementation of post-pruning (another regularisation technique for decision trees) will be added in the following commits.
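A minimal usage sketch of the new parameter (assuming the constructor ordering added in this PR and the mlpack::tree namespace):

```cpp
#include <mlpack/methods/decision_tree/decision_tree.hpp>

using namespace mlpack::tree;

int main()
{
  arma::mat data(5, 500, arma::fill::randu);
  arma::Row<size_t> labels(500);
  for (size_t i = 0; i < 500; ++i)
    labels[i] = i % 2;

  // Any node whose best split improves the gain by less than 0.01 becomes a
  // leaf instead of splitting further.
  DecisionTree<> tree(data, labels, 2 /* numClasses */,
                      10 /* minimumLeafSize */, 0.01 /* minimumGainSplit */);

  arma::Row<size_t> predictions;
  tree.Classify(data, predictions);
}
```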