[go: up, home]

Skip to content

Commit

Permalink
Merge pull request mlpack#1313 from rcurtin/misc-test-fixes
Browse files (browse the repository at this point in the history)
Miscellaneous test fixes
  • Loading branch information
rcurtin authored Mar 23, 2018
2 parents: 793fe4a + c2b7d9a; commit 446a821
Show file tree
Hide file tree
Showing 17 changed files with 280 additions and 174 deletions.
47 changes: 38 additions & 9 deletions src/mlpack/core/optimizers/sdp/primal_dual_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,16 @@ PrimalDualSolver<SDPType>::PrimalDualSolver(const SDPType& sdp,
*
* See (2.18) of [AHO98] for more details.
*/
static inline double
Alpha(const arma::mat& A, const arma::mat& dA, double tau)
static inline bool
Alpha(const arma::mat& A, const arma::mat& dA, double tau, double& alpha)
{
const arma::mat L = arma::chol(A, "lower");
const arma::mat Linv = arma::inv(arma::trimatl(L));
arma::mat L;
if (!arma::chol(L, A, "lower"))
return false;

arma::mat Linv;
if (!arma::inv(Linv, arma::trimatl(L)))
return false;
// TODO(stephentu): We only want the top eigenvalue, we should
// be able to do better than full eigen-decomposition.
const arma::vec evals = arma::eig_sym(-Linv * dA * Linv.t());
Expand All @@ -121,7 +126,8 @@ Alpha(const arma::mat& A, const arma::mat& dA, double tau)
if (alphahat < 0.)
// dA is PSD already
alphahat = 1.;
return std::min(1., tau * alphahat);
alpha = std::min(1., tau * alphahat);
return true;
}

/**
Expand Down Expand Up @@ -369,8 +375,21 @@ PrimalDualSolver<SDPType>::Optimize(arma::mat& X,
math::Smat(dsz, dZ);

// Step (2), determine step size lengths (alpha, beta)
alpha = Alpha(X, dX, tau);
beta = Alpha(Z, dZ, tau);
bool success = Alpha(X, dX, tau, alpha);
if (!success)
{
Log::Warn << "PrimalDualSolver::Optimize(): cholesky decomposition of X "
<< "failed! Terminating optimization.";
return primalObj;
}

success = Alpha(Z, dZ, tau, beta);
if (!success)
{
Log::Warn << "PrimalDualSolver::Optimize(): cholesky decomposition of Z "
<< "failed! Terminating optimization.";
return primalObj;
}

// See (7.1)
const double sigma =
Expand All @@ -384,8 +403,18 @@ PrimalDualSolver<SDPType>::Optimize(arma::mat& X,
dsz);
math::Smat(dsx, dX);
math::Smat(dsz, dZ);
alpha = Alpha(X, dX, tau);
beta = Alpha(Z, dZ, tau);
if (!Alpha(X, dX, tau, alpha))
{
Log::Warn << "PrimalDualSolver::Optimize(): cholesky decomposition of Z "
<< "failed! Terminating optimization.";
return primalObj;
}
if (!Alpha(Z, dZ, tau, beta))
{
Log::Warn << "PrimalDualSolver::Optimize(): cholesky decomposition of Z "
<< "failed! Terminating optimization.";
return primalObj;
}

// Iterate update
X += alpha * dX;
Expand Down
25 changes: 22 additions & 3 deletions src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -953,20 +953,39 @@ void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
delete right;
if (!parent)
delete dataset;

parent = NULL;
left = NULL;
right = NULL;
}

ar & BOOST_SERIALIZATION_NVP(begin);
ar & BOOST_SERIALIZATION_NVP(count);
ar & BOOST_SERIALIZATION_NVP(bound);
ar & BOOST_SERIALIZATION_NVP(stat);
ar & BOOST_SERIALIZATION_NVP(parent);

ar & BOOST_SERIALIZATION_NVP(parentDistance);
ar & BOOST_SERIALIZATION_NVP(furthestDescendantDistance);
ar & BOOST_SERIALIZATION_NVP(dataset);

// Save children last; otherwise boost::serialization gets confused.
ar & BOOST_SERIALIZATION_NVP(left);
ar & BOOST_SERIALIZATION_NVP(right);
bool hasLeft = (left != NULL);
bool hasRight = (right != NULL);

ar & BOOST_SERIALIZATION_NVP(hasLeft);
ar & BOOST_SERIALIZATION_NVP(hasRight);
if (hasLeft)
ar & BOOST_SERIALIZATION_NVP(left);
if (hasRight)
ar & BOOST_SERIALIZATION_NVP(right);

if (Archive::is_loading::value)
{
if (left)
left->parent = this;
if (right)
right->parent = this;
}
}

} // namespace tree
Expand Down
10 changes: 7 additions & 3 deletions src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1560,7 +1560,8 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree() :
furthestDescendantDistance(0.0),
localMetric(false),
localDataset(false),
metric(NULL)
metric(NULL),
distanceComps(0)
{
// Nothing to do.
}
Expand Down Expand Up @@ -1590,6 +1591,8 @@ void CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::serialize(
delete metric;
if (localDataset && dataset)
delete dataset;

parent = NULL;
}

ar & BOOST_SERIALIZATION_NVP(dataset);
Expand All @@ -1599,12 +1602,13 @@ void CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::serialize(
ar & BOOST_SERIALIZATION_NVP(stat);
ar & BOOST_SERIALIZATION_NVP(numDescendants);

ar & BOOST_SERIALIZATION_NVP(parent);
bool hasParent = (parent != NULL);
ar & BOOST_SERIALIZATION_NVP(hasParent);
ar & BOOST_SERIALIZATION_NVP(parentDistance);
ar & BOOST_SERIALIZATION_NVP(furthestDescendantDistance);
ar & BOOST_SERIALIZATION_NVP(metric);

if (Archive::is_loading::value && parent == NULL)
if (Archive::is_loading::value && !hasParent)
{
localMetric = true;
localDataset = true;
Expand Down
10 changes: 8 additions & 2 deletions src/mlpack/core/tree/octree/octree_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,8 @@ void Octree<MetricType, StatisticType, MatType>::serialize(

if (!parent)
delete dataset;

parent = NULL;
}

ar & BOOST_SERIALIZATION_NVP(begin);
Expand All @@ -651,11 +653,15 @@ void Octree<MetricType, StatisticType, MatType>::serialize(
ar & BOOST_SERIALIZATION_NVP(parentDistance);
ar & BOOST_SERIALIZATION_NVP(furthestDescendantDistance);
ar & BOOST_SERIALIZATION_NVP(metric);

ar & BOOST_SERIALIZATION_NVP(parent);
ar & BOOST_SERIALIZATION_NVP(dataset);

ar & BOOST_SERIALIZATION_NVP(children);

if (Archive::is_loading::value)
{
for (size_t i = 0; i < children.size(); ++i)
children[i]->parent = this;
}
}

//! Split the node.
Expand Down
24 changes: 18 additions & 6 deletions src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,7 @@ RectangleTree() :
parent(NULL),
begin(0),
count(0),
numDescendants(0),
maxLeafSize(0),
minLeafSize(0),
parentDistance(0.0),
Expand Down Expand Up @@ -959,7 +960,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
root = root->Parent();
}
if (stillShrinking)
stillShrinking = root->ShrinkBoundForBound(bound);
root->ShrinkBoundForBound(bound);

root = parent;
while (root != NULL)
Expand All @@ -977,7 +978,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
root = root->Parent();
}
if (stillShrinking)
stillShrinking = root->AuxiliaryInfo().UpdateAuxiliaryInfo(root);
root->AuxiliaryInfo().UpdateAuxiliaryInfo(root);

// Reinsert the points at the root node.
for (size_t j = 0; j < count; j++)
Expand Down Expand Up @@ -1020,7 +1021,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
root = root->Parent();
}
if (stillShrinking)
stillShrinking = root->ShrinkBoundForBound(bound);
root->ShrinkBoundForBound(bound);

root = parent;
while (root != NULL)
Expand All @@ -1038,7 +1039,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
root = root->Parent();
}
if (stillShrinking)
stillShrinking = root->AuxiliaryInfo().UpdateAuxiliaryInfo(root);
root->AuxiliaryInfo().UpdateAuxiliaryInfo(root);

// Reinsert the nodes at the root node.
for (size_t i = 0; i < numChildren; i++)
Expand Down Expand Up @@ -1261,6 +1262,8 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,

if (ownsDataset && dataset)
delete dataset;

parent = NULL;
}

ar & BOOST_SERIALIZATION_NVP(maxNumChildren);
Expand All @@ -1279,16 +1282,25 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
ar & BOOST_SERIALIZATION_NVP(parentDistance);
ar & BOOST_SERIALIZATION_NVP(dataset);
ar & BOOST_SERIALIZATION_NVP(ownsDataset);
ar & BOOST_SERIALIZATION_NVP(parent);

ar & BOOST_SERIALIZATION_NVP(points);
ar & BOOST_SERIALIZATION_NVP(auxiliaryInfo);

// Since we may or may not be holding children, we need to serialize _only_
// numChildren children.
for (size_t i = 0; i < numChildren; ++i)
{
std::ostringstream oss;
oss << "children" << i;
ar & boost::serialization::make_nvp(oss.str().c_str(), children[i]);

if (Archive::is_loading::value)
children[i]->parent = this;
}
for (size_t i = numChildren; i < maxNumChildren + 1; ++i)
{
children[i] = NULL;
ar & BOOST_SERIALIZATION_NVP(children);
}
}

} // namespace tree
Expand Down
6 changes: 4 additions & 2 deletions src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,10 @@ bool XTreeSplit::SplitNonLeafNode(TreeType *tree, std::vector<bool>& relevels)
bool useMinOverlapSplit = false;
if (tiedOnOverlap)
{
if (overlapBestAreaAxis/areaBestAreaAxis < MAX_OVERLAP)
if (overlapBestAreaAxis / areaBestAreaAxis < MAX_OVERLAP)
{
tree->numDescendants = 0;
tree->bound.Clear();
for (size_t i = 0; i < numChildren; i++)
{
if (i < bestAreaIndexOnBestAxis + tree->MinNumChildren())
Expand All @@ -450,7 +452,7 @@ bool XTreeSplit::SplitNonLeafNode(TreeType *tree, std::vector<bool>& relevels)
}
else
{
if (overlapBestOverlapAxis/areaBestOverlapAxis < MAX_OVERLAP)
if (overlapBestOverlapAxis / areaBestOverlapAxis < MAX_OVERLAP)
{
tree->numDescendants = 0;
tree->bound.Clear();
Expand Down
33 changes: 29 additions & 4 deletions src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,9 +756,12 @@ void SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
delete right;
if (!parent && localDataset)
delete dataset;

parent = NULL;
left = NULL;
right = NULL;
}

ar & BOOST_SERIALIZATION_NVP(parent);
ar & BOOST_SERIALIZATION_NVP(count);
ar & BOOST_SERIALIZATION_NVP(pointsIndex);
ar & BOOST_SERIALIZATION_NVP(overlappingNode);
Expand All @@ -770,12 +773,34 @@ void SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
ar & BOOST_SERIALIZATION_NVP(furthestDescendantDistance);
ar & BOOST_SERIALIZATION_NVP(dataset);

if (Archive::is_loading::value && parent == NULL)
if (Archive::is_loading::value)
localDataset = true;

// Save children last; otherwise boost::serialization gets confused.
ar & BOOST_SERIALIZATION_NVP(left);
ar & BOOST_SERIALIZATION_NVP(right);
bool hasLeft = (left != NULL);
bool hasRight = (right != NULL);

ar & BOOST_SERIALIZATION_NVP(hasLeft);
ar & BOOST_SERIALIZATION_NVP(hasRight);

if (hasLeft)
ar & BOOST_SERIALIZATION_NVP(left);
if (hasRight)
ar & BOOST_SERIALIZATION_NVP(right);

if (Archive::is_loading::value)
{
if (left)
{
left->parent = this;
left->localDataset = false;
}
if (right)
{
right->parent = this;
right->localDataset = false;
}
}
}

} // namespace tree
Expand Down
15 changes: 13 additions & 2 deletions src/mlpack/methods/det/dtree_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1016,10 +1016,21 @@ void DTree<MatType, TagType>::serialize(Archive& ar,
delete left;
if (right)
delete right;

left = NULL;
right = NULL;
}

ar & BOOST_SERIALIZATION_NVP(left);
ar & BOOST_SERIALIZATION_NVP(right);
bool hasLeft = (left != NULL);
bool hasRight = (right != NULL);

ar & BOOST_SERIALIZATION_NVP(hasLeft);
ar & BOOST_SERIALIZATION_NVP(hasRight);

if (hasLeft)
ar & BOOST_SERIALIZATION_NVP(left);
if (hasRight)
ar & BOOST_SERIALIZATION_NVP(right);

if (root)
{
Expand Down
2 changes: 1 addition & 1 deletion src/mlpack/methods/hmm/hmm_train_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ static void mlpackMain()
if (CLI::HasParam("input_model") && CLI::HasParam("tolerance"))
{
Log::Info << "Tolerance of existing model in '"
<< CLI::GetPrintableParam<std::string>("input_model") << "' will be "
<< CLI::GetPrintableParam<HMMModel*>("input_model") << "' will be "
<< "replaced with specified tolerance of " << tolerance << "." << endl;
}

Expand Down
Loading

0 comments on commit 446a821

Please sign in to comment.