diff --git a/.gitignore b/.gitignore
index ba5c59398b68083c6c1c5fe820fb9070d999c18e..57409a5cddc52f82eb67bf88b0ae28ca23e8a72b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,8 +4,7 @@
 # C++ Build
 build*/
 install*/
-include/aidge/backend/quantization_version.h
-
+include/aidge/quantization_version.h
 # VSCode
 .vscode
 
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 4256774056379969c7406a35e4bcde3ff25c6550..e94a62dd6b89922c90006028a4d1c5b1171709b0 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -12,36 +12,24 @@ stages:
   - deploy
 
 include:
-  - project: 'eclipse/aidge/gitlab_shared_files' 
+  - project: 'eclipse/aidge/gitlab_shared_files'
     ref: 'main'
-    file: 
+    file:
       # choose which jobs to run by including the corresponding files.
       - '.gitlab/ci/ubuntu_cpp.gitlab-ci.yml'
-
       - '.gitlab/ci/ubuntu_python.gitlab-ci.yml'
-      - '.gitlab/ci/release/cibuildwheel_ubuntu.gitlab-ci.yml'   
- 
-      # - '.gitlab/ci/windows_cpp.gitlab-ci.yml'
-
-      # - '.gitlab/ci/windows_python.gitlab-ci.yml'   
-      # - '.gitlab/ci/release/cibuildwheel_windows.gitlab-ci.yml'   
-
+      - '.gitlab/ci/release/cibuildwheel_ubuntu.gitlab-ci.yml'
 
 test:ubuntu_python:
   before_script:
-    - !reference [.retrieve_deps:apt, script]
-    - source venv/bin/activate
-    - python -m pip install numpy unittest-xml-reporting
-    - python -m pip list
+    - !reference [.setup:test:ubuntu_python, before_script]
     - DEPS_NAMES=("aidge_onnx")
     - DEPENDENCY_JOB="build:ubuntu_python"
     - !reference [.ubuntu:download:artifacts, script]
 
 coverage:ubuntu_python:
-  before_script: 
-    - !reference [.retrieve_deps:apt, script]
-    - source venv/bin/activate
-    - python -m pip install numpy coverage 
+  before_script:
+    - !reference [.setup:coverage:ubuntu_python, before_script]
     - DEPS_NAMES=("aidge_onnx")
     - DEPENDENCY_JOB="build:ubuntu_python"
     - !reference [.ubuntu:download:artifacts, script]
@@ -65,12 +53,12 @@ release:pip:ubuntu:
 #     # Install dependencies
 #     - choco install cmake.install --installargs '"ADD_CMAKE_TO_PATH=System"' -Y
 #     - choco install git -Y
-#     - choco install python --version=$python_version -Y 
+#     - choco install python --version=$python_version -Y
 #     # Update PATH
 #     - $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
 #     - python -m pip install cibuildwheel==2.17.0
 #     # Download repositories
 #     - $DEPS_NAMES = "aidge_core","aidge_backend_cpu","aidge_backend_cuda","aidge_onnx"
 #     - $DEPENDENCY_JOB="build:windows_python"
-#     - !reference [.windows:download:repositories, script] 
+#     - !reference [.windows:download:repositories, script]
 
diff --git a/aidge_quantization/unit_tests/test_ptq.py b/aidge_quantization/unit_tests/test_ptq.py
index e2acab95c7c3c7ee517ebaba5af102d336679cbe..f6b243c27b1a08dbbfc5da522e385ceb4ec9c2f4 100644
--- a/aidge_quantization/unit_tests/test_ptq.py
+++ b/aidge_quantization/unit_tests/test_ptq.py
@@ -9,29 +9,35 @@ import sys
 from pathlib import Path
 
 """
-Unit test for the PTQ pipeline:
-This script is designed to test and validate the accuracy of five small model topologies on the MNIST dataset:
-["MiniResNet", "ConvNet", "BranchNetV4", "TestNet", "MLP"]
-It compares the results for three configurations: the baseline, quantization, and quantization with single shift. 
-The value of sigma represents the tolerance for the tests.
+    Unit test for the PTQ pipeline
+    ==============================
+    This script is designed to test and validate the accuracy of five small model topologies on the MNIST dataset:
+    ["MiniResNet", "ConvNet", "BranchNetV4", "TestNet", "MLP"]
+    It compares the results for three configurations: baseline, quantization, and quantization with single shift.
+    The value of DELTA represents the tolerance of the tests.
 """
 aidge_core.Log.set_console_level(aidge_core.Level.Error)  # Reduce useless logs
 # --------------------------------------------------------------
 # CONFIGURATION
 # --------------------------------------------------------------
 
-NB_SAMPLES = 1000
+NB_SAMPLES   = 1000
 SAMPLE_SHAPE = (1, 1, 28, 28)
-NB_BITS = 4
-CLIPPING = aidge_quantization.Clipping.MSE
+NB_BITS      = 4
+TARGET_TYPE  = aidge_core.dtype.int32
+CLIPPING     = aidge_quantization.Clipping.MSE
+NO_QUANT     = False
+OPTIM_SIGNS  = True
+FOLD_GRAPH   = True
+DELTA        = 0.05
+
 EXPECTED_RESULTS = {
-    "MiniResNet.onnx": (95.4, 94.5, 94.7),
-    "ConvNet.onnx": (97.9, 97.7, 97.4),
-    "BranchNetV4.onnx": (93.8, 93.2, 93.7),
-    "TestNet.onnx": (95.5, 94.2, 94.2),
-    "MLP.onnx": (94.7, 94.2, 93.3)
+    "MiniResNet.onnx"  : (95.4, 94.4, 95.0),
+    "ConvNet.onnx"     : (97.9, 97.2, 96.7),
+    "BranchNetV4.onnx" : (93.8, 92.7, 93.7),
+    "TestNet.onnx"     : (95.5, 94.0, 94.5),
+    "MLP.onnx"         : (94.7, 92.9, 93.8)
 }
-SIGMA = 0.05
 
 # --------------------------------------------------------------
 # UTILS
@@ -46,8 +52,8 @@ def propagate(model, scheduler, sample):
     return np.array(output_tensor)
 
 def compute_accuracy(model, samples, labels):
-    schedueler = aidge_core.SequentialScheduler(model)
-    acc = sum(labels[i] == np.argmax(propagate(model, schedueler, x)) for i, x in enumerate(samples))
+    scheduler = aidge_core.SequentialScheduler(model)
+    acc = sum(labels[i] == np.argmax(propagate(model, scheduler, x)) for i, x in enumerate(samples))
     return acc / len(samples)
 
 # --------------------------------------------------------------
@@ -60,62 +66,68 @@ class TestQuantization(unittest.TestCase):
         curr_file_dir = Path(__file__).parent.resolve()
         self.samples = np.load(gzip.GzipFile(curr_file_dir / 'assets/mnist_samples.npy.gz', "r"))
         self.labels = np.load(gzip.GzipFile(curr_file_dir / 'assets/mnist_labels.npy.gz', "r"))
-        self.quantized_sample = np.load(gzip.GzipFile(curr_file_dir / 'assets/mnist_samples.npy.gz', "r")) * ((1 << (NB_BITS - 1)) - 1)
-        
+        self.quant_samples = np.round(self.samples.copy() * (2**(NB_BITS-1)-1))
+
     def run_model_test(self, model_name):
+
+        expected_base, expected_quant, expected_quant_ssa = EXPECTED_RESULTS[model_name]
+
+        # load the model ...
+
         model_path = Path(__file__).parent / "assets" / model_name
         model = aidge_onnx.load_onnx(model_path, verbose=False)
         aidge_core.remove_flatten(model)
-        model.set_datatype(aidge_core.dtype.float64)
+
+        model.set_datatype(aidge_core.dtype.float32)
         model.set_backend("cpu")
         
-        expected_base, expected_quant, expected_quant_ss = EXPECTED_RESULTS[model_name]
-        
-        # Baseline Accuracy
-        base_accuracy = compute_accuracy(model, self.samples[:NB_SAMPLES], self.labels)
-        self.assertAlmostEqual(base_accuracy * 100, expected_base, delta=SIGMA, msg=f"[X] Baseline accuracy mismatch for {model_name}. Expected accuracy was: {expected_base}, but got: {base_accuracy * 100}")
-        
-        # Quantize
+        # create the tensor subset
+
         tensors = [aidge_core.Tensor(np.reshape(sample, SAMPLE_SHAPE)) for sample in self.samples[:NB_SAMPLES]]
 
-        aidge_quantization.quantize_network(network = model,
-                                    nb_bits =  NB_BITS,
-                                    input_dataset = tensors,
-                                    clipping_mode = CLIPPING,
-                                    target_type = aidge_core.dtype.float64,
-                                    no_quantization = False,
-                                    optimize_signs = True,
-                                    single_shift = False,
-                                    use_cuda = False,
-                                    fold_graph = True,
-                                    bitshift_rounding = False,
-                                    verbose = False)
-        quant_accuracy = compute_accuracy(model, self.quantized_sample[:NB_SAMPLES], self.labels)
-
-        self.assertAlmostEqual(quant_accuracy * 100, expected_quant, delta=SIGMA, msg=f"[X] Quantized accuracy mismatch for {model_name},Expected accuracy was: {expected_quant}, but got: {quant_accuracy * 100}")
-        
-        # Quantize with Single Shift
-        model_ss = aidge_onnx.load_onnx(model_path, verbose=False)
-        aidge_core.remove_flatten(model_ss)
-        model_ss.set_datatype(aidge_core.dtype.float64)
-        model_ss.set_backend("cpu")
+        # BASELINE ACCURACY
+
+        base_accuracy = compute_accuracy(model, tensors, self.labels[:NB_SAMPLES])
+        self.assertAlmostEqual(base_accuracy * 100, expected_base, delta=DELTA, msg=f"[X] Baseline accuracy mismatch for {model_name}. Expected accuracy was: {expected_base}, but got: {base_accuracy * 100}")
+
+        # QUANTIZED ACCURACY
+
+        aidge_quantization.quantize_network(
+            network=model,
+            nb_bits=NB_BITS,
+            calibration_set=tensors,
+            target_type=TARGET_TYPE,
+            clipping_mode=CLIPPING,
+            no_quant=NO_QUANT,
+            optimize_signs=OPTIM_SIGNS,
+            single_shift=False,
+            use_cuda=False,
+            fold_graph=FOLD_GRAPH)
+
+        quant_accuracy = compute_accuracy(model, self.quant_samples[:NB_SAMPLES], self.labels)
+        self.assertAlmostEqual(quant_accuracy * 100, expected_quant, delta=DELTA, msg=f"[X] Quantized accuracy mismatch for {model_name}. Expected accuracy was: {expected_quant}, but got: {quant_accuracy * 100}")
         
-        aidge_quantization.quantize_network(network = model_ss,
-                                            nb_bits =  NB_BITS,
-                                            input_dataset = tensors,
-                                            clipping_mode = CLIPPING,
-                                            target_type = aidge_core.dtype.float64,
-                                            no_quantization = False,
-                                            optimize_signs = True,
-                                            single_shift = True,
-                                            use_cuda = False,
-                                            fold_graph = True,
-                                            bitshift_rounding = False,
-                                            verbose = False)    
+        # QUANTIZED ACCURACY WITH SSA
+
+        model = aidge_onnx.load_onnx(model_path, verbose=False)
+        model.set_datatype(aidge_core.dtype.float32)
+        model.set_backend("cpu")
         
-        quant_accuracy_ss = compute_accuracy(model_ss, self.quantized_sample[:NB_SAMPLES], self.labels)
-        self.assertAlmostEqual(quant_accuracy_ss * 100, expected_quant_ss, delta=SIGMA, msg=f"[X] Quantized Single Shift accuracy mismatch for {model_name}.Expected accuracy was: {expected_quant_ss}, but got: {quant_accuracy_ss * 100}")
-    
+        aidge_quantization.quantize_network(
+            network=model,
+            nb_bits=NB_BITS,
+            calibration_set=tensors,
+            target_type=TARGET_TYPE,
+            clipping_mode=CLIPPING,
+            no_quant=NO_QUANT,
+            optimize_signs=OPTIM_SIGNS,
+            single_shift=True,
+            use_cuda=False,
+            fold_graph=FOLD_GRAPH)
+
+        quant_accuracy_ssa = compute_accuracy(model, self.quant_samples[:NB_SAMPLES], self.labels)
+        self.assertAlmostEqual(quant_accuracy_ssa * 100, expected_quant_ssa, delta=DELTA, msg=f"[X] Quantized accuracy (with SSA) mismatch for {model_name}. Expected accuracy was: {expected_quant_ssa}, but got: {quant_accuracy_ssa * 100}")
+
     def test_models(self):
         for model in EXPECTED_RESULTS.keys():
             with self.subTest(model=model):
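Note: the test above exercises the reworked `quantize_network` entry point. For reference, a minimal standalone sketch of the new calling convention, using the keyword arguments introduced by this change (model path, bit width and calibration data are illustrative, not part of the diff):

```python
import numpy as np
import aidge_core
import aidge_onnx
import aidge_quantization

# Hypothetical model and calibration data, for illustration only.
model = aidge_onnx.load_onnx("ConvNet.onnx", verbose=False)
aidge_core.remove_flatten(model)
model.set_datatype(aidge_core.dtype.float32)
model.set_backend("cpu")

samples = np.random.rand(100, 1, 1, 28, 28).astype(np.float32)
calibration_set = [aidge_core.Tensor(s) for s in samples]

# New signature: calibration_set and target_type replace input_dataset,
# and the old no_quantization / bitshift_rounding arguments are gone.
aidge_quantization.quantize_network(
    network=model,
    nb_bits=8,
    calibration_set=calibration_set,
    target_type=aidge_core.dtype.int32,
    clipping_mode=aidge_quantization.Clipping.MSE,
    no_quant=False,
    optimize_signs=True,
    single_shift=False,
    use_cuda=False,
    fold_graph=True,
    verbose=False)
```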
diff --git a/include/aidge/operator/PTQMetaOps.hpp b/include/aidge/operator/PTQMetaOps.hpp
index ff8235c6ea4c92935b869d8ba522a3fdcbc9b8e2..96182202f38be20afa539eb41a8d32b989afcf9f 100644
--- a/include/aidge/operator/PTQMetaOps.hpp
+++ b/include/aidge/operator/PTQMetaOps.hpp
@@ -16,74 +16,69 @@
 
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/graph/Node.hpp"
+#include "aidge/data/Data.hpp"
 
 namespace Aidge {
 
-/// @brief Quantizer acts as a meta-operator to handle scaling operations in the PTQ, replacing the Scaling Operator.
-/// This operator is composed of a sequence of [Mul] -> [Clip] -> [Round] operations.
-///
-/// @param scalingFactor The scaling factor to apply to the input (essentially a scalar to multiply the input with).
-/// @param clip_min The minimum value for the clip operation.
-/// @param clip_max The maximum value for the clip operation.
-/// @param name The name of the meta-operator node created.
-/// @return A shared pointer to an instance of the meta-operator node.
-std::shared_ptr<Aidge::Node> Quantizer(double scalingFactor, double clipMin, double clipMax, const std::string& name);
+    /**
+     * @brief Create a Quantizer node that initially consists of a multiplier and a scaling factor.
+     * @param scalingFactor The value of the multiplicative coefficient.
+     * @param name Name of the Quantizer.
+     * @return A shared pointer to the created Quantizer node.
+     */
+    std::shared_ptr<Aidge::Node> Quantizer(double scalingFactor, const std::string& name);
 
-/// @brief IntQuantizer acts as an extension of the Quantizer meta-operator, enabling seamless integration 
-///        into computation graphs with a data type other than Float while preserving floating-point precision.
-/// 
-/// This operator modifies the provided Quantizer by inserting explicit casting operations before and after 
-/// the quantization process. It first casts the input to Float64, applies the quantization steps (Mul, Clip, Round), 
-/// and then casts the result back to the target data type. This ensures compatibility with integer-based computation graphs 
-/// while maintaining the precision of floating-point operations.
-///
-/// @param oldQuantizer A shared pointer to the existing Quantizer node that will be adapted.
-/// @param targetType The target data type to which the final output should be cast after the quantization process.
-/// @param name The name of the meta-operator node created.
-/// @return A shared pointer to a new instance of the modified meta-operator node.
-std::shared_ptr<Node> IntQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, const std::string& name);
+    /**
+     * @brief Given a Quantizer, multiply its internal scaling factor by a given value.
+     * @param quantizer The Quantizer to modify.
+     * @param coeff The value to multiply the scaling factor by.
+     */
+    void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff);
 
-/// @brief BitShiftQuantizer acts as an extension of the Quantizer meta-operator, enabling seamless integration 
-///        into computation graphs with a data type other than Float while preserving floating-point precision.
-/// 
-/// This operator modifies the provided Quantizer by inserting explicit casting operations before and after 
-/// the quantization process. It first casts the input to Float64, applies the quantization steps (Mul, Clip, Round), 
-/// and then casts the result back to the target data type. This ensures compatibility with integer-based computation graphs 
-/// while maintaining the precision of floating-point operations.
-///
-/// @param oldQuantizer A shared pointer to the existing Quantizer node that will be adapted.
-/// @param targetType The target data type to which the final output should be cast after the quantization process.
-/// @param name The name of the meta-operator node created.
-/// @return A shared pointer to a new instance of the modified meta-operator node.
-std::shared_ptr<Node> BitShiftQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType,bool bitshiftRounding, const std::string& name);
+    /**
+     * @brief Given a Quantizer, create a copy of it that has a Round node and a Clip node added
+     * at its endpoint, and replace the given Quantizer by it (a swap is also done by reference).
+     * @param quantizer The quantizer to modify and replace.
+     * @param clipMin The minimum value of the Clip node.
+     * @param clipMax The maximum value of the Clip node.
+     */
+    void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double clipMax);
+    
+    /**
+     * @brief Given a Quantizer, create a copy of it that has the Round node removed,
+     * and replace the given Quantizer by it (a swap is also done by reference). 
+     * @param quantizer The quantizer to modify and replace.
+     */
+    void removeRound(std::shared_ptr<Node>& quantizer);
 
-/// @brief Updates the scaling factor of a PTQ meta-operator node, allowing for dynamic adjustment of the scaling parameter.
-/// This function sets a new scaling factor for a specified meta-operator node, modifying the scalar applied in the [Mul] operation.
-/// The meta-operator node must be a PTQ-specific operator, such as a Quantizer or Scaling node.
-///
-/// @param MetaOpNode A shared pointer to the PTQ meta-operator node whose scaling factor will be updated.
-/// @param newScalingFactor The new scaling factor to apply to the meta-operator node.
-/// @return True if the scaling factor was successfully updated, false if the operation failed (e.g., if MetaOpNode is null or incompatible).
-void updateScalingFactor(std::shared_ptr<Aidge::Node> MetaOpNode, double newScalingFactor);
+    /**
+     * @brief Given a Quantizer, create a copy of it where the Mul node is replaced by
+     * a Bit-Shift node, and replace the given Quantizer by it (a swap is also done by reference). 
+     * @param quantizer The quantizer to modify and replace.
+     */
+    void replaceScalingWithBitShift(std::shared_ptr<Node>& quantizer);
 
-/// @brief Retrieves the current scaling factor of a PTQ meta-operator node.
-/// This function returns the scaling factor associated with the specified PTQ meta-operator node,
-/// allowing inspection of the current scalar applied in the [Mul] operation.
-///
-/// @param MetaOpNode A shared pointer to the PTQ meta-operator node whose scaling factor is being queried.
-/// @return The scaling factor currently applied to the meta-operator node, or -1 if the operation fails (e.g., if MetaOpNode is null or incompatible).
-double getScalingFactor(std::shared_ptr<Aidge::Node> MetaOpNode);
+    /**
+     * @brief Given a Quantizer, create a copy of it that has Cast nodes inserted at its IOs,
+     * and replace the given Quantizer by it (a swap is also done by reference). The input Cast
+     * node converts the input data to the internal type, while the output Cast converts it back
+     * to the external type.
+     * @param quantizer The quantizer to modify and replace.
+     * @param externalType The external data type used for the casts.
+     */
+    void castQuantizerIOs(std::shared_ptr<Node>& quantizer, Aidge::DataType externalType);
 
-/// @brief Sets the clip range for an existing Quantizer node by specifying minimum and maximum clipping values.
-/// This function modifies the clip range of a Quantizer node, allowing adjustment of the range within which values are clipped
-/// in the [Clip] operation of the Quantizer sequence.
-///
-/// @param QuantizerNode A shared pointer to the Quantizer node whose clip range is being set.
-/// This node should have been created using the Quantizer function.
-/// @param min The minimum value for the clip range. Values below this will be clipped to this minimum.
-/// @param max The maximum value for the clip range. Values above this will be clipped to this maximum.
-/// @return True if the clip range was successfully set, false if the operation failed (e.g., if QuantizerNode is null).
-void setClipRange(std::shared_ptr<Aidge::Node> QuantizerNode, double min, double max);
+    /**
+     * @brief Given a Quantizer, retrieve the coefficient of its Mul node.
+     * @param quantizer The quantizer containing the multiplicative coefficient.
+     * @return The current scaling factor of the Quantizer.
+     */
+    double getScalingFactor(std::shared_ptr<Aidge::Node> quantizer);
+
+    /**
+     * @brief Given a Quantizer containing a Clip node, replace its clipping values.
+     * @param quantizer The quantizer containing the Clip node.
+     * @param min The min clipping value.
+     * @param max The max clipping value.
+     */
+    void setClipRange(std::shared_ptr<Aidge::Node> quantizer, double min, double max);
 
 }
 
diff --git a/include/aidge/quantization/PTQ/Clipping.hpp b/include/aidge/quantization/PTQ/Clipping.hpp
index 159b64f12f8c6ae2bb3e88592b29f211e15fa614..35f23f5f2022128238e1991717876d6462d0b6da 100644
--- a/include/aidge/quantization/PTQ/Clipping.hpp
+++ b/include/aidge/quantization/PTQ/Clipping.hpp
@@ -33,10 +33,10 @@ namespace Aidge
      * @param valueRanges A map associating each considered node name to its corresponding output range.
      * @param nbBins Desired number of bins of the returned histograms.
      * @param graphView The GraphView containing the considered nodes.
-     * @param inputDataSet The input dataset, consisting of a vector of input samples.
+     * @param calibrationSet The calibration dataset, consisting of a vector of input samples.
      * @return A map associating each node name to it's corresponding activation histogram.
      */
-    std::unordered_map<std::shared_ptr<Node>, std::vector<int>> computeHistograms(std::unordered_map<std::shared_ptr<Node>, double> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda);
+    std::unordered_map<std::shared_ptr<Node>, std::vector<int>> computeHistograms(std::unordered_map<std::shared_ptr<Node>, double> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, bool useCuda);
 
     /**
      * @brief Given an input activation histogram, compute the optimal clipping value in the sense of the Lp norm.
@@ -63,11 +63,11 @@ namespace Aidge
      * @param valueRanges The map associating each affine node to its output range.
      * @param nbBits The quantization number of bits.
      * @param graphView The GraphView containing the considered nodes.
-     * @param inputDataSet The input dataset, consisting of a vector of input samples.
+     * @param calibrationSet The calibration dataset, consisting of a vector of input samples.
      * @param verbose Whether to print the clipping values or not.
      * @return The corrected map associating each provided node to its clipped range.
      */
-    std::unordered_map<std::shared_ptr<Node>, double> adjustRanges(Clipping clippingMode, std::unordered_map<std::shared_ptr<Node>, double> valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda, bool verbose);
+    std::unordered_map<std::shared_ptr<Node>, double> adjustRanges(Clipping clippingMode, std::unordered_map<std::shared_ptr<Node>, double> valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, bool useCuda, bool verbose);
 
 }
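As shown above, the calibration dataset parameter is now named `calibrationSet` throughout. For callers driving the calibration steps manually through the Python bindings, a hedged sketch of the range/clipping flow (reusing the `model` and `calibration_set` names from the earlier sketch):

```python
import aidge_quantization

# Compute each scaling node's output range over the calibration set.
ranges = aidge_quantization.compute_ranges(
    network=model,
    calibration_set=calibration_set,
    scaling_nodes_only=True,
    use_cuda=False)

# Correct the ranges with MSE-optimal clipping values.
clipped_ranges = aidge_quantization.adjust_ranges(
    clipping_mode=aidge_quantization.Clipping.MSE,
    value_ranges=ranges,
    nb_bits=8,
    network=model,
    calibration_set=calibration_set,
    use_cuda=False,
    verbose=False)
```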
 
diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp
index 7f11c012636b04d0434ff7a6a04bbf5131096171..d9b944e33cc5706bb8f62ddeb1553ace0619245d 100644
--- a/include/aidge/quantization/PTQ/PTQ.hpp
+++ b/include/aidge/quantization/PTQ/PTQ.hpp
@@ -89,24 +89,12 @@ namespace Aidge {
     std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> graphView, bool newSchedule = true, bool verbose = false);
     
     /**
-     * @brief Inserts a scaling node below the given producer node in the graph view. 
-     *        If the node is already a producer scaling node, it accumulates the scaling factor by multiplyins its value directly.
-     *
+     * @brief Inserts a scaling node below the given producer node in the graphView. 
      * @param node A shared pointer to the producer node where the scaling node will be inserted (below).
-     * @param scalingFactor The scaling factor to apply.
      * @param graphView A shared pointer to the graph view in which the nodes are located.
-     * @return True if the scaling node was successfully inserted or the scaling factor was accumulated; False otherwise.
      */
-    bool insertScalingBelowProducer(std::shared_ptr<Node> node, double scalingFactor, std::shared_ptr<GraphView> graphView);
-
-    /**
-     * @brief Inserts a rounding node below the given producer (also below its ows producerScaling) node in the graph view. 
-     *
-     * @param node A shared pointer to the producer node where the rounding node will be inserted.
-     * @param graphView A shared pointer to the graph view in which the nodes are located.
-     * @return True if the rounding node was successfully inserted; False otherwise.
-     */
-    bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView);
+    void insertScalingBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView);
 
     /**
      * @brief Determine whether an input GraphView can be quantized or not.
@@ -124,9 +112,13 @@ namespace Aidge {
     void multiplyScalingFactor(std::shared_ptr<Aidge::Node> node, double coeff);
 
 
+    /**
+     * @brief Prepare a network before the quantization is applied to it, by removing, replacing
+     * or fusing the nodes that are not supported by the PTQ pipeline. 
+     * @param graphView The network to prepare for the quantization
+     */
     void prepareNetwork(std::shared_ptr<GraphView> graphView);
 
-
     /**
      * @brief Insert a scaling node after each affine node of the GraphView.
      * Also insert a scaling node in every purely residual branches.
@@ -143,11 +135,11 @@ namespace Aidge {
     /**
      * @brief Compute the activation ranges of every affine node, given an input dataset.
      * @param graphView The GraphView containing the affine nodes, on which the inferences are performed.
-     * @param inputDataSet The input dataset, consisting of a vector of input samples.
+     * @param calibrationSet The calibration dataset, consisting of a vector of input samples.
      * @param scalingNodesOnly Whether to restrain the retreival of the ranges to scaling nodes only or not.
      * @return A map associating each affine node name to it's corresponding output range.
      */
-    std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool scalingNodesOnly, bool useCuda);
+    std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, bool scalingNodesOnly, bool useCuda);
 
     /**
      * @brief Normalize the activations of each affine node so that they fit in the [-1:1] range.
@@ -176,34 +168,34 @@ namespace Aidge {
      */
     void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, bool noQuant, bool optimizeSigns, bool verbose);
 
+    /**
+     * @brief Take a quantized GraphView represented in floating precision and cast it to the desired target precision.
+     * If the single-shift option is set to True, the scaling nodes contained in the activation quantizers are replaced with bit-shifts.
+     * @param graphView The GraphView to modify.
+     * @param targetType The desired precision of the cast.
+     * @param singleShift If set to True, replace the scaling factors by bit-shifts.
+     * @param bitShiftRounding If singleShift is True, specifies the kind of bit-shift rounding (currently disabled in the signature).
+     */
+    void castQuantizedNetwork(std::shared_ptr<GraphView> graphView, Aidge::DataType targetType, bool singleShift /*, bool bitShiftRounding*/);
+
     /**
      * @brief Main quantization routine. Performs every step of the quantization pipeline.
      * @param graphView The GraphView to be quantized.
      * @param nbBits The desired number of bits of the quantization.
-     * @param inputDataSet The input dataset on which the value ranges are computed.
-     * @param clippingMode Type of the clipping optimization. Can be either 'MAX', 'MSE', 'AA' or 'KL'.
-     * @param targetType Desired target type to cast the graph into (default is Float64 which will NOT apply casting on the network)
+     * @param calibrationSet The calibration dataset used to calibrate the activations.
+     * @param targetType The desired data-type of the output GraphView.
+     * @param clippingMode Type of the clipping optimization. Can be either 'MAX', 'MSE', 'AA' or 'KL'.
      * @param noQuant Whether to apply the rounding operations or not.
      * @param optimizeSigns Whether to take account of the IO signs of the operators or not.
-     * @param singleShift Whether to convert the scaling factors into powers of two. If true the approximations are compensated using the previous nodes weights.
-     * @param useCuda Wheter to speed up the PTQ by computing the values ranges using CUDA kernels. 
-     * This flag does not set the backend of the graphview to "cuda" at the end of the PTQ pipeline 
-     * @param foldGraph Whether to apply the constant folding recipe which makes the end graphview much easier to read
-     * @param bitshiftRounding Whether rounding should be applied after bit-shifting operations. If enabled, the result of bit-shifting is rounded to the nearest integer.
-     * @param verbose Whether to print internal informations about the quantization process.
-     */
-    void quantizeNetwork(std::shared_ptr<GraphView> graphView, 
-        std::uint8_t nbBits,
-        std::vector<std::shared_ptr<Tensor>> inputDataSet,
-        Clipping clippingMode,
-        DataType targetType,
-        bool noQuant,
-        bool optimizeSigns,
-        bool singleShift,
-        bool useCuda,
-        bool foldGraph,
-        bool bitshiftRounding,
-        bool verbose);
+     * @param singleShift Whether to convert the scaling factors into powers of two. If true, the approximations are compensated using the previous nodes' parameters.
+     * @param useCuda Whether to use the CUDA backend for performing the activation calibration or not.
+     * @param foldGraph Whether to fold the parameter quantizers after the quantization or not.
+     * @param verbose Whether to print internal information about the quantization process or not.
+     */
+    void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> calibrationSet, DataType targetType, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose);
+
+
     /**
      * @brief Compute the weight ranges of every affine node. Provided for debugging purposes.
      * @param graphView The GraphView containing the affine nodes.
diff --git a/include/aidge/quantization_version.h b/include/aidge/quantization_version.h
deleted file mode 100644
index 546263af3a7e8b7a73991173f48d0b095c7d9501..0000000000000000000000000000000000000000
--- a/include/aidge/quantization_version.h
+++ /dev/null
@@ -1,11 +0,0 @@
-#ifndef VERSION_H
-#define VERSION_H
-
-namespace Aidge {
-static constexpr const int PROJECT_VERSION_MAJOR = 0;
-static constexpr const int PROJECT_VERSION_MINOR = 2;
-static constexpr const int PROJECT_VERSION_PATCH = 0;
-static constexpr const char * PROJECT_VERSION = "0.2.0";
-static constexpr const char * PROJECT_GIT_HASH = "f50c860";
-}
-#endif // VERSION_H
diff --git a/python_binding/pybind_PTQ.cpp b/python_binding/pybind_PTQ.cpp
index d7bc00dcc095736419732b9ed56918ca37663b50..ad6931c8f6dcc9e6f3dd8d16fb57e6cadf06efe6 100644
--- a/python_binding/pybind_PTQ.cpp
+++ b/python_binding/pybind_PTQ.cpp
@@ -55,13 +55,13 @@ void init_PTQ(py::module &m) {
     :type network: :py:class:`aidge_core.GraphView`
     )mydelimiter");
 
-    m.def("compute_ranges", &computeRanges, py::arg("network"), py::arg("input_dataset"), py::arg("scaling_nodes_only"), py::arg("use_cuda"),
+    m.def("compute_ranges", &computeRanges, py::arg("network"), py::arg("calibration_set"), py::arg("scaling_nodes_only"), py::arg("use_cuda"),
     R"mydelimiter(
     Compute the activation ranges of every affine node, given an input dataset.
     :param network: The GraphView containing the affine nodes, on which the inferences are performed.
     :type network: :py:class:`aidge_core.GraphView`
-    :param input_dataset: The input dataset, consisting of a vector of input samples.
-    :type input_dataset: list of :py:class:`aidge_core.Tensor`
+    :param calibration_set: The calibration dataset, consisting of a vector of input samples.
+    :type calibration_set: list of :py:class:`aidge_core.Tensor`
     :param scaling_nodes_only: Whether to restrain the retreival of the ranges to scaling nodes only or not
     :type scaling_nodes_only: bool
     :return: A map associating each considered node name to it's corresponding output range.
@@ -78,56 +78,61 @@ void init_PTQ(py::module &m) {
     :type value_ranges: list of float.
     )mydelimiter");
 
-    m.def("quantize_normalized_network", &quantizeNormalizedNetwork, py::arg("network"), py::arg("nb_bits"), py::arg("no_quantization")=false, py::arg("optimize_signs"), py::arg("verbose") = false,
+    m.def("quantize_normalized_network", &quantizeNormalizedNetwork, py::arg("network"), py::arg("nb_bits"), py::arg("no_quant")=false, py::arg("optimize_signs"), py::arg("verbose") = false,
     R"mydelimiter(
     Quantize an already normalized (in term of parameters and activations) network.
     :param network: The GraphView to be quantized.
     :type network: :py:class:`aidge_core.GraphView`
     :param nb_bits: The desired number of bits of the quantization.
     :type nb_bits: int
-    :param apply_rounding: Whether to apply the rounding operations or not.
-    :type apply_rounding: bool
+    :param no_quant: If set to True, the rounding operations are not applied.
+    :type no_quant: bool
     :param optimize_signs: Whether to take account of the IO signs of the operators or not.
     :type optimize_signs: bool
     :param verbose: Whether to print the sign map or not.
     :type verbose: bool
     )mydelimiter");
 
-    m.def("quantize_network",
-        &quantizeNetwork,
-        py::arg("network"),
-        py::arg("nb_bits"),
-        py::arg("input_dataset"),
-        py::arg("clipping_mode") = Clipping::MAX,
-        py::arg("target_type") = DataType::Float64,
-        py::arg("no_quantization") = false,
-        py::arg("optimize_signs") = false,
-        py::arg("single_shift") = false, 
-        py::arg("use_cuda") = false,
-        py::arg("fold_graph") = true,
-        py::arg("bitshift_rounding") = false,
-        py::arg("verbose") = false,
+    m.def("cast_quantized_network", &castQuantizedNetwork, py::arg("network"), py::arg("target_type"), py::arg("single_shift"), /* py::arg("bitshift_rounding"),*/
+    R"mydelimiter(
+    Take a quantized GraphView represented in floating precision and cast it to the desired target precision.
+    If the single-shift option is set to True, the scaling nodes contained in the activation quantizers are replaced with bit-shifts.
+    :param network: The GraphView to cast.
+    :type network: :py:class:`aidge_core.GraphView`
+    :param target_type: The desired data-type to cast the network to.
+    :type target_type: :py:class:`aidge_core.DataType`
+    :param single_shift: If set to True, replace the scaling factors by bit-shifts.
+    :type single_shift: bool
+    )mydelimiter");
+
+    m.def("quantize_network", &quantizeNetwork, py::arg("network"), py::arg("nb_bits"), py::arg("calibration_set"), py::arg("target_type"), py::arg("clipping_mode") = Clipping::MAX , py::arg("no_quant") = false, py::arg("optimize_signs") = false, py::arg("single_shift") = false,  py::arg("use_cuda") = false, py::arg("fold_graph") = true, py::arg("verbose") = false,
     R"mydelimiter(
     Main quantization routine. Performs every step of the quantization pipeline.
     :param network: The GraphView to be quantized.
     :type network: :py:class:`aidge_core.GraphView`
     :param nb_bits: The desired number of bits of the quantization.
     :type nb_bits: int
-    :param input_dataset: The input dataset on which the value ranges are computed.
-    :type input_dataset: list of :py:class:`aidge_core.Tensor`
+    :param calibration_set: The calibration dataset used to calibrate the activations.
+    :type calibration_set: list of :py:class:`aidge_core.Tensor`
+    :param target_type: The desired data-type of the output GraphView.
+    :type target_type: :py:class:`aidge_core.DataType`
     :param clipping_mode: Type of the clipping optimization. Can be either 'MAX', 'MSE', 'AA' or 'KL'.
-    :type clipping_mode: string
-    :param no_quantization: Whether to truly quantize the network or not.
-    :type no_quantization: bool
+    :type clipping_mode: :py:class:`aidge_quantization.Clipping`
+    :param no_quant: If set to True, the rounding operations are not applied.
+    :type no_quant: bool
     :param optimize_signs: Whether to take account of the IO signs of the operators or not.
     :type optimize_signs: bool
-    :param single_shift: Whether to convert the scaling factors into powers of two. If true the approximations are compensated using the previous nodes weights.
+    :param single_shift: Whether to convert the scaling factors into powers of two. If true, the approximations are compensated using the previous nodes' parameters.
+    :type single_shift: bool
+    :param use_cuda: Whether to use the CUDA backend for performing the activation calibration or not.
+    :type use_cuda: bool
+    :param fold_graph: Whether to fold the parameter quantizers after the quantization or not.
+    :type fold_graph: bool
     :param verbose: Whether to print internal informations about the quantization process.
     :type verbose: bool
     )mydelimiter");
 
-    m.def("compute_histograms", &computeHistograms, py::arg("value_ranges"), py::arg("nb_bins"), py::arg("network"), py::arg("input_dataset"), py::arg("use_cuda"),
+    m.def("compute_histograms", &computeHistograms, py::arg("value_ranges"), py::arg("nb_bins"), py::arg("network"), py::arg("calibration_set"), py::arg("use_cuda"),
     R"mydelimiter(
     Compute the histograms of the activations of each node contained in the map of the ranges (passed as argument).
     :param value_ranges: A map associating each considered node name to its corresponding output range.
@@ -136,8 +141,8 @@ void init_PTQ(py::module &m) {
     :type nb_bins: int
     :param network: The GraphView containing the considered nodes.
     :type network: :py:class:`aidge_core.GraphView`
-    :param input_dataset: The input dataset, consisting of a list of input samples.
-    :type input_dataset: list of :py:class:`aidge_core.Tensor`
+    :param calibration_set: The calibration dataset, consisting of a list of input samples.
+    :type calibration_set: list of :py:class:`aidge_core.Tensor`
     :return: A map associating each node name to it's corresponding activation histogram.
     :rtype: dict
     )mydelimiter");
@@ -166,7 +171,7 @@ void init_PTQ(py::module &m) {
     :rtype: float
     )mydelimiter");
 
-    m.def("adjust_ranges", &adjustRanges, py::arg("clipping_mode"), py::arg("value_ranges"), py::arg("nb_bits"), py::arg("network"), py::arg("input_dataset"), py::arg("use_cuda"), py::arg("verbose") = false,
+    m.def("adjust_ranges", &adjustRanges, py::arg("clipping_mode"), py::arg("value_ranges"), py::arg("nb_bits"), py::arg("network"), py::arg("calibration_set"), py::arg("use_cuda"), py::arg("verbose") = false,
     R"mydelimiter(
     Return a corrected map of the provided activation ranges.
     To do so compute the optimal clipping values for every node and multiply the input ranges by those values.
@@ -179,8 +184,8 @@ void init_PTQ(py::module &m) {
     :type nb_bits: int
     :param network: The GraphView containing the considered nodes.
     :type network: :py:class:`aidge_core.GraphView`
-    :param input_dataset: The input dataset, consisting of a list of input samples.
-    :type input_dataset: list of :py:class:`aidge_core.Tensor`
+    :param calibration_set: The calibration dataset, consisting of a list of input samples.
+    :type calibration_set: list of :py:class:`aidge_core.Tensor`
     :param verbose: Whether to print the clipping values or not.
     :type verbose: bool
     :return: The corrected map associating to each provided node its clipped range.
@@ -226,7 +231,6 @@ void init_PTQ(py::module &m) {
     )mydelimiter");
 
     m.def("prepare_network", &prepareNetwork, py::arg("network"), "prepare the network for the PTQ");
-
 }
 
 } // namespace Aidge
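The new `cast_quantized_network` binding splits the casting step out of the main pipeline. A minimal sketch, assuming `model` has already been quantized but left in floating precision (e.g. via `quantize_network(..., target_type=aidge_core.dtype.float32, ...)`):

```python
import aidge_core
import aidge_quantization

# Cast the float-represented quantized network to int32; with
# single_shift=True the activation scalings become bit-shifts.
aidge_quantization.cast_quantized_network(
    network=model,
    target_type=aidge_core.dtype.int32,
    single_shift=True)
```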
diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp
index 57787a8951a513cd0dc8660c6ef3a99b63e74729..33cf14667de7121f93d2804d66e6c8037b643f81 100644
--- a/src/PTQ/CLE.cpp
+++ b/src/PTQ/CLE.cpp
@@ -20,24 +20,17 @@
 #include "aidge/quantization/PTQ/PTQ.hpp"  // retrieveNodeVector
 
 #include "aidge/graph/GraphView.hpp"
-
 #include "aidge/scheduler/SequentialScheduler.hpp"
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/utils/Log.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
-#include "aidge/utils/Log.hpp"
-
-#include "aidge/operator/Mul.hpp"
-#include "aidge/operator/ArgMax.hpp"
-#include "aidge/operator/Abs.hpp"
-#include "aidge/operator/Reshape.hpp"
-#include "aidge/operator/Round.hpp"
 
 #include "aidge/operator/Mul.hpp"
 #include "aidge/operator/ArgMax.hpp"
 #include "aidge/operator/Abs.hpp"
 #include "aidge/operator/Reshape.hpp"
 #include "aidge/operator/Round.hpp"
+#include "aidge/operator/MetaOperator.hpp"
 
 namespace Aidge
 {
@@ -62,17 +55,29 @@ static bool nodeHasBias(std::shared_ptr<Node> node)
     return false;
 }
 
-// What is this thing ???
-// Function used to extract the local tensor (from a ProducerScalingNode)
-std::shared_ptr<Aidge::Tensor> getLocalTensor(std::shared_ptr<Node> node) 
+std::shared_ptr<Aidge::Tensor> getScaledWeightTensor(std::shared_ptr<Node> node) 
 {
-    if (node->getParent(1)->attributes()->hasAttr("quantization.ptq.isProducerScaling")) 
+    if (node->getParent(1)->attributes()->hasAttr("quantization.ptq.isProducerQuantizer"))
     {
-        std::shared_ptr<Aidge::OperatorTensor> operatorTensor = std::static_pointer_cast<OperatorTensor>(node->getParent(1)->getOperator());
-        operatorTensor->forward(); // We need the forward pass to compute the scaled value of the Tensor
-        return operatorTensor->getOutput(0);
-    } else {
-        return getWeightTensor(node);
+        auto quantizer = node->getParent(1);
+
+        // perform an inference on the branch
+
+        auto graphView = Sequential({quantizer});
+        graphView->add(quantizer->getParent(0));
+        SequentialScheduler scheduler(graphView);
+        scheduler.forward(true, {});
+
+        // gather and return the result
+
+        auto op = std::static_pointer_cast<MetaOperator_Op>(quantizer->getOperator());
+        auto result = op->getOutput(0);
+        return result;
+    } 
+    else 
+    {
+        auto result = getWeightTensor(node);        
+        return result;
     }
 }
 
@@ -115,26 +120,30 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
             std::shared_ptr<Node> n1 = affineNodeVector[i];
             std::shared_ptr<Node> n2 = affineNodeVector[i+1];
 
-            std::shared_ptr<Aidge::Tensor> n1localTensor = getLocalTensor(n1);
-            std::shared_ptr<Aidge::Tensor> n2localTensor = getLocalTensor(n2);
+            std::shared_ptr<Aidge::Tensor> w1 = getScaledWeightTensor(n1);
+            std::shared_ptr<Aidge::Tensor> w2 = getScaledWeightTensor(n2);
             
-            double r1 = getTensorAbsoluteMax(n1localTensor);
-            double r2 = getTensorAbsoluteMax(n2localTensor);
+            double r1 = getTensorAbsoluteMax(w1);
+            double r2 = getTensorAbsoluteMax(w2);
 
             double s1 = std::sqrt(r1 * r2) / r1;
             double s2 = std::sqrt(r1 * r2) / r2;
 
-            insertScalingBelowProducer(n1->getParent(1), s1, graphView);
+            multiplyScalingFactor(n1->getParent(1), s1);
 
             if (nodeHasBias(n1))
-                insertScalingBelowProducer(n1->getParent(2), s1, graphView);
+                multiplyScalingFactor(n1->getParent(2), s1);
 
-            insertScalingBelowProducer(n2->getParent(1), s2, graphView);
+            multiplyScalingFactor(n2->getParent(1), s2);
 
             double rangeDelta = std::abs(r1 - r2);
             if (rangeDelta > maxRangeDelta)
                 maxRangeDelta = rangeDelta;
         }
     }
     while (maxRangeDelta > targetDelta);
 }
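For reference, the equalization loop above rescales consecutive weight tensors toward their geometric mean, which is why iterating shrinks the range delta until it falls below `targetDelta`. A toy numpy illustration of the invariant (not library code):

```python
import numpy as np

w1 = np.array([0.5, -2.0, 1.0])  # toy weight tensors of two consecutive affine nodes
w2 = np.array([8.0, -4.0, 2.0])

r1, r2 = np.abs(w1).max(), np.abs(w2).max()
s1, s2 = np.sqrt(r1 * r2) / r1, np.sqrt(r1 * r2) / r2

# Both rescaled ranges now equal the geometric mean sqrt(r1 * r2), and since
# s1 * s2 == 1 the composition of the two (positively homogeneous) layers is preserved.
assert np.isclose(np.abs(w1 * s1).max(), np.abs(w2 * s2).max())
assert np.isclose(s1 * s2, 1.0)
```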
diff --git a/src/PTQ/Clipping.cpp b/src/PTQ/Clipping.cpp
index 5bd2a7da90f8ddd2d5d3903e4a4479e7654233e5..4107c555e6b356401ed06057c9a06463084b74bd 100644
--- a/src/PTQ/Clipping.cpp
+++ b/src/PTQ/Clipping.cpp
@@ -18,8 +18,8 @@
 
 namespace Aidge
 {
-    
-std::unordered_map<std::shared_ptr<Node>, std::vector<int>> computeHistograms(std::unordered_map<std::shared_ptr<Node>, double> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda)
+
+std::unordered_map<std::shared_ptr<Node>, std::vector<int>> computeHistograms(std::unordered_map<std::shared_ptr<Node>, double> valueRanges, int nbBins, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, bool useCuda)
 {
     if (useCuda)
         graphView->setBackend("cuda");
@@ -35,17 +35,10 @@ std::unordered_map<std::shared_ptr<Node>, std::vector<int>> computeHistograms(st
 
     // Setup the histograms ...
 
-    for (std::shared_ptr<Node> node : graphView->getNodes())
+    for (std::pair<std::shared_ptr<Node>, double> pair : valueRanges)
     {
-        bool isInsideRanges = (valueRanges.find(node) != valueRanges.end());
-        if (isInsideRanges)
-        {
-            std::vector<int> histogram;
-            for (int i = 0; i < nbBins; i++)
-                histogram.push_back(0);
-
-            histograms.insert(std::make_pair(node, histogram));
-        }
+        std::vector<int> histogram(nbBins, 0);
+        histograms.insert(std::make_pair(pair.first, histogram));
     }
 
     // Fill the histograms ...
@@ -54,7 +47,7 @@ std::unordered_map<std::shared_ptr<Node>, std::vector<int>> computeHistograms(st
 
     int it = 0;
 
-    for (std::shared_ptr<Tensor> inputTensor : inputDataSet)
+    for (std::shared_ptr<Tensor> inputTensor : calibrationSet)
     {
         Log::debug(" IT (BIS) : {}", it++);
 
@@ -66,35 +59,32 @@ std::unordered_map<std::shared_ptr<Node>, std::vector<int>> computeHistograms(st
         scheduler.forward(true, {inputTensor});
 
         // Gather values ...
-
-        for (std::shared_ptr<Node> node : graphView->getNodes())
+        
+        for (std::pair<std::shared_ptr<Node>, double> pair : valueRanges)
         {
-            bool isInsideRanges = (valueRanges.find(node) != valueRanges.end());
-            if (isInsideRanges)
-            {
-                double valueRange = valueRanges[node];
+            std::shared_ptr<Node> node = pair.first;
+            double valueRange = pair.second; 
 
-                std::shared_ptr<Operator> nodeOperator = node->getOperator();
-                std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0));
+            std::shared_ptr<Operator> nodeOperator = node->getOperator();
+            std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0));
 
-                if (useCuda)
-                    valueTensor->setBackend("cpu");
+            if (useCuda)
+                valueTensor->setBackend("cpu");
 
-                double * castedTensor = static_cast<double *> (valueTensor->getImpl()->rawPtr());
+            double * castedTensor = static_cast<double *> (valueTensor->getImpl()->rawPtr());
 
-                std::vector<int> nodeHistogram = histograms[node];
-                for(std::size_t i = 0; i < valueTensor->size(); i++)
-                {
-                    std::size_t bin = std::round(std::abs(castedTensor[i] / valueRange * nbBins));
-                    bin = std::min(bin, nodeHistogram.size() - 1);
-                    nodeHistogram[bin]++;
-                }
+            std::vector<int> nodeHistogram = histograms[node];
+            for(std::size_t i = 0; i < valueTensor->size(); i++)
+            {
+                std::size_t bin = std::round(std::abs(castedTensor[i] / valueRange * nbBins));
+                bin = std::min(bin, nodeHistogram.size() - 1);
+                nodeHistogram[bin]++;
+            }
 
-                histograms[node] = nodeHistogram;   
+            histograms[node] = nodeHistogram;   
 
-                if (useCuda)
-                    valueTensor->setBackend("cuda");
-            }
+            if (useCuda)
+                valueTensor->setBackend("cuda");
         }
 
         if (useCuda)
@@ -206,8 +196,7 @@ double computeKLClipping(std::vector<int> refHistogram, std::uint8_t nbBits)
     return bestClipping;
 }
 
-
-std::unordered_map<std::shared_ptr<Node>, double> adjustRanges(Clipping clippingMode, std::unordered_map<std::shared_ptr<Node>, double> valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool useCuda, bool verbose)
+std::unordered_map<std::shared_ptr<Node>, double> adjustRanges(Clipping clippingMode, std::unordered_map<std::shared_ptr<Node>, double> valueRanges, std::uint8_t nbBits, std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, bool useCuda, bool verbose)
 {
     double clipping = 1.0f;
 
@@ -218,11 +207,11 @@ std::unordered_map<std::shared_ptr<Node>, double> adjustRanges(Clipping clipping
         if (verbose)
             Log::info(" === CLIPPING VALUES === ");
 
-        std::unordered_map<std::shared_ptr<Node>, std::vector<int>> histograms = computeHistograms(valueRanges, nbBins, graphView, inputDataSet, useCuda);
+        std::unordered_map<std::shared_ptr<Node>, std::vector<int>> histograms = computeHistograms(valueRanges, nbBins, graphView, calibrationSet, useCuda);
 
         for (std::shared_ptr<Node> node : graphView->getNodes())
         {
-            if (node->attributes()->hasAttr("quantization.ptq.isScaling"))
+            if (node->attributes()->hasAttr("quantization.ptq.isActivationQuantizer"))
             {
                 std::vector<int> histogram = histograms[node];
 
@@ -241,7 +230,6 @@ std::unordered_map<std::shared_ptr<Node>, double> adjustRanges(Clipping clipping
         }
     }
     
-
     return valueRanges;
 }
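The simplified histogram loop above bins each activation value by normalizing it against the node's value range and saturating into the last bin. A minimal numpy rendition of that binning rule (illustrative only):

```python
import numpy as np

def activation_histogram(values, value_range, nb_bins):
    # bin = round(|v| / range * nb_bins), saturated to the last bin,
    # mirroring the inner loop of computeHistograms().
    bins = np.round(np.abs(values) / value_range * nb_bins).astype(int)
    bins = np.minimum(bins, nb_bins - 1)
    return np.bincount(bins, minlength=nb_bins)

hist = activation_histogram(np.array([0.1, -0.5, 0.9, 1.2]), value_range=1.0, nb_bins=4)
# values land in bins [0, 2, 3, 3] -> hist == [1, 0, 1, 2]
```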
 
diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index df203f2547e720bcfbef109e05e7ccca5ed42b9e..9e2a62c8b975bbca4f63c90d500324a75e492e64 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -9,34 +9,31 @@
  *
  ********************************************************************************/
 
-#include "aidge/quantization/PTQ/CLE.hpp"
-#include "aidge/quantization/PTQ/Clipping.hpp"
-#include "aidge/quantization/PTQ/PTQ.hpp"
-#include "aidge/operator/PTQMetaOps.hpp"
-
-#include "aidge/data/Tensor.hpp"
-#include "aidge/graph/GraphView.hpp"
-#include "aidge/graph/Node.hpp"
-#include "aidge/scheduler/SequentialScheduler.hpp"
-#include "aidge/scheduler/Scheduler.hpp"
-#include "aidge/utils/Log.hpp"
-#include "aidge/operator/MetaOperator.hpp"
-
-#include "aidge/operator/Producer.hpp"
-#include "aidge/operator/Mul.hpp"
-#include "aidge/operator/Round.hpp"
-#include "aidge/operator/ReLU.hpp"
-#include "aidge/operator/BatchNorm.hpp"
-#include "aidge/operator/Conv.hpp"
-#include "aidge/operator/ArgMax.hpp"
-#include "aidge/operator/Reshape.hpp"
-#include "aidge/operator/MatMul.hpp"
-#include "aidge/operator/Cast.hpp"
-
-
-
-#include "aidge/recipes/Recipes.hpp"
-#include "aidge/recipes/QuantRecipes.hpp"
+ #include "aidge/quantization/PTQ/CLE.hpp"
+ #include "aidge/quantization/PTQ/Clipping.hpp"
+ #include "aidge/quantization/PTQ/PTQ.hpp"
+ #include "aidge/operator/PTQMetaOps.hpp"
+ 
+ #include "aidge/data/Tensor.hpp"
+ #include "aidge/graph/GraphView.hpp"
+ #include "aidge/graph/Node.hpp"
+ #include "aidge/scheduler/SequentialScheduler.hpp"
+ #include "aidge/scheduler/Scheduler.hpp"
+ #include "aidge/utils/Log.hpp"
+ 
+ #include "aidge/operator/Producer.hpp"
+ #include "aidge/operator/Mul.hpp"
+ #include "aidge/operator/Round.hpp"
+ #include "aidge/operator/ReLU.hpp"
+ #include "aidge/operator/BatchNorm.hpp"
+ #include "aidge/operator/Conv.hpp"
+ #include "aidge/operator/ArgMax.hpp"
+ #include "aidge/operator/Reshape.hpp"
+ #include "aidge/operator/MatMul.hpp"
+ 
+ #include "aidge/recipes/Recipes.hpp"
+ #include "aidge/recipes/QuantRecipes.hpp"
+ #include "aidge/operator/MetaOperator.hpp"
 
 namespace Aidge
 {
@@ -83,20 +80,13 @@ bool isNotQuantized(std::shared_ptr<Node> node)
     return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end());
 }
 
-std::shared_ptr<Aidge::Node> getFirstNode(std::shared_ptr<GraphView> graphView)
-{
-    return graphView->getOrderedInputs()[0].first;
-}
 void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView,DataType targetType)
 {
-    for (std::shared_ptr<Aidge::Node> inputNode: graphView->inputNodes())
-    {
-        for (Aidge::IOIndex_t index = inputNode->getFirstFreeDataInput();index < inputNode->getNbFreeDataInputs(); index++)
-        {
-            inputNode->getOperator()->resetInput(index);
-        }
-    }
+    for (std::shared_ptr<Aidge::Node> node : graphView->inputNodes())
+        for (Aidge::IOIndex_t i = node->getFirstFreeDataInput(); i < node->getNbFreeDataInputs(); i++)
+            node->getOperator()->resetInput(i);
 }
+
 bool checkArchitecture(std::shared_ptr<GraphView> graphView)
 {
     std::set<std::string> removedNodeTypes({"Flatten", "Softmax", "BatchNorm2D"});
@@ -222,22 +212,6 @@ static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> paren
     return index;
 }
 
-void multiplyScalingFactor(std::shared_ptr<Aidge::Node> node, double coeff)
-{
-    AIDGE_ASSERT(node->type() == "Mul" && hasAttr(node, "isProducerScaling") || hasAttr(node, "isScaling"),
-        "Cannot update the scaling factor on Node of type {} with no scaling tag", node->type());
-    
-    auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(1);
-
-    std::shared_ptr<Tensor> fallback;
-    const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu");
-    double previousScalingFactor = localTensor.get<double>(0);
-
-    std::shared_ptr<Tensor> resultTensor = std::make_shared<Tensor>(Array1D<double, 1> {previousScalingFactor * coeff});
-    node->input(1).first->getOperator()->setOutput(0, resultTensor);
-}
-
-// Utility function that insert a node below another one already connected 
 static void insertChildren(std::shared_ptr<Node> parent, std::shared_ptr<Node> newNode, std::shared_ptr<GraphView> graphView) 
 {
     // Checking the parents always have at least 1 children
@@ -267,80 +241,110 @@ static void insertChildren(std::shared_ptr<Node> parent, std::shared_ptr<Node> n
 
     graphView->add(newNode);
 }
-void applyConstFold(std::shared_ptr<GraphView> &graphView)
+
+void foldProducerQuantizers(std::shared_ptr<GraphView> graphView)
 {
-    for (const std::shared_ptr<Node> node : graphView->getNodes())
+    std::vector<std::shared_ptr<Node>> producerQuantizers;
+    for (std::shared_ptr<Node> node : graphView->getNodes())
+        if (hasAttr(node, "isProducerQuantizer"))
+            producerQuantizers.push_back(node);
+
+    for (std::shared_ptr<Node> quantizer : producerQuantizers)
     {
-        if (node->type() == "Producer" )
-        {
-            const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator());
-            producer->constant() = true;
-        }
+        // Set the param producer to be constant
+
+        auto paramProducer = quantizer->getParent(0);
+        auto paramProducerOp = std::static_pointer_cast<Producer_Op>(paramProducer->getOperator());
+        paramProducerOp->constant() = true;
+
+        // Set the internal producers of the quantizer to be constant
+
+        auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
+    
+        auto microGraph = quantizerOp->getMicroGraph();
+        
+        for (auto producer : microGraph->getNodes())
+            if (producer->type() == "Producer")
+            {
+                auto producerOp = std::static_pointer_cast<Producer_Op>(producer->getOperator());
+                producerOp->constant() = true;  
+            }   
+
+        expandMetaOp(quantizer); // mandatory for now !!!
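+        // (presumably needed because constantFolding() does not fold
+        // inside MetaOperator micro-graphs)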
     }
+
     constantFolding(graphView);
 }
 
-bool castQuantizedGraph(std::shared_ptr<GraphView> &graphView, Aidge::DataType targetType, bool singleShift,bool bitshiftRounding)
+void castQuantizedNetwork(std::shared_ptr<GraphView> graphView, Aidge::DataType targetType, bool singleShift /*, bool bitShiftRounding*/)
 {
-    //We need a deepcopy of the graphs nodes since we will replace some nodes
-    std::vector<std::shared_ptr<Node>> nodeVector(graphView->getNodes().begin(), graphView->getNodes().end());
+    std::set<Aidge::DataType> supportedFloatTypes = {DataType::Float16, DataType::Float32, DataType::Float64};
+    std::set<Aidge::DataType> supportedIntTypes = {DataType::Int16, DataType::Int32, DataType::Int64};
 
-    for (std::shared_ptr<Node> node : nodeVector)
+    bool castToFloat = (supportedFloatTypes.find(targetType) != supportedFloatTypes.end());
+    bool castToInt = (supportedIntTypes.find(targetType) != supportedIntTypes.end());
+
+    if (castToFloat)
     {
-        if (node->type() == "Round" && node->attributes()->hasAttr("quantization.ptq.isProducerRounding"))
-        {
-            std::shared_ptr<Aidge::Node> castNode =  Cast(targetType,node->name() + "_Cast");
-            castNode->getOperator()->setDataType(targetType);
-            castNode->getOperator()->setBackend(node->getOperator()->backend());
-            insertChildren(node,castNode,graphView);
-            castNode->attributes()->addAttr("quantization.ptq.isProducerCasting",0.0);
-            node->getOperator()->setDataType(DataType::Float64);
-        }
-        else if(node->type() == "Quantizer")
+        graphView->setDataType(targetType);
+    }
+    else if (castToInt)
+    {
+        if (singleShift)
         {
-            if(singleShift)
-            {
-                std::shared_ptr<Node> newBitShiftQuantizer = BitShiftQuantizer(node,targetType,bitshiftRounding,node->name()+"_BitShift_Quantizer");
-                newBitShiftQuantizer->getOperator()->setBackend(node->getOperator()->backend());
-                graphView->replace({node},{newBitShiftQuantizer});
+            Log::notice(" Replacing scaling nodes with bit-shifts ...");
 
-            }
-            else //If single shift is not enabled we keep using the alternative Int Quantizer (which cast the data before and after the regular Quantizer Operations) 
+            // Replace the scaling nodes with bit-shifts (activations only)
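+            // (in single-shift mode the activation scaling factors have been
+            // approximated by powers of two, so each Mul can be replaced by
+            // an exact integer bit-shift)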
+
+            std::set<std::shared_ptr<Node>> nodes = graphView->getNodes(); // must be called again because of removeRound() !
+            for (std::shared_ptr<Node> node : nodes)
             {
-                std::shared_ptr<Node> newIntQuantizer = IntQuantizer(node,targetType,node->name());
-                newIntQuantizer->getOperator()->setBackend(node->getOperator()->backend());
-                graphView->replace({node},{newIntQuantizer});
+                if (node->type() == "Quantizer")
+                {
+                    if (hasAttr(node, "isActivationQuantizer"))
+                    {
+                        removeRound(node);
+                        replaceScalingWithBitShift(node);
+                    }
+                    else if (hasAttr(node, "isProducerQuantizer"))
+                        castQuantizerIOs(node, targetType);
+                }
             }
+
+            // Cast the nodes (except the producers and quantizers) to integer precision
+            nodes = graphView->getNodes();
+            for (std::shared_ptr<Node> node : nodes)
+                if (node->type() != "Producer" && !hasAttr(node, "isProducerQuantizer")) // TODO : double check this !
+                    node->getOperator()->setDataType(targetType);
         }
-        else if (node->type() != "Producer" &&
-        !node->attributes()->hasAttr("quantization.ptq.isProducerScaling")) 
-        {              
-            node->getOperator()->setDataType(targetType);
-        }   
-    }
-    return true;
-}
+        else
+        {
+            // Set the nodes (except the producers and quantizers) to have integer IOs
 
-bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView)
-{
-    if (hasAttr(node, "isProducerScaling") && node->type() != "Round")
-    {
-        std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round");
-        roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
-        roundNode->getOperator()->setBackend(determineBackend(node));
+            std::set<std::shared_ptr<Node>> nodes = graphView->getNodes();
+            for (std::shared_ptr<Node> node : nodes)
+                if (node->type() != "Quantizer" && node->type() != "Producer")
+                    node->getOperator()->setDataType(targetType);
 
-        insertChildren(node, roundNode, graphView);
-        addAttr(roundNode, "isProducerRounding");
-    
-        return true;
+            // Cast the quantizers' inputs and outputs by inserting Cast nodes
+
+            for (std::shared_ptr<Node> node : nodes)
+                if (node->type() ==  "Quantizer")
+                    castQuantizerIOs(node, targetType);
+        }
+    }
+    else
+    {
+        Log::error(" Cannot cast the quantized network : target type '{}' is not supported ! ", targetType);
     }
-    return false;
 }
 
 double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
 {
+    std::shared_ptr<Tensor> fallback;
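+    // (fallback receives a temporary copy when refCastFrom() needs to convert the data)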
+
     // get the abs tensor
-    std::shared_ptr<Tensor> fallback; //Fallback tensor for refCastFR
+
     std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs());
 
     // flatten the abs tensor
@@ -373,28 +377,6 @@ double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
     return localFlatTensor.get<double>(maxIndex);
 }
 
-
-// TODO : pass nodeVector by reference ...
-static std::vector<std::shared_ptr<Node>> removeMatchingNodes(std::vector<std::shared_ptr<Node>> nodeVector, std::string nodeType)
-{
-    std::vector<std::shared_ptr<Node>> remainingNodes;
-    for (std::shared_ptr<Node> node : nodeVector)
-        if (node->type() != nodeType)
-            remainingNodes.push_back(node);
-
-    return remainingNodes;
-}
-
-static std::vector<std::shared_ptr<Node>> removeProdScalingNodes(std::vector<std::shared_ptr<Node>> nodeVector)
-{
-    std::vector<std::shared_ptr<Node>> remainingNodes;
-    for (std::shared_ptr<Node> node : nodeVector)
-        if (!hasAttr(node, "isProducerScaling"))
-            remainingNodes.push_back(node);
-
-    return remainingNodes;
-}
-
 static void fixScheduling(std::vector<std::shared_ptr<Node>>& nodeVector) {
 
     std::vector<std::shared_ptr<Node>> correctedVector;
@@ -414,22 +396,42 @@ static void fixScheduling(std::vector<std::shared_ptr<Node>>& nodeVector) {
 
 static std::shared_ptr<Tensor> getWeightTensor(std::shared_ptr<Node> node)
 {
-    return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(1);
+    std::shared_ptr<Node> producer = node->getParent(1);
+
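+    // the weight may sit behind a producer Quantizer ;
+    // if so, walk up to the raw Producer node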
+    if (producer->type() == "Quantizer") 
+        producer = producer->getParent(0);
+
+    return std::static_pointer_cast<OperatorTensor>(producer->getOperator())->getOutput(0);
 }
 
 static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
 {
-    return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2);
+    std::shared_ptr<Node> producer = node->getParent(2);
+
+    if (producer->type() == "Quantizer") 
+        producer = producer->getParent(0);
+
+    return std::static_pointer_cast<OperatorTensor>(producer->getOperator())->getOutput(0);
 }
 
 std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> graphView, bool newSchedule, bool verbose)
 {
     std::vector<std::shared_ptr<Node>> nodeVector = graphView->getOrderedNodes();
    
+    // Remove duplicate nodes. XXX Is it still needed ???
+
     fixScheduling(nodeVector); 
 
-    nodeVector = removeMatchingNodes(nodeVector, "Producer");
-    nodeVector = removeProdScalingNodes(nodeVector);
+    // Remove Producers and their Scalings
+
+    std::vector<std::shared_ptr<Node>> remainingNodes;
+    for (std::shared_ptr<Node> node : nodeVector)
+        if ((node->type() != "Producer") && !hasAttr(node, "isProducerQuantizer"))
+            remainingNodes.push_back(node);
+
+    nodeVector = remainingNodes;
+
+    // Verbose
 
     if (verbose) 
     {
@@ -441,6 +443,10 @@ std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView>
     return nodeVector;    
 }
 
+static std::shared_ptr<Node> getFirstNode(std::shared_ptr<GraphView> graphView)
+{
+    return retrieveNodeVector(graphView)[0];
+}
 
 // TODO : enhance this by modifying OperatorImpl in "core" ...
 static DataType getDataType(std::shared_ptr<Node> node)
@@ -449,54 +455,56 @@ static DataType getDataType(std::shared_ptr<Node> node)
     return op->getOutput(0)->dataType();
 }
 
-static std::shared_ptr<Aidge::Node> createScalingNode(std::string name, std::vector<std::string> attributes, double value)
+// XXX double check this !
+static bool nodeHasBias(std::shared_ptr<Node> node)
 {
-    std::shared_ptr<Node> scalingNode = Mul(name);
-  
-    for (std::string a : attributes)
-        addAttr(scalingNode, a);
-    
-    // Add the scaling factor as a producer of the node
-
-    std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {value});
-    std::shared_ptr<Node> scalingFactorProducer = addProducer(scalingNode, 1, {1}, "ScalingFactor"); 
-
-    scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor);
-
-    // XXX graphView->add(scalingNode);
-
-    return scalingNode;
+    if (node->getParents().size() == 3) {
+        std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
+        if (biasTensor)
+            return true;
+    }
+    return false;
 }
 
-bool insertScalingBelowProducer(std::shared_ptr<Node> producerNode, double scalingFactor, std::shared_ptr<GraphView> graphView)
+// TODO: rework this !
+static std::shared_ptr<Node> getPreviousScalingNode(std::shared_ptr<Node> node)
 {
-    if (hasAttr(producerNode, "isProducerRounding"))
-    {
-        // In this case we 'bump' the node to the one above him (an actual ProducerScaling)
-        // because the round node is not usable (only used when SSA is enabled)
-        producerNode = producerNode->getParent(0);
-    }
-
-    if (hasAttr(producerNode, "isProducerScaling"))
-    {
-        // We accumulate the previous scaling factors by multiplying the SF of the ProducerScaling node 
-        // (adding new nodes each time would make the graph unusable)
-        multiplyScalingFactor(producerNode, scalingFactor);
-        return true;
+    std::shared_ptr<Node> currNode = node;
+    while(!hasAttr(currNode, "isActivationQuantizer")) {
+        if (currNode->getParents().size() == 0) {
+            Log::warn(" Warning : No previous Scaling node were found ! ");
+            break;
+        }
+        currNode = currNode->getParents()[0];
     }
+    return currNode;
+}
 
-    AIDGE_ASSERT(producerNode->type() == "Producer", " Cannot apply a scaling factor on node of type: {} which is not a Producer", producerNode->type());
-   
+void insertScalingBelowProducer(std::shared_ptr<Node> producerNode, std::shared_ptr<GraphView> graphView)
+{
     std::string scalingNodeName = makeUniqueName(producerNode->name() + "_ProducerScaling", graphView);
-    std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isProducerScaling"}, scalingFactor);
+    std::shared_ptr<Node> scalingNode = Quantizer(1.0, scalingNodeName);
+    addAttr(scalingNode, "isProducerQuantizer");
 
     scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
-    scalingNode->getOperator()->setBackend(determineBackend(producerNode));
+    scalingNode->getOperator()->setBackend(determineBackend(producerNode)); // XXX use the producer parent instead ???
 
     insertChildren(producerNode, scalingNode, graphView);
-    graphView->add(scalingNode->getParent(1)); // add the scaling factor producer
+}
 
-    return true;
+void insertProducerScalingNodes(std::shared_ptr<GraphView> graphView)
+{
+    std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes();
+
+    for (std::shared_ptr<Node> node : nodeSet)
+    {
+        if (isAffine(node))
+        {
+            insertScalingBelowProducer(node->getParent(1), graphView);
+            if (nodeHasBias(node))
+                insertScalingBelowProducer(node->getParent(2), graphView);
+        }
+    }
 }
 
 // XXX HERE : Branches containing only Seamless nodes should be considered as residual too !!!
@@ -524,47 +532,39 @@ void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView)
                     Log::info(" ### inserting multiplicative node ...");
 
                     std::string residualNodeName = makeUniqueName(parentNode->name() + "_Res", graphView);
-                    std::shared_ptr<Node> residualNode = createScalingNode(residualNodeName, {"isScaling", "isResidual"}, 1.0);
+                    
+                    // XXX
+                    std::shared_ptr<Node> residualNode = Quantizer(1.0, residualNodeName);
+                    addAttr(residualNode, "isActivationQuantizer");
 
                     residualNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
                     residualNode->getOperator()->setBackend(determineBackend(parentNode));
 
                     graphView->insertParent(node, residualNode, i, 0, 0);
-                    graphView->add(residualNode->getParent(1)); // add the scaling factor producer
-
                 }
             }
         }
     }
 }
 
-static std::shared_ptr<Node> getPreviousScalingNode(std::shared_ptr<Node> node)
-{
-    std::shared_ptr<Node> currNode = node;
-    while(!hasAttr(currNode, "isScaling"))
-    {
-        if (currNode->getParents().size() == 0)
-        {
-            Log::warn(" Warning : No previous Scaling node were found ! ");
-            break;
-        }
-        currNode = currNode->getParents()[0];
-    }
-    return currNode;
-}
-
 void insertScalingNodes(std::shared_ptr<GraphView> graphView)
 {
+    insertProducerScalingNodes(graphView);
     insertResidualScalingNodes(graphView);
 
     std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes();
 
     for (std::shared_ptr<Node> parentNode : nodeSet)
     {
+        // Insert a Scaling node after each node that has to be quantized
+
         if (isAffine(parentNode) || isMerging(parentNode) || isNotQuantized(parentNode))
         {
             std::string scalingNodeName = makeUniqueName(parentNode->name() + "_Scaling", graphView);
-            std::shared_ptr<Node> scalingNode = createScalingNode(scalingNodeName, {"isScaling"}, 1.0);
+
+            // XXX XXX XXX
+            std::shared_ptr<Node> scalingNode = Quantizer(1.0, scalingNodeName);
+            addAttr(scalingNode, "isActivationQuantizer");
 
             scalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
             scalingNode->getOperator()->setBackend(determineBackend(parentNode));
@@ -572,42 +572,30 @@ void insertScalingNodes(std::shared_ptr<GraphView> graphView)
             if (parentNode->getChildren().size() > 0) {
                 insertChildren(parentNode, scalingNode, graphView);
             } else {
-                // Log::info(" last node reached ! ");
                 parentNode->addChild(scalingNode, 0, 0);
                 graphView->add(scalingNode);
             }
-
-            graphView->add(scalingNode->getParent(1)); // add the scaling factor producer
-
+            
             // In the case the node is a non-linear operator we want to add an extra
             // scaling node before it to rescale its input ...
 
             if (isNotQuantized(parentNode))
             {
                 std::string prevScalingNodeName = makeUniqueName(parentNode->name() + "_PrevScaling", graphView);
-                std::shared_ptr<Node> prevScalingNode = createScalingNode(prevScalingNodeName, {"isScaling"}, 1.0);
+
+                // XXX XXX XXX
+                std::shared_ptr<Node> prevScalingNode = Quantizer(1.0, prevScalingNodeName);
+                addAttr(prevScalingNode, "isActivationQuantizer");
 
                 prevScalingNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
                 prevScalingNode->getOperator()->setBackend(determineBackend(parentNode));
 
                 graphView->insertParent(parentNode, prevScalingNode, 0, 0, 0);
-                graphView->add(prevScalingNode->getParent(1)); // add the scaling factor producer
             }
         }
     }
 }
 
-// XXX double check this !
-static bool nodeHasBias(std::shared_ptr<Node> node)
-{
-    if (node->getParents().size() == 3) {
-        std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
-        if (biasTensor)
-            return true;
-    }
-    return false;
-}
-
 void normalizeParameters(std::shared_ptr<GraphView> graphView)
 {
     // CREATE THE ACCUMULATED RATIO MAP ///////////////////////////////////////
@@ -620,12 +608,12 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
 
     // ITERATE OVER THE GRAPH /////////////////////////////////////////////////
 
-    std::shared_ptr<Node> firstNode =getFirstNode(graphView);
+    std::shared_ptr<Node> firstNode = getFirstNode(graphView);
 
     for (std::shared_ptr<Node> node : nodeVector)
     {
         // Scaling nodes still have a ratio of 1, so they are seamless ...
-        if (node->type() == "ReLU" || hasAttr(node, "isScaling") || isSeamless(node))
+        if (node->type() == "ReLU" || hasAttr(node, "isActivationQuantizer") || isSeamless(node))
         {
             if (node != firstNode)
             {
@@ -640,11 +628,11 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
             // Rescale the weight tensor
             
             std::shared_ptr<Tensor> weightTensor = getWeightTensor(node);
-            double scaling = getTensorAbsoluteMax(weightTensor);
-            double ratio = 1.0 / scaling;
 
-            //rescaleTensor(weightTensor, ratio);
-            insertScalingBelowProducer(node->getParent(1), ratio, graphView);
+            double ratio = 1.0 / getTensorAbsoluteMax(weightTensor);
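+            // (scaling by the inverse of the absolute max normalizes the weights into [-1, 1])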
+
+            // rescaleTensor(weightTensor, ratio);
+            multiplyScalingFactor(node->getParent(1), ratio);
 
             // Accumulate the ratio
 
@@ -661,7 +649,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
             {
                 std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
                 //rescaleTensor(biasTensor, accumulatedRatios[node] );
-                insertScalingBelowProducer(node->getParent(2), accumulatedRatios[node], graphView);
+                multiplyScalingFactor(node->getParent(2), accumulatedRatios[node]);
             }
         }
 
@@ -729,44 +717,13 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
     }
 }
 
-// XXX TODO : take care of the CUDA backend for this too !!!
-std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<GraphView> graphView, std::shared_ptr<Tensor> inputTensor, bool scalingNodesOnly)
-{
-    std::unordered_map<std::shared_ptr<Node>, double> valueRanges;
-
-    SequentialScheduler scheduler(graphView);
-    scheduler.resetScheduling();
-
-    // Inference ... 
-
-    scheduler.forward(true, {inputTensor});
-
-    // Gather ranges ...
-
-    std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes();
-    for (std::shared_ptr<Node> node : nodeSet)
-    {
-        if ((scalingNodesOnly && hasAttr(node, "isScaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
-        {
-            std::shared_ptr<Operator> nodeOperator = node->getOperator();
-            std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0));
-            double range = getTensorAbsoluteMax(valueTensor);
-
-            // Associate the value to the scaling node ...
-            valueRanges.insert(std::make_pair(node, range));
-        }
-    }
-
-    return valueRanges;
-}
-
-std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, bool scalingNodesOnly, bool useCuda)
+std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, bool scalingNodesOnly, bool useCuda)
 {
     std::unordered_map<std::shared_ptr<Node>, double> valueRanges;
     std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes();
     
     for (std::shared_ptr<Node> node : nodeSet)
-        if ((scalingNodesOnly && hasAttr(node, "isScaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
+        if ((scalingNodesOnly && hasAttr(node, "isActivationQuantizer")) || (!scalingNodesOnly && (node->type() != "Producer")))
             valueRanges.insert(std::make_pair(node, 0));
 
     if (useCuda)
@@ -777,7 +734,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
 
     int it = 0;
 
-    for (std::shared_ptr<Tensor> sample : inputDataSet)
+    for (std::shared_ptr<Tensor> sample : calibrationSet)
     {
         //Log::info(" IT : {}", it++);
 
@@ -793,7 +750,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
         std::unordered_map<std::shared_ptr<Node>, double> sampleRanges;
         for (std::shared_ptr<Node> node : nodeSet)
         {
-            if ((scalingNodesOnly && hasAttr(node, "isScaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
+            if ((scalingNodesOnly && hasAttr(node, "isActivationQuantizer")) || (!scalingNodesOnly && (node->type() != "Producer")))
             {
                 std::shared_ptr<Operator> nodeOperator = node->getOperator();
                 std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0));
@@ -815,7 +772,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
 
         for (std::shared_ptr<Node> node : nodeSet)
         {
-            if ((scalingNodesOnly && hasAttr(node, "isScaling")) || (!scalingNodesOnly && (node->type() != "Producer")))
+            if ((scalingNodesOnly && hasAttr(node, "isActivationQuantizer")) || (!scalingNodesOnly && (node->type() != "Producer")))
                 if (sampleRanges[node] > valueRanges[node])
                     valueRanges[node] = sampleRanges[node];
         }
@@ -832,7 +789,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
 
 void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges)
 {
-    std::shared_ptr<Node> firstNode =  getFirstNode(graphView);
+    std::shared_ptr<Node> firstNode = getFirstNode(graphView);
 
     // CREATE THE ACCUMULATED RATIO MAP ///////////////////////////////////////
 
@@ -861,7 +818,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m
 
         // Use the Scaling nodes to rescale the ranges ...
 
-        if (hasAttr(node, "isScaling")) 
+        if (hasAttr(node, "isActivationQuantizer")) 
         {
             std::shared_ptr<Node> prevNode = node->getParent(0);
 
@@ -874,11 +831,9 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m
 
             // If prevNode is Affine, fix the bias ...
 
-            if (isAffine(prevNode)) {
-                if (nodeHasBias(prevNode)) {                
-                    insertScalingBelowProducer(prevNode->getParent(2), 1.0 / prevRatio, graphView);
-                }
-            }
+            if (isAffine(prevNode))
+                if (nodeHasBias(prevNode))  
+                    multiplyScalingFactor(prevNode->getParent(2), 1.0 / prevRatio);
         }
 
         // Merging nodes handling : use a maximum arbitration ...
@@ -929,19 +884,18 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m
 
             multiplyScalingFactor(prevScalingNode, prevRatio);
         }
-
     }
 }
 
 std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose)
 {
-    std::shared_ptr<Node> firstNode =  getFirstNode(graphView);
+    std::shared_ptr<Node> firstNode = getFirstNode(graphView);
 
     std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap;
 
     std::pair<bool, bool> unsignedPair(true, true);
     for (std::shared_ptr<Node> node : graphView->getNodes())
-        if (node->type() != "Producer")
+        if (node->type() != "Producer") // XXX XXX XXX we should use nodeVector instead ...
             signMap.insert(std::make_pair(node, unsignedPair));
 
     // ITERATE OVER THE GRAPH
@@ -963,7 +917,7 @@ std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(
             signMap[node].second = false;
         } 
 
-        if (hasAttr(node, "isScaling")) 
+        if (hasAttr(node, "isActivationQuantizer")) 
         {
             signMap[node].second = false;
 
@@ -1010,7 +964,7 @@ std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(
                 // Arbitration : Signed type wins !
                 for(std::shared_ptr<Node> parent : parentNodes)
                 {
-                    while (!hasAttr(parent, "isScaling"))
+                    while (!hasAttr(parent, "isActivationQuantizer"))
                     {
                         signMap[parent] = std::make_pair(false, false);
                         // We are on a branch so nodes always have 1 parent ...
@@ -1074,6 +1028,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
         signMap = computeSignMap(graphView, verbose);
     else
     {
+        // XXX XXX XXX we should use the retrieved node vector instead ...
         std::pair<bool, bool> signedPair(false, false);
         for (std::shared_ptr<Node> node : graphView->getNodes())
             if (node->type() != "Producer")
@@ -1089,11 +1044,11 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
         if (isAffine(node))
         {
             // Rescale the weight tensor
-            std::shared_ptr<Tensor> weightTensor = getWeightTensor(node);
-            insertScalingBelowProducer(node->getParent(1),signedMax,graphView);
+            multiplyScalingFactor(node->getParent(1), signedMax);
 
+            // UUU Quantize the Producer !!!
             if (!noQuant)
-                insertRoundBelowProducer(node->getParent(1),graphView);
+                appendRoundClip(node->getParent(1), -(signedMax + 1), signedMax);
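+                // (e.g. for nbBits = 8 : signedMax = 127, so the weights are
+                // rounded and clipped to the [-128, 127] range)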
 
             // Rescale the bias tensor
             if (nodeHasBias(node))  
@@ -1101,11 +1056,12 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
                 bool inputIsUnsigned = signMap[node].first;
                 double rescaling = inputIsUnsigned ? unsignedMax * signedMax : signedMax * signedMax;
             
-                std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
-                insertScalingBelowProducer(node->getParent(2),rescaling,graphView);
+                multiplyScalingFactor(node->getParent(2), rescaling);
 
+                // XXX TODO : enhance this ! 
+                int biasMax = (1 << (12 + nbBits));
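+                // (e.g. for nbBits = 8 : biasMax = 1 << 20, so the bias is
+                // clipped to the [-(2^20 + 1), 2^20] range)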
                 if (!noQuant)
-                    insertRoundBelowProducer(node->getParent(2),graphView);
+                    appendRoundClip(node->getParent(2), -(biasMax + 1), biasMax);
             }
 
             // Compensate the rescaling using the next Scaling node
@@ -1120,7 +1076,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
             
             std::shared_ptr<Node> scalingNode = getUniqueChild(node); // TODO : assert if scalingNode is a Scaling ...
 
-            multiplyScalingFactor(scalingNode,rescaling) ;          
+            multiplyScalingFactor(scalingNode, rescaling);          
         }
         
         if (isMerging(node))
@@ -1139,7 +1095,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
             if (node->type() == "MatMul")
                 rescaling /= inputIsUnsigned ? unsignedMax : signedMax;
 
-            multiplyScalingFactor(scalingNode, rescaling) ;          
+            multiplyScalingFactor(scalingNode, rescaling);          
         }
 
         if (isNotQuantized(node))
@@ -1155,7 +1111,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
         
         // Handle the Scaling Nodes ...
 
-        if (hasAttr(node, "isScaling"))
+        if (hasAttr(node, "isActivationQuantizer")) 
         {
             // Don't touch the scalings that precede non-linearities ...
 
@@ -1166,40 +1122,33 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_
 
             if (!noQuant && !precedesNonLinearNode) 
             {  
-                // Replace the Scaling Node by a Quantizer
+                // we need to gather the sign information before we modify
+                // the node pointer with appendRoundClip() ... 
 
-                auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(1);
-                std::shared_ptr<Tensor> fallback;
-                const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu");
-                double oldScalingFactor = localTensor.get<double>(0); //!\\ 
+                bool inputIsUnsigned  = signMap[node].first;
+                bool outputIsUnsigned = signMap[node].second;
 
-                std::shared_ptr<Node> quantizerNode = Quantizer(oldScalingFactor, -(signedMax + 1), signedMax, node->name());
-                quantizerNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
-                quantizerNode->getOperator()->setBackend(determineBackend(node));
-                graphView->replace({node, node->getParent(1)}, {quantizerNode});
+                appendRoundClip(node, -(signedMax + 1), signedMax);
 
                 if (optimizeSigns)
                 {
                     double rescaling = 1.0;
 
-                    bool inputIsUnsigned  = signMap[node].first;
-                    bool outputIsUnsigned = signMap[node].second;
-
                     rescaling /= inputIsUnsigned  ? unsignedMax : signedMax;
                     rescaling *= outputIsUnsigned ? unsignedMax : signedMax;
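+                    // (e.g. for 8 bits, an unsigned input and a signed output
+                    // yield a rescaling of 127 / 255)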
 
-                    double currScalingFactor = getScalingFactor(quantizerNode);
-                    updateScalingFactor(quantizerNode, currScalingFactor * rescaling);
-                    
-                    if(outputIsUnsigned)
-                        setClipRange(quantizerNode, 0, unsignedMax);                 
+                    // XXX XXX XXX
+                    multiplyScalingFactor(node, rescaling);
+
+                    if (outputIsUnsigned)
+                        setClipRange(node, 0, unsignedMax);                                  
                 }
             }
         }
     }
 }
 
-static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits)
+void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, bool noQuant)
 {
     // XXX Use the signMap to increase the resolution when possible ...
     double signedMax = (1 << (nbBits - 1)) - 1;    
@@ -1210,7 +1159,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
     {
         // The appropriate strategy is to check if the Quantizer is not   
         // preceded by a Weighted node (that is not forking), and insert
-        // a coeff node (Compensation) if so ...
+        // a mul node (Compensation) before it if so ...
  
         if (node->type() == "Quantizer")
         {
@@ -1227,7 +1176,7 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
                 std::string mulNodeName = makeUniqueName(node->name() + "_Mul", graphView);
                 std::shared_ptr<Node> mulNode = Mul(mulNodeName);
                 
-                addAttr(mulNode, "isCompensation");
+                // XXX XXX XXX addAttr(mulNode, "isCompensation");
 
                 mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
                 mulNode->getOperator()->setBackend(determineBackend(node));
@@ -1247,14 +1196,26 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
 
                 // Adapt the scaling factor value accordingly
 
-                double currScalingFactor = getScalingFactor(node); 
-                updateScalingFactor(node, currScalingFactor / signedMax); 
+                multiplyScalingFactor(node, 1.0 / signedMax); // XXX XXX XXX OK
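+                // (this cancels the multiplication introduced by the compensation
+                // Mul node, leaving the network output unchanged)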
+
+                // Insert a Quantizer for the coeffProducer that will handle  
+                // the single-shift approximation via its scalingFactor ...
+
+                insertScalingBelowProducer(coeffProducer, graphView);
+                
+                if (!noQuant) 
+                {
+                    // XXX XXX XXX double check this ...
+                    std::shared_ptr<Node> coeffQuantizer = mulNode->getParent(1);
+                    appendRoundClip(coeffQuantizer, -(signedMax + 1), signedMax);
+                }
+
             }
         }
     }
 }
 
-void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool noQuant)
+void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView)
 {
     std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
 
@@ -1265,39 +1226,33 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
             std::shared_ptr<Node> linearNode = node->getParent(0);
 
             double base = getScalingFactor(node);
-
             double approx = std::pow(2, std::ceil(std::log2(base)));
+            double ratio = approx / base;
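+            // (e.g. base = 0.02 : approx = 2^ceil(log2(0.02)) = 2^-5 = 0.03125,
+            // and ratio = 0.03125 / 0.02 = 1.5625)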
 
-            updateScalingFactor(node, approx);
+            // set the scaling factor value to the approximation ...
 
-            double ratio = base / approx;
+            multiplyScalingFactor(node, ratio);
 
-            insertScalingBelowProducer(linearNode->getParent(1), ratio, graphView);
-            if (!noQuant)
-                insertRoundBelowProducer(linearNode->getParent(1), graphView);
+            // compensate the ratio using the previous node scaling factors ...
 
+            multiplyScalingFactor(linearNode->getParent(1), 1.0 / ratio);
             if (nodeHasBias(linearNode))
-            {
-                insertScalingBelowProducer(linearNode->getParent(2), ratio, graphView);
-                if (!noQuant)
-                    insertRoundBelowProducer(linearNode->getParent(2), graphView);
-            }
+                multiplyScalingFactor(linearNode->getParent(2), 1.0 / ratio);
         }
     }
 }
 
-
 static void printScalingFactors(std::shared_ptr<GraphView> graphView)
 {
     for (auto node : retrieveNodeVector(graphView))
-        if (hasAttr(node, "isScaling") || node->type() == "Quantizer")
+        if (hasAttr(node, "isActivationQuantizer") || node->type() == "Quantizer")
         {
             double scalingFactor = getScalingFactor(node);
-            Log::info(" {:.6f} ({})", scalingFactor, node->name());
+            Log::notice(" SCALING FACTOR : {} ({})", scalingFactor, node->name());
         }
 }
 
-static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> inputDataSet, DataType dataType)
+static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std::shared_ptr<Tensor>> calibrationSet, DataType dataType)
 {
     graphView->setDataType(dataType);
 
@@ -1308,32 +1263,21 @@ static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std:
             inputTensor->setDataType(dataType);
     }
 
-    for (auto tensor : inputDataSet)
+    for (auto tensor : calibrationSet)
         tensor->setDataType(dataType);
 }
 
-void quantizeNetwork(std::shared_ptr<GraphView> graphView, 
-    std::uint8_t nbBits, 
-    std::vector<std::shared_ptr<Tensor>> inputDataSet,
-    Clipping clippingMode, 
-    DataType targetType,
-    bool noQuant, 
-    bool optimizeSigns,
-    bool singleShift,
-    bool useCuda,
-    bool foldGraph,
-    bool bitshiftRounding,
-    bool verbose)
+void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> calibrationSet, DataType targetType, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose)
 {
-    Log::notice(" === QUANT PTQ 0.2.21 === ");
+    Log::notice(" === QUANT PTQ 0.3.0 === ");
 
     graphView->setBackend("cpu");
 
     if (!checkArchitecture(graphView))
         return;
 
-    DataType initialDataType = (inputDataSet[0])->dataType();
-    setupDataType(graphView, inputDataSet, DataType::Float64);
+    DataType initialDataType = (calibrationSet[0])->dataType();
+    setupDataType(graphView, calibrationSet, DataType::Float64);
 
     Log::notice(" Preparing the network for the PTQ ... ");
     prepareNetwork(graphView);
@@ -1341,17 +1285,17 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView,
     Log::notice(" Inserting the scaling nodes ...");
     insertScalingNodes(graphView);
 
-    // TODO : double check this !
-    crossLayerEqualization(graphView);
+    // TODO : double check the CLE ...
+    crossLayerEqualization(graphView); // XXX XXX XXX
 
     Log::notice(" Normalizing the parameters ...");
     normalizeParameters(graphView);
 
     Log::notice(" Computing the value ranges ...");
-    std::unordered_map<std::shared_ptr<Node>, double> valueRanges = computeRanges(graphView, inputDataSet, true, useCuda);
+    std::unordered_map<std::shared_ptr<Node>, double> valueRanges = computeRanges(graphView, calibrationSet, true, useCuda);
 
     Log::notice(" Optimizing the clipping values ...");
-    valueRanges = adjustRanges(clippingMode, valueRanges, nbBits, graphView, inputDataSet, useCuda, verbose);
+    valueRanges = adjustRanges(clippingMode, valueRanges, nbBits, graphView, calibrationSet, useCuda, verbose);
 
     Log::notice(" Normalizing the activations ...");
     normalizeActivations(graphView, valueRanges);
@@ -1362,38 +1306,30 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView,
     if (singleShift)
     {
         Log::notice( " Inserting the compensation nodes ...");
-        insertCompensationNodes(graphView, nbBits);
+        insertCompensationNodes(graphView, nbBits, noQuant);
 
         Log::notice(" Performing the Single-Shift approximation ...");
-        performSingleShiftApproximation(graphView, noQuant);
-    }
-    if( targetType != DataType::Float64 && targetType != DataType::Float32 && targetType != DataType::Float16) 
-    {
-        AIDGE_ASSERT(!noQuant,"Cannot cast operators with the noQuant (Fake Quantization) flag set to true!")
-        Log::notice("Starting to cast operators into the desired type ...");
-        castQuantizedGraph(graphView,targetType,singleShift,bitshiftRounding);
-        
-        graphView->updateInputsOutputs();
-        clearGraphViewInputNodes(graphView,targetType); //Convert all input tensors of the GV into targetType
+        performSingleShiftApproximation(graphView);
     }
-    else
-    {
-        setupDataType(graphView, inputDataSet, targetType);
-    }
-    if(foldGraph)
+
+    Log::notice(" Casting the network to the target type ({}) ...", targetType);
+    castQuantizedNetwork(graphView, targetType, singleShift);
+
+    if (foldGraph)
     {
-        Log::notice("Applying constant folding recipe to the graph ...");
-        applyConstFold(graphView);
+        Log::notice(" Folding the Producer's Quantizers ...");
+        foldProducerQuantizers(graphView);
     }
-    //Mandatory to handle all of the newly added connections!
-    graphView->updateInputsOutputs();
-    
-    //Clearing input nodes
-    Log::notice("Clearing all input nodes ...");
+
+    // TODO ...
+    // Log::notice(" Clearing the input nodes ...");
 
     if (verbose)
         printScalingFactors(graphView);
-    
+
+    if (useCuda)
+        graphView->setBackend("cuda");
+
     Log::notice(" Reseting the scheduler ...");
     SequentialScheduler scheduler(graphView);
     scheduler.resetScheduling();
@@ -1424,8 +1360,9 @@ void clearBiases(std::shared_ptr<GraphView> graphView)
         if (node->type() == "FC" || node->type() == "Conv2D") {
             std::shared_ptr<Tensor> biasTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2);
             //rescaleTensor(biasTensor, 0);
-            insertScalingBelowProducer(node->getParent(2), 0, graphView);
+            //insertScalingBelowProducer(node->getParent(2), 0, graphView);
+            multiplyScalingFactor(node->getParent(2), 0);
         }
     }
 }
-}
\ No newline at end of file
+}
diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp
index c043e4739991c6b7b93fc8e5a710e56e0ce2ac30..82b2e501dc275b361899a0ae8284f8a5409d32dc 100644
--- a/src/operator/PTQMetaOps.cpp
+++ b/src/operator/PTQMetaOps.cpp
@@ -32,159 +32,346 @@
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/utils/Log.hpp"
 
-
 namespace Aidge
+{ 
+
+static bool hasAttr(std::shared_ptr<Aidge::Node> node, std::string attr)
 {
-static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, std::string nodeType)
-{
-    std::shared_ptr<Node> mulNode = nullptr;
-    for(std::shared_ptr<Node> node : graphView->getNodes())
-        if (node->type() == nodeType)
-            mulNode = node;
+    return node->attributes()->hasAttr("quantization.ptq." + attr);
+}
 
-    return mulNode;
+static void addAttr(std::shared_ptr<Aidge::Node> node, std::string attr, double value = 0.0)
+{
+    node->attributes()->addAttr("quantization.ptq." + attr, value);
 }
-std::shared_ptr<Node> Quantizer(double scalingFactor, double clipMin, double clipMax, const std::string& name)
+
+// TODO : rework this
+static void copyDynamicAttributes(std::shared_ptr<Aidge::Node> prevNode, std::shared_ptr<Aidge::Node> newNode)
 {
-    // create the nodes
+    if (hasAttr(prevNode, "isProducerQuantizer"))
+        addAttr(newNode, "isProducerQuantizer");
 
-    std::shared_ptr<Node> mulNode =  Mul((!name.empty()) ? name + "_MulQuant" : "");
-    std::shared_ptr<Node> roundNode = Round((!name.empty()) ? name + "_RoundQuant" : "");
-    std::shared_ptr<Node> clipNode = Clip((!name.empty()) ? name + "_ClipQuant" : "", clipMin, clipMax);
+    if (hasAttr(prevNode, "isActivationQuantizer"))
+        addAttr(newNode, "isActivationQuantizer");  
+}
+
+std::shared_ptr<Node> Quantizer(double scalingFactor, const std::string& name)
+{
+    std::shared_ptr<Node> mulNode =  Mul(name + "_MulQuant");
 
-    // connect the scaling factor producer
+    // Scaling Factor Producer
 
     std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor});
     std::shared_ptr<Node> scalingFactorProducer = addProducer<1>(mulNode, 1, {1}, "ScalingFactor");
     scalingFactorProducer->getOperator()->setOutput(0, scalingFactorTensor);
 
-    // create the metaop graph
+    // TODO : the above could be replaced by :
+    // std::shared_ptr<Node> scalingFactorProducer = Producer(scalingFactorTensor);
+    // scalingFactorProducer->addChild(mulNode, 0, 1);
+
+    // create the graphView ...
 
-    std::shared_ptr<GraphView> graphView = Sequential({mulNode, roundNode, clipNode});
-    std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); // XXX why not use the graphView ???
+    std::shared_ptr<GraphView> graphView = Sequential({mulNode});
+    graphView->add(scalingFactorProducer);
 
-    // return the metaop
+    // alternative : capture the Producer ...
+    // std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); 
 
-    return MetaOperator("Quantizer", connectedGraphView, {}, name); // XXX alternative prototype
+    std::shared_ptr<Node> quantizer = MetaOperator("Quantizer", graphView, {}, name); // a simpler prototype exists ...
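+    // NB : at this point the micro-graph only contains the Mul node and its
+    // scaling-factor Producer ; Round and Clip are appended later by appendRoundClip()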
 
+    return quantizer;
 }
-std::shared_ptr<Node> BitShiftQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType,bool bitshiftRounding, const std::string& name)
-{
-    double scalingFactor = getScalingFactor(oldQuantizer);
 
-    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op> (oldQuantizer->getOperator());
-    std::shared_ptr<Node> oldclipNode = getSubNode(metaOp->getMicroGraph(), "Clip");
+void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff)
+{
+    auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
 
-    if (!oldclipNode) {
-    Log::warn("Invalid PTQ MetaOperator, no Clip found inside node of type {}", oldQuantizer->type());
-        return nullptr;
-    }
+    // Get the Mul node from the microGraph
 
-    std::shared_ptr<Clip_Op> clipOp = std::static_pointer_cast<Clip_Op>(oldclipNode->getOperator());
-    int shift = std::log2(scalingFactor);
-    BitShift_Op::BitShiftDirection direction = BitShift_Op::BitShiftDirection::left;
+    std::shared_ptr<Node> mulNode = nullptr;
+    auto microGraph = quantizerOp->getMicroGraph();
+    for (auto node : microGraph->getNodes())
+        if (node->type() == "Mul")
+            mulNode = node;
 
-    if(shift < 0 )
-    {
-        direction = BitShift_Op::BitShiftDirection::right;
-        shift = -shift;
-    }
+    // Retrieve the previous scaling factor
 
-    std::shared_ptr<Node> bitShiftNode = BitShift(direction,bitshiftRounding,(!name.empty()) ? name + "_MulIQuant" : "");
-    std::shared_ptr<Node> clipNode = Clip((!name.empty()) ? name + "_IClipQuant" : "", clipOp->min(), clipOp->max());
+    auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(mulNode->getOperator())->getInput(1);
 
-    std::shared_ptr<Tensor> bitshiftTensor = std::make_shared<Tensor>(Array1D<int, 1> {shift});
-    std::shared_ptr<Node> bitshiftProducer = addProducer(bitShiftNode, 1, {1}, "ScalingFactor");
-     
-    bitshiftProducer->getOperator()->setOutput(0, bitshiftTensor);
-    bitshiftProducer->attributes()->addAttr("quantization.ptq.ShiftAmount",shift);
-    bitshiftProducer->getOperator()->setDataType(targetType); 
+    std::shared_ptr<Tensor> fallback;
+    const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu");
+    double prevScalingFactor = localTensor.get<double>(0);    
 
-    // connect the scaling factor producer
+    // Create the new scaling factor tensor
 
-    bitShiftNode->getOperator()->setDataType(targetType);
-    clipNode->getOperator()->setDataType(targetType);
-    
-    // create the metaop graph
+    std::shared_ptr<Tensor> newScalingFactorTensor = std::make_shared<Tensor>(prevScalingFactor * coeff);
+    newScalingFactorTensor->setBackend(scalingFactorTensor->backend());
+    newScalingFactorTensor->setDataType(scalingFactorTensor->dataType());
 
-    std::shared_ptr<GraphView> graphView = Sequential({bitShiftNode,clipNode});
-    std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(bitShiftNode); // XXX why not use the graphView ???
+    // Set the tensor of the producer
 
-    // return the metaop 
-    return MetaOperator("BitShiftQuantizer", connectedGraphView, {}, name); // XXX alternative prototype
+    auto producer = mulNode->getParent(1);
+    producer->getOperator()->setOutput(0, newScalingFactorTensor);
 }
-std::shared_ptr<Node> IntQuantizer(std::shared_ptr<Node> oldQuantizer, DataType targetType, const std::string& name)
+
+void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double clipMax)
 {
-    std::shared_ptr<Node> castPreNode =  Cast(DataType::Float64,((!name.empty()) ? name + "_PreCast" : ""));
-    std::shared_ptr<Node> castPostNode =  Cast(targetType,((!name.empty()) ? name + "_PostCast" : ""));
+    // Retrieve a clone of the microGraph
 
-    castPreNode->getOperator()->setDataType(DataType::Float64);
-    castPostNode->getOperator()->setDataType(targetType);
+    auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
+    auto microGraph = quantizerOp->getMicroGraph()->clone();
 
-    std::shared_ptr<GraphView> graphView = Sequential({castPreNode, oldQuantizer->clone(), castPostNode});
+    // Save the datatype / backend
 
-    return MetaOperator("IntQuantizer", graphView, {}, name); // XXX alternative prototype
-}
+    auto outputNode = *(microGraph->outputNodes().begin());
+    auto outputOp = std::static_pointer_cast<OperatorTensor> (outputNode->getOperator());
 
-void updateScalingFactor(std::shared_ptr<Node> metaOpNode, double scalingFactor)
-{
-    if(metaOpNode->type() != "Scaling" && metaOpNode->type() != "Quantizer")
-        Log::warn("Cannot update the scaling factor on Node of type {}", metaOpNode->type());
+    auto dataType = outputOp->getOutput(0)->dataType();
+    auto backend = outputOp->getOutput(0)->backend();
 
-    std::shared_ptr<Tensor> scalingFactorTensor = std::make_shared<Tensor>(Array1D<double, 1> {scalingFactor});
+    // append round
+
+    auto roundNode = Round(quantizer->name() + "_RoundQuant");
+    outputNode->addChild(roundNode, 0, 0);
+    microGraph->add(roundNode);
+
+    // append clip
 
-    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op>(metaOpNode->getOperator());
+    auto clipNode = Clip(quantizer->name() + "_ClipQuant");
 
-    std::shared_ptr<Node> mulNode = getSubNode(metaOp->getMicroGraph(), "Mul");
+    auto minTensor = std::make_shared<Tensor>(clipMin);
+    auto minNode = Producer(minTensor);
+    minNode->addChild(clipNode, 0, 1);
+    microGraph->add(minNode);
 
-    if (!mulNode)
-        Log::warn("Invalid PTQ MetaOperator, no Mul node found inside ! ");
+    auto maxTensor = std::make_shared<Tensor>(clipMax);
+    auto maxNode = Producer(maxTensor);   
+    maxNode->addChild(clipNode, 0, 2);
+    microGraph->add(maxNode);
 
-    mulNode->input(1).first->getOperator()->setOutput(0, scalingFactorTensor);
+    roundNode->addChild(clipNode, 0, 0);
+    microGraph->add(clipNode);
+
+    // set the datatype / backend
+
+    microGraph->setDataType(dataType);
+    microGraph->setBackend(backend);
+
+    // create the new quantizer and replace the previous one
+
+    std::shared_ptr<Node> newQuantizer = MetaOperator("Quantizer", microGraph, {}, quantizer->name());
+    copyDynamicAttributes(quantizer, newQuantizer);
+    GraphView::replace({quantizer}, {newQuantizer});
+
+    // replace the old pointer with the new one (by reference)
+
+    quantizer = newQuantizer;
 }
 
-double getScalingFactor(std::shared_ptr<Node> MetaOpNode)
+double getScalingFactor(std::shared_ptr<Node> quantizer)
 {
-    if (MetaOpNode->type() != "Scaling" && MetaOpNode->type() != "Quantizer") {
-        Log::warn("Cannot get the scaling factor on Node of type {}", MetaOpNode->type());
-        return 0;
-    }
+    // Retrieve the previous microGraph
 
-    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op>(MetaOpNode->getOperator());
+    auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
+    auto microGraph = quantizerOp->getMicroGraph();
 
-    std::shared_ptr<Node> mulNode = getSubNode(metaOp->getMicroGraph(), "Mul");
+    // Get the Mul node from the microGraph
 
-    if (!mulNode) {
-        Log::warn("Invalid PTQ MetaOperator, no Mul found inside node of type {}", MetaOpNode->type());
-        return 0;
-    }
+    std::shared_ptr<Node> mulNode = nullptr;
+    for (auto node : microGraph->getNodes())
+        if (node->type() == "Mul")
+            mulNode = node;
+
+    auto mulOp = std::static_pointer_cast<OperatorTensor> (mulNode->getOperator()); 
+
+    // Retrieve the scaling factor
+
+    auto scalingFactorTensor = mulOp->getInput(1);
 
-    auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(mulNode->getOperator())->getInput(1);
     std::shared_ptr<Tensor> fallback;
     const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu");
+    double scalingFactor = localTensor.get<double>(0);
 
-    return localTensor.get<double>(0);
+    return scalingFactor;
 }
 
+void setClipRange(std::shared_ptr<Node> quantizer, double min, double max)
+{
+    auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
+    auto microGraph = quantizerOp->getMicroGraph();
+
+    std::shared_ptr<Node> clipNode = nullptr;
+    for (auto node : microGraph->getNodes())
+        if (node->type() == "Clip")
+            clipNode = node;
+
+    // TODO : assert that clipNode is not a nullptr ...
+
+    auto clipOp = std::static_pointer_cast<Clip_Op> (clipNode->getOperator()); 
+
+    // set the attributes
+
+    clipOp->max() = max;
+    clipOp->min() = min;
+
+    // Retrieve the previous min/max tensors
+
+    auto minTensor = std::static_pointer_cast<OperatorTensor>(clipNode->getOperator())->getInput(1);
+    auto maxTensor = std::static_pointer_cast<OperatorTensor>(clipNode->getOperator())->getInput(2);
+
+    // Create the new min/max tensors
+
+    std::shared_ptr<Tensor> newMinTensor = std::make_shared<Tensor>(min);
+    newMinTensor->setBackend(minTensor->backend());
+    newMinTensor->setDataType(minTensor->dataType());
+
+    std::shared_ptr<Tensor> newMaxTensor = std::make_shared<Tensor>(max);
+    newMaxTensor->setBackend(maxTensor->backend());
+    newMaxTensor->setDataType(maxTensor->dataType());
+
+    // Set the tensors of the producer
+
+    auto minProducer = clipNode->getParent(1);
+    minProducer->getOperator()->setOutput(0, newMinTensor);
 
-void setClipRange(std::shared_ptr<Node> quantizerNode, double min, double max)
+    auto maxProducer = clipNode->getParent(2);
+    maxProducer->getOperator()->setOutput(0, newMaxTensor);
+}
+
+void removeRound(std::shared_ptr<Node>& quantizer)
 {
-    if (quantizerNode->type() != "Quantizer") {
-        Log::warn("Cannot set the clipping range on Node of type {}", quantizerNode->type());
-        return;
-    }
+    // Retrieve a clone of the microGraph
 
-    std::shared_ptr<MetaOperator_Op> metaOp = std::static_pointer_cast<MetaOperator_Op> (quantizerNode->getOperator());
+    auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
+    auto microGraph = quantizerOp->getMicroGraph()->clone();
 
-    std::shared_ptr<Node> clipNode = getSubNode(metaOp->getMicroGraph(), "Clip");
+    // retrieve the rounding node
 
-    if (!clipNode) {
-        Log::warn("Invalid PTQ MetaOperator, no Clip found inside node of type {}", quantizerNode->type());
+    std::shared_ptr<Node> roundNode = nullptr;
+    for (auto node : microGraph->getNodes())
+        if (node->type() == "Round")
+            roundNode = node;
+
+    if (roundNode == nullptr)
         return;
+    
+    // remove the Round node
+
+    microGraph->replace({roundNode}, {});
+
+    // Create the new quantizer and replace the previous one 
+
+    std::shared_ptr<Node> newQuantizer = MetaOperator("Quantizer", microGraph, {}, quantizer->name());
+    copyDynamicAttributes(quantizer, newQuantizer);
+    GraphView::replace({quantizer}, {newQuantizer});
+
+    // replace the old pointer with the new one (by reference)
+
+    quantizer = newQuantizer;  
+}
+
+void replaceScalingWithBitShift(std::shared_ptr<Node>& quantizer)
+{
+    // Retrieve a clone of the microGraph
+
+    auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
+    auto microGraph = quantizerOp->getMicroGraph()->clone();
+
+    // retrieve the multiplicative (scaling) node
+
+    std::shared_ptr<Node> mulNode = nullptr;
+    for (auto node : microGraph->getNodes())
+        if (node->type() == "Mul")
+            mulNode = node;
+
+    auto mulOp = std::static_pointer_cast<OperatorTensor> (mulNode->getOperator());
+
+    // Save the datatype / backend
+
+    auto dataType = mulOp->getOutput(0)->dataType();
+    auto backend = mulOp->getOutput(0)->backend();
+
+    // compute the shift value 
+
+    double scalingFactor = getScalingFactor(quantizer);
+    int bitShiftAmount = -std::round(std::log2(scalingFactor));
+    auto bitShiftDirection = BitShift_Op::BitShiftDirection::right;
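+    // (e.g. scalingFactor = 1/128 : log2 = -7, so bitShiftAmount = 7
+    // and the scaling becomes a right shift by 7 bits)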
+
+    Log::notice(" SHIFT AMOUNT = {} ({})", bitShiftAmount, scalingFactor);
+
+    if (bitShiftAmount < 0 )
+    {
+        bitShiftDirection = BitShift_Op::BitShiftDirection::left;
+        bitShiftAmount = -bitShiftAmount;
     }
 
-    std::shared_ptr<Clip_Op> clipOp = std::static_pointer_cast<Clip_Op>(clipNode->getOperator());
-    clipOp->max() = max;
-    clipOp->min() = min;
+    bool bitShiftRounding = true; // XXX use an argument !!!
+
+    // create the replacement bit-shift nodes
+
+    auto bitShiftNode = BitShift(bitShiftDirection, bitShiftRounding, quantizer->name() + "_BitShiftQuant"); 
+    auto bitShiftTensor = std::make_shared<Tensor>(Array1D<int, 1> {bitShiftAmount});
+
+    auto bitShiftProducer = Producer(bitShiftTensor, "bitShiftAmount");
+    bitShiftProducer->addChild(bitShiftNode, 0, 1);
+
+    // edit the micrograph
+
+    microGraph->replace({mulNode, mulNode->getParent(1)}, {bitShiftNode, bitShiftNode->getParent(1)});
+
+    // set the datatype / backend
+
+    microGraph->setDataType(dataType);
+    microGraph->setBackend(backend);
+
+    // create the new quantizer and replace the previous one
+
+    std::shared_ptr<Node> newQuantizer = MetaOperator("Quantizer", microGraph, {}, quantizer->name());
+    copyDynamicAttributes(quantizer, newQuantizer);  
+    GraphView::replace({quantizer}, {newQuantizer});
+
+    // replace the old pointer with the new one (by reference)
+
+    quantizer = newQuantizer;  
+}
+
+void castQuantizerIOs(std::shared_ptr<Node>& quantizer, Aidge::DataType externalType)
+{
+    // Retrieve a clone of the microGraph
+
+    auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
+    auto microGraph = quantizerOp->getMicroGraph()->clone();
+
+    // Edit the micrograph (insert Cast nodes at its IOs)
+
+    auto mulNode = *(microGraph->inputNodes().begin()); // TODO : assert that mulNode is a Mul !
+    auto mulOp = std::static_pointer_cast<OperatorTensor> (mulNode->getOperator());
+
+    auto internalType = mulOp->getOutput(0)->dataType();
+    auto castInputNode = Cast(internalType, quantizer->name() + "_CastIn");  
+    auto castOutputNode = Cast(externalType, quantizer->name() + "_CastOut");
+
+    microGraph = Sequential({castInputNode, microGraph, castOutputNode});
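+    // (the quantizer now exchanges externalType tensors with the rest of the
+    // graph, while its internal nodes keep computing in internalType)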
+
+    // Set the micrograph datatype 
+
+    microGraph->setDataType(internalType);
+    castOutputNode->getOperator()->setDataType(externalType);
+
+    // Set the micrograph backend 
+
+    auto backend = mulOp->getOutput(0)->backend(); 
+    microGraph->setBackend(backend); 
+
+    // Create the new quantizer and replace the old one
+
+    std::shared_ptr<Node> newQuantizer = MetaOperator("Quantizer", microGraph, {}, quantizer->name());
+    copyDynamicAttributes(quantizer, newQuantizer);  
+    GraphView::replace({quantizer}, {newQuantizer});
+
+    // replace the old pointer with the new one (by reference)
+
+    quantizer = newQuantizer;  
 }
+
 }
\ No newline at end of file