diff --git a/CMakeLists.txt b/CMakeLists.txt
index b3c6d459dfaf29f5accbc0be4565a3709e9ffd3b..afb882af0d02000e5490f1d2a0c56b4487481be9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -61,7 +61,7 @@ option(PYBIND "python binding" OFF)
 option(WERROR "Warning as error" OFF)
 option(TEST "Enable tests" OFF)
 option(COVERAGE "Enable coverage" OFF)
-option(CUDA "Enable CUDA backend" OFF) # XXX OFF
+option(CUDA "Enable CUDA backend" ON) # XXX should default to OFF
 option(ENABLE_ASAN "Enable ASan (AddressSanitizer) for runtime analysis of memory use (over/underflow, memory leak, ...)" OFF)
 
 ##############################################
@@ -182,7 +182,6 @@ endif()
 
 # Coverage flags for GCC
 if(CMAKE_COMPILER_IS_GNUCXX AND COVERAGE)
-    include(CodeCoverage)
     append_coverage_compiler_flags()
 endif()
 
diff --git a/include/aidge/backend/cpu/operator/FixedQImpl_kernels.hpp b/include/aidge/backend/cpu/operator/FixedQImpl_kernels.hpp
index 9d7a106cdd6b1a0970c87b2a27cc7d6637846b49..935d8f065a5e91729c5c0ff25b13f5ea1234a8b6 100644
--- a/include/aidge/backend/cpu/operator/FixedQImpl_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/FixedQImpl_kernels.hpp
@@ -23,7 +23,7 @@ void FixedQImpl_cpu_forward_kernel(
     std::size_t nbBits,
     float span_,
     bool isOutputUnsigned,
-    std::size_t inputLenght,
+    std::size_t inputLength,
     const void* input_,
     void* output_)
 {
@@ -40,7 +40,7 @@ void FixedQImpl_cpu_forward_kernel(
     const I* input = static_cast<const I*>(input_);
     O* output = static_cast<O*>(output_);
 
-    for (std::size_t i = 0; i < inputLenght; ++i) {
+    for (std::size_t i = 0; i < inputLength; ++i) {
         I clipped = std::max(lower, std::min(input[i], upper));
         output[i] = std::round(clipped / stepSize) * stepSize;
     }
@@ -49,14 +49,14 @@ void FixedQImpl_cpu_forward_kernel(
 
 template <class GI, class GO>
 void FixedQImpl_cpu_backward_kernel(
-    const std::size_t inputLenght,
+    const std::size_t inputLength,
     const void* grad_output_,
     void* grad_input_)
 {
     const GO* grad_output = static_cast<const GO*>(grad_output_);
     GI* grad_input = static_cast<GI*>(grad_input_);
 
-    for (std::size_t i = 0; i < inputLenght; ++i) {
+    for (std::size_t i = 0; i < inputLength; ++i) {
         // Straight Through Estimator
         grad_input[i] = grad_output[i];
     }
diff --git a/include/aidge/quantization/PTQ/PTQMetaOps.hpp b/include/aidge/operator/PTQMetaOps.hpp
similarity index 100%
rename from include/aidge/quantization/PTQ/PTQMetaOps.hpp
rename to include/aidge/operator/PTQMetaOps.hpp
diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp
index 4fc38bc3b959ec8264ddaddbd4673fbe1f75e4ab..bfe671e3556c3af2c367ce7f86708f01c8e3d3b5 100644
--- a/include/aidge/quantization/PTQ/PTQ.hpp
+++ b/include/aidge/quantization/PTQ/PTQ.hpp
@@ -124,11 +124,11 @@ namespace Aidge {
      * @brief Quantize an already normalized (in terms of parameters and activations) network.
      * @param graphView The GraphView to be quantized.
      * @param nbBits The desired number of bits of the quantization.
-     * @param applyRounding Whether to apply the rounding operations or not.
+     * @param noQuant Whether to skip the rounding operations or not.
      * @param optimizeSigns Whether to take account of the IO signs of the operators or not.
     * @param verbose Whether to print the sign map or not.
     */
-    void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, bool applyRounding, bool optimizeSigns, bool verbose);
+    void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, bool noQuant, bool optimizeSigns, bool verbose);
 
     /**
     * @brief Main quantization routine. Performs every step of the quantization pipeline.
@@ -136,12 +136,12 @@ namespace Aidge {
     * @param nbBits The desired number of bits of the quantization.
     * @param inputDataSet The input dataset on which the value ranges are computed.
     * @param clippingMode Type of the clipping optimization. Can be either 'MAX', 'MSE', 'AA' or 'KL'.
-    * @param applyRounding Whether to apply the rounding operations or not.
+    * @param noQuant Whether to skip the rounding operations or not.
     * @param optimizeSigns Whether to take account of the IO signs of the operators or not.
     * @param singleShift Whether to convert the scaling factors into powers of two. If true the approximations are compensated using the previous nodes' weights.
     * @param verbose Whether to print internal information about the quantization process.
     */
-    void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool applyRounding, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose);
+    void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose);
 
     /**
     * @brief Compute the weight ranges of every affine node. Provided for debugging purposes.
diff --git a/include/aidge/quantization/QAT/QAT_LSQ.hpp b/include/aidge/quantization/QAT/QAT_LSQ.hpp
index a44c71b04ca9e9c6a8fba27c615c99b4893d3d8c..922187abca915daa1c00f3949d0d791b0d3e1c39 100644
--- a/include/aidge/quantization/QAT/QAT_LSQ.hpp
+++ b/include/aidge/quantization/QAT/QAT_LSQ.hpp
@@ -22,22 +22,13 @@ namespace Aidge {
 namespace QuantLSQ {
 
 /**
- * @brief Insert the LSQ quantizer nodes in a given GraphView
- * @param graphView The GraphView containing the graph to quantize.
+ * @brief Given a GraphView with parameters properly initialized, insert
+ * the LSQ quantizer nodes and set up the adjustment of their step-sizes.
+ * @param graphView The GraphView containing the network to quantize.
 * @param nbBits Number of quantization bits.
- * @param span Fixed output span of the quantizers.
 */
-void insertQuantizers(std::shared_ptr<GraphView> graphView, std::size_t nbBits, float step_size);
 
-/**
- * @brief Given a GraphView with parameters properly initialized and some calibration data,
- * insert the LSQ quantizer nodes, and adjust their step-sizes.
- * @param graphView The GraphView containing the graph to quantize.
- * @param nbBits Number of quantization bits.
- * @param calibrationData Calibration data used to adjust the spans.
- * @param scale Multiplicative constant applied to the spans.
- */
-void insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, std::size_t nbBits, std::shared_ptr<Tensor> calibrationData);
+void setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits);
 
 } // namespace QuantLSQ
 } // namespace Aidge
diff --git a/python_binding/pybind_PTQ.cpp b/python_binding/pybind_PTQ.cpp
index b5193bddcfe345a1702f02fcc139a4cf5b94a1ce..1de797693468273814f4c5e82a161991648d06ff 100644
--- a/python_binding/pybind_PTQ.cpp
+++ b/python_binding/pybind_PTQ.cpp
@@ -78,7 +78,7 @@ void init_PTQ(py::module &m) {
     :type value_ranges: list of float.
     )mydelimiter");
 
-    m.def("quantize_normalized_network", &quantizeNormalizedNetwork, py::arg("network"), py::arg("nb_bits"), py::arg("no_quant")=false, py::arg("optimize_signs"), py::arg("verbose") = false,
+    m.def("quantize_normalized_network", &quantizeNormalizedNetwork, py::arg("network"), py::arg("nb_bits"), py::arg("no_quantization") = false, py::arg("optimize_signs") = false, py::arg("verbose") = false,
    R"mydelimiter(
    Quantize an already normalized (in terms of parameters and activations) network.
    :param network: The GraphView to be quantized.
@@ -93,7 +93,7 @@ void init_PTQ(py::module &m) {
    :type verbose: bool
    )mydelimiter");
 
-    m.def("quantize_network", &quantizeNetwork ,py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("clipping_mode") = Clipping::MAX , py::arg("no_quantization") = true, py::arg("optimize_signs") = false, py::arg("single_shift") = false, py::arg("use_cuda") = false, py::arg("verbose") = false,
+    m.def("quantize_network", &quantizeNetwork, py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("clipping_mode") = Clipping::MAX, py::arg("no_quantization") = false, py::arg("optimize_signs") = false, py::arg("single_shift") = false, py::arg("use_cuda") = false, py::arg("verbose") = false,
    R"mydelimiter(
    Main quantization routine. Performs every step of the quantization pipeline.
    :param network: The GraphView to be quantized.
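As a sanity check of the renamed keyword arguments, here is a minimal Python-side call sketch. The `aidge_quantization` import name, the `model`/`samples` objects, and the Python exposure of the `Clipping` enum are assumptions; the keyword names come from the `py::arg` declarations above.

```python
import aidge_quantization  # assumed import name of this package's binding

# model: an aidge_core.GraphView; samples: a list of aidge_core.Tensor
# (both assumed to be prepared elsewhere)
aidge_quantization.quantize_network(
    network=model,
    nb_bits=8,
    input_dataset=samples,
    clipping_mode=aidge_quantization.Clipping.MAX,  # assumed enum exposure
    no_quantization=False,  # renamed from no_quant / applyRounding
    optimize_signs=True,
    single_shift=False,
    use_cuda=False,
    verbose=False)
```

Note that `optimize_signs` is given a default in the binding above: pybind11 rejects a non-default `py::arg` following a defaulted one at module import time.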
diff --git a/python_binding/pybind_QAT_LSQ.cpp b/python_binding/pybind_QAT_LSQ.cpp
index 206985efe4558a84ce1ed67a1264bd6902213764..4bba3b6baa5eda41a024399eb1be1402c74b2c1b 100644
--- a/python_binding/pybind_QAT_LSQ.cpp
+++ b/python_binding/pybind_QAT_LSQ.cpp
@@ -23,8 +23,6 @@ void init_QAT_LSQ(py::module &m) {
 
     auto mQuantLSQ = m.def_submodule("lsq");
 
-    mQuantLSQ.def("insert_quantizers", &QuantLSQ::insertQuantizers, py::arg("network"), py::arg("nb_bits"), py::arg("step_size"));
-
-    mQuantLSQ.def("insert_and_init_quantizers", &QuantLSQ::insertAndInitQuantizers, py::arg("network"), py::arg("nb_bits"), py::arg("calibration_data"));
+    mQuantLSQ.def("setup_quantizers", &QuantLSQ::setupQuantizers, py::arg("network"), py::arg("nb_bits"));
 }
 } // namespace Aidge
diff --git a/setup.py b/setup.py
index 1bfc0ac515fd8cceeec4cba666addc1e7666fd25..cde7c1e513e8f3092474bddcb57842efced415e6 100644
--- a/setup.py
+++ b/setup.py
@@ -63,7 +63,7 @@ class AidgePkgBuild(build_ext):
         cxx_compiler = os.environ.get("AIDGE_CXX_COMPILER", "g++")
         build_type = os.environ.get("AIDGE_BUILD_TYPE", "Release")
         asan = os.environ.get("AIDGE_ASAN", "OFF")
-        with_cuda = os.environ.get("AIDGE_WITH_CUDA", "OFF")
+        with_cuda = os.environ.get("AIDGE_WITH_CUDA", "ON")  # XXX should default to "OFF"
         cmake_arch = os.environ.get("AIDGE_CMAKE_ARCH", "")
         build_gen = os.environ.get("AIDGE_BUILD_GEN", "")
 
diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp
index 5265d9c9b1326e73ee4080fe5f69fed5047a0dbb..28858d0e3c693a7620bc32806008523e0602faa9 100644
--- a/src/PTQ/CLE.cpp
+++ b/src/PTQ/CLE.cpp
@@ -24,6 +24,12 @@
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/utils/Log.hpp"
 
+#include "aidge/operator/Mul.hpp"
+#include "aidge/operator/ArgMax.hpp"
+#include "aidge/operator/Abs.hpp"
+#include "aidge/operator/Reshape.hpp"
+#include "aidge/operator/Round.hpp"
+
 namespace Aidge
 {
 
@@ -39,27 +45,58 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
 
 static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling)
 {
-    // Get the tensor data pointer
-    double * castedTensor = static_cast<double *> (tensor->getImpl()->rawPtr());
-
-    // Rescale the tensor
-    for(std::size_t i = 0; i < tensor->size(); i++)
-        castedTensor[i] *= scaling;
+    auto mulOp = Mul_Op();
+    mulOp.setDataType(tensor->dataType());
+    mulOp.setBackend(tensor->backend());
+
+    std::shared_ptr<Aidge::Tensor> scalingTensor = std::make_shared<Aidge::Tensor>(scaling);
+    scalingTensor->setDataType(tensor->dataType());
+    scalingTensor->setBackend(tensor->backend());
+
+    mulOp.associateInput(0, tensor);
+    mulOp.associateInput(1, scalingTensor);
+
+    mulOp.forward();
+
+    auto outTensor = mulOp.getOutput(0);
+    *tensor = *outTensor;
 }
 
+// TODO : make the retrieval of argmax values backend independent (refCastFrom)
 static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
 {
-    // Get the tensor data pointer and edit it
-    double * castedTensor = static_cast<double*> (tensor->getImpl()->rawPtr());
-
-    // Get the tensor absolute max value
-    double maxValue = 0.0;
-    for(std::size_t i = 0; i < tensor->size(); ++i) {
-        if(std::fabs(castedTensor[i]) > maxValue) {
-            maxValue = std::fabs(castedTensor[i]);
-        }
-    }
-    return maxValue;
+    // get the abs tensor
+    std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs());
+
+    // flatten the abs tensor
+    std::int64_t nbElement = tensor->size();
+
+    auto reshapeOp = Reshape_Op({nbElement});
+    reshapeOp.setDataType(tensor->dataType());
+    reshapeOp.setBackend(tensor->backend());
+
+    reshapeOp.associateInput(0, absTensor);
+    reshapeOp.forward();
+    std::shared_ptr<Tensor> flatTensor = reshapeOp.getOutput(0);
+
+    // Get the argmax
+    auto argmaxOp = ArgMax_Op(0, true, false);
+    argmaxOp.setDataType(tensor->dataType());
+    argmaxOp.setBackend(tensor->backend());
+
+    argmaxOp.associateInput(0, flatTensor);
+    argmaxOp.forward();
+    std::shared_ptr<Tensor> argmaxTensor = argmaxOp.getOutput(0);
+
+    // Return the max
+    int maxIndex = std::round(argmaxTensor->get<double>(0));
+
+    return flatTensor->get<double>(maxIndex);
 }
 
 void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetDelta)
@@ -83,22 +120,13 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
         if (isAffine(node))
             affineNodeVector.push_back(node);
 
-    if (affineNodeVector.empty()) {
-        Log::notice("No affine nodes found in the network. CLE cannot be applied.");
-        return;
-    }
 
     double maxRangeDelta;
-    int iteration = 0;
 
     do
    {
-        ++iteration;
         maxRangeDelta = 0.0;
 
-        //std::cout << " ----- " << std::endl;
-        //for (std::shared_ptr<Node> node : affineNodeVector)
-        //    std::cout << getTensorAbsoluteMax(getWeightTensor(node)) << std::endl;
-
-        for (std::size_t i = 0; i < (affineNodeVector.size() - 1); i++)
+        for (size_t i = 0; i + 1 < affineNodeVector.size(); i++) // i + 1 avoids a size_t underflow when there are no affine nodes
        {
             std::shared_ptr<Node> n1 = affineNodeVector[i];
             std::shared_ptr<Node> n2 = affineNodeVector[i+1];
@@ -120,9 +148,6 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
            }
        }
    } while (maxRangeDelta > targetDelta);
-
-    Log::notice("CLE completed after {} iterations. Final max range delta: {:.6f}",
-        iteration, maxRangeDelta);
 }
 }
\ No newline at end of file
diff --git a/src/PTQ/Clipping.cpp b/src/PTQ/Clipping.cpp
index 57ad7a836bbb6251a8eeb6da87e3647b4f54afe2..66b0ab36fba7634d7ee350cdccb27895ffa52da1 100644
--- a/src/PTQ/Clipping.cpp
+++ b/src/PTQ/Clipping.cpp
@@ -26,7 +26,7 @@ std::map<std::string, std::vector<int>> computeHistograms(std::map<std::string,
 
     std::shared_ptr<Node> firstNode = retrieveNodeVector(graphView)[0];
 
-    //std::cout << " COMPUTING HISTOGRAMS ... " << std::endl;
+    // Log::debug(" COMPUTING HISTOGRAMS ... ");
"); std::map<std::string, std::vector<int>> histograms; diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index 0e26313475bbbda23a56dcdda52d55a0a5af8204..7c29ee0b9178fbb07f4a2d5edf9f0ad7ac8dcac4 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -12,7 +12,7 @@ #include "aidge/quantization/PTQ/CLE.hpp" #include "aidge/quantization/PTQ/Clipping.hpp" #include "aidge/quantization/PTQ/PTQ.hpp" -#include "aidge/quantization/PTQ/PTQMetaOps.hpp" +#include "aidge/operator/PTQMetaOps.hpp" #include "aidge/data/Tensor.hpp" @@ -28,6 +28,12 @@ #include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/Conv.hpp" +#include "aidge/operator/ArgMax.hpp" +#include "aidge/operator/Abs.hpp" +#include "aidge/operator/Reshape.hpp" +#include "aidge/operator/Round.hpp" + + #include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/QuantRecipes.hpp" @@ -66,51 +72,75 @@ bool checkArchitecture(std::shared_ptr<GraphView> graphView) return true; } -static void fillTensor(std::shared_ptr<Tensor> tensor, double value) +static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling) { - // Get the tensor data pointer - double * castedTensor = static_cast <double *> (tensor->getImpl()->rawPtr()); + auto mulOp = Mul_Op(); + mulOp.setDataType(tensor->dataType()); + mulOp.setBackend(tensor->backend()); - // Fill the tensor - for(std::size_t i = 0; i < tensor->size(); i++) - castedTensor[i] = value; -} + std::shared_ptr<Aidge::Tensor> scalingTensor = std::make_shared<Aidge::Tensor>(scaling); + scalingTensor->setDataType(tensor->dataType()); + scalingTensor->setBackend(tensor->backend()); -static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling) -{ - // Get the tensor data pointer - double * castedTensor = static_cast <double *> (tensor->getImpl()->rawPtr()); + mulOp.associateInput(0, tensor); + mulOp.associateInput(1, scalingTensor); - // Rescale the tensor - for(std::size_t i = 0; i < tensor->size(); i++) - castedTensor[i] *= scaling; + mulOp.forward(); + + auto outTensor = mulOp.getOutput(0); + *tensor = *outTensor; } static void roundTensor(std::shared_ptr<Tensor> tensor) { - // Get the tensor data pointer - double * castedTensor = static_cast <double *> (tensor->getImpl()->rawPtr()); + auto roundOp = Round_Op(); + roundOp.setDataType(tensor->dataType()); + roundOp.setBackend(tensor->backend()); - // Rescale the tensor - for(std::size_t i = 0; i < tensor->size(); i++) - castedTensor[i] = std::nearbyint(castedTensor[i]);//Round + roundOp.associateInput(0, tensor); + roundOp.forward(); + + auto outTensor = roundOp.getOutput(0); + *tensor = *outTensor; } -static double getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor) +// TODO : make the retreival of argmax values backend independant (refCastFrom) +static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) { - // Get the tensor data pointer and edit it - double * castedTensor = static_cast<double*>(tensor->getImpl()->rawPtr()); - - // Get the tensor absolute max value - double maxValue = 0.0f; - for(std::size_t i = 0; i < tensor->size(); ++i) { - if(std::fabs(castedTensor[i]) > maxValue) { - maxValue = std::fabs(castedTensor[i]); - } - } - return maxValue; + // get the abs tensor + + std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs()); + + // flatten the abs tensor + + std::int64_t nbElement = tensor->size(); + + auto reshapeOp = Reshape_Op({nbElement}); + reshapeOp.setDataType(tensor->dataType()); + reshapeOp.setBackend(tensor->backend()); + + reshapeOp.associateInput(0, absTensor); + reshapeOp.forward(); + 
+    std::shared_ptr<Tensor> flatTensor = reshapeOp.getOutput(0);
+
+    // Get the argmax
+    auto argmaxOp = ArgMax_Op(0, true, false);
+    argmaxOp.setDataType(tensor->dataType());
+    argmaxOp.setBackend(tensor->backend());
+
+    argmaxOp.associateInput(0, flatTensor);
+    argmaxOp.forward();
+    std::shared_ptr<Tensor> argmaxTensor = argmaxOp.getOutput(0);
+
+    // Return the max
+    int maxIndex = std::round(argmaxTensor->get<double>(0));
+
+    return flatTensor->get<double>(maxIndex);
 }
 
+// TODO : pass nodeVector by reference ...
 static std::vector<std::shared_ptr<Node>> removeMatchingNodes(std::vector<std::shared_ptr<Node>> nodeVector, std::string nodeType)
 {
@@ -185,6 +215,8 @@ void prepareNetwork(std::shared_ptr<GraphView> graphView)
 {
     removeFlatten(graphView);
 
+    sanitizeNodeNames(graphView);
+
     bool containsBatchNorm = false;
     std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
 
@@ -876,50 +908,42 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
 
     for (std::shared_ptr<Node> node : nodeVector)
     {
-        // A merging node is always followed by a scaling node at this point ...
+        // A merging node is always followed by a Quantizer node at this point
         if (node->type() == "Quantizer")
         {
+            // check if the Quantizer is a residual one, and insert a compensation node if so ...
             bool prevNodeIsForking = ((node->getParent(0))->getChildren().size() > 1);
             bool prevNodeIsAffine = isAffine(node->getParent(0));
             bool insertNode = prevNodeIsForking || !prevNodeIsAffine;
 
             if (insertNode)
             {
-                // create and insert the multplicative node
+                // create and insert the multiplicative node before the Quantizer
                 std::string mulNodeName = makeUniqueName(node->name() + "_Mul", graphView);
                 std::shared_ptr<Node> mulNode = Mul(mulNodeName);
-
                 mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
                 mulNode->getOperator()->setBackend("cpu");
 
                 graphView->insertParent(node, mulNode, 0, 0, 0);
 
-                // create and insert the producer node
-
-                std::shared_ptr<Tensor> inputTensor = std::static_pointer_cast<Tensor> (mulNode->getOperator()->getRawInput(0));
-                std::shared_ptr<Tensor> coeffTensor = std::make_shared<Tensor>();
+                // Add the coeff producer to the multiplier node
+                std::shared_ptr<Node> coeffProducer = addProducer(mulNode, 1, {1}, "");
+                std::shared_ptr<Tensor> coeffTensor = std::make_shared<Tensor>(signedMax);
+                coeffProducer->getOperator()->setOutput(0, coeffTensor);
 
-                coeffTensor->setDataType(DataType::Float64); // getDataType(parentNode)
-                coeffTensor->setBackend("cpu");
+                coeffProducer->getOperator()->setDataType(DataType::Float64);
+                coeffProducer->getOperator()->setBackend("cpu");
 
-                coeffTensor->resize(inputTensor->dims());
-                fillTensor(coeffTensor, 1);
+                graphView->add(coeffProducer); // needed ?
 
-                std::shared_ptr<Node> producerNode = Producer(coeffTensor, makeUniqueName("coeff", graphView));
-                producerNode->addChild(mulNode);
-                graphView->add(producerNode);
-
-                // rescale the coeffs and edit scaling factor
-
-                fillTensor(coeffTensor, signedMax);
-
-                double currScalingFactor = getScalingFactor(node); // XXX bad naming !
+                // Adapt the scaling factor value accordingly
+                double currScalingFactor = getScalingFactor(node);
                 updateScalingFactor(node, currScalingFactor / signedMax);
-
-                // TODO : double check this !!!
-                //std::cout << getTensorAbsoluteMax(coeffTensor) << std::endl;
             }
         }
     }
@@ -931,7 +955,8 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
 
     for (std::shared_ptr<Node> node : nodeVector)
     {
-        // Use a metaoperator of type Scaling or MulCompensation instead
+        // TODO : use Compensation nodes instead of Mul nodes
+
         if (isAffine(node) || (node->type() == "Mul"))
         {
             std::shared_ptr<Node> scalingNode = (*node->getChildren().begin());
@@ -940,7 +965,7 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
 
             double approx = std::pow(2, std::ceil(std::log2(base)));
 
-            updateScalingFactor(scalingNode,approx);
+            updateScalingFactor(scalingNode, approx);
 
             double ratio = base / approx;
 
@@ -954,7 +979,7 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
                 std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
                 rescaleTensor(biasTensor, ratio);
                 if (!noQuant)
-                   roundTensor(biasTensor);
+                    roundTensor(biasTensor);
             }
         }
     }
@@ -962,7 +987,6 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
 
 static void printScalingFactors(std::shared_ptr<GraphView> graphView)
 {
-    Log::info(" === SCALING FACTORS === ");
     for (auto node : retrieveNodeVector(graphView))
         if (node->type() == "Scaling" || node->type() == "Quantizer")
         {
@@ -995,7 +1019,7 @@ static void printRanges(std::shared_ptr<GraphView> graphView, std::map<std::stri
     auto scheduling = scheduler.getStaticScheduling();
     for (auto node : scheduling)
         if (node->type() == "Scaling")
-            fmt::println("{} range = {}", node->name(), valueRanges[node->name()]);
+            Log::info(" {} range = {} ", node->name(), valueRanges[node->name()]);
 }
 
 void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose)
@@ -1024,13 +1048,13 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
     Log::info(" Computing the value ranges ...");
     std::map<std::string, double> valueRanges = computeRanges(graphView, inputDataSet, true, useCuda);
 
-    //std::cout << " === RANGES (BEFORE ADJUST) ===" << std::endl;
+    //Log::info(" === RANGES (BEFORE ADJUST) ===");
     //printRanges(graphView, valueRanges);
 
     Log::info(" Optimizing the clipping values ...");
     valueRanges = adjustRanges(clippingMode, valueRanges, nbBits, graphView, inputDataSet, useCuda, verbose);
 
-    //std::cout << " === RANGES (AFTER ADJUST) ===" << std::endl;
+    //Log::info(" === RANGES (AFTER ADJUST) ===");
     //printRanges(graphView, valueRanges);
 
     Log::info(" Normalizing the activations ...");
@@ -1051,14 +1075,15 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
     if (verbose)
         printScalingFactors(graphView);
 
-    //std::cout << " === SCALINGS (BEFORE CAST) ===" << std::endl;
+    //Log::info(" === SCALINGS (BEFORE CAST) ===");
     //printScalingFactors(graphView);
 
     setupDataType(graphView, inputDataSet, initialDataType);
+
     if (useCuda)
         graphView->setBackend("cuda");
 
-    //std::cout << " === SCALINGS (AFTER CAST) ===" << std::endl;
+    //Log::info(" === SCALINGS (AFTER CAST) ===");
     //printScalingFactors(graphView);
 
     Log::info(" Resetting the scheduler ...");
@@ -1098,7 +1123,7 @@ void clearBiases(std::shared_ptr<GraphView> graphView)
 void devPTQ(std::shared_ptr<GraphView> graphView)
 {
     for (std::shared_ptr<Node> node : graphView->getNodes())
-        fmt::println(" UUU : {}", node->name());
+        Log::info(" UUU : {} ", node->name());
 }
 }
diff --git a/src/QAT/QAT_FixedQ.cpp b/src/QAT/QAT_FixedQ.cpp
index 9160b4ae6add5ae0347e008962956dc90c3a36fd..6ada53239f92d19f96dc87e0b91247aa093caecf 100644
--- a/src/QAT/QAT_FixedQ.cpp
+++ b/src/QAT/QAT_FixedQ.cpp
@@ -91,7 +91,7 @@ static std::map<std::string, float> collectInputStats(std::shared_ptr<GraphView>
             const auto op = std::static_pointer_cast<FixedQ_Op>(node->getOperator());
             float inputStd = getTensorStd(op->getInput(0));
             inputStats.insert(std::make_pair(node->name(), inputStd));
-            fmt::println("{} -> {}", node->name(), inputStd);
+            Log::info(" {} -> {} ", node->name(), inputStd);
         }
     }
 
@@ -108,7 +108,7 @@ static std::map<std::string, float> collectParamStats(std::shared_ptr<GraphView>
             const auto op = std::static_pointer_cast<FixedQ_Op>(node->getOperator());
             float paramStd = getTensorStd(op->getInput(1));
             paramStats.insert(std::make_pair(node->name(), paramStd));
-            fmt::println("{} -> {}", node->name(), paramStd);
+            Log::info(" {} -> {} ", node->name(), paramStd);
         }
     }
 
@@ -156,7 +156,7 @@ void QuantFixedQ::devQAT(std::shared_ptr<GraphView> graphView)
     scheduler.generateScheduling();
     auto s = scheduler.getStaticScheduling();
     for (std::shared_ptr<Node> node : s)
-        fmt::println(" name : {}", node->name());
+        Log::info(" name : {} ", node->name());
 }
 
 }
\ No newline at end of file
diff --git a/src/QAT/QAT_LSQ.cpp b/src/QAT/QAT_LSQ.cpp
index 9b51e846df498a9303b7373ae1c86d4b007a96f0..6eae077b060027eb4029f6b59f55376a1674df70 100644
--- a/src/QAT/QAT_LSQ.cpp
+++ b/src/QAT/QAT_LSQ.cpp
@@ -21,193 +21,152 @@
 #include "aidge/graph/Matching.hpp"
 #include "aidge/recipes/QuantRecipes.hpp"
 
-namespace Aidge {
 
-void QuantLSQ::insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, float stepSize)
+namespace Aidge {
-{
-    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)");
 
-    for (const auto& match : matches)
-    {
-        auto linearNode = match.graph->rootNode();
+static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
+{
+    auto valueTensor = (*tensor).abs().mean();
+    std::shared_ptr<Tensor> fallback;
+    const Tensor& localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu");
+    return localTensor.get<float>(0);
+}
 
-        std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
-        std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1};
+static float getTensorStd(std::shared_ptr<Tensor> tensor)
+{
+    auto valueTensor = (*tensor);
+
+    auto skewedTensor = valueTensor - valueTensor.mean();
+    auto squaredTensor = skewedTensor * skewedTensor;
+    auto varianceTensor = squaredTensor.mean();
 
-        // INPUT QUANTIZERS INSERTION
+    std::shared_ptr<Tensor> fallback;
+    auto localTensor = varianceTensor.refCastFrom(fallback, DataType::Float32, "cpu");
+
+    float variance = localTensor.get<float>(0);
+    return std::sqrt(variance);
+}
 
-        // TODO : double check this, and use createUniqueName()
-        auto inputQuantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView);
-        auto inputQuantizerNode = LSQ(signedRange, inputQuantizerName);
 
-        // Set the step size
+// INIT THE STEP SIZE OF A QUANTIZER NODE
 
-        auto inputStepSizeOp = inputQuantizerNode->getParent(1)->getOperator();
-        auto inputStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        inputStepSizeOp->setOutput(0, inputStepSizeTensor);
+static bool initStepSize(std::shared_ptr<Node> quantizer)
+{
+    const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator());
 
-        // Absorb the ReLU when possible ...
+    // This is the formula proposed in the LSQ paper ...
-        // XXX is this safe ???
-        bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]);
-        // bool nodeHasParent = (linearNode->getParents().size() != 0);
+    // float inputAbsMean = getTensorAbsMean(quantizerOp->getInput(0));
+    // float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second));
 
-        if (nodeHasParent) {
-            auto parentNode = linearNode->getParents()[0];
-            if (parentNode->type() == "ReLU") {
-                auto inputQuantizerOp = std::static_pointer_cast<LSQ_Op> (inputQuantizerNode->getOperator());
-                inputQuantizerOp->range() = unsignedRange;
-                graphView->replace({parentNode}, {});
-            }
-        }
+    // ... but in practice this heuristic based on the input std seems to work better :
 
-        // We need to handle the case where the linear node is the first one ...
+    float inputStd = getTensorStd(quantizerOp->getInput(0));
+    float stepSize = 8.0f * (inputStd / (quantizerOp->range().second));
 
-        if (nodeHasParent) {
-            graphView->insertParent(linearNode, inputQuantizerNode, 0, 0, 0);
-        } else {
-            inputQuantizerNode->addChild(graphView);
-            graphView->add(inputQuantizerNode);
-        }
+    // TODO : use the scalar constructor
+    auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
 
-        // PARAM QUANTIZERS INSERTION
+    // XXX Manage backend here ?
+    stepSizeTensor->setBackend(quantizerOp->getInput(0)->backend());
+    stepSizeTensor->setDataType(quantizerOp->getInput(0)->dataType());
 
-        // TODO : double check this, and use createUniqueName()
-        auto paramQuantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView);
-        auto paramQuantizerNode = LSQ(signedRange, paramQuantizerName);
-        graphView->insertParent(linearNode, paramQuantizerNode, 1, 0, 0);
+    auto stepSizeProducer = quantizer->getParent(1);
 
-        // Set the step size
+    stepSizeProducer->getOperator()->setOutput(0, stepSizeTensor);
 
-        auto paramStepSizeOp = paramQuantizerNode->getParent(1)->getOperator();
-        auto paramStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        paramStepSizeOp->setOutput(0, paramStepSizeTensor);
-    }
+    Log::notice(" [ INIT STEP SIZE = {} ] ", stepSize);
 
+    return false;
 }
 
-static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
+static void setupInputQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
 {
-    auto backend = tensor->backend();
-    if (backend == "cuda")
-        tensor->setBackend("cpu");
+    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
 
-    float acc = 0;
-    float* castedTensor = static_cast<float *> (tensor->getImpl()->rawPtr());
-    for(std::size_t i = 0; i < tensor->size(); i++)
-        acc += std::abs(castedTensor[i]);
-    acc /= static_cast<float> (tensor->size());
+    for (const auto& match : matches)
+    {
+        auto linearNode = match.graph->rootNode();
 
-    if (backend == "cuda")
-        tensor->setBackend("cuda");
+        // Log::notice(" SET INPUT QUANTIZER : {} ", linearNode->type());
 
-    return acc;
-}
+        std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
+        std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1};
 
-static std::map<std::string, float> collectInputStats(std::shared_ptr<GraphView> graphView, std::shared_ptr<Tensor> calibrationData, bool useCuda)
-{
-    // Propagate the calibration tensor
+        // Create the input quantizer node
 
-    SequentialScheduler scheduler(graphView);
-    scheduler.resetScheduling();
-    scheduler.forward(true, {calibrationData});
+        auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView);
+        auto quantizerNode = LSQ(signedRange, quantizerName);
 
-    // Store the input tensor statistics
+        // Init the step-size lazily, via the node's beforeForward hook
 
-    if (useCuda)
-        graphView->setBackend("cpu");
+        quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
 
-    std::map<std::string, float> inputStats;
-    for (auto node : graphView->getNodes())
-    {
-        if (node->type() == "FC" || node->type() == "Conv2D") // TODO: use graph matching !!!
-        {
-            const auto op = std::static_pointer_cast<LSQ_Op>(node->getOperator());
-            float inputAbsMean = getTensorAbsMean(op->getInput(0));
-            inputStats.insert(std::make_pair(node->name(), inputAbsMean));
-            fmt::println("{} -> {}", node->name(), inputAbsMean);
-        }
-    }
+        // Absorb the ReLU when possible ...
 
-    if (useCuda)
-        graphView->setBackend("cuda");
+        bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]); // XXX is this safe ?
 
-    return inputStats;
-}
+        if (nodeHasParent)
+        {
+            bool allParentsAreReLU = true;
+            for (auto parentNode : linearNode->getParents())
+                if (parentNode->type() != "ReLU" && parentNode->type() != "Producer") // skip the weight and bias Producers, which are also parents of the node
+                    allParentsAreReLU = false;
+
+            if (allParentsAreReLU) {
+                auto quantizerOp = std::static_pointer_cast<LSQ_Op> (quantizerNode->getOperator());
+                quantizerOp->range() = unsignedRange;
+            }
 
-static std::map<std::string, float> collectParamStats(std::shared_ptr<GraphView> graphView, bool useCuda)
-{
-    if (useCuda)
-        graphView->setBackend("cpu");
+            // TODO : remove the ReLUs when possible
+        }
 
-    std::map<std::string, float> paramStats;
-    for (auto node : graphView->getNodes())
-    {
-        if (node->type() == "FC" || node->type() == "Conv2D") // TODO: use graph matching !!!
-        {
-            const auto op = std::static_pointer_cast<LSQ_Op>(node->getOperator());
-            float paramAbsMean = getTensorAbsMean(op->getInput(1));
-            paramStats.insert(std::make_pair(node->name(), paramAbsMean));
-            fmt::println("{} -> {}", node->name(), paramAbsMean);
+        // Insert the quantizer in the graphView ...
+        // (We need to handle the case where the linear node is the first one)
+
+        if (nodeHasParent) {
+            graphView->insertParent(linearNode, quantizerNode, 0, 0, 0);
+        } else {
+            quantizerNode->addChild(graphView);
+            graphView->add(quantizerNode);
+        }
     }
 }
 
-    if (useCuda)
-        graphView->setBackend("cuda");
-
-    return paramStats;
-}
 
-static void adjustQuantizersStepSizes(std::shared_ptr<GraphView> graphView, std::map<std::string, float> inputStats, std::map<std::string, float> paramStats)
-{
-    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)");
+// PARAM QUANTIZERS INSERTION
 
-    for (const auto& match : matches)
-    {
-        auto linearNode = match.graph->rootNode();
+static void setupParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
+{
+    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
 
-        // INPUT QUANTIZERS STEP-SIZES
+    std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
 
-        auto inputQuantNode = linearNode->getParent(0);
-        auto inputQuantOp = std::static_pointer_cast<LSQ_Op>(inputQuantNode->getOperator());
+    for (const auto& match : matches)
+    {
+        auto linearNode = match.graph->rootNode();
 
-        float absMean = inputStats[linearNode->name()];
-        float stepSize = 2.0f * (absMean / std::sqrt(inputQuantOp->range().second));
+        // Log::notice(" SET PARAM QUANTIZER : {} ", linearNode->type());
 
-        auto inputStepSizeOp = inputQuantNode->getParent(1)->getOperator();
-        // XXX inputStepSizeOp->setOutput(0, std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})));
-        auto inputStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        inputStepSizeOp->setOutput(0, inputStepSizeTensor);
+        // TODO : double check this, and use createUniqueName()
+        auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView);
+        auto quantizerNode = LSQ(signedRange, quantizerName);
 
-        // PARAM QUANTIZERS STEP-SIZES
+        // Init the step-size lazily, via the node's beforeForward hook
 
-        auto paramQuantNode = linearNode->getParent(1);
-        auto paramQuantOp = std::static_pointer_cast<LSQ_Op>(paramQuantNode->getOperator());
+        quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
 
-        absMean = paramStats[linearNode->name()];
-        stepSize = 2.0f * (absMean / std::sqrt(paramQuantOp->range().second));
+        // Insert the quantizer in the graphView
 
-        auto paramStepSizeOp = paramQuantNode->getParent(1)->getOperator();
-        // XXX paramStepSizeOp->setOutput(0, std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})));
-        auto paramStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        paramStepSizeOp->setOutput(0, paramStepSizeTensor);
+        graphView->insertParent(linearNode, quantizerNode, 1, 0, 0);
     }
 }
 
-void QuantLSQ::insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, std::shared_ptr<Tensor> calibrationData)
+void QuantLSQ::setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
 {
-    bool useCuda = (calibrationData->backend() == "cuda");
-
-    // Collect the tensor statistics
-    auto inputStats = collectInputStats(graphView, calibrationData, useCuda);
-
-    auto paramStats = collectParamStats(graphView, useCuda);
-
-    // Insert the quantizers
-    insertQuantizers(graphView, nbBits, 1.0);
-
-    // Adjust the quantizers step-sizes
-    adjustQuantizersStepSizes(graphView, inputStats, paramStats);
+    sanitizeNodeNames(graphView);
+    setupInputQuantizers(graphView, nbBits);
+    setupParamQuantizers(graphView, nbBits);
 }
 
 }
\ No newline at end of file
diff --git a/src/backend/cuda/operator/LSQImpl.cpp b/src/backend/cuda/operator/LSQImpl.cpp
index c66bd8a5aa78513b4bcceec83f9c9d87ffed2b11..fa45f211e72f6742b72584aadf2a109c3bdca594 100644
--- a/src/backend/cuda/operator/LSQImpl.cpp
+++ b/src/backend/cuda/operator/LSQImpl.cpp
@@ -52,19 +52,6 @@ void Aidge::LSQImpl_cuda::backward() {
     std::shared_ptr<Tensor> gra_int1 = op_.getInput(1)->grad();
     std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
 
-    // XXX
-/*
-    size_t tmp;
-
-    cudaDeviceSetLimit(cudaLimitStackSize, 2048);
-    cudaDeviceGetLimit(&tmp, cudaLimitStackSize );
-    printf(" stack limit = %ld \n", tmp);
-
-    cudaDeviceSetLimit(cudaLimitMallocHeapSize, 100000000);
-    cudaDeviceGetLimit(&tmp, cudaLimitMallocHeapSize);
-    printf(" heap limit = %ld \n", tmp);
-*/
-
     if (gra_int0->size() > mWorkspaceSize) {
         // std::cout << " reallocation " << sizeof(gra_int0) << " " << gra_int0->size() << std::endl;
         if (mWorkspace != nullptr) {
@@ -87,12 +74,7 @@ void Aidge::LSQImpl_cuda::backward() {
         gra_int0->getImpl()->rawPtr(),
         gra_int1->getImpl()->rawPtr(),
         mWorkspace);
-/*
-    gra_int1->setBackend("cpu");
-    float *castedTensor = static_cast<float *> (gra_int1->getImpl()->rawPtr());
-    std::cout << castedTensor[0] << std::endl;
-    gra_int1->setBackend("cuda");
-*/
+
 }
 
 Aidge::LSQImpl_cuda::~LSQImpl_cuda() {
diff --git a/src/PTQ/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp
similarity index 98%
rename from src/PTQ/PTQMetaOps.cpp
rename to src/operator/PTQMetaOps.cpp
index 77018c23aee2f1ef6f430389393fd35e97baa0f6..56245da47076d8930ce29ab75e549d97d0d7493d 100644
--- a/src/PTQ/PTQMetaOps.cpp
+++ b/src/operator/PTQMetaOps.cpp
@@ -9,13 +9,12 @@
  *
  ********************************************************************************/
 
-#include "aidge/quantization/PTQ/PTQMetaOps.hpp"
+#include "aidge/operator/PTQMetaOps.hpp"
 
 #include <memory>
 #include <string>
 #include <utility>
 
-//Operator
 #include "aidge/operator/Clip.hpp"
 #include "aidge/operator/Mul.hpp"
 #include "aidge/operator/Round.hpp"
diff --git a/src/recipes/QuantRecipes.cpp b/src/recipes/QuantRecipes.cpp
index 6e1dcdb1b64c0a1e94c74ce66cb71f1a458bca35..f03eb462088b16645fe600769e2a5e2c990f21b6 100644
--- a/src/recipes/QuantRecipes.cpp
+++ b/src/recipes/QuantRecipes.cpp
@@ -9,24 +9,13 @@
  *
  ********************************************************************************/
 
-/*
-#include "aidge/data/Tensor.hpp"
-#include "aidge/graph/GraphView.hpp"
-#include "aidge/graph/Node.hpp"
-#include "aidge/scheduler/SequentialScheduler.hpp"
-#include "aidge/scheduler/Scheduler.hpp"
-#include "aidge/utils/Log.hpp"
-
-#include "aidge/operator/Producer.hpp"
-#include "aidge/operator/Mul.hpp"
-#include "aidge/operator/ReLU.hpp"
-#include "aidge/operator/Scaling.hpp"
-*/
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/BatchNorm.hpp"
 //#include "aidge/quantization/PTQ/PTQ.hpp"
 #include "aidge/recipes/QuantRecipes.hpp"
+#include "aidge/graph/Node.hpp"
+
 
 namespace Aidge
 {
@@ -55,14 +44,16 @@ void insertBatchNormNodes(std::shared_ptr<GraphView> graphView)
 {
     for (std::shared_ptr<Node> parentNode : graphView->getNodes())
     {
-        if (parentNode->type() == "Conv2D")
+        // TODO : use graph matching
+
+        if (parentNode->type() == "Conv2D" || parentNode->type() == "PaddedConv2D")
         {
-            std::shared_ptr<Conv_Op<2>> convOperator = std::static_pointer_cast<Conv_Op<2>> (parentNode->getOperator());
-            int nb_channels = convOperator->getInput(1)->dims()[0];
-            fmt::println(" NB CHANNELS = {}", nb_channels); // TODO : remove this ...
+            std::shared_ptr<OperatorTensor> convOperator = std::static_pointer_cast<OperatorTensor> (parentNode->getOperator());
+            int nbChannels = convOperator->getInput(1)->dims()[0];
+            Log::notice(" NB CHANNELS = {} ", nbChannels);
 
             std::string batchnormNodeName = makeUniqueName(parentNode->name() + "_BN", graphView);
-            std::shared_ptr<Node> batchnormNode = BatchNorm<2>(nb_channels, 1e-5, 0.1, false, batchnormNodeName);
+            std::shared_ptr<Node> batchnormNode = BatchNorm<2>(nbChannels, 1e-5, 0.1, false, batchnormNodeName);
 
             batchnormNode->getOperator()->setDataType(DataType::Float32);
             batchnormNode->getOperator()->setBackend("cpu");
@@ -118,6 +109,7 @@ std::string makeUniqueName(std::string baseName, std::shared_ptr<GraphView> grap
     return newName;
 }
 
+
 void sanitizeNodeNames(std::shared_ptr<GraphView> graphView)
 {
     for (std::shared_ptr<Node> node : graphView->getNodes())
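With the LSQ rework above, step sizes are no longer calibrated offline from collected statistics; each quantizer initializes its own step size from the standard deviation of its input on the first forward pass, through the `addBeforeForward` hook. A minimal sketch of the resulting Python-side flow follows (the `aidge_quantization` import name and the `model` object are assumptions; the `lsq` submodule and `setup_quantizers` signature come from the binding above):

```python
import aidge_quantization  # assumed import name of this package's binding

# model: an already-built aidge_core.GraphView (assumed)
# Insert the input and param LSQ quantizers in one call; node names are
# sanitized first, as done by QuantLSQ::setupQuantizers().
aidge_quantization.lsq.setup_quantizers(network=model, nb_bits=4)

# The step sizes are computed lazily: the first forward pass through the
# network triggers initStepSize() on every inserted quantizer, so no
# explicit calibration data needs to be passed to setup_quantizers().
```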