From 7fe7f6d914b0cf1547858cd22e9b8078b60d1a30 Mon Sep 17 00:00:00 2001
From: Benjamin Halimi <benjamin.halimi@cea.fr>
Date: Wed, 27 Nov 2024 14:19:54 +0000
Subject: [PATCH] Add the BatchNorm train/test flag support

---
 include/aidge/operator/BatchNorm.hpp         | 20 ++++++++++++--------
 python_binding/operator/pybind_BatchNorm.cpp |  7 ++++---
 src/operator/BatchNorm.cpp                   |  9 +++++----
 unit_tests/graph/Test_Matching.cpp           |  4 ++--
 4 files changed, 23 insertions(+), 17 deletions(-)

diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp
index cdac7935f..8f33380b2 100644
--- a/include/aidge/operator/BatchNorm.hpp
+++ b/include/aidge/operator/BatchNorm.hpp
@@ -24,7 +24,7 @@
 
 namespace Aidge {
 
-enum class BatchNormAttr { Epsilon, Momentum };
+enum class BatchNormAttr { Epsilon, Momentum, TrainingMode };
 
 template <DimIdx_t DIM>
 class BatchNorm_Op : public OperatorTensor,
@@ -33,7 +33,7 @@ public:
     static const std::string Type;
 
 private:
-    using Attributes_ = StaticAttributes<BatchNormAttr, float, float>;
+    using Attributes_ = StaticAttributes<BatchNormAttr, float, float, bool>;
     template <BatchNormAttr e>
     using attr = typename Attributes_::template attr<e>;
     const std::shared_ptr<Attributes_> mAttributes;
@@ -42,7 +42,7 @@ public:
 
     BatchNorm_Op() = delete;
 
-    constexpr BatchNorm_Op(float epsilon, float momentum)
+    constexpr BatchNorm_Op(float epsilon, float momentum, bool trainingMode)
         : OperatorTensor(Type,
                             {InputCategory::Data,
                                 InputCategory::Param,
@@ -52,7 +52,9 @@ public:
                             1),
             mAttributes(std::make_shared<Attributes_>(
                 attr<BatchNormAttr::Epsilon>(epsilon),
-                attr<BatchNormAttr::Momentum>(momentum))) {}
+                attr<BatchNormAttr::Momentum>(momentum),
+                attr<BatchNormAttr::TrainingMode>(trainingMode)
+            )) {}
 
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
@@ -84,6 +86,7 @@ public:
     inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
     inline float& epsilon() const { return mAttributes->template getAttr<BatchNormAttr::Epsilon>(); }
     inline float& momentum() const { return mAttributes->template getAttr<BatchNormAttr::Momentum>(); }
+    inline bool& trainingMode() const { return mAttributes->template getAttr<BatchNormAttr::TrainingMode>(); }
 
     static const std::vector<std::string> getInputsName() {
         return {"data_input", "scale", "shift", "mean", "variance"};
@@ -101,16 +104,17 @@ template <DimSize_t DIM>
 std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures,
                                 const float epsilon = 1.0e-5F,
                                 const float momentum = 0.1F,
+                                const bool trainingMode = false,
                                 const std::string& name = "");
 }  // namespace Aidge
 
-extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const std::string&);
-extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const std::string&);
-extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const std::string&);
+extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const bool, const std::string&);
+extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const bool, const std::string&);
+extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const bool, const std::string&);
 
 namespace {
 template <>
-const char *const EnumStrings<Aidge::BatchNormAttr>::data[] = { "epsilon", "momentum" };
+const char *const EnumStrings<Aidge::BatchNormAttr>::data[] = { "epsilon", "momentum", "training_mode" };
 }
 
 #endif //AIDGE_CORE_OPERATOR_BATCHNORM_H_
diff --git a/python_binding/operator/pybind_BatchNorm.cpp b/python_binding/operator/pybind_BatchNorm.cpp
index 43b44eb73..c380f5940 100644
--- a/python_binding/operator/pybind_BatchNorm.cpp
+++ b/python_binding/operator/pybind_BatchNorm.cpp
@@ -26,16 +26,17 @@ void declare_BatchNormOp(py::module& m) {
     const std::string pyClassName("BatchNorm" + std::to_string(DIM) + "DOp");
     py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, OperatorTensor>(
         m, pyClassName.c_str(), py::multiple_inheritance())
-        .def(py::init<float, float>(),
+        .def(py::init<float, float, bool>(),
             py::arg("epsilon"),
-            py::arg("momentum"))
+            py::arg("momentum"),
+            py::arg("training_mode"))
         .def_static("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName)
         .def_static("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName)
         .def_readonly_static("Type", &BatchNorm_Op<DIM>::Type);
 
     declare_registrable<BatchNorm_Op<DIM>>(m, pyClassName);
 
-    m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nb_features"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = "");
+    m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nb_features"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("training_mode") = false, py::arg("name") = "");
 }
 
 void init_BatchNorm(py::module &m) {
diff --git a/src/operator/BatchNorm.cpp b/src/operator/BatchNorm.cpp
index b18be5287..24a49e56c 100644
--- a/src/operator/BatchNorm.cpp
+++ b/src/operator/BatchNorm.cpp
@@ -108,9 +108,10 @@ template <Aidge::DimSize_t DIM>
 inline std::shared_ptr<Aidge::Node> Aidge::BatchNorm(const Aidge::DimSize_t nbFeatures,
                                                      const float epsilon,
                                                      const float momentum,
+                                                     const bool trainingMode,
                                                      const std::string& name) {
     static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported");
-    auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name);
+    auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum, trainingMode), name);
     addProducer(batchNorm, 1, {nbFeatures}, "scale");
     addProducer(batchNorm, 2, {nbFeatures}, "shift");
     addProducer(batchNorm, 3, {nbFeatures}, "batch_mean");
@@ -118,6 +119,6 @@ inline std::shared_ptr<Aidge::Node> Aidge::BatchNorm(const Aidge::DimSize_t nbFe
     return batchNorm;
 }
 
-template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const std::string&);
-template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const std::string&);
-template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const std::string&);
+template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const bool, const std::string&);
+template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const bool, const std::string&);
+template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const bool, const std::string&);
diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp
index 8c5fa222a..582c73565 100644
--- a/unit_tests/graph/Test_Matching.cpp
+++ b/unit_tests/graph/Test_Matching.cpp
@@ -352,11 +352,11 @@ TEST_CASE("[core/graph] Matching") {
     auto g2 = Sequential({
         Producer({16, 3, 512, 512}, "dataProvider"),
         Conv(3, 4, {5, 5}, "conv1"),
-        BatchNorm<2>(4, 1.0e-5, 0.1, "bn1"),
+        BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn1"),
         Conv(4, 4, {5, 5}, "conv2"),
         ReLU("relu2"),
         Conv(4, 4, {5, 5}, "conv3"),
-        BatchNorm<2>(4, 1.0e-5, 0.1, "bn3"),
+        BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn3"),
         FC(4, 4, false, "fc1"),
         FC(4, 4, false, "fc2"),
         FC(4, 4, false, "fc3"),
-- 
GitLab
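
Usage note (not part of the patch itself): with this change, a BatchNorm node's train/test behaviour is selected through the new trainingMode argument at construction time, and can be toggled afterwards through the mutable trainingMode() accessor added above. The sketch below is a minimal illustration of those call sites; it assumes the patched headers and Aidge's usual Node::getOperator() accessor, and the include paths are a guess at the repository layout.

    // Minimal sketch: build a 2D BatchNorm node in inference mode, then
    // switch it to training mode via the new mutable attribute.
    #include <memory>
    #include "aidge/graph/Node.hpp"             // assumed location of Node
    #include "aidge/operator/BatchNorm.hpp"

    int main() {
        // Factory signature after the patch:
        // (nbFeatures, epsilon, momentum, trainingMode, name).
        auto bn = Aidge::BatchNorm<2>(4, 1.0e-5F, 0.1F,
                                      /*trainingMode=*/false, "bn1");

        // The flag is a plain mutable attribute, so a training loop can
        // flip it without rebuilding the node.
        auto op = std::static_pointer_cast<Aidge::BatchNorm_Op<2>>(
            bn->getOperator());
        op->trainingMode() = true;
        return 0;
    }

On the Python side, the bound factories gain a matching keyword with the same default, e.g. BatchNorm2D(nb_features, epsilon=1e-5, momentum=0.1, training_mode=False, name="bn1").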