Skip to content
Snippets Groups Projects
Commit 627c405a authored by Benjamin Halimi's avatar Benjamin Halimi Committed by Benjamin Halimi
Browse files

add the BatchNorm train/test flag support

parent cd840ec6
No related branches found
No related tags found
1 merge request: !263 "Add the BatchNorm train/test flag support"
Pipeline #60210 canceled
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
namespace Aidge { namespace Aidge {
enum class BatchNormAttr { Epsilon, Momentum }; enum class BatchNormAttr { Epsilon, Momentum, TrainingMode };
template <DimIdx_t DIM> template <DimIdx_t DIM>
class BatchNorm_Op : public OperatorTensor, class BatchNorm_Op : public OperatorTensor,
...@@ -33,7 +33,7 @@ public: ...@@ -33,7 +33,7 @@ public:
static const std::string Type; static const std::string Type;
private: private:
using Attributes_ = StaticAttributes<BatchNormAttr, float, float>; using Attributes_ = StaticAttributes<BatchNormAttr, float, float, int>;
template <BatchNormAttr e> template <BatchNormAttr e>
using attr = typename Attributes_::template attr<e>; using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes; const std::shared_ptr<Attributes_> mAttributes;
...@@ -42,7 +42,7 @@ public: ...@@ -42,7 +42,7 @@ public:
BatchNorm_Op() = delete; BatchNorm_Op() = delete;
constexpr BatchNorm_Op(float epsilon, float momentum) constexpr BatchNorm_Op(float epsilon, float momentum, int trainingMode)
: OperatorTensor(Type, : OperatorTensor(Type,
{InputCategory::Data, {InputCategory::Data,
InputCategory::Param, InputCategory::Param,
...@@ -52,7 +52,9 @@ public: ...@@ -52,7 +52,9 @@ public:
1), 1),
mAttributes(std::make_shared<Attributes_>( mAttributes(std::make_shared<Attributes_>(
attr<BatchNormAttr::Epsilon>(epsilon), attr<BatchNormAttr::Epsilon>(epsilon),
attr<BatchNormAttr::Momentum>(momentum))) {} attr<BatchNormAttr::Momentum>(momentum),
attr<BatchNormAttr::TrainingMode>(trainingMode)
)) {}
/** /**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
...@@ -84,6 +86,7 @@ public: ...@@ -84,6 +86,7 @@ public:
inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; } inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
inline float& epsilon() const { return mAttributes->template getAttr<BatchNormAttr::Epsilon>(); } inline float& epsilon() const { return mAttributes->template getAttr<BatchNormAttr::Epsilon>(); }
inline float& momentum() const { return mAttributes->template getAttr<BatchNormAttr::Momentum>(); } inline float& momentum() const { return mAttributes->template getAttr<BatchNormAttr::Momentum>(); }
inline int& trainingMode() const { return mAttributes->template getAttr<BatchNormAttr::TrainingMode>(); }
static const std::vector<std::string> getInputsName() { static const std::vector<std::string> getInputsName() {
return {"data_input", "scale", "shift", "mean", "variance"}; return {"data_input", "scale", "shift", "mean", "variance"};
...@@ -101,16 +104,17 @@ template <DimSize_t DIM> ...@@ -101,16 +104,17 @@ template <DimSize_t DIM>
std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures, std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures,
const float epsilon = 1.0e-5F, const float epsilon = 1.0e-5F,
const float momentum = 0.1F, const float momentum = 0.1F,
const int trainingMode = 0,
const std::string& name = ""); const std::string& name = "");
} // namespace Aidge } // namespace Aidge
extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const std::string&); extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const int, const std::string&);
extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const std::string&); extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const int, const std::string&);
extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const std::string&); extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const int, const std::string&);
namespace { namespace {
template <> template <>
const char *const EnumStrings<Aidge::BatchNormAttr>::data[] = { "epsilon", "momentum" }; const char *const EnumStrings<Aidge::BatchNormAttr>::data[] = { "epsilon", "momentum", "training_mode" };
} }
#endif //AIDGE_CORE_OPERATOR_BATCHNORM_H_ #endif //AIDGE_CORE_OPERATOR_BATCHNORM_H_
...@@ -26,16 +26,17 @@ void declare_BatchNormOp(py::module& m) { ...@@ -26,16 +26,17 @@ void declare_BatchNormOp(py::module& m) {
const std::string pyClassName("BatchNorm" + std::to_string(DIM) + "DOp"); const std::string pyClassName("BatchNorm" + std::to_string(DIM) + "DOp");
py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, OperatorTensor>( py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, OperatorTensor>(
m, pyClassName.c_str(), py::multiple_inheritance()) m, pyClassName.c_str(), py::multiple_inheritance())
.def(py::init<float, float>(), .def(py::init<float, float, int>(),
py::arg("epsilon"), py::arg("epsilon"),
py::arg("momentum")) py::arg("momentum"),
py::arg("training_mode"))
.def_static("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName) .def_static("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName)
.def_static("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName) .def_static("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName)
.def_readonly_static("Type", &BatchNorm_Op<DIM>::Type); .def_readonly_static("Type", &BatchNorm_Op<DIM>::Type);
declare_registrable<BatchNorm_Op<DIM>>(m, pyClassName); declare_registrable<BatchNorm_Op<DIM>>(m, pyClassName);
m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nb_features"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = ""); m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nb_features"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("training_mode") = 0, py::arg("name") = "");
} }
void init_BatchNorm(py::module &m) { void init_BatchNorm(py::module &m) {
......
...@@ -108,9 +108,10 @@ template <Aidge::DimSize_t DIM> ...@@ -108,9 +108,10 @@ template <Aidge::DimSize_t DIM>
inline std::shared_ptr<Aidge::Node> Aidge::BatchNorm(const Aidge::DimSize_t nbFeatures, inline std::shared_ptr<Aidge::Node> Aidge::BatchNorm(const Aidge::DimSize_t nbFeatures,
const float epsilon, const float epsilon,
const float momentum, const float momentum,
const int trainingMode,
const std::string& name) { const std::string& name) {
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported"); static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported");
auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name); auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum, trainingMode), name);
addProducer(batchNorm, 1, {nbFeatures}, "scale"); addProducer(batchNorm, 1, {nbFeatures}, "scale");
addProducer(batchNorm, 2, {nbFeatures}, "shift"); addProducer(batchNorm, 2, {nbFeatures}, "shift");
addProducer(batchNorm, 3, {nbFeatures}, "batch_mean"); addProducer(batchNorm, 3, {nbFeatures}, "batch_mean");
...@@ -118,6 +119,6 @@ inline std::shared_ptr<Aidge::Node> Aidge::BatchNorm(const Aidge::DimSize_t nbFe ...@@ -118,6 +119,6 @@ inline std::shared_ptr<Aidge::Node> Aidge::BatchNorm(const Aidge::DimSize_t nbFe
return batchNorm; return batchNorm;
} }
template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const std::string&); template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const int, const std::string&);
template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const std::string&); template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const int, const std::string&);
template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const std::string&); template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const int, const std::string&);
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment