diff --git a/.gitignore b/.gitignore
index ba5c59398b68083c6c1c5fe820fb9070d999c18e..c64cbb5b6997c5c332326460eb36296247a88979 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,8 +5,10 @@
 build*/
 install*/
 include/aidge/backend/quantization_version.h
+include/aidge/quantization_version.h
 
-# VSCode
+
+# VSCode
 .vscode
 
 # Python
diff --git a/aidge_quantization/_version.py b/aidge_quantization/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d34d3557071ed5c22aea83c63bfb7684b180cf9
--- /dev/null
+++ b/aidge_quantization/_version.py
@@ -0,0 +1,4 @@
+# file generated by setuptools_scm
+# don't change, don't track in version control
+__version__ = version = '0.2.1.dev60+g8044e79.d20250106'
+__version_tuple__ = version_tuple = (0, 2, 1, 'dev60', 'g8044e79.d20250106')
\ No newline at end of file
diff --git a/include/aidge/quantization/PTQ/PTQMetaOps.hpp b/include/aidge/operator/PTQMetaOps.hpp
similarity index 64%
rename from include/aidge/quantization/PTQ/PTQMetaOps.hpp
rename to include/aidge/operator/PTQMetaOps.hpp
index 62fac873235f2b89a242042de9260fc350ad6aa8..a65e4d52a11eb83463208088707da57cbc78eae2 100644
--- a/include/aidge/quantization/PTQ/PTQMetaOps.hpp
+++ b/include/aidge/operator/PTQMetaOps.hpp
@@ -37,13 +37,33 @@ namespace Aidge {
 /// @return A shared pointer to an instance of the meta-operator node.
 std::shared_ptr<Aidge::Node> Quantizer(double scalingFactor, double clipMin, double clipMax, const std::string& name);
 
-/// @brief The purpose of Scaling is to encapsulate the Mul operator and tag it as a PTQ node rather than a regular Mul operator.
-/// Therefore, this meta-operator consists solely of a [Mul] operation.
+/// @brief IntQuantizer acts as an extension of the Quantizer meta-operator, enabling seamless integration 
+///        into computation graphs with a data type other than Float while preserving floating-point precision.
+/// 
+/// This operator modifies the provided Quantizer by inserting explicit casting operations before and after 
+/// the quantization process. It first casts the input to Float64, applies the quantization steps (Mul, Clip, Round), 
+/// and then casts the result back to the target data type. This ensures compatibility with integer-based computation graphs 
+/// while maintaining the precision of floating-point operations.
 ///
-/// @param scalingFactor The scaling factor to apply to the input (a scalar to multiply the input with).
+/// @param oldQuantizer A shared pointer to the existing Quantizer node that will be adapted.
+/// @param targetType The target data type to which the final output should be cast after the quantization process.
 /// @param name The name of the meta-operator node created.
-/// @return A shared pointer to an instance of the scaling node.
-std::shared_ptr<Aidge::Node> Scaling(double scalingFactor, const std::string& name = "");
+/// @return A shared pointer to a new instance of the modified meta-operator node.
+std::shared_ptr<Node> IntQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, const std::string& name);
+
+/// @brief BitShiftQuantizer acts as an extension of the Quantizer meta-operator, enabling seamless integration
+///        into integer computation graphs by replacing the floating-point scaling with a bit-shift operation.
+///
+/// This operator modifies the provided Quantizer by converting its scaling factor into an equivalent bit shift,
+/// which is exact when the single-shift option is enabled (the scaling factors are then powers of two), and by
+/// casting the result to the target data type. This allows the quantization step to be performed using integer
+/// arithmetic only, ensuring compatibility with integer-based computation graphs.
+///
+/// @param oldQuantizer A shared pointer to the existing Quantizer node that will be adapted.
+/// @param targetType The target data type to which the final output should be cast after the quantization process.
+/// @param name The name of the meta-operator node created.
+/// @return A shared pointer to a new instance of the modified meta-operator node.
+std::shared_ptr<Node> BitShiftQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, const std::string& name);
 
 /// @brief Updates the scaling factor of a PTQ meta-operator node, allowing for dynamic adjustment of the scaling parameter.
 /// This function sets a new scaling factor for a specified meta-operator node, modifying the scalar applied in the [Mul] operation.
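Review note: for reference, the Quantizer meta-operator is also exposed to Python as `quantizer` (see pybind_PTQ.cpp below). A minimal, hedged usage sketch; the scaling factor, range, and name are illustrative only:

```python
import aidge_quantization

# Build a standalone Quantizer meta-operator (Mul -> Clip -> Round) with an
# illustrative scaling factor and a signed 8-bit clipping range.
q = aidge_quantization.quantizer(0.05, -128.0, 127.0, "my_quantizer")
print(q.name())  # -> "my_quantizer"
```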
diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp
index d2b8b7f78fccc15cf4afd598b02f0f7b391375e9..3a35017404337c60845578aee8d0f0bb249bb0b7 100644
--- a/include/aidge/quantization/PTQ/PTQ.hpp
+++ b/include/aidge/quantization/PTQ/PTQ.hpp
@@ -66,6 +66,26 @@ namespace Aidge {
      * @return The scheduled vector of nodes
      */
     std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> graphView, bool newSchedule = true, bool verbose = false);
+    
+    /**
+     * @brief Inserts a scaling node below the given producer node in the graph view. 
+     *        If the node is already a producer-scaling node, the scaling factor is accumulated by multiplying its value directly.
+     *
+     * @param node A shared pointer to the producer node where the scaling node will be inserted (below).
+     * @param scalingFactor The scaling factor to apply.
+     * @param graphView A shared pointer to the graph view in which the nodes are located.
+     * @return True if the scaling node was successfully inserted or the scaling factor was accumulated; False otherwise.
+     */
+    bool insertScalingBelowProducer(std::shared_ptr<Node> node, double scalingFactor, std::shared_ptr<GraphView> graphView);
+
+    /**
+     * @brief Inserts a rounding node below the given producer node (below its producer-scaling node, if one exists) in the graph view.
+     *
+     * @param node A shared pointer to the producer node where the rounding node will be inserted.
+     * @param graphView A shared pointer to the graph view in which the nodes are located.
+     * @return True if the rounding node was successfully inserted; False otherwise.
+     */
+    bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView);
 
     /**
      * @brief Determine whether an input GraphView can be quantized or not.
@@ -74,6 +94,14 @@ namespace Aidge {
      */
     bool checkArchitecture(std::shared_ptr<GraphView> graphView);
 
+    /**
+     * @brief This function multiplies the existing scaling factor by a given coefficient. It verifies that the node is of the correct type ("Mul") 
+     * and has the `isScaling` attribute. If these conditions are not met, a warning is logged.
+     * @param node A shared pointer to an `Aidge::Node` object representing the node to modify.
+     * @param coeff  A double representing the multiplication coefficient to apply to the scaling factor.
+     */
+    void multiplyScalingFactor(std::shared_ptr<Aidge::Node> node, double coeff);
+
 
     void prepareNetwork(std::shared_ptr<GraphView> graphView);
 
@@ -138,7 +166,8 @@ namespace Aidge {
      * @param singleShift Whether to convert the scaling factors into powers of two. If true the approximations are compensated using the previous nodes weights. 
      * @param verbose Whether to print internal informations about the quantization process.
      */
-    void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool applyRounding, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose);
+    void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet,
+     Clipping clippingMode, DataType targetType, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose);
 
     /**
      * @brief Compute the weight ranges of every affine node. Provided for debugging purposes.
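Review note: `multiplyScalingFactor` is also exposed to Python (see pybind_PTQ.cpp below). A minimal sketch, assuming `scaling_node` is a "Mul" node carrying the `isScaling` or `isProducerScaling` tag, obtained from the graph by the caller; the factor value is illustrative:

```python
import aidge_quantization

# Fold an extra factor of 0.5 into the node's scaling-factor producer.
aidge_quantization.multiply_scaling_factor(scaling_node, 0.5)
```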
diff --git a/include/aidge/quantization/QAT/QAT_LSQ.hpp b/include/aidge/quantization/QAT/QAT_LSQ.hpp
index 4970be07fae8737a1c2863600757bb81ff3a65f9..d7d03ca78ff63b328ba068dd4ff82c61270218e3 100644
--- a/include/aidge/quantization/QAT/QAT_LSQ.hpp
+++ b/include/aidge/quantization/QAT/QAT_LSQ.hpp
@@ -20,22 +20,14 @@ namespace Aidge {
 namespace QuantLSQ {
 
 /**
- * @brief Insert the LSQ quantizer nodes in a given GraphView
- * @param graphView The GraphView containing the graph to quantize.
+ * @brief Given a GraphView with parameters properly initialized, insert
+ * the LSQ quantizer nodes and set up the adjustment of their step-sizes.
+ * @param graphView The GraphView containing the network to quantize.
  * @param nbBits Number of quantization bits.
- * @param span Fixed output span of the quantizers.
  */
-void insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, float step_size);
+void setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits);
 
-/**
- * @brief Given a GraphView with parameters properly initialized and some calibration data,
- * insert the LSQ quantizer nodes, and adjust their step-sizes.
- * @param graphView The GraphView containing the graph to quantize.
- * @param nbBits Number of quantization bits.
- * @param calibrationData Calibration data used to adjust the spans.
- * @param scale Multiplicative constant applied to the spans.
- */
-void insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, std::shared_ptr<Tensor> calibrationData);
+void devLSQ(std::shared_ptr<Tensor> tensor);
 
 }
 }
diff --git a/include/aidge/quantization_version.h b/include/aidge/quantization_version.h
index 546263af3a7e8b7a73991173f48d0b095c7d9501..909ab28d77313e34ed93f46af9ef1dc1d086a036 100644
--- a/include/aidge/quantization_version.h
+++ b/include/aidge/quantization_version.h
@@ -3,9 +3,9 @@
 
 namespace Aidge {
 static constexpr const int PROJECT_VERSION_MAJOR = 0;
-static constexpr const int PROJECT_VERSION_MINOR = 2;
+static constexpr const int PROJECT_VERSION_MINOR = 3;
 static constexpr const int PROJECT_VERSION_PATCH = 0;
-static constexpr const char * PROJECT_VERSION = "0.2.0";
-static constexpr const char * PROJECT_GIT_HASH = "f50c860";
+static constexpr const char * PROJECT_VERSION = "0.3.0";
+static constexpr const char * PROJECT_GIT_HASH = "f0f9e60";
 }
 #endif // VERSION_H
diff --git a/python_binding/pybind_PTQ.cpp b/python_binding/pybind_PTQ.cpp
index b5193bddcfe345a1702f02fcc139a4cf5b94a1ce..290d59d822cd34f861533f3adc0019ab7fa538e9 100644
--- a/python_binding/pybind_PTQ.cpp
+++ b/python_binding/pybind_PTQ.cpp
@@ -13,11 +13,10 @@
 #include <pybind11/stl.h>
 
 #include <string>
-
+#include "aidge/operator/PTQMetaOps.hpp"
 #include "aidge/quantization/PTQ/Clipping.hpp"
 #include "aidge/quantization/PTQ/CLE.hpp"
 #include "aidge/quantization/PTQ/PTQ.hpp"
-
 #include "aidge/graph/GraphView.hpp"
 
 namespace py = pybind11;
@@ -40,6 +39,8 @@ void init_PTQ(py::module &m) {
     :rtype: bool
     )mydelimiter");
 
+    m.def("quantizer",&Quantizer,py::arg("sf"),py::arg("min"),py::arg("max"),py::arg("name"));
+
     m.def("insert_scaling_nodes", &insertScalingNodes, py::arg("network"),
     R"mydelimiter(
     Insert a scaling node after each affine node of the GraphView.
@@ -48,6 +49,14 @@ void init_PTQ(py::module &m) {
     :type network: :py:class:`aidge_core.GraphView`
     )mydelimiter");
 
+    m.def( "multiply_scaling_factor",&multiplyScalingFactor,py::arg("node"), py::arg("coeff"),
+     R"mydelimiter(
+    Updates the scaling factor of a "Mul" node in a graph if the node is marked as a scaling node. This function multiplies the existing scaling factor by a given coefficient.
+    :param node: A node representing the node to modify.
+    :param coeff: A floating value representing the multiplication coefficient to apply to the scaling factor.
+    )mydelimiter"
+    );
+    
     m.def("normalize_parameters", &normalizeParameters, py::arg("network"),
     R"mydelimiter(
     Normalize the parameters of each parametrized node, so that they fit in the [-1:1] range.
@@ -93,7 +102,9 @@ void init_PTQ(py::module &m) {
     :type verbose: bool
     )mydelimiter");
 
-    m.def("quantize_network", &quantizeNetwork ,py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("clipping_mode") = Clipping::MAX , py::arg("no_quantization") = true, py::arg("optimize_signs") = false, py::arg("single_shift") = false,  py::arg("use_cuda") = false, py::arg("verbose") = false,
+    m.def("quantize_network", &quantizeNetwork ,py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), 
+    py::arg("clipping_mode") = Clipping::MAX ,py::arg("target_type") = DataType::Float64 ,py::arg("no_quantization") = true, py::arg("optimize_signs") = false,
+    py::arg("single_shift") = false,  py::arg("use_cuda") = false, py::arg("fold_graph") = true, py::arg("verbose") = false,
     R"mydelimiter(
     Main quantization routine. Performs every step of the quantization pipeline.
     :param network: The GraphView to be quantized.
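Review note: two keyword arguments were inserted in the middle of the signature (`target_type` after `clipping_mode`, `fold_graph` before `verbose`), so existing positional callers must be updated. A hedged usage sketch, where `model` and `calibration_tensors` are assumed to be prepared by the caller and the chosen values are illustrative:

```python
import aidge_core
import aidge_quantization

aidge_quantization.quantize_network(
    model, 8, calibration_tensors,
    clipping_mode=aidge_quantization.Clipping.MSE,
    target_type=aidge_core.dtype.float64,
    no_quantization=False,
    optimize_signs=True,
    single_shift=False,
    use_cuda=False,
    fold_graph=True,
    verbose=False)
```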
diff --git a/python_binding/pybind_QAT_LSQ.cpp b/python_binding/pybind_QAT_LSQ.cpp
index 206985efe4558a84ce1ed67a1264bd6902213764..0b9fcc29d1a144708537084d4538eaa47873cd05 100644
--- a/python_binding/pybind_QAT_LSQ.cpp
+++ b/python_binding/pybind_QAT_LSQ.cpp
@@ -23,8 +23,9 @@ void init_QAT_LSQ(py::module &m) {
 
     auto mQuantLSQ = m.def_submodule("lsq");
 
-    mQuantLSQ.def("insert_quantizers", &QuantLSQ::insertQuantizers, py::arg("network"), py::arg("nb_bits"), py::arg("step_size"));
+    mQuantLSQ.def("setup_quantizers", &QuantLSQ::setupQuantizers, py::arg("network"), py::arg("nb_bits"));
+
+    mQuantLSQ.def("dev_lsq", &QuantLSQ::devLSQ, py::arg("tensor"));
 
-    mQuantLSQ.def("insert_and_init_quantizers", &QuantLSQ::insertAndInitQuantizers, py::arg("network"), py::arg("nb_bits"), py::arg("calibration_data"));
 }
 } // namespace Aidge
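Review note: a hedged sketch of the renamed LSQ binding, assuming `model` is a GraphView with properly initialized parameters:

```python
import aidge_quantization

# The former insert_quantizers(network, nb_bits, step_size) call becomes:
aidge_quantization.lsq.setup_quantizers(model, 8)
```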
diff --git a/scripts/PTQ/ptq_ts.py b/scripts/PTQ/ptq_ts.py
new file mode 100644
index 0000000000000000000000000000000000000000..b836a7b41981735299ffedb610dc42acda37d903
--- /dev/null
+++ b/scripts/PTQ/ptq_ts.py
@@ -0,0 +1,135 @@
+import unittest
+import re
+import numpy as np
+import gzip
+import aidge_core
+import aidge_onnx
+import os
+import copy
+import aidge_backend_cpu
+import aidge_quantization
+import sys
+import concurrent.futures
+
+aidge_core.Log.set_console_level(aidge_core.Level.Error)
+
+SIGMA = 0.05  # tolerance
+
+def print_in_color(text, color_code):
+    print(f"\033[{color_code}m{text}\033[0m")
+
+def run_model_test(model_name, expected_values, use_multithreading, asset_path, model_path):
+    NB_SAMPLES = 1000
+    NB_BITS = 4
+    CLIPPING = aidge_quantization.Clipping.MSE
+    VERBOSE = False
+
+    results = []
+
+    samples = np.load(gzip.GzipFile(asset_path + '/mnist_samples.npy.gz', "r"))
+    labels = np.load(gzip.GzipFile(asset_path + '/mnist_labels.npy.gz', "r"))
+
+    def load_model():
+        model = aidge_onnx.load_onnx(model_path + '/' + model_name + ".onnx", verbose=False)
+        aidge_core.remove_flatten(model)
+        model.set_datatype(aidge_core.dtype.float32)
+        model.set_backend("cpu")
+        return model
+
+    aidge_model = load_model()
+    scheduler = aidge_core.SequentialScheduler(aidge_model)
+
+    def propagate(model, scheduler, sample):
+        sample = np.reshape(sample, (1, 1, 28, 28))
+        input_tensor = aidge_core.Tensor(sample)
+        scheduler.forward(True, [input_tensor])
+        output_node = model.get_output_nodes().pop()
+        output_tensor = output_node.get_operator().get_output(0)
+        return np.array(output_tensor)
+
+    def compute_accuracy(model, samples, labels):
+        acc = sum(labels[i] == np.argmax(propagate(model, scheduler, x)) for i, x in enumerate(samples))
+        return acc / len(samples)
+
+    base_accuracy = compute_accuracy(aidge_model, samples[:NB_SAMPLES], labels)
+    if abs(base_accuracy * 100 - expected_values[0]) >= SIGMA:
+        results.append(f"❌ [ERROR] Baseline accuracy mismatch for {model_name}: Expected {expected_values[0]}, got {base_accuracy * 100:.2f}")
+    else:
+        results.append(f"✅ Baseline accuracy for {model_name}: Expected {expected_values[0]}, got {base_accuracy * 100:.2f}")
+
+    quant_model = load_model()
+    tensors = [aidge_core.Tensor(np.reshape(sample, (1, 1, 28, 28))) for sample in samples[:NB_SAMPLES]]
+    aidge_quantization.quantize_network(quant_model, NB_BITS, tensors, CLIPPING, aidge_core.dtype.float64, no_quantization=False, optimize_signs=True, single_shift=False, verbose=VERBOSE)
+    scheduler = aidge_core.SequentialScheduler(quant_model)
+
+    scaling = 2**(NB_BITS - 1) - 1
+    samples = samples * scaling
+
+    quant_accuracy = compute_accuracy(quant_model, samples[:NB_SAMPLES], labels)
+    if abs(quant_accuracy * 100 - expected_values[1]) >= SIGMA:
+        results.append(f"❌ [ERROR] Quantized accuracy mismatch for {model_name}: Expected {expected_values[1]}, got {quant_accuracy * 100:.2f}")
+    else:
+        results.append(f"✅ Quantized accuracy for {model_name}: Expected {expected_values[1]}, got {quant_accuracy * 100:.2f}")
+
+    # Single-shift quantization
+    quant_model_ss = load_model()
+    aidge_quantization.quantize_network(quant_model_ss, NB_BITS, tensors, CLIPPING, aidge_core.dtype.float64, no_quantization=False, optimize_signs=True, single_shift=True, verbose=VERBOSE)
+    scheduler = aidge_core.SequentialScheduler(quant_model_ss)
+    quant_accuracy_ss = compute_accuracy(quant_model_ss, samples[:NB_SAMPLES], labels)
+
+    if abs(quant_accuracy_ss * 100 - expected_values[2]) >= SIGMA:
+        results.append(f"❌ [ERROR] Quantized Single Shift Approximation accuracy mismatch for {model_name}: Expected {expected_values[2]}, got {quant_accuracy_ss * 100:.2f}")
+    else:
+        results.append(f"✅ Quantized Single Shift Approximation accuracy for {model_name}: Expected {expected_values[2]}, got {quant_accuracy_ss * 100:.2f}")
+
+    return model_name, results
+
+def run_quantization_test(use_multithreading, model_path, asset_path):
+    EXPECTED_RESULTS = {
+        "MiniResNet": (95.4, 94.5, 94.7),
+        "ConvNet": (97.9, 97.7, 97.4),
+        "BranchNetV4": (93.8, 93.2, 93.7),
+        "TestNet": (95.5, 94.2, 94.2),
+        "MLP": (94.7, 94.2, 93.3)
+    }
+
+    all_results = []
+
+    if use_multithreading:
+        with concurrent.futures.ProcessPoolExecutor() as executor:
+            futures = {executor.submit(run_model_test, model, values, use_multithreading, asset_path, model_path): model for model, values in EXPECTED_RESULTS.items()}
+
+            for future in concurrent.futures.as_completed(futures):
+                model_name = futures[future]
+                try:
+                    model_name, results = future.result()
+                    all_results.append((model_name, results))
+                except Exception as exc:
+                    all_results.append((model_name, [f"❌ [ERROR] {model_name} test failed with exception: {exc}"]))
+    else:
+        for model, values in EXPECTED_RESULTS.items():
+            try:
+                model_name, results = run_model_test(model, values, use_multithreading, asset_path, model_path)
+                all_results.append((model_name, results))
+            except Exception as exc:
+                all_results.append((model, [f"❌ [ERROR] {model} test failed with exception: {exc}"]))
+
+    os.system("clear")
+    for model_name, results in all_results:
+        print(f"Results for {model_name}:")
+        for result in results:
+            if "❌ [ERROR]" in result:
+                print_in_color(result, 31)
+            else:
+                print_in_color(result, 32)
+        print()
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser(description="Run quantization tests.")
+    parser.add_argument("-j", action="store_true", help="Enable multithreading")
+    parser.add_argument("--models_path", type=str, default="/data1/is156025/nz280189/sbx/Models", help="Path to models directory (default: /data)")
+    parser.add_argument("--asset_path", type=str, default="/data1/is156025/nz280189/sbx/assets", help="Path to assets directory (default: /data)")
+    args = parser.parse_args()
+
+    run_quantization_test(use_multithreading=args.j, model_path=args.models_path, asset_path=args.asset_path)
diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp
index 2c818155877349ad5e5a141469de9f6657873be7..eb5ca7a04ae28326094523d4f6e6974b99aec283 100644
--- a/src/PTQ/CLE.cpp
+++ b/src/PTQ/CLE.cpp
@@ -14,11 +14,18 @@
 #include "aidge/quantization/PTQ/PTQ.hpp"
 
 #include "aidge/graph/GraphView.hpp"
+
 #include "aidge/scheduler/SequentialScheduler.hpp"
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/utils/Log.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
 
+#include "aidge/operator/Mul.hpp"
+#include "aidge/operator/ArgMax.hpp"
+#include "aidge/operator/Abs.hpp"
+#include "aidge/operator/Reshape.hpp"
+#include "aidge/operator/Round.hpp"
+
 namespace Aidge
 {
 
@@ -34,27 +41,68 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
 
 static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling)
 {
-    // Get the tensor data pointer
-    double * castedTensor = static_cast<double *> (tensor->getImpl()->rawPtr());
-
-    // Rescale the tensor
-    for(std::size_t i = 0; i < tensor->size(); i++)
-        castedTensor[i] *= scaling;
+    auto mulOp = Mul_Op();
+    mulOp.setDataType(tensor->dataType());
+    mulOp.setBackend(tensor->backend());
+
+    std::shared_ptr<Aidge::Tensor> scalingTensor = std::make_shared<Aidge::Tensor>(Aidge::Array1D<double, 1> {scaling});
+    scalingTensor->setDataType(tensor->dataType());
+    scalingTensor->setBackend(tensor->backend());
+
+    mulOp.associateInput(0, tensor);
+    mulOp.associateInput(1, scalingTensor);
+
+    mulOp.forward();
+    
+    auto outTensor = mulOp.getOutput(0);
+    *tensor = *outTensor;
+    //tensor->copyCast(*outTensor);
 }
 
-static double getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor)
+// TODO : make the retrieval of argmax values backend independent (refCastFrom)
+static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
 {
-    // Get the tensor data pointer and edit it
-    double * castedTensor = static_cast<double*> (tensor->getImpl()->rawPtr());
-
-    // Get the tensor absolute max value
-    double maxValue = 0.0f;
-    for(std::size_t i = 0; i < tensor->size(); ++i) {
-        if(std::fabs(castedTensor[i]) > maxValue) {
-            maxValue = std::fabs(castedTensor[i]);
-        }
+    // get the abs tensor
+
+    std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs());
+
+    // flatten the abs tensor
+
+    std::int64_t nbElement = tensor->size();
+
+    auto reshapeOp = Reshape_Op({nbElement});
+    reshapeOp.setDataType(tensor->dataType());
+    reshapeOp.setBackend(tensor->backend());
+
+    reshapeOp.associateInput(0, absTensor);
+    reshapeOp.forward();
+    std::shared_ptr<Tensor> flatTensor = reshapeOp.getOutput(0);
+
+    // Get the argmax
+
+    auto argmaxOp = ArgMax_Op(0, true, false);
+    argmaxOp.setDataType(tensor->dataType());
+    argmaxOp.setBackend(tensor->backend());
+
+    argmaxOp.associateInput(0, flatTensor);
+    argmaxOp.forward();
+    std::shared_ptr<Tensor> argmaxTensor = argmaxOp.getOutput(0);
+
+    // Return the max
+
+    int maxIndex = std::round(argmaxTensor->get<double>(0));
+
+    return flatTensor->get<double>(maxIndex);
+}
+// Function used to extract the local tensor (from a ProducerScaling node)
+std::shared_ptr<Aidge::Tensor> getLocalTensor(std::shared_ptr<Node> node) {
+    if (node->getParent(1)->attributes()->hasAttr("isProducerScaling")) {
+        std::shared_ptr<Aidge::OperatorTensor> operatorTensor = std::static_pointer_cast<OperatorTensor>(node->getParent(1)->getOperator());
+        operatorTensor->forward(); // We need the forward pass to compute the scaled value of the Tensor
+        return operatorTensor->getOutput(0);
+    } else {
+        return getWeightTensor(node);
     }
-    return maxValue;
 }
 
 void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetDelta)
@@ -94,16 +142,18 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
             std::shared_ptr<Node> n1 = affineNodeVector[i];
             std::shared_ptr<Node> n2 = affineNodeVector[i+1];
 
-            double r1 = getTensorAbsoluteMax(getWeightTensor(n1));
-            double r2 = getTensorAbsoluteMax(getWeightTensor(n2));
+            std::shared_ptr<Aidge::Tensor> n1localTensor = getLocalTensor(n1);
+            std::shared_ptr<Aidge::Tensor> n2localTensor = getLocalTensor(n2);
+            
+            double r1 = getTensorAbsoluteMax(n1localTensor);
+            double r2 = getTensorAbsoluteMax(n2localTensor);
 
             double s1 = std::sqrt(r1 * r2) / r1;
             double s2 = std::sqrt(r1 * r2) / r2;
 
-            rescaleTensor(getWeightTensor(n1), s1);
-            rescaleTensor(getWeightTensor(n2), s2);
-
-            rescaleTensor(getBiasTensor(n1), s1);
+            insertScalingBelowProducer(n1->getParent(1), s1, graphView);
+            insertScalingBelowProducer(n2->getParent(1), s2, graphView);
+            insertScalingBelowProducer(n1->getParent(2), s1, graphView);
 
             double rangeDelta = std::abs(r1 - r2);
             if (rangeDelta > maxRangeDelta)
diff --git a/src/PTQ/Clipping.cpp b/src/PTQ/Clipping.cpp
index 57ad7a836bbb6251a8eeb6da87e3647b4f54afe2..1901e3864066d3e9bc00f3093fe099c5bcfdec94 100644
--- a/src/PTQ/Clipping.cpp
+++ b/src/PTQ/Clipping.cpp
@@ -222,7 +222,7 @@ std::map<std::string, double> adjustRanges(Clipping clippingMode, std::map<std::
 
         for (std::shared_ptr<Node> node : graphView->getNodes())
         {
-            if (node->type() == "Scaling")
+            if (node->attributes()->hasAttr("isScaling"))
             {
                 std::vector<int> histogram = histograms[node->name()];
 
diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index 0e26313475bbbda23a56dcdda52d55a0a5af8204..c2bc0e20dee70ba88c05f49d1b7acacb66da047b 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -12,8 +12,7 @@
 #include "aidge/quantization/PTQ/CLE.hpp"
 #include "aidge/quantization/PTQ/Clipping.hpp"
 #include "aidge/quantization/PTQ/PTQ.hpp"
-#include "aidge/quantization/PTQ/PTQMetaOps.hpp"
-
+#include "aidge/operator/PTQMetaOps.hpp"
 
 #include "aidge/data/Tensor.hpp"
 #include "aidge/graph/GraphView.hpp"
@@ -22,11 +21,16 @@
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/utils/Log.hpp"
 
+#include "aidge/operator/BitShift.hpp"
 #include "aidge/operator/Producer.hpp"
 #include "aidge/operator/Mul.hpp"
 #include "aidge/operator/ReLU.hpp"
 #include "aidge/operator/BatchNorm.hpp"
 #include "aidge/operator/Conv.hpp"
+#include "aidge/operator/ArgMax.hpp"
+#include "aidge/operator/Reshape.hpp"
+#include "aidge/operator/Cast.hpp"
+
 
 #include "aidge/recipes/Recipes.hpp"
 #include "aidge/recipes/QuantRecipes.hpp"
@@ -49,6 +53,155 @@ bool isMerging(std::shared_ptr<Node> node)
 {
     return (mergingNodeTypes.find(node->type()) != mergingNodeTypes.end());
 }
+static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode)
+{
+    int index = 0;
+    while (node->getParent(index) != parentNode) 
+        index++;
+    return index;
+}
+
+void multiplyScalingFactor(std::shared_ptr<Aidge::Node> node, double coeff)
+{
+    AIDGE_ASSERT(node->type() == "Mul" && (node->attributes()->hasAttr("isProducerScaling") || node->attributes()->hasAttr("isScaling")),
+    "Cannot update the scaling factor on Node of type {} with no scaling tag",node->type());
+    auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(1);
+    std::shared_ptr<Tensor> fallback;
+    const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu");
+    
+    double previousScalingFactor = localTensor.get<double>(0);
+    std::shared_ptr<Tensor> finalTensor = std::make_shared<Tensor>(Array1D<double, 1> {previousScalingFactor * coeff});
+    node->input(1).first->getOperator()->setOutput(0, finalTensor);
+}
+/* Util function to insert a node below another one already connected */
+void insertNodeBetween(std::shared_ptr<Node> parent, 
+                       std::shared_ptr<Node> newNode, 
+                       std::shared_ptr<GraphView> graphView) 
+{
+    // Check that the parent has at least one child
+    if(parent->getChildren().size() == 0)
+    {
+        parent->addChild(newNode, 0, 0);
+        graphView->add(newNode);
+        return;
+    }
+    std::vector<std::shared_ptr<Node>> nextNodes = parent->getChildren(0);
+    std::vector<int> inputIndices(nextNodes.size());
+    for (std::size_t i = 0; i < nextNodes.size(); i++) {
+        inputIndices[i] = getInputIndex(nextNodes[i], parent);
+    }
+
+    // Disconnect the children from the parent
+    for (std::shared_ptr<Node> nextNode : nextNodes) {
+        parent->removeChild(nextNode, 0);
+    }
+
+    // Insert the new node between the child and the parent
+    parent->addChild(newNode, 0, 0);
+    for (std::size_t i = 0; i < nextNodes.size(); i++) {
+        newNode->addChild(nextNodes[i], 0, inputIndices[i]);
+    }
+
+    graphView->add(newNode);
+}
+
+void applyConstFold(std::shared_ptr<GraphView> &graphView)
+{
+    for (const std::shared_ptr<Node> node : graphView->getNodes())
+    {
+        if (node->type() == "Producer" )
+        {
+            const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator());
+            producer->constant() = true;
+        }
+    }
+    constantFolding(graphView);
+}
+// TODO : add a condition that inserts a Cast node to convert the user input data into the desired type
+bool castQuantizedGraph(std::shared_ptr<GraphView> &graphView, Aidge::DataType targetType, bool singleShift)
+{
+    // We need a copy of the graph's node set, since we will replace some nodes while iterating
+    std::vector<std::shared_ptr<Node>> nodeVector(graphView->getNodes().begin(), graphView->getNodes().end());
+
+    for (std::shared_ptr<Node> node : nodeVector)
+    {
+        if (node->type() == "Round" && node->attributes()->hasAttr("isProducerRounding"))
+        {
+            std::shared_ptr<Aidge::Node> castNode = Cast(targetType, node->name() + "_Cast");
+            castNode->getOperator()->setDataType(targetType);
+            castNode->getOperator()->setBackend(node->getOperator()->backend());
+            insertNodeBetween(node, castNode, graphView);
+            castNode->attributes()->addAttr("isProducerCasting", 0.0);
+            node->getOperator()->setDataType(DataType::Float64);
+        }
+        else if(node->type() == "Quantizer")
+        {
+            if(singleShift)
+            {
+                std::shared_ptr<Node> newBitShiftQuantizer = BitShiftQuantizer(node, targetType, node->name() + "_BitShift_Quantizer");
+                newBitShiftQuantizer->getOperator()->setBackend(node->getOperator()->backend());
+                graphView->replace({node}, {newBitShiftQuantizer});
+
+            }
+            else // If single shift is not enabled, we keep using the alternative IntQuantizer (which casts the data before and after the regular Quantizer operations)
+            {
+                std::shared_ptr<Node> newIntQuantizer = IntQuantizer(node, targetType, node->name());
+                newIntQuantizer->getOperator()->setBackend(node->getOperator()->backend());
+                graphView->replace({node}, {newIntQuantizer});
+            }
+        }
+        else if (node->type() != "Producer" &&
+        !node->attributes()->hasAttr("isProducerScaling")) 
+        {              
+            node->getOperator()->setDataType(targetType);
+        }   
+    }
+    return true;
+}
+bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView)
+{
+    std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round");
+    roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
+    roundNode->getOperator()->setBackend("cpu");
+
+    insertNodeBetween(node, roundNode, graphView);
+
+    roundNode->attributes()->addAttr("isProducerRounding", 0.0);
+    return true;
+}
+bool insertScalingBelowProducer(std::shared_ptr<Node> node, double scalingFactor, std::shared_ptr<GraphView> graphView)
+{
+    if(node->attributes()->hasAttr("isProducerRounding"))
+    {
+        // In this case we 'bump' the node to the one above it (an actual ProducerScaling),
+        // because the round node is not usable (it is only used when SSA is enabled)
+        node = node->getParent(0);
+    }
+    if(node->attributes()->hasAttr("isProducerScaling"))
+    {
+        // We accumulate multiple scaling factors by multiplying the SF of the ProducerScaling node
+        // (adding a new node each time would make the graph unusable)
+        multiplyScalingFactor(node, scalingFactor);
+        return true;
+    }
+    AIDGE_ASSERT(node->type() == "Producer","Cannot apply a scaling factor on node of type: {} which is not a producer", node->type());
+    std::string scalingNodeName = makeUniqueName(node->name() + "_Producer_Scaling", graphView);
+    
+    std::shared_ptr<Aidge::Node> scalingNode = Mul(scalingNodeName);
+    scalingNode->attributes()->addAttr("isProducerScaling",0.0);
+    
+    std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor});
+    std::shared_ptr<Node> scalingFactorProducer = addProducer(scalingNode, 1, {1}, "Factor"); 
+    scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor);
+    graphView->add(scalingFactorProducer);
+    
+    scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
+    scalingNode->getOperator()->setBackend("cpu");
+
+    insertNodeBetween(node, scalingNode, graphView);
+
+    return true;
+}
 
 bool checkArchitecture(std::shared_ptr<GraphView> graphView)
 {
@@ -66,51 +219,43 @@ bool checkArchitecture(std::shared_ptr<GraphView> graphView)
     return true;
 }
 
-static void fillTensor(std::shared_ptr<Tensor> tensor, double value)
+// TODO : make the retrieval of argmax values backend independent (refCastFrom)
+static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
 {
-    // Get the tensor data pointer
-    double * castedTensor = static_cast <double *> (tensor->getImpl()->rawPtr());
+    // get the abs tensor
 
-    // Fill the tensor
-    for(std::size_t i = 0; i < tensor->size(); i++)
-        castedTensor[i] = value;
-}
+    std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs());
 
-static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling)
-{
-    // Get the tensor data pointer
-    double * castedTensor = static_cast <double *> (tensor->getImpl()->rawPtr());
+    // flatten the abs tensor
 
-    // Rescale the tensor
-    for(std::size_t i = 0; i < tensor->size(); i++)
-        castedTensor[i] *= scaling;
-}
+    std::int64_t nbElement = tensor->size();
 
-static void roundTensor(std::shared_ptr<Tensor> tensor)
-{
-    // Get the tensor data pointer
-    double * castedTensor = static_cast <double *> (tensor->getImpl()->rawPtr());
+    auto reshapeOp = Reshape_Op({nbElement});
+    reshapeOp.setDataType(tensor->dataType());
+    reshapeOp.setBackend(tensor->backend());
 
-    // Rescale the tensor
-    for(std::size_t i = 0; i < tensor->size(); i++)
-        castedTensor[i] = std::nearbyint(castedTensor[i]);//Round
-}
+    reshapeOp.associateInput(0, absTensor);
+    reshapeOp.forward();
+    std::shared_ptr<Tensor> flatTensor = reshapeOp.getOutput(0);
 
-static double getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor)
-{
-    // Get the tensor data pointer and edit it
-    double * castedTensor = static_cast<double*>(tensor->getImpl()->rawPtr());
-
-    // Get the tensor absolute max value
-    double maxValue = 0.0f;
-    for(std::size_t i = 0; i < tensor->size(); ++i) {
-        if(std::fabs(castedTensor[i]) > maxValue) {
-            maxValue = std::fabs(castedTensor[i]);
-        }
-    }
-    return maxValue;
+    // Get the argmax
+
+    auto argmaxOp = ArgMax_Op(0, true, false);
+    argmaxOp.setDataType(tensor->dataType());
+    argmaxOp.setBackend(tensor->backend());
+
+    argmaxOp.associateInput(0, flatTensor);
+    argmaxOp.forward();
+    std::shared_ptr<Tensor> argmaxTensor = argmaxOp.getOutput(0);
+
+    // Return the max
+
+    int maxIndex = std::round(argmaxTensor->get<double>(0));
+
+    return flatTensor->get<double>(maxIndex);
 }
 
+
 // TODO : pass nodeVector by reference ...
 static std::vector<std::shared_ptr<Node>> removeMatchingNodes(std::vector<std::shared_ptr<Node>> nodeVector, std::string nodeType)
 {
@@ -121,6 +266,15 @@ static std::vector<std::shared_ptr<Node>> removeMatchingNodes(std::vector<std::s
 
     return remainingNodes;
 }
+static std::vector<std::shared_ptr<Node>> removeProdScalingNodes(std::vector<std::shared_ptr<Node>> nodeVector)
+{
+    std::vector<std::shared_ptr<Node>> remainingNodes;
+    for (std::shared_ptr<Node> node : nodeVector)
+        if (!node->attributes()->hasAttr("isProducerScaling"))
+            remainingNodes.push_back(node);
+
+    return remainingNodes;
+}
 
 static void fixScheduling(std::vector<std::shared_ptr<Node>>& nodeVector) {
 
@@ -165,12 +319,13 @@ std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView>
 
     fixScheduling(nodeVector);
     nodeVector = removeMatchingNodes(nodeVector, "Producer");
+    nodeVector = removeProdScalingNodes(nodeVector);
 
     if (verbose) 
     {
-        Log::info("NB OF NODES = {}", nodeVector.size());
+        Log::notice("NB OF NODES = {}", nodeVector.size());
         for (std::shared_ptr<Node> node : nodeVector)
-            Log::info("{} {}", node->type(), node->name());
+            Log::notice("{} {}", node->type(), node->name());
     }
 
     return nodeVector;    
@@ -184,6 +339,7 @@ static std::shared_ptr<Node> getFirstNode(std::shared_ptr<GraphView> graphView)
 void prepareNetwork(std::shared_ptr<GraphView> graphView)
 {
     removeFlatten(graphView);
+    sanitizeNodeNames(graphView);
 
     bool containsBatchNorm = false;
     std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
@@ -228,29 +384,30 @@ void insertResidualNodes(std::shared_ptr<GraphView> graphView)
                 if (parentIsForking)
                 {
                     // temporary verbose ...
-                    Log::info(" ### found residual branch at index {}", i);
-                    Log::info(" ### inserting multiplicative node ...");
+                    Log::notice(" ### found residual branch at index {}", i);
+                    Log::notice(" ### inserting multiplicative node ...");
 
                     std::string residualNodeName = makeUniqueName(parentNode->name() + "_Res", graphView);
-                    std::shared_ptr<Node> residualNode = Scaling(1.0, residualNodeName);
+                    std::shared_ptr<Node> residualNode = Mul(residualNodeName);
+                    residualNode->attributes()->addAttr("isScaling", 0.0);
+                    residualNode->attributes()->addAttr("isResidual", 0.0);
+                    
+                    //Adding the SF as a producer of the node
+                    std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {1.0});
+                    std::shared_ptr<Node> scalingFactorProducer = addProducer(residualNode, 1, {1}, "ScalingFactor"); 
+                    scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor);
 
-                    residualNode->getOperator()->setDataType(DataType::Float64); //getDataType(parentNode)
+                    residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
                     residualNode->getOperator()->setBackend("cpu");
 
                     graphView->insertParent(node, residualNode, i, 0, 0);
+                    graphView->add(scalingFactorProducer);
                 }
             }
         }
     }
 }
 
-static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> parentNode)
-{
-    int index = 0;
-    while (node->getParent(index) != parentNode) 
-        index++;
-    return index;
-}
 
 void insertScalingNodes(std::shared_ptr<GraphView> graphView)
 {
@@ -263,37 +420,30 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
         if (isAffine(parentNode) || isMerging(parentNode))
         {
             std::string scalingNodeName = makeUniqueName(parentNode->name() + "_Scaling", graphView);
-            std::shared_ptr<Node> scalingNode = Scaling(1.0, scalingNodeName);
+            //std::shared_ptr<Node> scalingNode = Scaling(1.0, scalingNodeName);
+            
+            //Adding Mul operator with tag "isScaling"
+            std::shared_ptr<Aidge::Node> scalingNode = Mul(scalingNodeName);
+            scalingNode->attributes()->addAttr("isScaling",0.0);
+
+            //Adding the SF as a producer of the node
+            std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {1.0});
+            std::shared_ptr<Node> scalingFactorProducer = addProducer(scalingNode, 1, {1}, "ScalingFactor"); 
+            scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor);
 
             scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
             scalingNode->getOperator()->setBackend("cpu");
 
             if (parentNode->getChildren().size() > 0)
             {
-                // SCALING NODE INSERTION
-                
-                // We always have one output from Affine and Add nodes, but possibly multiple childs
-                std::vector<std::shared_ptr<Node>> nextNodes = parentNode->getChildren(0); 
-
-                // For each node in nextNodes store the connexion index
-                std::vector<int> inputIndices(nextNodes.size());
-                for (std::size_t i = 0; i < nextNodes.size(); i++)
-                    inputIndices[i] = getInputIndex(nextNodes[i], parentNode);
-                    
-                for (std::shared_ptr<Node> nextNode : nextNodes)
-                    parentNode->removeChild(nextNode, 0);
-
-                parentNode->addChild(scalingNode, 0, 0);
-
-                for (std::size_t i = 0; i < nextNodes.size(); i++)
-                    scalingNode->addChild(nextNodes[i], 0, inputIndices[i]);
-
-                graphView->add(scalingNode);
+                insertNodeBetween(parentNode,scalingNode,graphView);
+                graphView->add(scalingFactorProducer);
             }
             else
             {
-                // Log::info(" last node reached ! ");
+                // Log::notice(" last node reached ! ");
                 parentNode->addChild(scalingNode, 0, 0);
+                graphView->add(scalingFactorProducer);
                 graphView->add(scalingNode);
             }
         }
@@ -303,7 +453,7 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
 static std::shared_ptr<Node> getPreviousScalingNode(std::shared_ptr<Node> mergingNode)
 {
     std::shared_ptr<Node> currNode = mergingNode;
-    while(currNode->type() != "Scaling")
+    while(!currNode->attributes()->hasAttr("isScaling"))
     {
         if (currNode->getParents().size() == 0)
         {
@@ -346,7 +496,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
     for (std::shared_ptr<Node> node : nodeVector)
     {
         // Scaling nodes still have a ratio of 1, so they are seamless ...
-        if (node->type() == "ReLU" || node->type() == "Scaling" || isSeamless(node))
+        if (node->type() == "ReLU" || node->attributes()->hasAttr("isScaling") || isSeamless(node))
         {
             if (node != firstNode)
             {
@@ -362,7 +512,8 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
             std::shared_ptr<Tensor> weightTensor = getWeightTensor(node);
             double scaling = getTensorAbsoluteMax(weightTensor);
             double ratio = 1.0 / scaling;
-            rescaleTensor(weightTensor, ratio);
+            //rescaleTensor(weightTensor, ratio);
+            insertScalingBelowProducer(node->getParent(1), ratio, graphView);
 
             // Accumulate the ratio
             if (node == firstNode)
@@ -380,7 +531,8 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
             if (nodeHasBias(node))
             {
                 std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
-                rescaleTensor(biasTensor, accumulatedRatios[node->name()] );
+                //rescaleTensor(biasTensor, accumulatedRatios[node->name()] );
+                insertScalingBelowProducer(node->getParent(2), accumulatedRatios[node->name()], graphView);
             }
         }
 
@@ -407,8 +559,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
 
                 std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode);
 
-                double currScalingFactor = getScalingFactor(scalingNode);
-                updateScalingFactor(scalingNode, currScalingFactor / rescaling);
+                multiplyScalingFactor(scalingNode, 1.0 / rescaling);
 
                 accumulatedRatios[mergingNode->name()] /= rescaling; // optional ...
             }
@@ -433,7 +584,7 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView
     std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes();
     for (std::shared_ptr<Node> node : nodeSet)
     {
-        if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
+        if ((scalingNodesOnly && (node->attributes()->hasAttr("isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer")))
         {
             std::shared_ptr<Operator> nodeOperator = node->getOperator();
             std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0));
@@ -455,7 +606,7 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView
     // std::shared_ptr<Node> inputNode = getFirstNode(graphView);
 
     for (std::shared_ptr<Node> node : nodeSet)
-        if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
+        if ((scalingNodesOnly && (node->attributes()->hasAttr("isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer")))
             valueRanges.insert(std::make_pair(node->name(), 0));
 
     if (useCuda)
@@ -468,7 +619,7 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView
 
     for (std::shared_ptr<Tensor> sample : inputDataSet)
     {
-        //Log::info(" IT : {}", it++);
+        //Log::notice(" IT : {}", it++);
 
         // Inference ...
 
@@ -482,7 +633,7 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView
         std::map<std::string, double> sampleRanges;
         for (std::shared_ptr<Node> node : nodeSet)
         {
-            if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
+            if ((scalingNodesOnly && (node->attributes()->hasAttr("isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer")))
             {
                 std::shared_ptr<Operator> nodeOperator = node->getOperator();
                 std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0));
@@ -504,7 +655,7 @@ std::map<std::string, double> computeRanges(std::shared_ptr<GraphView> graphView
 
         for (std::shared_ptr<Node> node : nodeSet)
         {
-            if ((scalingNodesOnly && (node->type() == "Scaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
+            if ((scalingNodesOnly && (node->attributes()->hasAttr("isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer")))
                 {
                     std::string nodeName = node->name();
                     if (sampleRanges[nodeName] > valueRanges[nodeName])
@@ -540,7 +691,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
     for (std::shared_ptr<Node> node : nodeVector)
     {
         // Seamless scaling factor propagation ...
-    
+
         if (isAffine(node) || isSeamless(node) || node->type() == "ReLU") 
         {
             if (node == firstNode)
@@ -554,11 +705,13 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
             }
         }
 
+
         // Here prevNode is either a 'Affine' or a 'Merging'
         // => do not split the cases, just handle the bias ...
 
-        if (node->type() == "Scaling") 
+        if (node->attributes()->hasAttr("isScaling")) 
         {
+
             // retrieve the previous scaling factor ...
             std::shared_ptr<Node> prevNode = node->getParent(0);
             double prevScalingFactor = scalingFactors[prevNode->name()];
@@ -566,8 +719,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
             // ValueRanges must contains all the scaling nodes !!!
             double scalingFactor = valueRanges[node->name()]; 
 
-            double currScalingFactor = getScalingFactor(node);
-            updateScalingFactor(node, currScalingFactor / (scalingFactor / prevScalingFactor));
+            multiplyScalingFactor(node, 1.0 / (scalingFactor / prevScalingFactor));
 
             scalingFactors[node->name()] = scalingFactor;
 
@@ -575,11 +727,13 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
 
             if (isAffine(prevNode))
             {
+
                 bool prevNodeHasBias = nodeHasBias(prevNode);
                 if (prevNodeHasBias)  
-                {
+                {                
                     std::shared_ptr<Tensor> biasTensor = getBiasTensor(prevNode);
-                    rescaleTensor(biasTensor, 1.0 / prevScalingFactor);
+                    //rescaleTensor(biasTensor, 1.0 / prevScalingFactor);
+                    insertScalingBelowProducer(prevNode->getParent(2), 1.0 / prevScalingFactor, graphView);
                 }
             }
         }
@@ -608,10 +762,9 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st
                 double rescaling = mergingNodeScaling / maxScaling;
 
                 std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode);
-                //Log::info(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name());
-
-                double currScalingFactor = getScalingFactor(scalingNode);
-                updateScalingFactor(scalingNode, currScalingFactor * rescaling);                
+                //Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name());
+                
+                multiplyScalingFactor(scalingNode, rescaling);
             }
         }
     }
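Review note: a minimal numeric sketch of the branch alignment performed above, with illustrative values; every branch feeding a merging node is rescaled so that all branches share the largest scaling factor:

```python
# Each branch's scaling node is multiplied by its branch scaling divided by
# the maximum over all branches, aligning every branch on max_scaling.
branch_scalings = [0.5, 2.0, 1.0]
max_scaling = max(branch_scalings)
rescalings = [s / max_scaling for s in branch_scalings]  # [0.25, 1.0, 0.5]
```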
@@ -647,7 +800,7 @@ std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<Grap
             signMap[node->name()].second = false;
         } 
 
-        if (node->type() == "Scaling") 
+        if (node->attributes()->hasAttr("isScaling")) 
         {
             signMap[node->name()].second = false;
 
@@ -694,7 +847,7 @@ std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<Grap
                 // Arbitration : Signed type wins !
                 for(std::shared_ptr<Node> parent : parentNodes)
                 {
-                    while (parent->type() != "Scaling")
+                    while (!parent->attributes()->hasAttr("isScaling"))
                     {
                         signMap[parent->name()] = std::make_pair(false, false);
                         // We are on a branch so nodes always have 1 parent ...
@@ -725,9 +878,9 @@ std::map<std::string, std::pair<bool, bool>> computeSignMap(std::shared_ptr<Grap
 
     if (verbose)
     {
-        Log::info(" === SIGN MAP === ");
+        Log::notice(" === SIGN MAP === ");
         for (std::shared_ptr<Node> node : nodeVector)
-            Log::info(" {}{} | {}", static_cast<int>(signMap[node->name()].first), static_cast<int>(signMap[node->name()].second), node->name());
+            Log::notice(" {}{} | {}", static_cast<int>(signMap[node->name()].first), static_cast<int>(signMap[node->name()].second), node->name());
     }
 
     // SANITY CHECK (TEMPORARY)
@@ -776,26 +929,23 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
         if (isAffine(node))
         {
             // Rescale the weight tensor
-
             std::shared_ptr<Tensor> weightTensor = getWeightTensor(node);
-            rescaleTensor(weightTensor, signedMax);
+            insertScalingBelowProducer(node->getParent(1), signedMax, graphView);
 
             if (!noQuant)
-                roundTensor(weightTensor);
+                insertRoundBelowProducer(node->getParent(1), graphView);
 
             // Rescale the bias tensor
-
             if (nodeHasBias(node))  
             {
                 bool inputIsUnsigned = signMap[node->name()].first;
                 double rescaling = inputIsUnsigned ? unsignedMax * signedMax : signedMax * signedMax;
-                
-
+            
                 std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
-                rescaleTensor(biasTensor, rescaling);
+                insertScalingBelowProducer(node->getParent(2), rescaling, graphView);
 
                 if (!noQuant)
-                    roundTensor(biasTensor);
+                    insertRoundBelowProducer(node->getParent(2), graphView);
             }
 
             // Compensate the rescaling using the next Scaling node
@@ -810,8 +960,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
             
             std::shared_ptr<Node> scalingNode = *(node->getChildren().begin()); // Assert if scalingNode is a Scaling ...
 
-            double currScalingFactor = getScalingFactor(scalingNode);
-            updateScalingFactor(scalingNode, currScalingFactor * rescaling);
+            multiplyScalingFactor(scalingNode, rescaling);
         }
         
         if (isMerging(node))
@@ -826,23 +975,25 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
 
             std::shared_ptr<Node> scalingNode = *(node->getChildren().begin()); // Assert if scalingNode is a Scaling ...
         
-            double currScalingFactor = getScalingFactor(scalingNode); // XXX bad naming
-            updateScalingFactor(scalingNode, currScalingFactor * rescaling);
+            multiplyScalingFactor(scalingNode, rescaling);
         }
         
         // Handle the Scaling Nodes ...
 
-        if (node->type() == "Scaling")
+        if (node->attributes()->hasAttr("isScaling"))
         {
             if (!noQuant) 
             {  
                 // Replace  the  Scaling Node by Quantizer
+                auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(1);
+                std::shared_ptr<Tensor> fallback;
+                const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu");
+                double old_sf = localTensor.get<double>(0);
 
-                std::shared_ptr<Node> quantizerNode = Quantizer(getScalingFactor(node), -(signedMax + 1), signedMax, node->name());
+                std::shared_ptr<Node> quantizerNode = Quantizer(old_sf, -(signedMax + 1), signedMax, node->name());
                 quantizerNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
                 quantizerNode->getOperator()->setBackend("cpu");
-
-                graphView->replace({node}, {quantizerNode});
+                graphView->replace({node, node->getParent(1)}, {quantizerNode});
 
                 if (optimizeSigns)
                 {
@@ -856,6 +1007,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
 
                     double currScalingFactor = getScalingFactor(quantizerNode);
                     updateScalingFactor(quantizerNode, currScalingFactor * rescaling);
+                    
 
                     if(outputIsUnsigned)
                     {
@@ -876,51 +1028,40 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
 
     for (std::shared_ptr<Node> node : nodeVector)
     {
-        // A merging node is always followed by a scaling node at this point ...
+        // A merging node is always followed by a Quantizer node at this point
 
-        if (node->type() == "Quantizer")
+        if (node->type() == "Quantizer" && (node->attributes()->hasAttr("isResidual") || !isAffine(node->getParent(0))))
         {   
-            bool prevNodeIsForking = ((node->getParent(0))->getChildren().size() > 1);
-            bool prevNodeIsAffine = isAffine(node->getParent(0));
-            bool insertNode = prevNodeIsForking || !prevNodeIsAffine;
-
-            if (insertNode)
-            {
-                // create and insert the multplicative node
-
-                std::string mulNodeName = makeUniqueName(node->name() + "_Mul", graphView);
-                std::shared_ptr<Node> mulNode = Mul(mulNodeName);
-
-                mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
-                mulNode->getOperator()->setBackend("cpu");
-
-                graphView->insertParent(node, mulNode, 0, 0, 0);
 
-                // create and insert the producer node
+            // check if the Quantizer is a residual one, and insert a compensation node if so:
+            // create and insert the multiplicative node before the Quantizer
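+            // (the inserted Mul rescales the activations by signedMax, and the Quantizer
+            // scaling factor is divided by the same amount below, so the graph output
+            // is left unchanged)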
 
-                std::shared_ptr<Tensor> inputTensor = std::static_pointer_cast<Tensor> (mulNode->getOperator()->getRawInput(0));
-                std::shared_ptr<Tensor> coeffTensor = std::make_shared<Tensor>();
+            std::string mulNodeName = makeUniqueName(node->name() + "_Mul", graphView);
+            std::shared_ptr<Node> mulNode = Mul(mulNodeName);
+            
+            mulNode->attributes()->addAttr("isCompensation", 0.0);
+            mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
+            mulNode->getOperator()->setBackend("cpu");
 
-                coeffTensor->setDataType(DataType::Float64); // getDataType(parentNode)
-                coeffTensor->setBackend("cpu"); 
+            graphView->insertParent(node, mulNode, 0, 0, 0);
 
-                coeffTensor->resize(inputTensor->dims());
-                fillTensor(coeffTensor, 1); 
+            // Add the coeff producer to the multiplier node
 
-                std::shared_ptr<Node> producerNode = Producer(coeffTensor, makeUniqueName("coeff", graphView));
-                producerNode->addChild(mulNode);
-                graphView->add(producerNode);
+            std::shared_ptr<Node> coeffProducer = addProducer(mulNode, 1, {1}, ""); 
+            std::shared_ptr<Tensor> coeffTensor = std::make_shared<Tensor>(Array1D<double, 1> {signedMax});
+            coeffProducer->getOperator()->setOutput(0, coeffTensor);
 
-                // rescale the coeffs and edit scaling factor
+            coeffProducer->getOperator()->setDataType(DataType::Float64);
+            coeffProducer->attributes()->addAttr("quantization.ptq.CompensationCoeff", signedMax);
+            coeffProducer->getOperator()->setBackend("cpu"); 
 
-                fillTensor(coeffTensor, signedMax);
+            graphView->add(coeffProducer); // needed ?
 
-                double currScalingFactor = getScalingFactor(node); // XXX bad naming !
-                updateScalingFactor(node, currScalingFactor / signedMax);
+            // Adapt the scaling factor value accordingly
 
-                // TODO : double check this !!!
-                //std::cout << getTensorAbsoluteMax(coeffTensor) << std::endl;
-            }
+            double currScalingFactor = getScalingFactor(node);
+            updateScalingFactor(node, currScalingFactor / signedMax);
         }
     }
 }
@@ -931,10 +1072,11 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
 
     for (std::shared_ptr<Node> node : nodeVector)
     {
-        // Use A meatoperator of type Scaling of MulCompensation instead
-        if (isAffine(node) || (node->type() == "Mul"))
+        if (isAffine(node) || (node->type() == "Mul" && node->attributes()->hasAttr("isCompensation")))
         {
             std::shared_ptr<Node> scalingNode = (*node->getChildren().begin());
+            // skip the casting node (if any) to reach the actual scaling node
+            if (scalingNode->attributes()->hasAttr("isCasting"))
+                scalingNode = (*scalingNode->getChildren().begin());
 
             double base = getScalingFactor(scalingNode);
 
@@ -944,17 +1086,16 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
 
             double ratio = base / approx;
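+            // 'approx' is the single-shift (power-of-two) approximation of 'base';
+            // the residual ratio is folded back into the weights and biases below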
 
-            std::shared_ptr<Tensor> weightTensor = getWeightTensor(node);
-            rescaleTensor(weightTensor, ratio);
-            if (!noQuant)
-                roundTensor(weightTensor);
+            insertScalingBelowProducer(node->getParent(1), ratio, graphView);
+            if (!noQuant && !node->getParent(1)->attributes()->hasAttr("isProducerRounding"))
+                insertRoundBelowProducer(node->getParent(1), graphView);
 
             if (nodeHasBias(node))
             {
-                std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
-                rescaleTensor(biasTensor, ratio);
-                if (!noQuant)
-                roundTensor(biasTensor);
+                insertScalingBelowProducer(node->getParent(2), ratio, graphView);
+
+                if (!noQuant && !node->getParent(2)->attributes()->hasAttr("isProducerRounding"))
+                    insertRoundBelowProducer(node->getParent(2), graphView);
             }
         }
     }
@@ -962,12 +1103,12 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
 
 static void printScalingFactors(std::shared_ptr<GraphView> graphView)
 {
-    Log::info(" === SCALING FACTORS === ");
+    Log::notice(" === SCALING FACTORS === ");
     for (auto node : retrieveNodeVector(graphView))
-        if (node->type() == "Scaling" || node->type() == "Quantizer")
+        if (node->attributes()->hasAttr("isScaling") || node->type() == "Quantizer")
         {
             double scalingFactor = getScalingFactor(node);
-            Log::info(" {:.6f} ({})", scalingFactor, node->name());
+            Log::notice(" {:.6f} ({})", scalingFactor, node->name());
         }
 }
 
@@ -994,13 +1135,14 @@ static void printRanges(std::shared_ptr<GraphView> graphView, std::map<std::stri
 
     auto scheduling = scheduler.getStaticScheduling();
     for (auto node : scheduling)
-        if (node->type() == "Scaling")
-            fmt::println("{} range = {}", node->name(), valueRanges[node->name()]);
+        if (node->attributes()->hasAttr("isScaling"))
+            Log::debug("{} range = {}", node->name(), valueRanges[node->name()]);
 }
 
-void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose)
+void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet,
+                     Clipping clippingMode, DataType targetType, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose)
 {
-    Log::info(" === QUANT PTQ 0.2.21 === ");
+    Log::notice(" === QUANT PTQ 0.2.21 === ");
 
     graphView->setBackend("cpu");
 
@@ -1010,62 +1152,79 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
     if (!checkArchitecture(graphView))
         return;
 
-    Log::info(" Preparing the network for the PTQ ... ");
+    Log::notice(" Preparing the network for the PTQ ... ");
     prepareNetwork(graphView);
 
-    Log::info(" Inserting the scaling nodes ...");
+    Log::notice(" Inserting the scaling nodes ...");
     insertScalingNodes(graphView);
 
     crossLayerEqualization(graphView);
-
-    Log::info(" Normalizing the parameters ...");
+    Log::notice(" Normalizing the parameters ...");
     normalizeParameters(graphView);
 
-    Log::info(" Computing the value ranges ...");
+    Log::notice(" Computing the value ranges ...");
     std::map<std::string, double> valueRanges = computeRanges(graphView, inputDataSet, true, useCuda);
 
-    //std::cout << " === RANGES (BEFORE ADJUST) ===" << std::endl;
+    //Log::debug("=== RANGES (BEFORE ADJUST) ===");
     //printRanges(graphView, valueRanges);
 
-    Log::info(" Optimizing the clipping values ...");
+    Log::notice(" Optimizing the clipping values ...");
     valueRanges = adjustRanges(clippingMode, valueRanges, nbBits, graphView, inputDataSet, useCuda, verbose);
 
-    //std::cout << " === RANGES (AFTER ADJUST) ===" << std::endl;
+    //Log::debug("=== RANGES (AFTER ADJUST) ===");
     //printRanges(graphView, valueRanges);
-
-    Log::info(" Normalizing the activations ...");
+    Log::notice(" Normalizing the activations ...");
     normalizeActivations(graphView, valueRanges);
 
-    Log::info(" Quantizing the normalized network ...");
+    Log::notice(" Quantizing the normalized network ...");
     quantizeNormalizedNetwork(graphView, nbBits, noQuant, optimizeSigns, verbose);
-
+    
     if (singleShift)
     {
-        Log::info( " Inserting the compensation nodes ...");
+        Log::notice( " Inserting the compensation nodes ...");
         insertCompensationNodes(graphView, nbBits);
 
-        Log::info(" Performing the Single-Shift approximation ...");
+        Log::notice(" Performing the Single-Shift approximation ...");
         performSingleShiftApproximation(graphView, noQuant);
     }
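+
+    // if the target type is not a floating point type, explicitly cast the quantized
+    // operators into it; otherwise simply propagate the target type through the graph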
+    if (targetType != DataType::Float64 && targetType != DataType::Float32 && targetType != DataType::Float16)
+    {
+        AIDGE_ASSERT(!noQuant, "Cannot cast operators with the noQuant (Fake Quantization) flag set to true!");
+        Log::notice("Starting to cast operators into the desired type ...");
+        castQuantizedGraph(graphView, targetType, singleShift);
+    }
+    else
+    {
+        setupDataType(graphView, inputDataSet, targetType);
+    }
+
+    if (foldGraph)
+    {
+        Log::notice("Applying constant folding recipe to the graph ...");
+        applyConstFold(graphView);
+    }
+    // Mandatory to handle all of the newly added connections!
+    graphView->updateInputsOutputs();
+
+    // reset the input nodes
+    /*for (Aidge::NodePtr input_node : graphView->inputNodes())
+    {
+        std::static_pointer_cast<OperatorTensor>(input_node->getOperator())->resetInput();
+    }*/
 
     if (verbose)
         printScalingFactors(graphView);
 
-    //std::cout << " === SCALINGS (BEFORE CAST) ===" << std::endl;
-    //printScalingFactors(graphView);
 
-    setupDataType(graphView, inputDataSet, initialDataType);
-    if (useCuda)
-        graphView->setBackend("cuda");
+    // if (useCuda)
+    //     graphView->setBackend("cuda");
 
-    //std::cout << " === SCALINGS (AFTER CAST) ===" << std::endl;
-    //printScalingFactors(graphView);
-
-    Log::info(" Reseting the scheduler ...");
+    Log::notice(" Reseting the scheduler ...");
     SequentialScheduler scheduler(graphView);
     scheduler.resetScheduling();
 
-    Log::info(" Network is quantized !");
+    Log::notice(" Network is quantized !");
+
 }
 
 std::map<std::string, double> getWeightRanges(std::shared_ptr<GraphView> graphView)
@@ -1090,15 +1249,14 @@ void clearBiases(std::shared_ptr<GraphView> graphView)
     for (std::shared_ptr<Node> node : graphView->getNodes()) {
         if (node->type() == "FC" || node->type() == "Conv2D") {
-            std::shared_ptr<Tensor> biasTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2);
-            rescaleTensor(biasTensor, 0);
+            insertScalingBelowProducer(node->getParent(2), 0, graphView);
         }
     }
 }
-
 void devPTQ(std::shared_ptr<GraphView> graphView) 
 {
     for (std::shared_ptr<Node> node : graphView->getNodes())
-        fmt::println(" UUU : {}", node->name());
+        Log::debug(" UUU : {}", node->name());
 }
-
 }
diff --git a/src/PTQ/PTQMetaOps.cpp b/src/PTQ/PTQMetaOps.cpp
deleted file mode 100644
index 527d8534ae4981471e1fa7dca04f08b4e668677b..0000000000000000000000000000000000000000
--- a/src/PTQ/PTQMetaOps.cpp
+++ /dev/null
@@ -1,152 +0,0 @@
-/********************************************************************************
- * Copyright (c) 2023 CEA-List
- *
- * This program and the accompanying materials are made available under the
- * terms of the Eclipse Public License 2.0 which is available at
- * http://www.eclipse.org/legal/epl-2.0.
- *
- * SPDX-License-Identifier: EPL-2.0
- *
- ********************************************************************************/
-
-#include "aidge/quantization/PTQ/PTQMetaOps.hpp"
-
-#include <array>
-#include <memory>
-#include <utility>
-
-//Operator
-#include "aidge/operator/Clip.hpp"
-#include "aidge/operator/Mul.hpp"
-#include "aidge/operator/Round.hpp"
-
-#include "aidge/graph/Node.hpp"
-#include "aidge/graph/OpArgs.hpp"
-#include "aidge/operator/MetaOperator.hpp"
-#include "aidge/operator/Producer.hpp"
-#include "aidge/utils/ArrayHelpers.hpp"
-#include "aidge/utils/Types.h"
-#include "aidge/operator/Identity.hpp"
-#include "aidge/data/Tensor.hpp"
-#include "aidge/operator/OperatorTensor.hpp"
-#include "aidge/utils/Log.hpp"
-
-
-namespace Aidge 
-{
-
-std::shared_ptr<Node> Quantizer(double scalingFactor, double clipMin, double clipMax, const std::string& name)
-{
-    // create the nodes
-
-    std::shared_ptr<Node> mulNode =  Mul((!name.empty()) ? name + "_MulQuant" : "");
-    std::shared_ptr<Node> roundNode = Round((!name.empty()) ? name + "_RoundQuant" : "");
-    std::shared_ptr<Node> clipNode = Clip((!name.empty()) ? name + "_ClipQuant" : "", clipMin, clipMax);
-
-    // connect the scaling factor producer
-
-    std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor});
-    std::shared_ptr<Node> scalingFactorProducer = addProducer<1>(mulNode, 1, {1}, "ScalingFactor"); 
-    scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor);
-    
-    // create the metaop graph
-
-    std::shared_ptr<GraphView> graphView = Sequential({mulNode, roundNode, clipNode});
-    std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); // XXX why not use the graphView ???
-
-    // return the metaop 
-
-    std::shared_ptr<Node> metaopNode = MetaOperator("Quantizer", connectedGraphView, {}, name); // XXX alternative prototype
-
-    return metaopNode; 
-}
-
-std::shared_ptr<Node> Scaling(double scalingFactor, const std::string& name)
-{
-    std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor});
-
-    std::shared_ptr<Node> mulNode = Mul((!name.empty()) ? name + "_Scaling" : "");
-
-    std::shared_ptr<Node> scalingFactorProducer = addProducer<1>(mulNode, 1, {1}, "ScalingFactor"); 
-    scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor);
-
-    std::shared_ptr<GraphView> graphView  = Sequential({mulNode});
-    std::shared_ptr<GraphView> connectedGraphView  = getConnectedGraphView(mulNode);
-
-    NodePtr metaopNode = MetaOperator("Scaling", connectedGraphView, {}, name);
-
-    return metaopNode;
-}
-
-static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, std::string nodeType)
-{
-    std::shared_ptr<Node> mulNode = nullptr;
-    for(std::shared_ptr<Node> node : graphView->getNodes())
-        if (node->type() == nodeType)
-            mulNode = node;
-
-    return mulNode;
-}
-
-void updateScalingFactor(std::shared_ptr<Node> metaOpNode, double scalingFactor)
-{
-    if(metaOpNode->type() != "Scaling" && metaOpNode->type() != "Quantizer")
-        Log::warn(" Cannot update the scaling factor on Node of type {}", metaOpNode->type());
-
-    std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor});
-
-    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op>(metaOpNode->getOperator());
-    
-    std::shared_ptr<Node> mulNode = getSubNode(metaOp->getMicroGraph(), "Mul");
-
-    if (!mulNode)
-        Log::warn(" Invalid PTQ MetaOperator, no Mul node found inside ! ");
-
-    mulNode->input(1).first->getOperator()->setOutput(0, scalingFactorTensor);
-}
-
-double getScalingFactor(std::shared_ptr<Node> MetaOpNode)
-{
-    if (MetaOpNode->type() != "Scaling" && MetaOpNode->type() != "Quantizer") {
-        Log::warn(" Cannot get the scaling factor on Node of type {}", MetaOpNode->type());
-        return 0;
-    }
-
-    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op>(MetaOpNode->getOperator());
-    
-    std::shared_ptr<Node> mulNode = getSubNode(metaOp->getMicroGraph(), "Mul");
-
-    if (!mulNode) {
-        Log::warn(" Invalid PTQ MetaOperator, no Mul found inside node of type {}", MetaOpNode->type());
-        return 0;
-    }
-
-    auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(mulNode->getOperator())->getInput(1);
-    std::shared_ptr<Tensor> fallback;
-    const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu"); 
-    
-    return localTensor.get<double>(0);
-}
-
-
-void setClipRange(std::shared_ptr<Node> quantizerNode, double min, double max)
-{
-    if (quantizerNode->type() != "Quantizer") {
-        Log::warn(" Cannot set the clipping range on Node of type {}", quantizerNode->type());
-        return;
-    }
-
-    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op> (quantizerNode->getOperator());
-
-    std::shared_ptr<Node> clipNode = getSubNode(metaOp->getMicroGraph(), "Clip");
-
-    if (!clipNode) {
-        Log::warn(" Invalid PTQ MetaOperator, no Clip found inside node of type {}", quantizerNode->type());
-        return;
-    }
-
-    std::shared_ptr<Clip_Op> clipOp = std::static_pointer_cast<Clip_Op>(clipNode->getOperator());
-    clipOp->max() = max;
-    clipOp->min() = min;
-}
-}
\ No newline at end of file
diff --git a/src/QAT/QAT_LSQ.cpp b/src/QAT/QAT_LSQ.cpp
index 9b51e846df498a9303b7373ae1c86d4b007a96f0..8a42770ac9ff5c9426c0538d407c7f58d0021c15 100644
--- a/src/QAT/QAT_LSQ.cpp
+++ b/src/QAT/QAT_LSQ.cpp
@@ -13,7 +13,6 @@
 #include "aidge/operator/LSQ.hpp"
 #include "aidge/operator/ReLU.hpp"
 
-
 #include "aidge/data/Tensor.hpp"
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/scheduler/SequentialScheduler.hpp"
@@ -23,7 +22,42 @@
 
 namespace Aidge {
 
-void QuantLSQ::insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, float stepSize)
+static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
+{
+    auto valueTensor = (*tensor).abs().mean();
+    std::shared_ptr<Tensor> fallback;
+    const Tensor& localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu");
+    return localTensor.get<float>(0);
+}
+
+// INIT THE STEP SIZE OF A QUANTIZER NODE
+
+static bool initStepSize(std::shared_ptr<Node> quantizer)
+{
+    const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator());
+
+    float inputAbsMean = getTensorAbsMean(quantizerOp->getInput(0));
+
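+    // LSQ initialization heuristic : stepSize = 2 * E[|x|] / sqrt(Qmax)
+    // (as proposed in the LSQ paper, "Learned Step Size Quantization", Esser et al.)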
+    float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second));
+
+    auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
+
+    // XXX Manage backend here ?
+    stepSizeTensor->setBackend(quantizerOp->getInput(0)->backend());
+    stepSizeTensor->setDataType(quantizerOp->getInput(0)->dataType());
+
+    auto stepSizeProducer = quantizer->getParent(1);
+
+    stepSizeProducer->getOperator()->setOutput(0, stepSizeTensor);
+
+    Log::debug("[ INIT STEP SIZE = {} ]",stepSize);
+
+    return false;
+}
+
+// INPUT QUANTIZERS INSERTION
+
+static void setupInputQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
 {
     const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)");
 
@@ -34,180 +68,76 @@ void QuantLSQ::insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbB
         std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
         std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1};
 
-        // INPUT QUANTIZERS INSERTION
+        // Create the input quantizer node
 
-        // TODO : double check this, and use createUniqueName()
-        auto inputQuantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView);  
-        auto inputQuantizerNode = LSQ(signedRange, inputQuantizerName);
+        auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView);  
+        auto quantizerNode = LSQ(signedRange, quantizerName);
 
-        // Set the step size
+        // Init the step-size using the node call stack
 
-        auto inputStepSizeOp = inputQuantizerNode->getParent(1)->getOperator();
-        auto inputStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        inputStepSizeOp->setOutput(0, inputStepSizeTensor);
+        quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
 
         // Absorb the ReLU when possible ...
 
-        // XXX is this safe ???
-        bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]); 
-        // bool nodeHasParent = (linearNode->getParents().size() != 0);
+        bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]);  // XXX is this safe ?
 
         if (nodeHasParent) {
             auto parentNode = linearNode->getParents()[0];
             if (parentNode->type() == "ReLU") {
-                auto inputQuantizerOp = std::static_pointer_cast<LSQ_Op> (inputQuantizerNode->getOperator());
-                inputQuantizerOp->range() = unsignedRange;
+                auto quantizerOp = std::static_pointer_cast<LSQ_Op> (quantizerNode->getOperator());
+                quantizerOp->range() = unsignedRange;
                 graphView->replace({parentNode}, {}); 
             }
         }
 
-        // We need to handle the case where the linear node is the first one ...
+        // Insert the quantizer in the graphView ...
+        // (We need to handle the case where the linear node is the first one)
 
         if (nodeHasParent) {
-            graphView->insertParent(linearNode, inputQuantizerNode, 0, 0, 0);
+            graphView->insertParent(linearNode, quantizerNode, 0, 0, 0);
         } else {
-            inputQuantizerNode->addChild(graphView);
-            graphView->add(inputQuantizerNode);
+            quantizerNode->addChild(graphView);
+            graphView->add(quantizerNode);
         }
-
-        // PARAM QUANTIZERS INSERTION
-
-        // TODO : double check this, and use createUniqueName()
-        auto paramQuantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView);  
-        auto paramQuantizerNode = LSQ(signedRange, paramQuantizerName); 
-        graphView->insertParent(linearNode, paramQuantizerNode, 1, 0, 0);
-
-        // Set the step size
-
-        auto paramStepSizeOp = paramQuantizerNode->getParent(1)->getOperator();
-        auto paramStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        paramStepSizeOp->setOutput(0, paramStepSizeTensor);
     }
-
 }
 
-static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
-{
-    auto backend = tensor->backend();
-    if (backend == "cuda")
-        tensor->setBackend("cpu");
-
-    float acc = 0;
-    float* castedTensor = static_cast<float *> (tensor->getImpl()->rawPtr());
-    for(std::size_t i = 0; i < tensor->size(); i++)
-        acc += std::abs(castedTensor[i]);
-    acc /= static_cast<float> (tensor->size());
-
-    if (backend == "cuda")
-        tensor->setBackend("cuda");
-
-    return acc;
-}
+// PARAM QUANTIZERS INSERTION
 
-static std::map<std::string, float> collectInputStats(std::shared_ptr<GraphView> graphView, std::shared_ptr<Tensor> calibrationData, bool useCuda)
+static void setupParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
 {
-    // Propagate the calibration tensor
+    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)");
 
-    SequentialScheduler scheduler(graphView);
-    scheduler.resetScheduling();
-    scheduler.forward(true, {calibrationData});
+    std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
 
-    // Store the input tensor statistics
+    for (const auto& match : matches) 
+    {       
+        auto linearNode = match.graph->rootNode(); 
 
-    if (useCuda)
-        graphView->setBackend("cpu"); 
+        // TODO : double check this, and use createUniqueName()
+        auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView);  
+        auto quantizerNode = LSQ(signedRange, quantizerName); 
 
-    std::map<std::string, float> inputStats;
-    for (auto node : graphView->getNodes())
-    {
-        if (node->type() == "FC" || node->type() == "Conv2D") // TODO: use graph matching !!!
-        {
-            const auto op = std::static_pointer_cast<LSQ_Op>(node->getOperator());
-            float inputAbsMean = getTensorAbsMean(op->getInput(0));
-            inputStats.insert(std::make_pair(node->name(), inputAbsMean));
-            fmt::println("{} -> {}", node->name(), inputAbsMean);
-        }
-    }
+        // Init the step-size using the node call stack
 
-    if (useCuda)
-        graphView->setBackend("cuda");
+        quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
 
-    return inputStats;
-}
+        // Insert the quantizer in the graphView
 
-static std::map<std::string, float> collectParamStats(std::shared_ptr<GraphView> graphView, bool useCuda)
-{
-    if (useCuda)
-        graphView->setBackend("cpu");
-
-    std::map<std::string, float> paramStats;
-    for (auto node : graphView->getNodes())
-    {
-        if (node->type() == "FC" || node->type() == "Conv2D") // TODO: use graph matching !!!
-        {
-            const auto op = std::static_pointer_cast<LSQ_Op>(node->getOperator());
-            float paramAbsMean = getTensorAbsMean(op->getInput(1));
-            paramStats.insert(std::make_pair(node->name(), paramAbsMean));
-            fmt::println("{} -> {}", node->name(), paramAbsMean);
-        }
+        graphView->insertParent(linearNode, quantizerNode, 1, 0, 0);
     }
-    
-    if (useCuda)
-        graphView->setBackend("cuda");
-
-    return paramStats;
 }
 
-static void adjustQuantizersStepSizes(std::shared_ptr<GraphView> graphView, std::map<std::string, float> inputStats, std::map<std::string, float> paramStats)
+void QuantLSQ::setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
 {
-    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)");
-
-    for (const auto& match : matches) 
-    {
-        auto linearNode = match.graph->rootNode();
-
-        // INPUT QUANTIZERS STEP-SIZES
-
-        auto inputQuantNode = linearNode->getParent(0);
-        auto inputQuantOp = std::static_pointer_cast<LSQ_Op>(inputQuantNode->getOperator());
-
-        float absMean = inputStats[linearNode->name()];
-        float stepSize = 2.0f * (absMean / std::sqrt(inputQuantOp->range().second));
-
-        auto inputStepSizeOp = inputQuantNode->getParent(1)->getOperator();
-        // XXX inputStepSizeOp->setOutput(0, std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})));
-        auto inputStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        inputStepSizeOp->setOutput(0, inputStepSizeTensor);
-
-        // PARAM QUANTIZERS STEP-SIZES
-
-        auto paramQuantNode = linearNode->getParent(1);
-        auto paramQuantOp = std::static_pointer_cast<LSQ_Op>(paramQuantNode->getOperator());
-
-        absMean = paramStats[linearNode->name()];
-        stepSize = 2.0f * (absMean / std::sqrt(paramQuantOp->range().second));
-
-        auto paramStepSizeOp = paramQuantNode->getParent(1)->getOperator();
-        // XXX paramStepSizeOp->setOutput(0, std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})));
-        auto paramStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        paramStepSizeOp->setOutput(0, paramStepSizeTensor);
-    }
+    setupInputQuantizers(graphView, nbBits);
+    setupParamQuantizers(graphView, nbBits);
 }
 
-void QuantLSQ::insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, std::shared_ptr<Tensor> calibrationData)
+void QuantLSQ::devLSQ(std::shared_ptr<Tensor> tensor)
 {
-    bool useCuda = (calibrationData->backend() == "cuda");
-
-    // Collect the tensor statisics
-    auto inputStats = collectInputStats(graphView, calibrationData, useCuda);
-
-    auto paramStats = collectParamStats(graphView, useCuda);
-
-    // Insert the quantizers
-    insertQuantizers(graphView, nbBits, 1.0);
-
-    // Adjust the quantizers step-sizes
-    adjustQuantizersStepSizes(graphView, inputStats, paramStats);
+    float mean = (tensor->mean()).get<float>(0);
+    Log::debug("MEAN = {}", mean);
 }
 
 }
\ No newline at end of file
diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..fb7366467cf9d1e84b3465929027e0217b2a354f
--- /dev/null
+++ b/src/operator/PTQMetaOps.cpp
@@ -0,0 +1,229 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "aidge/operator/PTQMetaOps.hpp"
+
+#include <array>
+#include <cmath>
+#include <memory>
+#include <utility>
+
+//Operator
+#include "aidge/operator/Clip.hpp"
+#include "aidge/operator/Mul.hpp"
+#include "aidge/operator/Round.hpp"
+#include "aidge/operator/Cast.hpp"
+#include "aidge/operator/BitShift.hpp"
+
+#include "aidge/graph/Node.hpp"
+#include "aidge/graph/OpArgs.hpp"
+#include "aidge/operator/MetaOperator.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/utils/ArrayHelpers.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/operator/Identity.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/utils/Log.hpp"
+
+
+namespace Aidge 
+{
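+
+// Helper : retrieve the (last) node of a given type inside a micro-graph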
+static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, std::string nodeType)
+{
+    std::shared_ptr<Node> resultNode = nullptr;
+    for (std::shared_ptr<Node> node : graphView->getNodes())
+        if (node->type() == nodeType)
+            resultNode = node;
+
+    return resultNode;
+}
+
+std::shared_ptr<Node> Quantizer(double scalingFactor, double clipMin, double clipMax, const std::string& name)
+{
+    // create the nodes
+
+    std::shared_ptr<Node> mulNode = Mul((!name.empty()) ? name + "_MulQuant" : "");
+    std::shared_ptr<Node> roundNode = Round((!name.empty()) ? name + "_RoundQuant" : "");
+    std::shared_ptr<Node> clipNode = Clip((!name.empty()) ? name + "_ClipQuant" : "", clipMin, clipMax);
+
+    // connect the scaling factor producer
+
+    std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor});
+    std::shared_ptr<Node> scalingFactorProducer = addProducer<1>(mulNode, 1, {1}, "ScalingFactor"); 
+    scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor);
+    
+    // create the metaop graph
+
+    std::shared_ptr<GraphView> graphView = Sequential({mulNode, roundNode, clipNode});
+    std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); // XXX why not use the graphView ???
+
+    // return the metaop 
+
+    std::shared_ptr<Node> metaopNode = MetaOperator("Quantizer", connectedGraphView, {}, name); // XXX alternative prototype
+
+    return metaopNode; 
+}
+
+std::shared_ptr<Node> BitShiftQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, const std::string& name)
+{
+    double scalingFactor = getScalingFactor(oldQuantizer);
+
+    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op> (oldQuantizer->getOperator());
+    std::shared_ptr<Node> oldClipNode = getSubNode(metaOp->getMicroGraph(), "Clip");
+
+    if (!oldClipNode) {
+        Log::warn("Invalid PTQ MetaOperator, no Clip found inside node of type {}", oldQuantizer->type());
+        return nullptr;
+    }
+
+    std::shared_ptr<Clip_Op> clipOp = std::static_pointer_cast<Clip_Op>(oldClipNode->getOperator());
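+
+    // the scaling factor is assumed to be a power of two at this point (as produced
+    // by the single-shift approximation), so log2(scalingFactor) is an integer shift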
+    int shift = static_cast<int>(std::round(std::log2(scalingFactor)));
+    BitShift_Op::BitShiftDirection direction = BitShift_Op::BitShiftDirection::left;
+
+    if(shift < 0 )
+    {
+        direction = BitShift_Op::BitShiftDirection::right;
+        shift = -shift;
+    }
+
+    std::shared_ptr<Node> bitShiftNode = BitShift(direction, (!name.empty()) ? name + "_BitShiftQuant" : "");
+    std::shared_ptr<Node> clipNode = Clip((!name.empty()) ? name + "_IClipQuant" : "", clipOp->min(), clipOp->max());
+
+    std::shared_ptr<Tensor> bitshiftTensor = std::make_shared<Tensor>(Array1D<int, 1> {shift});
+    std::shared_ptr<Node> bitshiftProducer = addProducer(bitShiftNode, 1, {1}, "ScalingFactor");
+     
+    bitshiftProducer->getOperator()->setOutput(0, bitshiftTensor);
+    bitshiftProducer->attributes()->addAttr("quantization.ptq.ShiftAmount", shift);
+    bitshiftProducer->getOperator()->setDataType(targetType); 
+
+    // set the data types of the micro-graph nodes
+
+    bitShiftNode->getOperator()->setDataType(targetType);
+    clipNode->getOperator()->setDataType(targetType);
+    
+    // create the metaop graph
+
+    std::shared_ptr<GraphView> graphView = Sequential({bitShiftNode, clipNode});
+    std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(bitShiftNode); // XXX why not use the graphView ???
+
+    // return the metaop 
+    std::shared_ptr<Node> metaopNode = MetaOperator("BitShiftQuantizer", connectedGraphView, {}, name); // XXX alternative prototype
+
+    return metaopNode; 
+}
+
+std::shared_ptr<Node> IntQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, const std::string& name)
+{
+    double scalingFactor = getScalingFactor(oldQuantizer);
+
+    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op> (oldQuantizer->getOperator());
+    std::shared_ptr<Node> oldClipNode = getSubNode(metaOp->getMicroGraph(), "Clip");
+
+    if (!oldClipNode) {
+        Log::warn("Invalid PTQ MetaOperator, no Clip found inside node of type {}", oldQuantizer->type());
+        return nullptr;
+    }
+
+    std::shared_ptr<Clip_Op> clipOp = std::static_pointer_cast<Clip_Op>(oldClipNode->getOperator());
+
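+    // rebuild the quantization pipeline in Float64, sandwiched between two Cast nodes:
+    // Cast(Float64) -> Mul -> Round -> Clip -> Cast(targetType)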
+    std::shared_ptr<Node> castPreNode = Cast(DataType::Float64, (!name.empty()) ? name + "_PreCast" : "");
+
+    std::shared_ptr<Node> mulNode = Mul((!name.empty()) ? name + "_MulIQuant" : "");
+    std::shared_ptr<Node> roundNode = Round((!name.empty()) ? name + "_IRoundQuant" : "");
+    std::shared_ptr<Node> clipNode = Clip((!name.empty()) ? name + "_IClipQuant" : "", clipOp->min(), clipOp->max());
+
+    std::shared_ptr<Node> castPostNode = Cast(targetType, (!name.empty()) ? name + "_PostCast" : "");
+
+    // set the data types of the micro-graph nodes
+
+    castPreNode->getOperator()->setDataType(DataType::Float64);
+    mulNode->getOperator()->setDataType(DataType::Float64);
+    roundNode->getOperator()->setDataType(DataType::Float64);
+    clipNode->getOperator()->setDataType(DataType::Float64);
+
+    castPostNode->getOperator()->setDataType(targetType);
+
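+    // connect the scaling factor producer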
+    std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor});
+    std::shared_ptr<Node> scalingFactorProducer = addProducer<1>(mulNode, 1, {1}, "ScalingFactor"); 
+    scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor);
+    
+    // create the metaop graph
+
+    std::shared_ptr<GraphView> graphView = Sequential({castPreNode, mulNode, roundNode, clipNode, castPostNode});
+    std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); // XXX why not use the graphView ???
+
+    // return the metaop 
+
+    std::shared_ptr<Node> metaopNode = MetaOperator("IntQuantizer", connectedGraphView, {}, name); // XXX alternative prototype
+
+    return metaopNode; 
+}
+
+
+void updateScalingFactor(std::shared_ptr<Node> metaOpNode, double scalingFactor)
+{
+    if (metaOpNode->type() != "Scaling" && metaOpNode->type() != "Quantizer")
+        Log::warn("Cannot update the scaling factor on Node of type {}", metaOpNode->type());
+
+    std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor});
+
+    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op>(metaOpNode->getOperator());
+    
+    std::shared_ptr<Node> mulNode = getSubNode(metaOp->getMicroGraph(), "Mul");
+
+    if (!mulNode) {
+        Log::warn("Invalid PTQ MetaOperator, no Mul node found inside !");
+        return;
+    }
+
+    mulNode->input(1).first->getOperator()->setOutput(0, scalingFactorTensor);
+}
+
+double getScalingFactor(std::shared_ptr<Node> metaOpNode)
+{
+    if (metaOpNode->type() != "Scaling" && metaOpNode->type() != "Quantizer") {
+        Log::warn("Cannot get the scaling factor on Node of type {}", metaOpNode->type());
+        return 0;
+    }
+
+    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op>(metaOpNode->getOperator());
+
+    std::shared_ptr<Node> mulNode = getSubNode(metaOp->getMicroGraph(), "Mul");
+
+    if (!mulNode) {
+        Log::warn("Invalid PTQ MetaOperator, no Mul found inside node of type {}", metaOpNode->type());
+        return 0;
+    }
+
+    auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(mulNode->getOperator())->getInput(1);
+    std::shared_ptr<Tensor> fallback;
+    const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu"); 
+    
+    return localTensor.get<double>(0);
+}
+
+
+void setClipRange(std::shared_ptr<Node> quantizerNode, double min, double max)
+{
+    if (quantizerNode->type() != "Quantizer") {
+        Log::warn("Cannot set the clipping range on Node of type {}", quantizerNode->type());
+        return;
+    }
+
+    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op> (quantizerNode->getOperator());
+
+    std::shared_ptr<Node> clipNode = getSubNode(metaOp->getMicroGraph(), "Clip");
+
+    if (!clipNode) {
+        Log::warn("Invalid PTQ MetaOperator, no Clip found inside node of type {}", quantizerNode->type());
+        return;
+    }
+
+    std::shared_ptr<Clip_Op> clipOp = std::static_pointer_cast<Clip_Op>(clipNode->getOperator());
+    clipOp->max() = max;
+    clipOp->min() = min;
+}
+}
\ No newline at end of file