diff --git a/.gitignore b/.gitignore
index f37378e300efeb5362882eb8d6eb59f028563a0e..ff07895b0e4e52c6c6a21ae0984fd3341ab972f7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,4 +19,7 @@ dist*/
 xml*/
 
 # ONNX
-*.onnx
\ No newline at end of file
+*.onnx
+
+# log
+*.log
diff --git a/.gitlab/ci/build.gitlab-ci.yml b/.gitlab/ci/build.gitlab-ci.yml
index d8472ca464e09814273ce13cdf340a92d2d7004c..baea3ca70266f1d3fcd489824da681106c85b73c 100644
--- a/.gitlab/ci/build.gitlab-ci.yml
+++ b/.gitlab/ci/build.gitlab-ci.yml
@@ -23,9 +23,15 @@ build:ubuntu_python:
     - DEPENDENCY_NAME="aidge_backend_cpu"
     - !reference [.download_dependency, script]
 
+    # aidge_learning
+    - DEPENDENCY_NAME="aidge_learning"
+    - !reference [.download_dependency, script]
+
     - python3 -m pip install virtualenv
     - virtualenv venv
     - source venv/bin/activate
+    - python3 -m pip install numpy coverage onnxruntime # used for tests
     - python3 -m pip install -r requirements.txt
     - python3 -m pip install . -v
   artifacts:
diff --git a/.gitlab/ci/coverage.gitlab-ci.yml b/.gitlab/ci/coverage.gitlab-ci.yml
index e3f06d0f592543d2b0661ee829674c908e90d030..88dcd3bea4cbd10e72fd0e3c3fa47fd2b67b4236 100644
--- a/.gitlab/ci/coverage.gitlab-ci.yml
+++ b/.gitlab/ci/coverage.gitlab-ci.yml
@@ -5,10 +5,10 @@ coverage:ubuntu_python:
     - docker
   script:
     - source venv/bin/activate
-    - python3 -m pip install numpy coverage
     - cd ${CI_PROJECT_NAME}
     # Retrieve the installation path of the module, since it is installed with pip.
     - export MODULE_LOCATION=`python -c "import ${CI_PROJECT_NAME} as _; print(_.__path__[0])"`
+    - echo $MODULE_LOCATION
     - python3 -m coverage run --source=$MODULE_LOCATION -m unittest discover -s unit_tests/ -v -b
     - python3 -m coverage report
     - python3 -m coverage xml
diff --git a/aidge_interop_torch/unit_tests/test_pytorch.py b/aidge_interop_torch/unit_tests/test_pytorch.py
index 30ff4d4fe7f846062abb1e5212f70b5eb21aa4b8..463b9ef5ec0207011caad8adaeeb050e1e868296 100755
--- a/aidge_interop_torch/unit_tests/test_pytorch.py
+++ b/aidge_interop_torch/unit_tests/test_pytorch.py
@@ -1,19 +1,21 @@
 import torch
 import aidge_interop_torch
 import aidge_core
+import aidge_learning
 import unittest
 
 import numpy as np
 
 torch.backends.cudnn.enabled = True
 torch.backends.cudnn.benchmark = True
+torch.set_printoptions(precision=7)
 weight_value = 0.05
-batch_size = 10
+batch_size = 2
 learning_rate = 0.01
 comparison_precision = 0.001
 absolute_presision = 0.0001
 epochs = 10
-
+# aidge_core.Log.set_console_level(aidge_core.Level.Debug)
 # TODO : add tensor test later ...
 
 # class test_tensor_conversion(unittest.TestCase):
@@ -44,7 +46,7 @@ class Test_Networks():
 
     def __init__(self, model1, model2, name="", test_backward=True, eval_mode=False, epochs=10, relative_precision=0.001, absolute_presision=0.0001, learning_rate=0.01, cuda=False):
         self.relative_precision = relative_precision
-        self.absolute_presision = absolute_presision
+        self.absolute_precision = absolute_presision
         self.epochs = epochs
         self.test_backward = test_backward
         self.model1 = model1
@@ -60,7 +62,6 @@ class Test_Networks():
             self.model1.train()
             self.model2.train()
         self.name = name
-
         if self.test_backward:
             self.optimizer1 = torch.optim.SGD(
                 self.model1.parameters(), lr=learning_rate)
@@ -74,14 +75,20 @@ class Test_Networks():
             i = i.item()
             j = j.item()
             if j != 0:
-                if abs(i-j) > self.relative_precision * abs(j) + self.absolute_presision:
+                if abs(i-j) > self.relative_precision * abs(j) + self.absolute_precision:
+                    return -1
+            elif i != 0:
+                if abs(i-j) > self.relative_precision * abs(i) + self.absolute_precision:
                     return -1
         return 0
 
     def unit_test(self, input_tensor, label):
         torch_tensor1 = input_tensor
         torch_tensor2 = input_tensor.detach().clone()
-        print(torch_tensor1)
+
+        torch_tensor1.requires_grad = True
+        torch_tensor2.requires_grad = True
+
         if self.test_backward:
             label1 = label
             label2 = label.detach().clone()
@@ -104,11 +111,23 @@ class Test_Networks():
 
             loss2 = self.criterion2(output2, label2)
             self.optimizer2.zero_grad()
+
             loss2.backward()
             self.optimizer2.step()
+            nan_flag = False
+            if loss1.isnan():
+                print("Loss1 is NaN")
+                nan_flag = True
+            if loss2.isnan():
+                print("Loss2 is NaN")
+                nan_flag = True
+            if nan_flag:
+                return -1
             if self.compare_tensor(loss1, loss2):
                 print("Different loss : ", loss1.item(), "|", loss2.item())
                 return -1
+            if self.compare_tensor(torch_tensor1.grad, torch_tensor2.grad):
+                print(f"Different input gradient:\nGrad1\n{torch_tensor1.grad}\nGrad2\n{torch_tensor2.grad}\n")
+                return -1
         return 0
 
     def test_multiple_step(self, input_size, label_size):
@@ -135,8 +154,8 @@ class TorchLeNet(torch.nn.Module):
         c1 = torch.nn.Conv2d(1, 6, 5, bias=False)
         c2 = torch.nn.Conv2d(6, 16, 5, bias=False)
         c3 = torch.nn.Conv2d(16, 120, 5, bias=False)
-        l1 = torch.nn.Linear(120, 84)
-        l2 = torch.nn.Linear(84, 10)
+        l1 = torch.nn.Linear(120, 84, bias=True)
+        l2 = torch.nn.Linear(84, 10, bias=True)
 
         torch.nn.init.constant_(c1.weight, weight_value)
         torch.nn.init.constant_(c2.weight, weight_value)
@@ -171,6 +190,46 @@ class TorchLeNet(torch.nn.Module):
         return x
 
 
+class Easy_graph(torch.nn.Module):
+
+    def __init__(self):
+        super(Easy_graph, self).__init__()
+
+        self.layer = torch.nn.Sequential(
+            torch.nn.Flatten(),
+            torch.nn.Linear(3, 4),
+            torch.nn.ReLU(),
+            torch.nn.Linear(4, 4),
+            torch.nn.ReLU(),
+            torch.nn.Linear(4, 4)
+        )
+
+    def forward(self, x):
+        x = self.layer(x)
+        return x
+# class Easy_graph(torch.nn.Module):
+
+#     def __init__(self):
+#         super(Easy_graph, self).__init__()
+
+#         self.layer = torch.nn.Sequential(
+#             torch.nn.Flatten(),
+#             torch.nn.Linear(32*32*3, 512),
+#             torch.nn.ReLU(),
+#             torch.nn.Linear(512, 256),
+#             torch.nn.ReLU(),
+#             torch.nn.Linear(256, 128),
+#             torch.nn.ReLU(),
+#             torch.nn.Linear(128, 10)
+#         )
+
+#     def forward(self, x):
+#         x = self.layer(x)
+#         return x
+
+
 class test_interop(unittest.TestCase):
 
     def tearDown(self):
@@ -193,6 +252,24 @@ class test_interop(unittest.TestCase):
         res = tester.test_multiple_step(input_size, (batch_size, 10))
         self.assertNotEqual(res, -1, msg="LeNet CPU eval failed")
 
+    def test_aidge_backward_CPU(self):
+        print('=== Testing aidge backward CPU ===')
+        input_size = (batch_size, 3, 1, 1)
+        torch_model = Easy_graph()
+
+        aidge_model = aidge_interop_torch.wrap(torch_model, input_size)
+
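+        # Attach an Aidge SGD optimizer with a constant learning-rate scheduler;
+        # its parameters are the producer nodes of the wrapped graph.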
+        opt = aidge_learning.SGD()
+        lrs = aidge_learning.constant_lr(0.01)
+        opt.set_learning_rate_scheduler(lrs)
+        opt.set_parameters(list(aidge_core.producers(aidge_model._graph_view)))
+        aidge_model.set_optimizer(opt)
+
+        tester = Test_Networks(torch_model, aidge_model, eval_mode=False, epochs=epochs)
+        res = tester.test_multiple_step(input_size, (batch_size, 4))
+
+        self.assertNotEqual(res, -1, msg="CPU train failed")
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/aidge_interop_torch/utils.py b/aidge_interop_torch/utils.py
index b1c31749afdef9845ea9a6e44a592285844ee3b7..a685f57a594b1cca6e67f15c33dd2e73aa2fe3e5 100644
--- a/aidge_interop_torch/utils.py
+++ b/aidge_interop_torch/utils.py
@@ -6,12 +6,20 @@ import aidge_core
 import aidge_backend_cpu
 
 import aidge_onnx
+import aidge_learning
 
 from onnxsim import simplify
 from typing import Union
 
 
 def convert_tensor(tensor):
+    """Convert a torch tensor to :py:class:`aidge_core.Tensor` and vice versa.
+
+    :param tensor: Tensor to convert.
+    :type tensor: torch.Tensor or :py:class:`aidge_core.Tensor`
+    :return: Converted tensor
+    :rtype: torch.Tensor or :py:class:`aidge_core.Tensor`
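+
+    A minimal, illustrative round trip (assuming ``torch`` is imported)::
+
+        aidge_t = convert_tensor(torch.ones(2, 3))  # torch.Tensor -> aidge_core.Tensor
+        torch_t = convert_tensor(aidge_t)           # aidge_core.Tensor -> torch.Tensor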
+    """
     if isinstance(tensor, torch.Tensor):
         return torch_tensor_to_aidge(tensor)
     elif isinstance(tensor, aidge_core.Tensor):
@@ -21,11 +29,15 @@ def convert_tensor(tensor):
             f"Object of type {type(tensor)} is not convertible.")
 
 
-def aidge_tensor_to_torch(aidge_tensor):
-    """
-    Convert Aidge.Tensor -> torch.Tensor
-    The conversion creates a GPU memory copy if the tensor is CUDA.
-    This method also convert the shape of the tensor to follow torch convention.
+def aidge_tensor_to_torch(aidge_tensor: aidge_core.Tensor) -> torch.Tensor:
+    """Convert:py:class:`aidge_core.Tensor` -> torch.Tensor
+
+    CUDA tensor is not handled yet, but when it will, the conversion will creates a GPU memory copy if the tensor is CUDA.
+
+    :param aidge_tensor: Tensor to convert
+    :type aidge_tensor: aidge_core.Tensor
+    :return: Converted tensor
+    :rtype: torch.Tensor
     """
     torch_tensor = torch.from_numpy(np.array(aidge_tensor))
     # TODO : handle cuda case !
@@ -34,7 +46,16 @@ def aidge_tensor_to_torch(aidge_tensor):
     return torch_tensor
 
 
-def torch_tensor_to_aidge(torch_tensor):
+def torch_tensor_to_aidge(torch_tensor: torch.Tensor) -> aidge_core.Tensor:
+    """Convert torch.Tensor -> :py:class:`aidge_core.Tensor`
+
+    CUDA tensors are not handled yet; once they are, the conversion will not create a GPU memory copy when the tensor is on a CUDA device.
+
+    :param torch_tensor: Tensor to convert
+    :type torch_tensor: torch.Tensor
+    :return: Converted tensor
+    :rtype: aidge_core.Tensor
+    """
     aidge_tensor = None
     numpy_tensor = torch_tensor.cpu().detach().numpy()
     # This operation creates a CPU memory copy.
@@ -53,10 +74,9 @@ class AidgeModule(torch.nn.Module):
 
     def __init__(self, graph_view, batch_size=None):
         """
-        :param block: Aidge block object to interface with PyTorch
-        :type block: :py:class:`aidge_core.GraphView`
+        :param graph_view: Aidge graph to interface with PyTorch
+        :type graph_view: :py:class:`aidge_core.GraphView`
         """
-        print("TEST")
         super().__init__()
         if not isinstance(graph_view, aidge_core.GraphView):
             raise TypeError(
@@ -65,7 +85,6 @@ class AidgeModule(torch.nn.Module):
 
         # TODO : better handling of backend ?
         # maybe a function set_backend similar to PyTorch ?
-        print("Set cpu !")
         graph_view.set_backend("cpu")
 
         # We need to add a random parameter to the module else pytorch refuse to compute gradient
@@ -78,6 +97,11 @@ class AidgeModule(torch.nn.Module):
         # TODO support of multi input graph later ...
         self.input_nodes = [None]
         self.scheduler = None
+        self.optimizer = None
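+        # Becomes True once aidge_core.compile_gradient has been run (first backward call).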
+        self.grad_compiled = False
+
+    def set_optimizer(self, opt):
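+        """Set the Aidge optimizer used during the backward pass.
+
+        Its ``reset_grad`` and ``update`` methods are called around the
+        scheduler's backward step during backpropagation.
+
+        :param opt: optimizer to use, e.g. :py:class:`aidge_learning.SGD`
+        """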
+        self.optimizer = opt
 
     def forward(self, inputs: torch.Tensor):
         """
@@ -108,14 +132,13 @@ class AidgeModule(torch.nn.Module):
                 else:
                     self.input_nodes[0].get_operator(
                     ).set_output(0, aidge_tensor)
-                # TODO: create one scheduler ! maybe in init ?
 
                 if not self.scheduler:
                     self.scheduler = aidge_core.SequentialScheduler(
                         self._graph_view)
 
                 # Run inference !
-                self.scheduler.forward(verbose=True)
+                self.scheduler.forward(forward_dims=True)
 
                 # TODO: support for multi output later ?
                 if len(self._graph_view.get_output_nodes()) > 1:
@@ -128,7 +151,34 @@ class AidgeModule(torch.nn.Module):
 
             @staticmethod
             def backward(ctx, grad_output):
-                raise RuntimeError("Backward is not yet supported with Aidge.")
+                # Compile the gradient graph once, on the first backward call.
+                if not self.grad_compiled:
+                    aidge_core.compile_gradient(self._graph_view)
+                    self.grad_compiled = True
+
+                if self.multi_outputs_flag:
+                    raise RuntimeError(
+                        "Backward is not possible if the model has multi-outputs")
+                if len(self.input_nodes) != 1:
+                    raise RuntimeError(
+                        "Multi-input is not handled for now in pytorch backpropagation")
+                # convert the output gradient to an AIDGE Tensor
+                aidge_grad_output = torch_tensor_to_aidge(grad_output)
+
+                if len(self._graph_view.get_output_nodes()) != 1:
+                    raise RuntimeError(
+                        f"Only single-output graphs are supported, got {len(self._graph_view.get_output_nodes())}")
+                output_node = list(self._graph_view.get_output_nodes())[0]
+                output_tensor = output_node.get_operator().get_output(0)
+                output_tensor.set_grad(aidge_grad_output)
+
+                # run the backpropagation
+                # TODO: remove update from the backprop
+                self.optimizer.reset_grad()
+                self.scheduler.backward()
+                self.optimizer.update()
+                # Get the gradient of the single input node (multi-input graphs are not handled yet).
+                aidge_out_grad = self.input_nodes[0].get_operator().get_output(0).grad()
+                # convert grad to torch
+                torch_out_grad = aidge_tensor_to_torch(aidge_out_grad)
+                return torch_out_grad
 
         # If the layer is at the beginning of the network requires grad is False.
         inputs.requires_grad = True
@@ -174,7 +224,6 @@ class ContextNoBatchNormFuse:
                     if current_bn.affine:  # Bias and Weights only created if affine=True
                         saved_bias = current_bn.bias.detach().clone()
                         saved_weight = current_bn.weight.detach().clone()
-                    print(current_bn.running_mean.shape)
                     # Real batchnorm forward
                     output_tensor = torch.nn.functional.batch_norm(
                         inputs,
@@ -244,7 +293,8 @@ def wrap(torch_model: torch.nn.Module,
     try:
         torch_device = next(torch_model.parameters()).device
     except StopIteration:
-        torch_device = torch.device('cpu') # Model has no parameter, defaulting to cpu
+        # Model has no parameter, defaulting to cpu
+        torch_device = torch.device('cpu')
 
     dummy_in = torch.zeros(input_size).to(torch_device)