diff --git a/.gitignore b/.gitignore
index d0ed31a39b0d6b71d209dae5105bba2f49b2d640..c8c3ec7486c62dcf085e91245e765a4266589596 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@ install*/
 __pycache__
 *.pyc
 *.egg-info
+wheelhouse/*
 
 # Mermaid
 *.mmd
diff --git a/include/aidge/learning/optimizer/Adam.hpp b/include/aidge/learning/optimizer/Adam.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..b5a1f01e84ed279ff7963c1179cd9d207fc4dca8
--- /dev/null
+++ b/include/aidge/learning/optimizer/Adam.hpp
@@ -0,0 +1,131 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CORE_OPTIMIZER_ADAM_H_
+#define AIDGE_CORE_OPTIMIZER_ADAM_H_
+
+#include <functional>
+#include <memory>
+#include <vector>
+#include <cmath>  // std::sqrt, std::pow
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/learning/optimizer/Optimizer.hpp"
+#include "aidge/utils/StaticAttributes.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/TensorUtils.hpp"
+
+namespace Aidge {
+
+enum class AdamAttr {
+    Beta1,
+    Beta2,
+    Epsilon
+};
+
+class Adam: public Optimizer, public StaticAttributes<AdamAttr, float, float, float> {
+private:
+    std::vector<Tensor> mMomentum1;
+    std::vector<Tensor> mMomentum2;
+    Tensor mLR{std::vector<std::size_t>({1})};
+    Tensor mBeta1{std::vector<std::size_t>({1})};
+    Tensor mReversedBeta1{std::vector<std::size_t>({1})};
+    Tensor mBeta2{std::vector<std::size_t>({1})};
+    Tensor mReversedBeta2{std::vector<std::size_t>({1})};
+    Tensor mEpsilon{std::vector<std::size_t>({1})};
+
+public:
+    using Attributes_ = StaticAttributes<AdamAttr, float, float, float>;
+    template <AdamAttr e>
+    using attr = typename Attributes_::template attr<e>;
+
+    Adam(const float beta1 = 0.9f, const float beta2 = 0.999f, const float epsilon = 1.0e-8f)
+        : Optimizer(),
+          Attributes_(attr<AdamAttr::Beta1>(beta1),
+                      attr<AdamAttr::Beta2>(beta2),
+                      attr<AdamAttr::Epsilon>(epsilon))
+    {
+        mBeta1.setBackend("cpu");
+        mBeta1.set<float>(0, beta1);
+        mReversedBeta1.setBackend("cpu");
+        mReversedBeta1.set<float>(0, 1.0f - beta1);
+		
+        mBeta2.setBackend("cpu");
+        mBeta2.set<float>(0, beta2);
+        mReversedBeta2.setBackend("cpu");
+        mReversedBeta2.set<float>(0, 1.0f - beta2);
+		
+        mEpsilon.setBackend("cpu");
+        mEpsilon.set<float>(0, epsilon);
+    }
+
+    void update() override final {		
+        mLR.setBackend(mParameters[0]->getImpl()->backend());
+        mLR.set<float>(0, learningRate());
+        if (mParameters[0]->getImpl()->backend() != mBeta1.getImpl()->backend()) {
+            mBeta1.setBackend(mParameters[0]->getImpl()->backend());
+            mReversedBeta1.setBackend(mParameters[0]->getImpl()->backend());
+            mBeta2.setBackend(mParameters[0]->getImpl()->backend());
+            mReversedBeta2.setBackend(mParameters[0]->getImpl()->backend());
+        }
+		
+        Tensor alpha{std::vector<std::size_t>({1})};
+        alpha.setBackend(mParameters[0]->getImpl()->backend());
+        alpha.set<float>(0, learningRate() * std::sqrt(1.0f - std::pow(mBeta2.get<float>(0), mLRScheduler.step() + 1))
+                                           / (1.0f - std::pow(mBeta1.get<float>(0), mLRScheduler.step() + 1)));
+
+        Tensor epsilon{std::vector<std::size_t>({1})};
+        epsilon.setBackend(mParameters[0]->getImpl()->backend());
+        epsilon.set<float>(0, mEpsilon.get<float>(0) * std::sqrt(1.0f - std::pow(mBeta2.get<float>(0), mLRScheduler.step() + 1)));
+		
+        if (mLRScheduler.step() == 0) {
+            for (std::size_t i = 0; i < mParameters.size(); ++i) {
+                mMomentum1[i].setBackend(mParameters[i]->getImpl()->backend());
+                mMomentum1[i].setDataType(mParameters[i]->grad()->dataType());
+                mMomentum1[i].zeros();
+                mMomentum2[i].setBackend(mParameters[i]->getImpl()->backend());
+                mMomentum2[i].setDataType(mParameters[i]->grad()->dataType());
+                mMomentum2[i].zeros();
+            }
+        }
+		
+        for (std::size_t i = 0; i < mParameters.size(); ++i) {
+            mMomentum1[i] = mBeta1 * mMomentum1[i] + mReversedBeta1 * (*mParameters[i]->grad());
+            mMomentum2[i] = mBeta2 * mMomentum2[i] + mReversedBeta2 * (*mParameters[i]->grad()) * (*mParameters[i]->grad());
+            *mParameters[i] = *mParameters[i] - alpha * mMomentum1[i] / (mMomentum2[i].sqrt() +  epsilon);
+        }
+        
+        mLRScheduler.update();
+    }
+
+    void setParameters(const std::vector<std::shared_ptr<Tensor>>& parameters) override final {
+        Optimizer::setParameters(parameters);
+        mMomentum1 = std::vector<Tensor>(parameters.size());
+        mMomentum2 = std::vector<Tensor>(parameters.size());
+        for (std::size_t i = 0; i < parameters.size(); ++i) {
+            mMomentum1[i] = Tensor(parameters[i]->dims());
+            mMomentum2[i] = Tensor(parameters[i]->dims());
+        }
+    }
+};
+
+} // namespace Aidge
+
+
+namespace {
+template <>
+const char *const EnumStrings<Aidge::AdamAttr>::data[] = {
+    "Beta1",
+    "Beta2",
+    "Epsilon"
+};
+}
+#endif // AIDGE_CORE_OPTIMIZER_ADAM_H_
diff --git a/include/aidge/learning/optimizer/Optimizer.hpp b/include/aidge/learning/optimizer/Optimizer.hpp
index 195d64965d3ba4eb89c9c4d0ca2155cb719f76f3..83ba3f37f35f608c416dc8750a25c8b226fac8bf 100644
--- a/include/aidge/learning/optimizer/Optimizer.hpp
+++ b/include/aidge/learning/optimizer/Optimizer.hpp
@@ -48,9 +48,9 @@ public:
 
     virtual void setParameters(const std::vector<std::shared_ptr<Tensor>>& parameters) {
         mParameters = parameters;
-        for (const auto& param : parameters) {
-            param->initGrad(); // create gradient and set it to zeros
-        }
+        // for (const auto& param : parameters) {
+        //     param->initGrad(); // create gradient and set it to zeros
+        // }
     }
 
     constexpr float learningRate() const noexcept {
diff --git a/include/aidge/loss/LossList.hpp b/include/aidge/loss/LossList.hpp
index 5a0241d9816becbaace75185e796c5ec7c787e89..17a51a8dcb65ef8cf6132577605735ab96608478 100644
--- a/include/aidge/loss/LossList.hpp
+++ b/include/aidge/loss/LossList.hpp
@@ -31,6 +31,8 @@ namespace loss {
  */
 Tensor MSE(std::shared_ptr<Tensor>& prediction,
            const std::shared_ptr<Tensor>& target);
