diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp
index cdac7935f6ded752201c04b2dda6cfb9e06438ec..8f33380b29a1509c75721994f16139b9f1a9e20a 100644
--- a/include/aidge/operator/BatchNorm.hpp
+++ b/include/aidge/operator/BatchNorm.hpp
@@ -24,7 +24,7 @@
 
 namespace Aidge {
 
-enum class BatchNormAttr { Epsilon, Momentum };
+enum class BatchNormAttr { Epsilon, Momentum, TrainingMode };
 
 template <DimIdx_t DIM>
 class BatchNorm_Op : public OperatorTensor,
@@ -33,7 +33,7 @@ public:
     static const std::string Type;
 
 private:
-    using Attributes_ = StaticAttributes<BatchNormAttr, float, float>;
+    using Attributes_ = StaticAttributes<BatchNormAttr, float, float, bool>;
     template <BatchNormAttr e>
     using attr = typename Attributes_::template attr<e>;
     const std::shared_ptr<Attributes_> mAttributes;
@@ -42,7 +42,7 @@ public:
 
     BatchNorm_Op() = delete;
 
-    constexpr BatchNorm_Op(float epsilon, float momentum)
+    constexpr BatchNorm_Op(float epsilon, float momentum, bool trainingMode = false)
         : OperatorTensor(Type,
                             {InputCategory::Data,
                                 InputCategory::Param,
@@ -52,7 +52,9 @@ public:
                             1),
           mAttributes(std::make_shared<Attributes_>(
             attr<BatchNormAttr::Epsilon>(epsilon),
-            attr<BatchNormAttr::Momentum>(momentum))) {}
+            attr<BatchNormAttr::Momentum>(momentum),
+            attr<BatchNormAttr::TrainingMode>(trainingMode)
+            )) {}
 
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
@@ -84,6 +86,7 @@ public:
     inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
     inline float& epsilon() const { return mAttributes->template getAttr<BatchNormAttr::Epsilon>(); }
     inline float& momentum() const { return mAttributes->template getAttr<BatchNormAttr::Momentum>(); }
+    inline bool& trainingMode() const { return mAttributes->template getAttr<BatchNormAttr::TrainingMode>(); }
 
     static const std::vector<std::string> getInputsName() {
         return {"data_input", "scale", "shift", "mean", "variance"};
@@ -101,16 +104,17 @@ template <DimSize_t DIM>
 std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures,
                                        const float epsilon = 1.0e-5F,
                                        const float momentum = 0.1F,
+                                       const bool trainingMode = false,
                                        const std::string& name = "");
 }  // namespace Aidge
 
-extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const std::string&);
-extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const std::string&);
-extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const std::string&);
+extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const bool, const std::string&);
+extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const bool, const std::string&);
+extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const bool, const std::string&);
 
 namespace {
 template <>
-const char *const EnumStrings<Aidge::BatchNormAttr>::data[] = { "epsilon", "momentum" };
+const char *const EnumStrings<Aidge::BatchNormAttr>::data[] = { "epsilon", "momentum", "training_mode" };
 }
 
 #endif //AIDGE_CORE_OPERATOR_BATCHNORM_H_
diff --git a/python_binding/operator/pybind_BatchNorm.cpp b/python_binding/operator/pybind_BatchNorm.cpp
index 43b44eb7300072e501d33829b88537850beef37a..c380f594038fe7f77d1c01a4e9c8285ccf5ad85a 100644
--- a/python_binding/operator/pybind_BatchNorm.cpp
+++ b/python_binding/operator/pybind_BatchNorm.cpp
@@ -26,16 +26,17 @@ void declare_BatchNormOp(py::module& m) {
     const std::string pyClassName("BatchNorm" + std::to_string(DIM) + "DOp");
     py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, OperatorTensor>(
     m, pyClassName.c_str(), py::multiple_inheritance())
-        .def(py::init<float, float>(),
+        .def(py::init<float, float, bool>(),
             py::arg("epsilon"),
-            py::arg("momentum"))
+            py::arg("momentum"),
+            py::arg("training_mode") = false)
         .def_static("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName)
         .def_static("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName)
         .def_readonly_static("Type", &BatchNorm_Op<DIM>::Type);
 
     declare_registrable<BatchNorm_Op<DIM>>(m, pyClassName);
 
-    m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nb_features"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = "");
+    m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nb_features"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("training_mode") = false, py::arg("name") = "");
 }
 
 void init_BatchNorm(py::module &m) {
diff --git a/src/operator/BatchNorm.cpp b/src/operator/BatchNorm.cpp
index b18be528795ccf470d7503ef1a915b6b66dc255c..24a49e56c755ef46404881511cfd6af89628e251 100644
--- a/src/operator/BatchNorm.cpp
+++ b/src/operator/BatchNorm.cpp
@@ -108,9 +108,10 @@ template <Aidge::DimSize_t DIM>
 inline std::shared_ptr<Aidge::Node> Aidge::BatchNorm(const Aidge::DimSize_t nbFeatures,
                                        const float epsilon,
                                        const float momentum,
+                                       const bool trainingMode,
                                        const std::string& name) {
     static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported");
-    auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name);
+    auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum, trainingMode), name);
     addProducer(batchNorm, 1, {nbFeatures}, "scale");
     addProducer(batchNorm, 2, {nbFeatures}, "shift");
     addProducer(batchNorm, 3, {nbFeatures}, "batch_mean");
@@ -118,6 +119,6 @@ inline std::shared_ptr<Aidge::Node> Aidge::BatchNorm(const Aidge::DimSize_t nbFe
     return batchNorm;
 }
 
-template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const std::string&);
-template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const std::string&);
-template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const std::string&);
+template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const bool, const std::string&);
+template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const bool, const std::string&);
+template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const bool, const std::string&);
diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp
index 8c5fa222a68a7f2eed329be7c49ca62d0d7ba52f..582c73565a4ef7bfc96e493e1e6029b1683676ab 100644
--- a/unit_tests/graph/Test_Matching.cpp
+++ b/unit_tests/graph/Test_Matching.cpp
@@ -352,11 +352,11 @@ TEST_CASE("[core/graph] Matching") {
     auto g2 = Sequential({
         Producer({16, 3, 512, 512}, "dataProvider"),
         Conv(3, 4, {5, 5}, "conv1"),
-        BatchNorm<2>(4, 1.0e-5, 0.1, "bn1"),
+        BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn1"),
         Conv(4, 4, {5, 5}, "conv2"),
         ReLU("relu2"),
         Conv(4, 4, {5, 5}, "conv3"),
-        BatchNorm<2>(4, 1.0e-5, 0.1, "bn3"),
+        BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn3"),
         FC(4, 4, false, "fc1"),
         FC(4, 4, false, "fc2"),
         FC(4, 4, false, "fc3"),