[go: nahoru, domu]

Skip to content

Commit

Permalink
Merge pull request mlpack#1310 from yamidark/fix_mean_shift_bug
Browse files Browse the repository at this point in the history
Fix bug in mean shift when max_iterations is too low relative to the size of the data
  • Loading branch information
rcurtin authored Mar 26, 2018
2 parents ff6f638 + a77b0b8 commit 789c194
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 10 deletions.
7 changes: 6 additions & 1 deletion src/mlpack/methods/mean_shift/mean_shift.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ namespace meanshift /** Mean shift clustering. */ {
* extern arma::mat data; // Dataset we want to run mean shift on.
* arma::Row<size_t> assignments; // Cluster assignments.
* arma::mat centroids; // Cluster centroids.
* bool forceConvergence = true; // Flag whether to force each centroid seed
*                               // to converge regardless of maxIterations.
*
* MeanShift<> meanShift;
* meanShift.Cluster(data, assignments, centroids);
* meanShift.Cluster(data, assignments, centroids, forceConvergence);
* @endcode
*
* @tparam UseKernel Use kernel or mean to calculate new centroid.
Expand Down Expand Up @@ -80,10 +82,13 @@ class MeanShift
* @param data Dataset to cluster.
* @param assignments Vector to store cluster assignments in.
* @param centroids Matrix in which centroids are stored.
* @param forceConvergence Flag whether to force each centroid seed to
* converge regardless of maxIterations.
*/
void Cluster(const MatType& data,
arma::Row<size_t>& assignments,
arma::mat& centroids,
// NOTE(review): forceConvergence sits before useSeeds, so a positional call
// Cluster(data, assignments, centroids, x) binds x to forceConvergence. Any
// caller written before this parameter existed that passed useSeeds as the
// fourth argument silently changes meaning -- verify call sites.
bool forceConvergence = true,
bool useSeeds = true);

//! Get the maximum number of iterations.
Expand Down
41 changes: 33 additions & 8 deletions src/mlpack/methods/mean_shift/mean_shift_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
const MatType& data,
arma::Row<size_t>& assignments,
arma::mat& centroids,
bool forceConvergence,
bool useSeeds)
{
if (radius <= 0)
Expand Down Expand Up @@ -216,8 +217,8 @@ inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
{
// Initial centroid is the seed itself.
allCentroids.col(i) = pSeeds->unsafe_col(i);
for (size_t completedIterations = 0; completedIterations < maxIterations;
completedIterations++)
// When forceConvergence is true the condition below can never become false,
// so the loop ends only through an exit inside its body (not part of this
// hunk). NOTE(review): a seed that never satisfies that exit test would
// iterate forever -- consider a hard upper bound as a safety net.
for (size_t completedIterations = 0; completedIterations < maxIterations
|| forceConvergence; completedIterations++)
{
// Store new centroid in this.
arma::colvec newCentroid = arma::zeros<arma::colvec>(pSeeds->n_rows);
Expand Down Expand Up @@ -260,12 +261,36 @@ inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
}
}

// Assign centroids to each point.
neighbor::KNN neighborSearcher(centroids);
arma::mat neighborDistances;
arma::Mat<size_t> resultingNeighbors;
neighborSearcher.Search(data, 1, resultingNeighbors, neighborDistances);
assignments = resultingNeighbors;
// If no centroid has converged due to too little iterations and without
// forcing convergence, fall back to a single centroid so that every point
// can still be given an assignment.
if (centroids.empty())
{
  Log::Warn << "No clusters converge, setting 1 random centroid calculated. "
      "Try a larger max_iterations or pass force_convergence flag." << std::endl;

  // Use the first data point when maxIterations is 0, otherwise the first
  // centroid computed for the seeds.
  if (maxIterations == 0)
  {
    centroids.insert_cols(centroids.n_cols, data.col(0));
  }
  else
  {
    centroids.insert_cols(centroids.n_cols, allCentroids.col(0));
  }
}

if (centroids.n_cols == 1)
{
  // A single cluster: every point belongs to cluster 0, so no search is
  // needed. The size must be given explicitly -- a bare zeros() only clears
  // the elements assignments already has, leaving an empty (or wrongly
  // sized) row as it was.
  assignments.zeros(data.n_cols);
}
else
{
  // Assign each point to its nearest centroid.
  neighbor::KNN neighborSearcher(centroids);
  arma::mat neighborDistances;
  arma::Mat<size_t> resultingNeighbors;
  neighborSearcher.Search(data, 1, resultingNeighbors, neighborDistances);
  assignments = resultingNeighbors;
}
}

} // namespace meanshift
Expand Down
6 changes: 5 additions & 1 deletion src/mlpack/methods/mean_shift/mean_shift_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ PARAM_FLAG("in_place", "If specified, a column containing the learned cluster "
"--output_file is overridden. (Do not use with Python.)", "P");
PARAM_FLAG("labels_only", "If specified, only the output labels will be "
    "written to the file specified by --output_file.", "l");