+Tensor BCE(std::shared_ptr<Tensor>& prediction,
+           const std::shared_ptr<Tensor>& target);
 
 }  // namespace loss
 }  // namespace Aidge
diff --git a/python_binding/learning/loss/pybind_Loss.cpp b/python_binding/learning/loss/pybind_Loss.cpp
index 5e3c3af23cb81effc87888f91ac108f8b1cfd61a..3b975747412aee077bece7d3d565524a496340a6 100644
--- a/python_binding/learning/loss/pybind_Loss.cpp
+++ b/python_binding/learning/loss/pybind_Loss.cpp
@@ -23,5 +23,6 @@ void init_Loss(py::module &m) {
     auto m_loss =
         m.def_submodule("loss", "Submodule dedicated to loss functions");
     m_loss.def("MSE", &loss::MSE, py::arg("graph"), py::arg("target"));
+    m_loss.def("BCE", &loss::BCE, py::arg("graph"), py::arg("target"));
 }
 }  // namespace Aidge
diff --git a/python_binding/learning/optimizer/pybind_Adam.cpp b/python_binding/learning/optimizer/pybind_Adam.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..03e8de97b06f40a7430294e6e55509178697450a
--- /dev/null
+++ b/python_binding/learning/optimizer/pybind_Adam.cpp
@@ -0,0 +1,27 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <pybind11/pybind11.h>
+
+#include "aidge/learning/optimizer/Optimizer.hpp"
+#include "aidge/learning/optimizer/Adam.hpp"
+
+namespace py = pybind11;
+namespace Aidge {
+// namespace learning {
+
+void init_Adam(py::module& m) {
+    py::class_<Adam, std::shared_ptr<Adam>, Attributes, Optimizer>(m, "Adam", py::multiple_inheritance())
+    .def(py::init<float, float, float>(), py::arg("beta1") = 0.9f, py::arg("beta2") = 0.999f, py::arg("epsilon") = 1.0e-8f)
+    .def("update", &Adam::update);
+}
+// }  // namespace learning
+}  // namespace Aidge
diff --git a/python_binding/pybind_learning.cpp b/python_binding/pybind_learning.cpp
index 3b4a16ceffb0db7bd7e1d407bcef5d5df830cb2f..c0566dd1bd4bcfc32977ad3372018d00a9c54259 100644
--- a/python_binding/pybind_learning.cpp
+++ b/python_binding/pybind_learning.cpp
@@ -19,12 +19,14 @@ namespace Aidge {
 void init_Loss(py::module&);
 void init_Optimizer(py::module&);
 void init_SGD(py::module&);
+void init_Adam(py::module&);
 void init_LRScheduler(py::module&);
 
 void init_Aidge(py::module& m) {
     init_Loss(m);
     init_Optimizer(m);
     init_SGD(m);
+    init_Adam(m);
 
     init_LRScheduler(m);
 }
diff --git a/src/loss/classification/BCE.cpp b/src/loss/classification/BCE.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d5156072e9aeff84470fc60a4efb7571de81483b
--- /dev/null
+++ b/src/loss/classification/BCE.cpp
@@ -0,0 +1,157 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <memory>
+#include <numeric>  // std::iota
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/OpArgs.hpp"
+#include "aidge/loss/LossList.hpp"
+#include "aidge/recipes/GraphViewHelper.hpp"
+#include "aidge/scheduler/Scheduler.hpp"
+#include "aidge/scheduler/SequentialScheduler.hpp"
+
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/operator/Add.hpp"
+#include "aidge/operator/Sub.hpp"
+#include "aidge/operator/Mul.hpp"
+#include "aidge/operator/Div.hpp"
+#include "aidge/operator/Ln.hpp"
+#include "aidge/operator/ReduceMean.hpp"
+
+#include "aidge/backend/cpu/operator/AddImpl.hpp"
+#include "aidge/backend/cpu/operator/SubImpl.hpp"
+#include "aidge/backend/cpu/operator/MulImpl.hpp"
+#include "aidge/backend/cpu/operator/DivImpl.hpp"
+#include "aidge/backend/cpu/operator/LnImpl.hpp"
+#include "aidge/backend/cpu/operator/ReduceMeanImpl.hpp"
+
+
+Aidge::Tensor Aidge::loss::BCE(std::shared_ptr<Tensor>& prediction,
+                               const std::shared_ptr<Tensor>& target) {
+    /*
+	Binay Cross Entropy (BCE) loss function
+
+    Implementation note:
+    loss function is computed using a graph in order to not be backend dependant.
+    */
+
+    AIDGE_ASSERT(target->dims().size() == 2,
+                 "Label must have two dims: [BatchSize, NbChannel]");
+    AIDGE_ASSERT(prediction->backend() == target->backend(),
+                 "'prediction' and 'target' Tensors must be on the "
+                 "same backend. Found {} and {}.\n",
+                 prediction->backend(), target->backend());
+    AIDGE_ASSERT(prediction->dims() == target->dims(),
+                 "'prediction' (shape {}) and 'target' (shape {}) Tensors must "
+                 "have the same dimensions.\n",
+                 prediction->dims(), target->dims());
+    AIDGE_ASSERT(prediction->dataType() == target->dataType(),
+                 "'prediction' (data type {}) and 'target' (data type {}) "
+                 "Tensors must have the same data type.\n",
+                 prediction->dataType(), target->dataType());
+
+    const float eps1 = 1.e-10f;
+    const float eps2 = 1.e-10f;
+
+    // Define nodes: inputs
+    const std::shared_ptr<Node> prediction_node = Producer(prediction, "pred");
+    const std::shared_ptr<Node> target_node = Producer(target, "label");
+
+    // Define nodes: add1 = prediction + eps1, add2 = target + eps1
+    const std::shared_ptr<Node> add1_node = Add(2, "add1");
+    const std::shared_ptr<Node> add2_node = Add(2, "add2");
+    prediction_node->addChild(add1_node, 0, 0);
+    Producer(std::make_shared<Tensor>(Array1D<float, 1>{{eps1}}))
+        ->addChild(add1_node, 0, 1);
+    target_node->addChild(add2_node, 0, 0);
+    Producer(std::make_shared<Tensor>(Array1D<float, 1>{{eps1}}))
+        ->addChild(add2_node, 0, 1);
+
+    // Define nodes: sub1 = 1 - prediction + eps2 and sub2 = - (1 - target + eps2)
+    const std::shared_ptr<Node> sub1_node = Sub("sub1");
+    const std::shared_ptr<Node> sub2_node = Sub("sub2");
+    Producer(std::make_shared<Tensor>(Array1D<float, 1>{{1.0f + eps2}}))
+        ->addChild(sub1_node, 0, 0);
+    prediction_node->addChild(sub1_node, 0, 1);
+    target_node->addChild(sub2_node, 0, 0);
+    Producer(std::make_shared<Tensor>(Array1D<float, 1>{{1.0f + eps2}}))
+        ->addChild(sub2_node, 0, 1);
+
+    // Define nodes: ln1 = ln(prediction + eps1) and ln2 = ln(1 - prediction + eps2)
+    const std::shared_ptr<Node> ln1_node = Ln("ln1");
+    const std::shared_ptr<Node> ln2_node = Ln("ln2");
+    add1_node-> addChild(ln1_node, 0, 0);
+    sub1_node-> addChild(ln2_node, 0, 0);
+
+    // Define nodes: mul1 = (target + eps1) * ln(prediction + eps1) and mul2 = - (1 - target + eps2) * ln(1 - prediction + eps2)
+    const std::shared_ptr<Node> mul1_node = Mul("mul1");
+    const std::shared_ptr<Node> mul2_node = Mul("mul2");
+    add2_node->addChild(mul1_node, 0, 0);
+    ln1_node->addChild(mul1_node, 0, 1);
+    sub2_node->addChild(mul2_node, 0, 0);
+    ln2_node->addChild(mul2_node, 0, 1);
+
+    // Define node: sub3 = - [(target + eps1) * ln(prediction + eps1) + (1 - target + eps2) * ln(1 - prediction + eps2)]
+    const std::shared_ptr<Node> sub3_node = Sub("sub3");
+    mul2_node->addChild(sub3_node, 0, 0);
+    mul1_node->addChild(sub3_node, 0, 1);
+
+    // Define nodes: div1 = (target + eps1) / (prediction + eps1) and div2 = - (1 - target + eps2)/(1 - prediction + eps2)
+    const std::shared_ptr<Node> div1_node = Div("div1");
+    const std::shared_ptr<Node> div2_node = Div("div2");
+    add2_node->addChild(div1_node, 0, 0);
+    add1_node->addChild(div1_node, 0, 1);
+    sub2_node->addChild(div2_node, 0, 0);
+    sub1_node->addChild(div2_node, 0, 1);
+
+    // Define node: add3 = (target + eps1) / (prediction + eps1) - (1 - target + eps2)/(1 - prediction + eps2)
+    const std::shared_ptr<Node> add3_node = Add(2, "add3");
+    div1_node->addChild(add3_node, 0, 0);
+    div2_node->addChild(add3_node, 0, 1);
+
+    // Define node: loss
+    std::vector<int> axes_dims(prediction->nbDims());
+    std::iota(std::begin(axes_dims), std::end(axes_dims), 0);
+    auto loss_node = ReduceMean(axes_dims, 1, "loss");
+    sub3_node->addChild(loss_node, 0, 0);
+
+    // Define node: gradient
+    const std::shared_ptr<Node> gradient_node = Mul("gradient");
+    add3_node->addChild(gradient_node, 0, 0);
+    Producer(std::make_shared<Tensor>(Array1D<float, 1>{{-1.0f/float(target->dims()[0])}}))
+        ->addChild(gradient_node, 0, 1);
+
+    // Create GraphView
+    std::shared_ptr<GraphView> gv_loss = std::make_shared<GraphView>("BCE");
+    gv_loss->add({prediction_node, target_node,
+                  add1_node->getParent(1), add1_node,
+                  add2_node->getParent(1), add2_node,
+                  sub1_node->getParent(0), sub1_node,
+                  sub2_node->getParent(1), sub2_node,
+                  ln1_node, ln2_node, mul1_node, mul2_node, div1_node, div2_node,
+                  sub3_node, loss_node,
+                  add3_node, gradient_node->getParent(1), gradient_node});
+    gv_loss->compile(prediction->getImpl()->backend(), prediction->dataType());
+
+    // Compute loss and gradient
+    SequentialScheduler ss_loss{gv_loss};
+    ss_loss.forward(false);
+
+    // prediction->initGrad(); // Enable gradient for output
+    std::shared_ptr<Tensor> outputGrad = prediction->grad();
+    const std::shared_ptr<OperatorTensor> gradient_op = std::dynamic_pointer_cast<OperatorTensor>(gradient_node->getOperator());
+    outputGrad->copyFrom(gradient_op->getOutput(0)->clone()); // Update gradient
+
+    const std::shared_ptr<OperatorTensor> loss_op = std::dynamic_pointer_cast<OperatorTensor>(loss_node->getOperator());
+    return loss_op->getOutput(0)->clone(); // Return loss
+}
diff --git a/src/loss/regression/MSE.cpp b/src/loss/regression/MSE.cpp
index 87f685a0f550a1cb60563503447407f70868ce9a..3d7ffe923bfa957c43fa93ef7c234ef1bdf63f06 100644
--- a/src/loss/regression/MSE.cpp
+++ b/src/loss/regression/MSE.cpp
@@ -15,6 +15,7 @@
 #include "aidge/backend/cpu/operator/PowImpl.hpp"
 #include "aidge/backend/cpu/operator/ReduceMeanImpl.hpp"
 #include "aidge/backend/cpu/operator/SubImpl.hpp"
+#include "aidge/backend/cpu/operator/MulImpl.hpp"
 #include "aidge/data/Tensor.hpp"
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/graph/OpArgs.hpp"
@@ -23,6 +24,7 @@
 #include "aidge/operator/Pow.hpp"
 #include "aidge/operator/ReduceMean.hpp"
 #include "aidge/operator/Sub.hpp"
+#include "aidge/operator/Mul.hpp"
 #include "aidge/recipes/GraphViewHelper.hpp"
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/scheduler/SequentialScheduler.hpp"
@@ -43,7 +45,7 @@ Aidge::Tensor Aidge::loss::MSE(std::shared_ptr<Tensor>& prediction,
     (2/NbBatch)->Mul->Gradient
     */
 
-    prediction->initGrad(); // Enable gradient for output
+    // prediction->initGrad(); // Enable gradient for output
 
     // compile_gradient(graph);  // Warning compile gradient here, without
     //                           // it, grad is nullptr. Maybe we can find a better
diff --git a/unit_tests/loss/classification/Test_BCE.cpp b/unit_tests/loss/classification/Test_BCE.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..01f0cb52441e91f46d80947c937e1baab807052c
--- /dev/null
+++ b/unit_tests/loss/classification/Test_BCE.cpp
@@ -0,0 +1,90 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <catch2/catch_test_macros.hpp>
+#include <cstddef>     // std::size_t
+#include <cmath>       //
+#include <functional>  // std::multiplies, std::plus
+#include <memory>      // std::make_unique
+#include <numeric>     // std::accumulate
+#include <random>      // std::random_device, std::mt19937,
+                       // std::uniform_int_distribution
+#include <vector>
+
+#include "aidge/loss/LossList.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/utils/TensorUtils.hpp"
+
+namespace Aidge {
+TEST_CASE("[loss/classification] BCE", "[loss][classification][BCE]") {
+    constexpr std::uint16_t NBTRIALS = 10;
+
+    // set random variables
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::uniform_int_distribution<std::size_t> dimsDist(1, 5);
+    std::uniform_int_distribution<std::size_t> nbDimsDist(1, 2);
+    std::uniform_real_distribution<float> valueDist(0.0f, 1.0f);
+
+    for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial) {
+        const std::size_t nb_dims = 2; // For BCE test, nb_dims is fixed as 2: NbBatch, NbChan
+        std::vector<std::size_t> dims(2);
+
+        for (std::size_t i = 0; i < nb_dims; ++i) { dims[i] = dimsDist(gen); }
+        const std::size_t nb_elements = std::accumulate(dims.cbegin(), dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
+
+        // create random predictions
+        std::unique_ptr<float[]> pred = std::make_unique<float[]>(nb_elements);
+        for (std::size_t i = 0; i < nb_elements; ++i) {
+            pred[i] = valueDist(gen);
+        }
+
+        // create random targets
+        std::unique_ptr<float[]> targ = std::make_unique<float[]>(nb_elements);
+        for (std::size_t i = 0; i < nb_elements; ++i) {
+            targ[i] = valueDist(gen);
+        }
+
+        // compute the BCE manually
+        const float eps1 = 1.0e-10f;
+        const float eps2 = 1.0e-10f;
+        std::unique_ptr<float[]> tmp_res_manual = std::make_unique<float[]>(nb_elements);
+        for (std::size_t i = 0; i < nb_elements; ++i) {
+            tmp_res_manual[i] = - ((targ[i] + eps1) * std::log(pred[i] + eps1) + (1.0f - targ[i] + eps2) * std::log(1.0f - pred[i] + eps2));
+        }
+        std::cout << "Output manual:" << std::endl;
+        std::shared_ptr<Tensor> tmp_tensor = std::make_shared<Tensor>(dims);
+        tmp_tensor->setBackend("cpu");
+        tmp_tensor->getImpl()->setRawPtr(tmp_res_manual.get(), nb_elements);
+        tmp_tensor->print();
+        const float res_manual = std::accumulate(&tmp_res_manual[0], &tmp_res_manual[nb_elements], 0.0f, std::plus<float>()) / static_cast<float>(nb_elements);
+
+        // compute the BCE using Aidge::loss::BCE function
+        std::cout << "Input 0 manual:" << std::endl;
+        std::shared_ptr<Tensor> pred_tensor = std::make_shared<Tensor>(dims);
+        pred_tensor->setBackend("cpu");
+        pred_tensor->getImpl()->setRawPtr(pred.get(), nb_elements);
+        pred_tensor->print();
+
+        std::cout << "Input 1 manual:" << std::endl;
+        std::shared_ptr<Tensor> targ_tensor = std::make_shared<Tensor>(dims);
+        targ_tensor->setBackend("cpu");
+        targ_tensor->getImpl()->setRawPtr(targ.get(), nb_elements);
+        targ_tensor->print();
+        
+        const Tensor res_function = loss::BCE(pred_tensor, targ_tensor);
+
+        // compare results
+        Tensor res_manual_tensor = Tensor(res_manual);
+        REQUIRE(approxEq<float>(res_manual, res_function));
+    }
+}
+}  // namespace Aidge
diff --git a/unit_tests/optimizer/Test_Adam.cpp b/unit_tests/optimizer/Test_Adam.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..bd297903d47b90b755ff59ace0e052aa62c309d7
--- /dev/null
+++ b/unit_tests/optimizer/Test_Adam.cpp
@@ -0,0 +1,145 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <catch2/catch_test_macros.hpp>
+#include <cstddef>  // std::size_t
+#include <cmath>  // std::sqrt, std::pow
+#include <memory>
+#include <random>   // std::random_device, std::mt19937, std::uniform_int_distribution
+#include <set>
+#include <vector>
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/backend/cpu/data/TensorImpl.hpp"
+#include "aidge/learning/learningRate/LRScheduler.hpp"
+#include "aidge/learning/learningRate/LRSchedulerList.hpp"
+#include "aidge/learning/optimizer/Optimizer.hpp"
+#include "aidge/learning/optimizer/Adam.hpp"
+#include "aidge/backend/cpu/operator/AddImpl.hpp"
+#include "aidge/backend/cpu/operator/MulImpl.hpp"
+#include "aidge/backend/cpu/operator/SubImpl.hpp"
+#include "aidge/backend/cpu/operator/DivImpl.hpp"
+#include "aidge/backend/cpu/operator/SqrtImpl.hpp"
+#include "aidge/utils/TensorUtils.hpp"
+
+namespace Aidge {
+TEST_CASE("[learning/Adam] update", "[Optimizer][Adam]") {
+    constexpr std::uint16_t NBTRIALS = 10;
+    // Create a random number generator
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::uniform_real_distribution<float> valueDist(0.1f, 1.0f); // Random float distribution between 0 and 1
+    std::uniform_real_distribution<float> paramDist(0.001f, 1.0f); // Random float distribution between 0 and 1
+    std::uniform_int_distribution<std::size_t> dimSizeDist(std::size_t(2), std::size_t(5));
+    std::uniform_int_distribution<std::size_t> nbDimsDist(std::size_t(1), std::size_t(5));
+
+
+    for (std::size_t trial = 0; trial < NBTRIALS; ++trial) {
+        // create a random number of Tensor with random dims and random values
+        // Create random Tensor, Random Gradient and random
+        const std::size_t nb_tensors = dimSizeDist(gen);
+        std::vector<std::size_t> size_tensors(nb_tensors, 1);
+
+        std::vector<std::shared_ptr<Tensor>> tensors(nb_tensors);
+        std::vector<std::unique_ptr<float[]>> val_tensors(nb_tensors);
+
+        std::vector<std::shared_ptr<Tensor>> optim_tensors(nb_tensors);
+
+        std::vector<std::shared_ptr<Tensor>> grad_tensors(nb_tensors);
+        std::vector<std::unique_ptr<float[]>> val_grad_tensors(nb_tensors);
+
+        std::vector<std::shared_ptr<Tensor>> momentum_tensors(nb_tensors);
+        std::vector<std::unique_ptr<float[]>> val_momentum1_tensors(nb_tensors);
+		std::vector<std::unique_ptr<float[]>> val_momentum2_tensors(nb_tensors);
+
+        for (std::size_t i = 0; i < nb_tensors; ++i) {
+            std::vector<std::size_t> dims(nbDimsDist(gen));
+            for (std::size_t d = 0; d < dims.size(); ++d) {
+                dims[d] = dimSizeDist(gen);
+                size_tensors[i] *= dims[d];
+            }
+
+            val_tensors[i] = std::make_unique<float[]>(size_tensors[i]);
+            val_grad_tensors[i] = std::make_unique<float[]>(size_tensors[i]);
+            val_momentum1_tensors[i] = std::make_unique<float[]>(size_tensors[i]);
+            val_momentum2_tensors[i] = std::make_unique<float[]>(size_tensors[i]);
+            for (std::size_t j = 0; j < size_tensors[i]; ++j) {
+                val_tensors[i][j] = valueDist(gen);
+                val_grad_tensors[i][j] = valueDist(gen);
+                val_momentum1_tensors[i][j] = 0.0f;
+				val_momentum2_tensors[i][j] = 0.0f;
+            }
+            tensors[i] = std::make_shared<Tensor>(dims);
+            tensors[i]->setBackend("cpu");
+            tensors[i]->getImpl()->setRawPtr(val_tensors[i].get(), size_tensors[i]);
+            optim_tensors[i] = std::make_shared<Tensor>(dims);
+            optim_tensors[i]->setBackend("cpu");
+            optim_tensors[i]->getImpl()->copy(val_tensors[i].get(), size_tensors[i]);
+            // optim_tensors[i]->initGrad();
+
+            grad_tensors[i] = std::make_shared<Tensor>(dims);
+            grad_tensors[i]->setBackend("cpu");
+            grad_tensors[i]->getImpl()->setRawPtr(val_grad_tensors[i].get(), size_tensors[i]);
+
+            momentum_tensors[i] = std::make_shared<Tensor>(dims);
+            momentum_tensors[i]->setBackend("cpu");
+            momentum_tensors[i]->getImpl()->setRawPtr(val_momentum1_tensors[i].get(), size_tensors[i]);
+            momentum_tensors[i]->getImpl()->setRawPtr(val_momentum2_tensors[i].get(), size_tensors[i]);
+
+            REQUIRE((tensors[i]->hasImpl() &&
+                    optim_tensors[i]->hasImpl() &&
+                    grad_tensors[i]->hasImpl()));
+        }
+
+        // generate parameters
+        float lr = paramDist(gen);
+        float beta1 = paramDist(gen);
+        float beta2 = paramDist(gen);
+        float epsilon = paramDist(gen);
+
+        // set Optimizer
+        Adam opt = Adam(beta1, beta2, epsilon);
+        opt.setParameters(optim_tensors);
+        for (std::size_t t = 0; t < nb_tensors; ++t) {
+            optim_tensors[t]->grad()->getImpl()->setRawPtr(val_grad_tensors[t].get(), size_tensors[t]);
+        }
+        opt.setLearningRateScheduler(learning::ConstantLR(lr));
+
+        for (std::size_t t = 0; t < nb_tensors; ++t) {
+            const Tensor tmpt1= *(opt.parameters().at(t));
+            const Tensor tmpt2= *tensors[t];
+            REQUIRE(approxEq<float,float>(tmpt2, tmpt1, 1e-5f, 1e-8f));
+        }
+
+        for (std::size_t step = 0; step < 10; ++step) {
+            // truth
+            float lr2 = lr * std::sqrt(1.0f - std::pow(beta2, step + 1)) / (1.0f - std::pow(beta1, step + 1));
+            float epsilon2 = epsilon * std::sqrt(1.0f - std::pow(beta2, step + 1));
+            for (std::size_t t = 0; t < nb_tensors; ++t) {
+                for (std::size_t i = 0; i < size_tensors[t]; ++i) {
+                    val_momentum1_tensors[t][i] = beta1 * val_momentum1_tensors[t][i] + (1.0f - beta1) * val_grad_tensors[t][i];
+                    val_momentum2_tensors[t][i] = beta2 * val_momentum2_tensors[t][i] + (1.0f - beta2) * val_grad_tensors[t][i] * val_grad_tensors[t][i];
+                    val_tensors[t][i] = val_tensors[t][i]
+                                      - lr2 * val_momentum1_tensors[t][i] / (std::sqrt(val_momentum2_tensors[t][i]) +  epsilon2);
+                }
+            }
+            // optimizer
+            opt.update();
+            // tests
+            for (std::size_t t = 0; t < nb_tensors; ++t) {
+                const Tensor tmpt1= *(opt.parameters().at(t));
+                const Tensor tmpt2= *tensors[t];
+                REQUIRE(approxEq<float,float>(tmpt2, tmpt1, 1e-5f, 1e-8f));
+            }
+        }
+    }
+}
+} // namespace Aidge
diff --git a/unit_tests/optimizer/Test_SGD.cpp b/unit_tests/optimizer/Test_SGD.cpp
index df9924d557d89d0483d018ce08951cf573e233d7..6b8edc60a6f1583d1241552442558bff5f2ce52e 100644
--- a/unit_tests/optimizer/Test_SGD.cpp
+++ b/unit_tests/optimizer/Test_SGD.cpp
@@ -77,7 +77,7 @@ TEST_CASE("[learning/SGD] update", "[Optimizer][SGD]") {
             optim_tensors[i] = std::make_shared<Tensor>(dims);
             optim_tensors[i]->setBackend("cpu");
             optim_tensors[i]->getImpl()->copy(val_tensors[i].get(), size_tensors[i]);
-            optim_tensors[i]->initGrad();
+            // optim_tensors[i]->initGrad();
 
             grad_tensors[i] = std::make_shared<Tensor>(dims);
             grad_tensors[i]->setBackend("cpu");
diff --git a/version.txt b/version.txt
index 17e51c385ea382d4f2ef124b7032c1604845622d..8294c184368c0ec9f84fbcc80c6b36326940c770 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.1.1
+0.1.2
\ No newline at end of file