Merge pull request mlpack#1275 from Manthan-R-Sheth/batchnorm

Implement BatchNorm Layer.

Showing 9 changed files with 466 additions and 1 deletion.
batch_norm.hpp
@@ -0,0 +1,202 @@
/**
 * @file batch_norm.hpp
 * @author Praveen Ch
 * @author Manthan-R-Sheth
 *
 * Definition of the Batch Normalization layer class.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license. You should have received a copy of the
 * 3-clause BSD license along with mlpack. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */

#ifndef MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP
#define MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

/**
 * Declaration of the Batch Normalization layer class. The layer transforms
 * the input data into zero mean and unit variance and then scales and shifts
 * the data by parameters, gamma and beta respectively. These parameters are
 * learned by the network.
 *
 * If deterministic is false (training), the mean and variance over the batch
 * are calculated and the data is normalized. If it is set to true (testing),
 * the mean and variance accrued over the training set are used.
 *
 * For more information, refer to the following paper:
 *
 * @code
 * @article{DBLP:journals/corr/IoffeS15,
 *   author  = {Sergey Ioffe and Christian Szegedy},
 *   title   = {Batch Normalization: Accelerating Deep Network Training by
 *              Reducing Internal Covariate Shift},
 *   journal = {CoRR},
 *   volume  = {abs/1502.03167},
 *   year    = {2015}
 * }
 * @endcode
 *
 * @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
 *         arma::sp_mat or arma::cube).
 * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
 *         arma::sp_mat or arma::cube).
 */

template <
    typename InputDataType = arma::mat,
    typename OutputDataType = arma::mat
>
class BatchNorm
{
 public:
  //! Create the BatchNorm object.
  BatchNorm();

  /**
   * Create the BatchNorm layer object for a specified number of input units.
   *
   * @param size The number of input units.
   * @param eps The epsilon added to the variance to ensure numerical
   *        stability.
   */
  BatchNorm(const size_t size, const double eps = 0.001);

  /**
   * Reset the layer parameters.
   */
  void Reset();

  /**
   * Forward pass of the Batch Normalization layer. Transforms the input data
   * into zero mean and unit variance, scales the data by a factor gamma, and
   * shifts it by beta.
   *
   * @param input Input data for the layer.
   * @param output Resulting output activations.
   */
  template<typename eT>
  void Forward(const arma::Mat<eT>&& input, arma::Mat<eT>&& output);

  /**
   * Backward pass through the layer.
   *
   * @param input The input activations.
   * @param gy The backpropagated error.
   * @param g The calculated gradient.
   */
  template<typename eT>
  void Backward(const arma::Mat<eT>&& input,
                arma::Mat<eT>&& gy,
                arma::Mat<eT>&& g);

  /**
   * Calculate the gradient using the output delta and the input activations.
   *
   * @param input The input activations.
   * @param error The calculated error.
   * @param gradient The calculated gradient.
   */
  template<typename eT>
  void Gradient(const arma::Mat<eT>&& input,
                arma::Mat<eT>&& error,
                arma::Mat<eT>&& gradient);

  //! Get the parameters.
  OutputDataType const& Parameters() const { return weights; }
  //! Modify the parameters.
  OutputDataType& Parameters() { return weights; }

  //! Get the input parameter.
  InputDataType const& InputParameter() const { return inputParameter; }
  //! Modify the input parameter.
  InputDataType& InputParameter() { return inputParameter; }

  //! Get the output parameter.
  OutputDataType const& OutputParameter() const { return outputParameter; }
  //! Modify the output parameter.
  OutputDataType& OutputParameter() { return outputParameter; }

  //! Get the delta.
  OutputDataType const& Delta() const { return delta; }
  //! Modify the delta.
  OutputDataType& Delta() { return delta; }

  //! Get the gradient.
  OutputDataType const& Gradient() const { return gradient; }
  //! Modify the gradient.
  OutputDataType& Gradient() { return gradient; }

  //! Get the value of the deterministic parameter.
  bool Deterministic() const { return deterministic; }
  //! Modify the value of the deterministic parameter.
  bool& Deterministic() { return deterministic; }

  //! Get the mean over the training data.
  OutputDataType TrainingMean() { return stats.mean(); }

  //! Get the variance over the training data.
  OutputDataType TrainingVariance() { return stats.var(1); }

  /**
   * Serialize the layer.
   */
  template<typename Archive>
  void serialize(Archive& ar, const unsigned int /* version */);

 private:
  //! Locally-stored number of input units.
  size_t size;

  //! Locally-stored epsilon value.
  double eps;

  //! Locally-stored scale parameter.
  OutputDataType gamma;

  //! Locally-stored shift parameter.
  OutputDataType beta;

  //! Locally-stored parameters (gamma and beta concatenated).
  OutputDataType weights;

  /**
   * If true, the mean and variance over the training set will be used
   * instead of being calculated over the batch.
   */
  bool deterministic;

  //! Locally-stored mean object.
  OutputDataType mean;

  //! Locally-stored variance object.
  OutputDataType variance;

  //! Locally-stored running statistics object.
  arma::running_stat_vec<arma::colvec> stats;

  //! Locally-stored gradient object.
  OutputDataType gradient;

  //! Locally-stored delta object.
  OutputDataType delta;

  //! Locally-stored input parameter object.
  InputDataType inputParameter;

  //! Locally-stored output parameter object.
  OutputDataType outputParameter;
}; // class BatchNorm

} // namespace ann
} // namespace mlpack

// Include the implementation.
#include "batch_norm_impl.hpp"

#endif
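
For context, a minimal usage sketch (not part of the diff). It assumes the mlpack ANN API of this period (FFN, Linear, ReLULayer, LogSoftMax, NegativeLogLikelihood) and that this PR's remaining changed files register BatchNorm with the network's layer types; the data, labels, and sizes are illustrative only.

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>

using namespace mlpack::ann;

int main()
{
  // Illustrative data: 10-dimensional inputs, 100 samples (one per column).
  arma::mat data = arma::randu<arma::mat>(10, 100);
  arma::mat labels = arma::ones<arma::mat>(1, 100);

  FFN<NegativeLogLikelihood<>> model;
  model.Add<Linear<>>(10, 5);
  model.Add<BatchNorm<>>(5);  // Normalize the 5 linear outputs over the batch.
  model.Add<ReLULayer<>>();
  model.Add<Linear<>>(5, 2);
  model.Add<LogSoftMax<>>();

  // Training uses batch statistics (deterministic = false); Predict() puts
  // the network in deterministic mode, which uses the statistics accrued
  // over the training set instead.
  model.Train(data, labels);

  arma::mat predictions;
  model.Predict(data, predictions);

  return 0;
}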
batch_norm_impl.hpp
@@ -0,0 +1,130 @@
/**
 * @file batch_norm_impl.hpp
 * @author Praveen Ch
 * @author Manthan-R-Sheth
 *
 * Implementation of the Batch Normalization layer.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license. You should have received a copy of the
 * 3-clause BSD license along with mlpack. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */

#ifndef MLPACK_METHODS_ANN_LAYER_BATCHNORM_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_BATCHNORM_IMPL_HPP

// In case it is not included.
#include "batch_norm.hpp"

namespace mlpack {
namespace ann { /** Artificial Neural Network. */

template<typename InputDataType, typename OutputDataType>
BatchNorm<InputDataType, OutputDataType>::BatchNorm() :
    size(10),
    eps(1e-7),
    deterministic(false)
{
  // Nothing to do here.
}

template <typename InputDataType, typename OutputDataType>
BatchNorm<InputDataType, OutputDataType>::BatchNorm(
    const size_t size, const double eps) :
    size(size),
    eps(eps),
    deterministic(false)
{
  // The joint weights vector holds gamma (scale) followed by beta (shift).
  weights.set_size(size + size, 1);
}

template<typename InputDataType, typename OutputDataType>
void BatchNorm<InputDataType, OutputDataType>::Reset()
{
  // Alias gamma and beta into the joint weights vector (no copy), then
  // reinitialize to the identity transform and clear the running statistics.
  gamma = arma::mat(weights.memptr(), size, 1, false, false);
  beta = arma::mat(weights.memptr() + gamma.n_elem, size, 1, false, false);
  deterministic = false;
  gamma.fill(1.0);
  beta.fill(0.0);
  stats.reset();
}

template<typename InputDataType, typename OutputDataType>
template<typename eT>
void BatchNorm<InputDataType, OutputDataType>::Forward(
    const arma::Mat<eT>&& input, arma::Mat<eT>&& output)
{
  output.reshape(input.n_rows, input.n_cols);

  // Mean and variance over the entire training set will be used to compute
  // the forward pass when deterministic is set to true.
  if (deterministic)
  {
    mean = stats.mean();
    variance = stats.var(1);
  }
  else
  {
    mean = arma::mean(input, 1);
    variance = arma::var(input, 1, 1);

    // Update the running statistics with every column (sample) of the batch.
    for (size_t i = 0; i < output.n_cols; i++)
    {
      stats(input.col(i));
    }
  }

  // Normalize to zero mean and unit variance, then scale by gamma and shift
  // by beta.
  output = input.each_col() - mean;
  output.each_col() %= gamma / arma::sqrt(variance + eps);
  output.each_col() += beta;
}

template<typename InputDataType, typename OutputDataType>
template<typename eT>
void BatchNorm<InputDataType, OutputDataType>::Backward(
    const arma::Mat<eT>&& input, arma::Mat<eT>&& gy, arma::Mat<eT>&& g)
{
  mean = arma::mean(input, 1);
  variance = arma::var(input, 1, 1);

  // Per-feature sum of gy % (input - mean); this feeds the variance term of
  // the batch-norm gradient.
  arma::mat m = arma::sum(gy % (input.each_col() - mean), 1);

  // Assemble n * gy - sum(gy) - (input - mean) % m / (variance + eps) term
  // by term, where n = input.n_cols, then scale the whole expression by
  // gamma / (n * sqrt(variance + eps)).
  g = (mean - input.each_col());
  g.each_col() %= m;
  g.each_col() %= 1.0 / (variance + eps);
  g += (gy.each_col() - arma::sum(gy, 1));
  g += (input.n_cols - 1) * gy;
  g.each_col() %= ((1.0 / input.n_cols) * gamma);
  g.each_col() %= (1.0 / arma::sqrt(variance + eps));
}

template<typename InputDataType, typename OutputDataType>
template<typename eT>
void BatchNorm<InputDataType, OutputDataType>::Gradient(
    const arma::Mat<eT>&& input,
    arma::Mat<eT>&& error,
    arma::Mat<eT>&& gradient)
{
  gradient.set_size(size + size, 1);

  // Normalize the input. The division must apply to the whole
  // (input - mean) term, not just the mean, so it is done as a second step
  // rather than inline (where operator precedence would bind it to the mean).
  arma::mat normalized = input.each_col() - arma::mean(input, 1);
  normalized.each_col() /= arma::sqrt(arma::var(input, 1, 1) + eps);

  // d(loss)/d(gamma) is the error weighted by the normalized input;
  // d(loss)/d(beta) is the plain sum of the error.
  gradient.submat(0, 0, gamma.n_elem - 1, 0) = arma::sum(normalized % error, 1);
  gradient.submat(gamma.n_elem, 0, gradient.n_elem - 1, 0) =
      arma::sum(error, 1);
}

template<typename InputDataType, typename OutputDataType>
template<typename Archive>
void BatchNorm<InputDataType, OutputDataType>::serialize(
    Archive& ar, const unsigned int /* version */)
{
  ar & BOOST_SERIALIZATION_NVP(gamma);
  ar & BOOST_SERIALIZATION_NVP(beta);
}

} // namespace ann
} // namespace mlpack

#endif
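
For reference, these are the equations the implementation above assembles term by term, matching Ioffe & Szegedy. In LaTeX, with n the number of columns (samples) in the batch and \mu_B, \sigma_B^2 the per-feature batch mean and variance:

% Forward pass: normalize, then scale and shift.
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad
y_i = \gamma\,\hat{x}_i + \beta

% Backward pass, as built up in Backward():
\frac{\partial L}{\partial x_i}
  = \frac{\gamma}{n\sqrt{\sigma_B^2 + \epsilon}}
    \left( n\,\frac{\partial L}{\partial y_i}
         - \sum_{j=1}^{n} \frac{\partial L}{\partial y_j}
         - \frac{x_i - \mu_B}{\sigma_B^2 + \epsilon}
           \sum_{j=1}^{n} \frac{\partial L}{\partial y_j}\,(x_j - \mu_B)
    \right)

% Parameter gradients, as computed in Gradient():
\frac{\partial L}{\partial \gamma}
  = \sum_{i=1}^{n} \frac{\partial L}{\partial y_i}\,\hat{x}_i, \qquad
\frac{\partial L}{\partial \beta}
  = \sum_{i=1}^{n} \frac{\partial L}{\partial y_i}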
(The remaining changed files in this commit did not load and are not shown.)