
Skip to content

Commit

Permalink
Correct forward pass
Browse files Browse the repository at this point in the history
  • Loading branch information
chvsp authored and Manthan-R-Sheth committed Mar 16, 2018
1 parent f8533ac commit a70abeb
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 13 deletions.
17 changes: 14 additions & 3 deletions src/mlpack/methods/ann/layer/batchnorm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class BatchNorm

BatchNorm();

BatchNorm(const size_t size);
BatchNorm(const size_t size, const double eps);

void Reset();

Expand Down Expand Up @@ -71,6 +71,15 @@ class BatchNorm
//! Modify the gradient.
OutputDataType& Gradient() { return gradient; }

OutputDataType& Mean() { return mean; }

OutputDataType& Variance() { return variance; }

OutputDataType& Gamma() { return gamma; }

OutputDataType& Beta() { return beta; }



template<typename Archive>
void Serialize(Archive& ar, const unsigned int /* version */);
Expand All @@ -91,11 +100,13 @@ class BatchNorm

OutputDataType mean;

OutputDataType trainingMean;
// OutputDataType trainingMean;

OutputDataType variance;

OutputDataType trainingVariance;
// OutputDataType trainingVariance;

arma::running_stat_vec<arma::colvec> stats;

OutputDataType gradient;

Expand Down
37 changes: 27 additions & 10 deletions src/mlpack/methods/ann/layer/batchnorm_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ BatchNorm<InputDataType, OutputDataType>::BatchNorm()

template <typename InputDataType, typename OutputDataType>
BatchNorm<InputDataType, OutputDataType>::BatchNorm(
    const size_t size, const double eps) : size(size), eps(eps)
{
  // Create a BatchNorm layer normalizing `size` features.
  //
  // @param size Number of features (rows) to normalize.
  // @param eps  Small constant added to the variance for numerical
  //             stability.
  //
  // NOTE: the default argument (`eps = 0.001`) was removed from this
  // out-of-line definition — default arguments are not allowed on an
  // out-of-line member definition of a class template; if a default is
  // wanted it must go on the in-class declaration instead.

  // weights holds gamma followed by beta, hence 2 * size entries;
  // Reset() aliases gamma/beta into this buffer.
  weights.set_size(size + size, 1);
}
Expand All @@ -36,25 +36,42 @@ BatchNorm<InputDataType, OutputDataType>::BatchNorm(
template<typename InputDataType, typename OutputDataType>
void BatchNorm<InputDataType, OutputDataType>::Reset()
{
  // Re-initialize the layer's trainable parameters and running statistics.
  //
  // gamma (scale) and beta (shift) are aliased into the single `weights`
  // buffer: the advanced arma::mat constructor with copy_aux_mem = false
  // makes them views over weights' memory, so optimizer updates to
  // `weights` are seen through gamma/beta and vice versa.
  gamma = arma::mat(weights.memptr(), size, 1, false, false);
  beta = arma::mat(weights.memptr() + gamma.n_elem, size, 1, false, false);

  // Start in training mode; callers switch to deterministic for testing.
  deterministic = false;

  // Identity transform by default: scale of 1, shift of 0.
  gamma.fill(1.0);
  beta.fill(0.0);

  // Clear the running mean/variance accumulator used for the
  // deterministic (test-time) pass.
  stats.reset();
}

template<typename InputDataType, typename OutputDataType>
template<typename eT>
void BatchNorm<InputDataType, OutputDataType>::Forward(
    const arma::Mat<eT>&& input, arma::Mat<eT>&& output)
{
  // Forward pass of batch normalization.
  //
  // Each feature (row of `input`) is normalized independently across the
  // batch (columns):
  //
  //   y_i = gamma_i * (x_i - mean(x_i)) / sqrt(var(x_i) + eps) + beta_i
  //
  // @param input  Input activations, one column per batch element.
  // @param output Normalized activations, same shape as input.

  // Every element is assigned below, so set_size (no element copy) is
  // sufficient; reshape would needlessly preserve old contents.
  output.set_size(input.n_rows, input.n_cols);

  for (size_t i = 0; i < input.n_rows; i++)
  {
    arma::mat inpRow = input.row(i);

    // Center the row around its batch mean.
    output.row(i) = inpRow - arma::as_scalar(arma::mean(inpRow, 1));

    // eps belongs INSIDE the square root — sqrt(var + eps) — as in the
    // standard batch-norm formulation; the previous code computed
    // sqrt(var) + eps. arma::var(..., 1, 1) uses the biased (1/N)
    // estimator over each row.
    output.row(i) /= arma::as_scalar(
        arma::sqrt(arma::var(inpRow, 1, 1) + eps));

    // Scale and shift with the learned per-feature parameters.
    output.row(i) *= arma::as_scalar(gamma.row(i));
    output.row(i) += arma::as_scalar(beta.row(i));
  }

  // TODO(review): the deterministic (test-time) path should use the
  // running statistics in `stats` instead of per-batch statistics —
  // `deterministic` is currently ignored here; confirm intended behavior.
}

template<typename InputDataType, typename OutputDataType>
Expand Down Expand Up @@ -86,8 +103,8 @@ void BatchNorm<InputDataType, OutputDataType>::Serialize(
{
ar & data::CreateNVP(gamma, "gamma");
ar & data::CreateNVP(beta, "beta");
ar & data::CreateNVP(trainingMean, "trainingMean");
ar & data::CreateNVP(trainingVariance, "trainingVariance");
ar & data::CreateNVP(stats.mean(), "trainingMean");
ar & data::CreateNVP(stats.var(1), "trainingVariance");
}

} // namespace ann
Expand Down

0 comments on commit a70abeb

Please sign in to comment.