From c617a2125f4fc814e7c96ab208d943d3778153ac Mon Sep 17 00:00:00 2001
From: bhalimi <benjamin.halimi@cea.fr>
Date: Wed, 27 Nov 2024 13:45:55 +0000
Subject: [PATCH] Change the BatchNorm training flag type from int to bool

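The trainingMode attribute of BatchNorm_Op is a two-state flag that is
only ever tested for truth, so bool expresses the intent better than
int. Update the attribute and its accessor in the C++ operator, the
factory signature and its explicit instantiations, the Python binding,
and the graph-matching unit test accordingly.

After this change the Python binding takes a boolean for training_mode.
A minimal usage sketch (the aidge_core module name is assumed here; the
argument names are the ones declared in the binding):

    import aidge_core  # assumed module name for these bindings

    # training_mode is now a bool instead of a 0/1 int
    bn = aidge_core.BatchNorm2D(nb_features=4,
                                epsilon=1.0e-5,
                                momentum=0.1,
                                training_mode=False,
                                name="bn1")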
---
 include/aidge/operator/BatchNorm.hpp         | 14 +++++++-------
 python_binding/operator/pybind_BatchNorm.cpp |  4 ++--
 src/operator/BatchNorm.cpp                   |  8 ++++----
 unit_tests/graph/Test_Matching.cpp           |  4 ++--
 4 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp
index 34366b9b6..8f33380b2 100644
--- a/include/aidge/operator/BatchNorm.hpp
+++ b/include/aidge/operator/BatchNorm.hpp
@@ -33,7 +33,7 @@ public:
     static const std::string Type;
 
 private:
-    using Attributes_ = StaticAttributes<BatchNormAttr, float, float, int>;
+    using Attributes_ = StaticAttributes<BatchNormAttr, float, float, bool>;
     template <BatchNormAttr e>
     using attr = typename Attributes_::template attr<e>;
     const std::shared_ptr<Attributes_> mAttributes;
@@ -42,7 +42,7 @@ public:
 
     BatchNorm_Op() = delete;
 
-    constexpr BatchNorm_Op(float epsilon, float momentum, int trainingMode)
+    constexpr BatchNorm_Op(float epsilon, float momentum, bool trainingMode)
         : OperatorTensor(Type,
                             {InputCategory::Data,
                                 InputCategory::Param,
@@ -86,7 +86,7 @@ public:
     inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
     inline float& epsilon() const { return mAttributes->template getAttr<BatchNormAttr::Epsilon>(); }
     inline float& momentum() const { return mAttributes->template getAttr<BatchNormAttr::Momentum>(); }
-    inline int& trainingMode() const { return mAttributes->template getAttr<BatchNormAttr::TrainingMode>(); }
+    inline bool& trainingMode() const { return mAttributes->template getAttr<BatchNormAttr::TrainingMode>(); }
 
     static const std::vector<std::string> getInputsName() {
         return {"data_input", "scale", "shift", "mean", "variance"};
@@ -104,13 +104,13 @@ template <DimSize_t DIM>
 std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures,
                                        const float epsilon = 1.0e-5F,
                                        const float momentum = 0.1F,
-                                       const int trainingMode = 0,
+                                       const bool trainingMode = false,
                                        const std::string& name = "");
 }  // namespace Aidge
 
-extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const int, const std::string&);
-extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const int, const std::string&);
-extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const int, const std::string&);
+extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const bool, const std::string&);
+extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const bool, const std::string&);
+extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const bool, const std::string&);
 
 namespace {
 template <>
diff --git a/python_binding/operator/pybind_BatchNorm.cpp b/python_binding/operator/pybind_BatchNorm.cpp
index 039147018..c380f5940 100644
--- a/python_binding/operator/pybind_BatchNorm.cpp
+++ b/python_binding/operator/pybind_BatchNorm.cpp
@@ -26,7 +26,7 @@ void declare_BatchNormOp(py::module& m) {
     const std::string pyClassName("BatchNorm" + std::to_string(DIM) + "DOp");
     py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, OperatorTensor>(
     m, pyClassName.c_str(), py::multiple_inheritance())
-        .def(py::init<float, float, int>(),
+        .def(py::init<float, float, bool>(),
             py::arg("epsilon"),
             py::arg("momentum"),
             py::arg("training_mode"))
@@ -36,7 +36,7 @@ void declare_BatchNormOp(py::module& m) {
 
     declare_registrable<BatchNorm_Op<DIM>>(m, pyClassName);
 
-    m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nb_features"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("training_mode") = 0, py::arg("name") = "");
+    m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nb_features"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("training_mode") = false, py::arg("name") = "");
 }
 
 void init_BatchNorm(py::module &m) {
diff --git a/src/operator/BatchNorm.cpp b/src/operator/BatchNorm.cpp
index 6a5d8819e..24a49e56c 100644
--- a/src/operator/BatchNorm.cpp
+++ b/src/operator/BatchNorm.cpp
@@ -108,7 +108,7 @@ template <Aidge::DimSize_t DIM>
 inline std::shared_ptr<Aidge::Node> Aidge::BatchNorm(const Aidge::DimSize_t nbFeatures,
                                        const float epsilon,
                                        const float momentum,
-                                       const int trainingMode,
+                                       const bool trainingMode,
                                        const std::string& name) {
     static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported");
     auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum, trainingMode), name);
@@ -119,6 +119,6 @@ inline std::shared_ptr<Aidge::Node> Aidge::BatchNorm(const Aidge::DimSize_t nbFe
     return batchNorm;
 }
 
-template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const int, const std::string&);
-template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const int, const std::string&);
-template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const int, const std::string&);
+template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const bool, const std::string&);
+template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const bool, const std::string&);
+template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const bool, const std::string&);
diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp
index ce454c409..582c73565 100644
--- a/unit_tests/graph/Test_Matching.cpp
+++ b/unit_tests/graph/Test_Matching.cpp
@@ -352,11 +352,11 @@ TEST_CASE("[core/graph] Matching") {
     auto g2 = Sequential({
         Producer({16, 3, 512, 512}, "dataProvider"),
         Conv(3, 4, {5, 5}, "conv1"),
-        BatchNorm<2>(4, 1.0e-5, 0.1, 0, "bn1"),
+        BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn1"),
         Conv(4, 4, {5, 5}, "conv2"),
         ReLU("relu2"),
         Conv(4, 4, {5, 5}, "conv3"),
-        BatchNorm<2>(4, 1.0e-5, 0.1, 0, "bn3"),
+        BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn3"),
         FC(4, 4, false, "fc1"),
         FC(4, 4, false, "fc2"),
         FC(4, 4, false, "fc3"),
-- 
GitLab