// -f: keep iterating past max_iterations until the clusters converge.
PARAM_FLAG("force_convergence", "If specified, the mean shift algorithm will "
    "continue running regardless of max_iterations until the clusters "
    "converge.", "f");
PARAM_MATRIX_OUT("output", "Matrix to write output labels or labeled data to.",
    "o");
PARAM_MATRIX_OUT("centroid", "If specified, the centroids of each cluster will "
Expand Down Expand Up @@ -89,7 +92,8 @@ static void mlpackMain()

Timer::Start("clustering");
Log::Info << "Performing mean shift clustering..." << endl;
meanShift.Cluster(dataset, assignments, centroids);
meanShift.Cluster(dataset, assignments, centroids,
CLI::HasParam("force_convergence"));
Timer::Stop("clustering");

Log::Info << "Found " << centroids.n_cols << " centroids." << endl;
Expand Down
35 changes: 35 additions & 0 deletions src/mlpack/tests/main_tests/mean_shift_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,41 @@ BOOST_AUTO_TEST_CASE(MeanShiftInPlaceTest)
BOOST_REQUIRE_EQUAL(CLI::GetParam<arma::mat>("output").n_cols, numCols);
}

/**
 * Ensure that force_convergence is used by testing that the
 * force_convergence flag makes a difference in the program.
 *
 * The program is run twice on the same data with max_iterations = 1: once
 * without and once with force_convergence. The number of centroids found
 * must differ between the two runs.
 */
BOOST_AUTO_TEST_CASE(MeanShiftForceConvergenceTest)
{
  arma::mat x;
  if (!data::Load("iris_test.csv", x))
    BOOST_FAIL("Cannot load test dataset iris_test.csv!");

  // Input the iris test points; x is reused for the second run below.
  SetInputParam("input", x);
  // Set a very small max_iterations (the literal 1 is already an int).
  SetInputParam("max_iterations", 1);

  mlpackMain();

  // n_cols is an unsigned arma::uword; keep it unsigned instead of narrowing
  // it to int.
  const size_t numCentroids1 = CLI::GetParam<arma::mat>("centroid").n_cols;

  ResetSettings();

  // Input the same data points; x is not needed afterwards, so move it.
  SetInputParam("input", std::move(x));
  // Set the same small max_iterations.
  SetInputParam("max_iterations", 1);
  // Turn the force_convergence flag on.
  SetInputParam("force_convergence", true);

  mlpackMain();

  const size_t numCentroids2 = CLI::GetParam<arma::mat>("centroid").n_cols;
  // Resulting number of centroids should be different.
  BOOST_REQUIRE_NE(numCentroids1, numCentroids2);
}

/**
* Ensure that radius is used by testing that the radius
* makes a difference in the program.
Expand Down

0 comments on commit 789c194

Please sign in to comment.