diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 4256774056379969c7406a35e4bcde3ff25c6550..6b36832776146dedcd397491fbaa3771e6558fdd 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -21,6 +21,7 @@ include: - '.gitlab/ci/ubuntu_python.gitlab-ci.yml' - '.gitlab/ci/release/cibuildwheel_ubuntu.gitlab-ci.yml' + # Cannot find successful job on aidge_backend_cuda yet # - '.gitlab/ci/windows_cpp.gitlab-ci.yml' # - '.gitlab/ci/windows_python.gitlab-ci.yml' diff --git a/aidge_quantization/__init__.py b/aidge_quantization/__init__.py index b00fae178421997967a79fc9fb0f680ed4afbe84..c321e4695d7a230eda90cc7edc9f3427fc45aa19 100644 --- a/aidge_quantization/__init__.py +++ b/aidge_quantization/__init__.py @@ -1 +1,2 @@ from aidge_quantization.aidge_quantization import * # import so generated by PyBind +from .freezeProducers import * \ No newline at end of file diff --git a/aidge_quantization/_version.py b/aidge_quantization/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..2d34d3557071ed5c22aea83c63bfb7684b180cf9 --- /dev/null +++ b/aidge_quantization/_version.py @@ -0,0 +1,4 @@ +# file generated by setuptools_scm +# don't change, don't track in version control +__version__ = version = '0.2.1.dev60+g8044e79.d20250106' +__version_tuple__ = version_tuple = (0, 2, 1, 'dev60', 'g8044e79.d20250106') \ No newline at end of file diff --git a/aidge_quantization/freezeProducers.py b/aidge_quantization/freezeProducers.py new file mode 100644 index 0000000000000000000000000000000000000000..87839718cce84fa233504aca712c30595797a307 --- /dev/null +++ b/aidge_quantization/freezeProducers.py @@ -0,0 +1,38 @@ +import aidge_core +import aidge_onnx + +def freeze_weights(graphview: aidge_core.GraphView, all_producers: bool = False): + """freeze the weights and bias of Convolution and fully connected nodes. 
Primarily intended so that constant folding may be applied to those parts of the graph + + :param graphview: model to freeze the weights in + :type graphview: py:class:`aidge_core.GraphView` + :param all_producers: defaults to False; if True, freezes all producers that are part of the weight input and bias input of the Conv or FC node + :type all_producers: bool + """ + def freeze_all(node): + for inpt in node.get_parents(): + if inpt is None: + break + elif inpt.type() != "Producer": + freeze_all(inpt) + else: + inpt.get_operator().attr.set_attr("constant",True) + + #Possible improvement: keep a registry of visited nodes to prevent unnecessary iterations + for node in graphview.get_nodes(): + #Search for Convolution and Fully connected nodes + if node.type() in ["FC","Conv1D", "Conv2D", "Conv3D","ConvDepthWise1D", "ConvDepthWise2D", "ConvDepthWise3D"]: + #iterate over its weights and, if present, bias + for inputs_id in range(node.get_nb_inputs() - 1): + parent_node = node.get_parent(inputs_id + 1) + + #walk up the parents until the producer is reached; if directly connected, no iteration is performed + #loop present to also be able to freeze producers so that they can get constant folded + if all_producers: + freeze_all(parent_node) + else: + while(parent_node.type() != "Producer"): + parent_node = parent_node.get_parent(0) + if parent_node is None: + raise RuntimeError(f"Could not find a parent producer for node {node.name()}") + parent_node.get_operator().attr.set_attr("constant",True) diff --git a/include/aidge/quantization/QAT/QAT_LSQ.hpp b/include/aidge/quantization/QAT/QAT_LSQ.hpp index 4970be07fae8737a1c2863600757bb81ff3a65f9..d7d03ca78ff63b328ba068dd4ff82c61270218e3 100644 --- a/include/aidge/quantization/QAT/QAT_LSQ.hpp +++ b/include/aidge/quantization/QAT/QAT_LSQ.hpp @@ -20,22 +20,14 @@ namespace Aidge { namespace QuantLSQ { /** - * @brief Insert the LSQ quantizer nodes in a given GraphView - * @param graphView The GraphView containing the graph to quantize. + * @brief Given a GraphView with parameters properly initialized, insert + * the LSQ quantizer nodes, and set up the adjustment of their step-sizes. + * @param graphView The GraphView containing the network to quantize. * @param nbBits Number of quantization bits. - * @param span Fixed output span of the quantizers. */ -void insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, float step_size); +void setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits); -/** - * @brief Given a GraphView with parameters properly initialized and some calibration data, - * insert the LSQ quantizer nodes, and adjust their step-sizes. - * @param graphView The GraphView containing the graph to quantize. - * @param nbBits Number of quantization bits. - * @param calibrationData Calibration data used to adjust the spans. - * @param scale Multiplicative constant applied to the spans.
- */ -void insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, std::shared_ptr<Tensor> calibrationData); +void devLSQ(std::shared_ptr<Tensor> tensor); } } diff --git a/include/aidge/quantization_version.h b/include/aidge/quantization_version.h index 546263af3a7e8b7a73991173f48d0b095c7d9501..eba0eab45a985df3d4fa2f898d54abf5bce7a1ea 100644 --- a/include/aidge/quantization_version.h +++ b/include/aidge/quantization_version.h @@ -3,9 +3,9 @@ namespace Aidge { static constexpr const int PROJECT_VERSION_MAJOR = 0; -static constexpr const int PROJECT_VERSION_MINOR = 2; +static constexpr const int PROJECT_VERSION_MINOR = 3; static constexpr const int PROJECT_VERSION_PATCH = 0; -static constexpr const char * PROJECT_VERSION = "0.2.0"; -static constexpr const char * PROJECT_GIT_HASH = "f50c860"; +static constexpr const char * PROJECT_VERSION = "0.3.0"; +static constexpr const char * PROJECT_GIT_HASH = "c374ce4"; } #endif // VERSION_H diff --git a/include/aidge/recipes/ONNXRecipes.hpp b/include/aidge/recipes/ONNXRecipes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..058deeae6e24221862f0ef2acadab406840059be --- /dev/null +++ b/include/aidge/recipes/ONNXRecipes.hpp @@ -0,0 +1,47 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_QUANTIZATION_RECIPES_ONNXRECIPES_H_ +#define AIDGE_QUANTIZATION_RECIPES_ONNXRECIPES_H_ + +#include <memory> + +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Matching.hpp" +#include "aidge/operator/MetaOperator.hpp" + +namespace Aidge { + /** + * @brief Prepare an Aidge model for ONNX export: regroup Aidge nodes into QuantizeLinear, DequantizeLinear or QLinearConv operators. + * @param graphView The GraphView to process. + * @param QoperatorFormat if true, the QLinearConv meta-operator format is used; if false, the QDQ or QuantizeDequantize format is used (see https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#onnx-quantization-representation-format) + */ +void quantizeMatchingtoExport(std::shared_ptr<GraphView> graphView, bool QoperatorFormat = true, bool foldWeights = true); + + /** + * @brief Create a QuantizeLinear meta-operator node together with its scaling factor and zero-point producers. + * @param scalingFactor Scaling factor used in the quantization operation + * @param zeroPoint Zero point used in the quantization operation, for Aidge quantization this should always be equal to 0 + * @param basename name used as base for the names of the QuantizeLinear metaoperator and its components + */ +std::shared_ptr<Node> createQuantizeLinearNode(float scalingFactor = 1.0, uint8_t zeroPoint = 0,const std::string basename = ""); + + /** + * @brief Create a DequantizeLinear meta-operator node together with its scaling factor and zero-point producers. + * @param descalingFactor Scaling factor used in the dequantization operation + * @param zeroPoint Zero point used in the dequantization operation, for Aidge quantization this should always be equal to 0 + * @param castDtype Dtype of the output of the DequantizeLinear metaop.
This argument may be deprecated in the future because of ONNX's imposed dtypes + * @param basename name used as base for the names of the DequantizeLinear metaoperator and its components + */ +std::shared_ptr<Node> createDequantizeLinearNode(Tensor descalingFactor, uint8_t zeroPoint,Aidge::DataType castDtype,const std::string basename = ""); +} + +#endif //AIDGE_QUANTIZATION_RECIPES_ONNXRECIPES_H_ diff --git a/python_binding/pybind_PTQ.cpp b/python_binding/pybind_PTQ.cpp index ae0a0def28a861e2fc207adbc27c6af47dc0ded8..8d5390db777c59694aefb175f4567665774bce8e 100644 --- a/python_binding/pybind_PTQ.cpp +++ b/python_binding/pybind_PTQ.cpp @@ -17,7 +17,6 @@ #include "aidge/quantization/PTQ/Clipping.hpp" #include "aidge/quantization/PTQ/CLE.hpp" #include "aidge/quantization/PTQ/PTQ.hpp" - #include "aidge/graph/GraphView.hpp" namespace py = pybind11; @@ -48,6 +47,14 @@ void init_PTQ(py::module &m) { :type network: :py:class:`aidge_core.GraphView` )mydelimiter"); + m.def( "multiply_scaling_factor",&multiplyScalingFactor,py::arg("node"), py::arg("coeff"), + R"mydelimiter( + Updates the scaling factor of a "Mul" node in a graph if the node is marked as a scaling node. This function multiplies the existing scaling factor by a given coefficient. + :param node: The node to modify. + :param coeff: A floating-point value giving the multiplication coefficient to apply to the scaling factor. + )mydelimiter" + ); + m.def("normalize_parameters", &normalizeParameters, py::arg("network"), R"mydelimiter( Normalize the parameters of each parametrized node, so that they fit in the [-1:1] range. diff --git a/python_binding/pybind_QAT_LSQ.cpp b/python_binding/pybind_QAT_LSQ.cpp index 206985efe4558a84ce1ed67a1264bd6902213764..0b9fcc29d1a144708537084d4538eaa47873cd05 100644 --- a/python_binding/pybind_QAT_LSQ.cpp +++ b/python_binding/pybind_QAT_LSQ.cpp @@ -23,8 +23,9 @@ void init_QAT_LSQ(py::module &m) { auto mQuantLSQ = m.def_submodule("lsq"); - mQuantLSQ.def("insert_quantizers", &QuantLSQ::insertQuantizers, py::arg("network"), py::arg("nb_bits"), py::arg("step_size")); + mQuantLSQ.def("setup_quantizers", &QuantLSQ::setupQuantizers, py::arg("network"), py::arg("nb_bits")); + + mQuantLSQ.def("dev_lsq", &QuantLSQ::devLSQ, py::arg("tensor")); - mQuantLSQ.def("insert_and_init_quantizers", &QuantLSQ::insertAndInitQuantizers, py::arg("network"), py::arg("nb_bits"), py::arg("calibration_data")); } } // namespace Aidge diff --git a/python_binding/pybind_Quantization.cpp b/python_binding/pybind_Quantization.cpp index 7ac344dcfcd4fc93e3bba1dcd19c1413f5a29d0c..b91b0b80996c84c7fb1d4906fcb407acca465c63 100644 --- a/python_binding/pybind_Quantization.cpp +++ b/python_binding/pybind_Quantization.cpp @@ -34,6 +34,7 @@ void init_PTQ(py::module &m); void init_QAT_FixedQ(py::module &m); void init_QAT_LSQ(py::module &m); void init_QuantRecipes(py::module &m); +void init_ONNXRecipes(py::module &m); void init_QuantizationVersionInfo(py::module &m); @@ -48,6 +49,7 @@ PYBIND11_MODULE(aidge_quantization, m) init_QAT_FixedQ(m); init_QAT_LSQ(m); init_QuantRecipes(m); + init_ONNXRecipes(m); init_QuantizationVersionInfo(m); } diff --git a/python_binding/recipes/pybind_ONNXRecipes.cpp b/python_binding/recipes/pybind_ONNXRecipes.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9664f76782dcdbfff8735e59cdb178b0753bc993 --- /dev/null +++ b/python_binding/recipes/pybind_ONNXRecipes.cpp @@ -0,0 +1,28 @@ +/******************************************************************************** + * Copyright (c) 2023
CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include "aidge/recipes/ONNXRecipes.hpp" +#include "aidge/graph/GraphView.hpp" + +namespace py = pybind11; + +namespace Aidge { + +void init_ONNXRecipes(py::module &m) { + + m.def("quantize_matching_to_export", &quantizeMatchingtoExport, py::arg("graph_view"), py::arg("qop")=true, py::arg("fold_weights")=true); + +} + +} // namespace Aidge diff --git a/setup.py b/setup.py index cde7c1e513e8f3092474bddcb57842efced415e6..ad42ebeca6d482e353917fa0951e5f8a4414df83 100644 --- a/setup.py +++ b/setup.py @@ -74,30 +74,20 @@ class AidgePkgBuild(build_ext): ) test_onoff = os.environ.get("AIDGE_BUILD_TEST", "OFF") - os.chdir(str(build_temp)) - - cmake_cmd = [ - "cmake", - *build_gen_opts, - str(cwd), - f"-DTEST={test_onoff}", - f"-DCMAKE_INSTALL_PREFIX:PATH={install_path}", - f"-DCMAKE_BUILD_TYPE={build_type}", - f"-DCMAKE_C_COMPILER={c_compiler}", - f"-DCMAKE_CXX_COMPILER={cxx_compiler}", - f"-DENABLE_ASAN={asan}", - f"-DCUDA={with_cuda}", - "-DPYBIND=ON", - f"-DPYBIND_INSTALL_PREFIX:PATH={pybind_install_prefix}", - "-DCMAKE_EXPORT_COMPILE_COMMANDS=1", - "-DCOVERAGE=OFF", - ] - - # Append architecture-specific arguments if provided - if cmake_arch: - cmake_cmd.append(cmake_arch) - - self.spawn(cmake_cmd) + self.spawn( + [ + "cmake", + *build_gen_opts, + str(cwd), + f"-DTEST={test_onoff}", + f"-DCMAKE_INSTALL_PREFIX:PATH={install_path}", + "-DCMAKE_BUILD_TYPE=Debug", #f"-DCMAKE_BUILD_TYPE={compile_type}", + "-DPYBIND=ON", + f"-DPYBIND_INSTALL_PREFIX:PATH={pybind_install_prefix}", + "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", + "-DCOVERAGE=OFF", + ] + ) if not self.dry_run: self.spawn( diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp index 2738f8a92154368962e9162fba62c41b7622d07c..56ff71f67e7fc93ef6b488a674da41133a9eb17f 100644 --- a/src/PTQ/CLE.cpp +++ b/src/PTQ/CLE.cpp @@ -39,6 +39,12 @@ #include "aidge/operator/Reshape.hpp" #include "aidge/operator/Round.hpp" +#include "aidge/operator/Mul.hpp" +#include "aidge/operator/ArgMax.hpp" +#include "aidge/operator/Abs.hpp" +#include "aidge/operator/Reshape.hpp" +#include "aidge/operator/Round.hpp" + namespace Aidge { diff --git a/src/QAT/QAT_LSQ.cpp b/src/QAT/QAT_LSQ.cpp index 9b51e846df498a9303b7373ae1c86d4b007a96f0..8a42770ac9ff5c9426c0538d407c7f58d0021c15 100644 --- a/src/QAT/QAT_LSQ.cpp +++ b/src/QAT/QAT_LSQ.cpp @@ -13,7 +13,6 @@ #include "aidge/operator/LSQ.hpp" #include "aidge/operator/ReLU.hpp" - #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/scheduler/SequentialScheduler.hpp" @@ -23,7 +22,42 @@ namespace Aidge { -void QuantLSQ::insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, float stepSize) +static float getTensorAbsMean(std::shared_ptr<Tensor> tensor) +{ + auto valueTensor = (*tensor).abs().mean(); + std::shared_ptr<Tensor> fallback; + const Tensor& localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu"); + return localTensor.get<float>(0); +} + +// INIT THE STEP SIZE OF A QUANTIZER NODE + +static bool initStepSize(std::shared_ptr<Node> quantizer) +{ + const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator()); + + float inputAbsMean = 
getTensorAbsMean(quantizerOp->getInput(0)); + + float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second)); + + auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); + + // XXX Manage backend here ? + stepSizeTensor->setBackend(quantizerOp->getInput(0)->backend()); + stepSizeTensor->setDataType(quantizerOp->getInput(0)->dataType()); + + auto stepSizeProducer = quantizer->getParent(1); + + stepSizeProducer->getOperator()->setOutput(0, stepSizeTensor); + + Log::debug("[ INIT STEP SIZE = {} ]",stepSize); + + return false; +} + +// INPUT QUANTIZERS INSERTION + +static void setupInputQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) { const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)"); @@ -34,180 +68,76 @@ void QuantLSQ::insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbB std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1}; std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1}; - // INPUT QUANTIZERS INSERTION + // Create the input quantizer node - // TODO : double check this, and use createUniqueName() - auto inputQuantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView); - auto inputQuantizerNode = LSQ(signedRange, inputQuantizerName); + auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView); + auto quantizerNode = LSQ(signedRange, quantizerName); - // Set the step size + // Init the step-size using the node call stack - auto inputStepSizeOp = inputQuantizerNode->getParent(1)->getOperator(); - auto inputStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); - inputStepSizeOp->setOutput(0, inputStepSizeTensor); + quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); }); // Absorb the ReLU when possible ... - // XXX is this safe ??? - bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]); - // bool nodeHasParent = (linearNode->getParents().size() != 0); + bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]); // XXX is this safe ? if (nodeHasParent) { auto parentNode = linearNode->getParents()[0]; if (parentNode->type() == "ReLU") { - auto inputQuantizerOp = std::static_pointer_cast<LSQ_Op> (inputQuantizerNode->getOperator()); - inputQuantizerOp->range() = unsignedRange; + auto quantizerOp = std::static_pointer_cast<LSQ_Op> (quantizerNode->getOperator()); + quantizerOp->range() = unsignedRange; graphView->replace({parentNode}, {}); } } - // We need to handle the case where the linear node is the first one ... + // Insert the quantizer in the graphView ... 
+ // (We need to handle the case where the linear node is the first one) if (nodeHasParent) { - graphView->insertParent(linearNode, inputQuantizerNode, 0, 0, 0); + graphView->insertParent(linearNode, quantizerNode, 0, 0, 0); } else { - inputQuantizerNode->addChild(graphView); - graphView->add(inputQuantizerNode); + quantizerNode->addChild(graphView); + graphView->add(quantizerNode); } - - // PARAM QUANTIZERS INSERTION - - // TODO : double check this, and use createUniqueName() - auto paramQuantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView); - auto paramQuantizerNode = LSQ(signedRange, paramQuantizerName); - graphView->insertParent(linearNode, paramQuantizerNode, 1, 0, 0); - - // Set the step size - - auto paramStepSizeOp = paramQuantizerNode->getParent(1)->getOperator(); - auto paramStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); - paramStepSizeOp->setOutput(0, paramStepSizeTensor); } - } -static float getTensorAbsMean(std::shared_ptr<Tensor> tensor) -{ - auto backend = tensor->backend(); - if (backend == "cuda") - tensor->setBackend("cpu"); - - float acc = 0; - float* castedTensor = static_cast<float *> (tensor->getImpl()->rawPtr()); - for(std::size_t i = 0; i < tensor->size(); i++) - acc += std::abs(castedTensor[i]); - acc /= static_cast<float> (tensor->size()); - - if (backend == "cuda") - tensor->setBackend("cuda"); - - return acc; -} +// PARAM QUANTIZERS INSERTION -static std::map<std::string, float> collectInputStats(std::shared_ptr<GraphView> graphView, std::shared_ptr<Tensor> calibrationData, bool useCuda) +static void setupParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) { - // Propagate the calibration tensor + const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)"); - SequentialScheduler scheduler(graphView); - scheduler.resetScheduling(); - scheduler.forward(true, {calibrationData}); + std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1}; - // Store the input tensor statistics + for (const auto& match : matches) + { + auto linearNode = match.graph->rootNode(); - if (useCuda) - graphView->setBackend("cpu"); + // TODO : double check this, and use createUniqueName() + auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView); + auto quantizerNode = LSQ(signedRange, quantizerName); - std::map<std::string, float> inputStats; - for (auto node : graphView->getNodes()) - { - if (node->type() == "FC" || node->type() == "Conv2D") // TODO: use graph matching !!! - { - const auto op = std::static_pointer_cast<LSQ_Op>(node->getOperator()); - float inputAbsMean = getTensorAbsMean(op->getInput(0)); - inputStats.insert(std::make_pair(node->name(), inputAbsMean)); - fmt::println("{} -> {}", node->name(), inputAbsMean); - } - } + // Init the step-size using the node call stack - if (useCuda) - graphView->setBackend("cuda"); + quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); }); - return inputStats; -} + // Insert the quantizer in the graphView -static std::map<std::string, float> collectParamStats(std::shared_ptr<GraphView> graphView, bool useCuda) -{ - if (useCuda) - graphView->setBackend("cpu"); - - std::map<std::string, float> paramStats; - for (auto node : graphView->getNodes()) - { - if (node->type() == "FC" || node->type() == "Conv2D") // TODO: use graph matching !!! 
- { - const auto op = std::static_pointer_cast<LSQ_Op>(node->getOperator()); - float paramAbsMean = getTensorAbsMean(op->getInput(1)); - paramStats.insert(std::make_pair(node->name(), paramAbsMean)); - fmt::println("{} -> {}", node->name(), paramAbsMean); - } + graphView->insertParent(linearNode, quantizerNode, 1, 0, 0); } - - if (useCuda) - graphView->setBackend("cuda"); - - return paramStats; } -static void adjustQuantizersStepSizes(std::shared_ptr<GraphView> graphView, std::map<std::string, float> inputStats, std::map<std::string, float> paramStats) +void QuantLSQ::setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) { - const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)"); - - for (const auto& match : matches) - { - auto linearNode = match.graph->rootNode(); - - // INPUT QUANTIZERS STEP-SIZES - - auto inputQuantNode = linearNode->getParent(0); - auto inputQuantOp = std::static_pointer_cast<LSQ_Op>(inputQuantNode->getOperator()); - - float absMean = inputStats[linearNode->name()]; - float stepSize = 2.0f * (absMean / std::sqrt(inputQuantOp->range().second)); - - auto inputStepSizeOp = inputQuantNode->getParent(1)->getOperator(); - // XXX inputStepSizeOp->setOutput(0, std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}))); - auto inputStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); - inputStepSizeOp->setOutput(0, inputStepSizeTensor); - - // PARAM QUANTIZERS STEP-SIZES - - auto paramQuantNode = linearNode->getParent(1); - auto paramQuantOp = std::static_pointer_cast<LSQ_Op>(paramQuantNode->getOperator()); - - absMean = paramStats[linearNode->name()]; - stepSize = 2.0f * (absMean / std::sqrt(paramQuantOp->range().second)); - - auto paramStepSizeOp = paramQuantNode->getParent(1)->getOperator(); - // XXX paramStepSizeOp->setOutput(0, std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}))); - auto paramStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); - paramStepSizeOp->setOutput(0, paramStepSizeTensor); - } + setupInputQuantizers(graphView, nbBits); + setupParamQuantizers(graphView, nbBits); } -void QuantLSQ::insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, std::shared_ptr<Tensor> calibrationData) +void QuantLSQ::devLSQ(std::shared_ptr<Tensor> tensor) { - bool useCuda = (calibrationData->backend() == "cuda"); - - // Collect the tensor statisics - auto inputStats = collectInputStats(graphView, calibrationData, useCuda); - - auto paramStats = collectParamStats(graphView, useCuda); - - // Insert the quantizers - insertQuantizers(graphView, nbBits, 1.0); - - // Adjust the quantizers step-sizes - adjustQuantizersStepSizes(graphView, inputStats, paramStats); + float mean = (tensor->mean()).get<float> (0); + Log::debug("MEAN = {}",mean); } } \ No newline at end of file diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp index f86d454245a7fe088edd027732a91f5775cd2acf..54fbde894921a7ee9526a1b68408e30975d66bd9 100644 --- a/src/operator/PTQMetaOps.cpp +++ b/src/operator/PTQMetaOps.cpp @@ -51,9 +51,18 @@ std::shared_ptr<Node> Quantizer(double scalingFactor, double clipMin, double cli // create the metaop graph std::shared_ptr<GraphView> graphView = Sequential({mulNode, roundNode, clipNode}); + + //Producers added to clip to not have dangling inputs + + std::shared_ptr<Node> clipMinProd = addProducer<1>(clipNode, 1, {}, "Min"); + std::shared_ptr<Node> clipMaxProd = addProducer<1>(clipNode, 2, {}, "Max"); + 
clipMinProd->getOperator()->setOutput(0,std::make_shared<Tensor>(clipMin)); + clipMaxProd->getOperator()->setOutput(0,std::make_shared<Tensor>(clipMax)); + + std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); // XXX why not use the graphView ??? - // return the metaop + // return the metaop std::shared_ptr<Node> metaopNode = MetaOperator("Quantizer", connectedGraphView, {}, name); // XXX alternative prototype diff --git a/src/recipes/ONNXRecipes.cpp b/src/recipes/ONNXRecipes.cpp new file mode 100644 index 0000000000000000000000000000000000000000..87b9d2c900f8c5718020cc916d5bf1c35f967d31 --- /dev/null +++ b/src/recipes/ONNXRecipes.cpp @@ -0,0 +1,414 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <memory> + +#include "aidge/recipes/ONNXRecipes.hpp" +#include "aidge/graph/Matching.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" + +#include "aidge/operator/Mul.hpp" +#include "aidge/operator/Add.hpp" +#include "aidge/operator/Sub.hpp" +#include "aidge/operator/Round.hpp" +#include "aidge/operator/Cast.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/operator/MetaOperator.hpp" +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/recipes/Recipes.hpp" + + +namespace Aidge { +std::shared_ptr<Node> createQuantizeLinearNode(float scalingFactor, uint8_t zeroPoint,const std::string basename){ + const std::shared_ptr<Node> mulNode = Mul(basename == "" ? "" : basename + "_MulQuant"); + const std::shared_ptr<Node> roundNode = Round(basename == "" ? "" : basename + "_RoundQuant"); + const std::shared_ptr<Node> addNode = Add(basename == "" ? "" : basename + "_AddQuant"); + const std::shared_ptr<Node> castNode = Cast(DataType::UInt8, basename == "" ? "" : basename + "_CastQuant"); + const std::shared_ptr<Node> castAddNode = Cast(DataType::Float32, basename == "" ? "" : basename + "_Cast_ZeroPointQuant"); + + mulNode->getOperator()->setDataType(DataType::Float32); + roundNode->getOperator()->setDataType(DataType::Float32); + castAddNode->getOperator()->setDataType(DataType::Float32); + addNode->getOperator()->setDataType(DataType::Float32); + castNode->getOperator()->setDataType(DataType::UInt8); + + const std::shared_ptr<GraphView> qlGraph = Sequential({mulNode, roundNode, addNode, castNode});//Would be less wasteful to just use multiple addChild? + castAddNode->addChild(addNode,0,1); + const std::shared_ptr<GraphView> Quantizegraph = getConnectedGraphView(castNode); + + const std::shared_ptr<Node> quantizeMetaOp = MetaOperator("QuantizeLinear", Quantizegraph,{},basename == "" ? "" : basename + "_QuantLinear"); + + const std::shared_ptr<Node> addProd = Producer(std::make_shared<Tensor>(Array1D<uint8_t, 1>{zeroPoint}),addNode->name() == "" ? "" : addNode->name() + "_ZeroPoint",true); + const std::shared_ptr<Node> mulProd = Producer(std::make_shared<Tensor>(Array1D<float, 1>{scalingFactor}),mulNode->name() == "" ? 
"" : mulNode->name() + "_ScaleFactor",true); + + mulProd->addChild(quantizeMetaOp,0,1); + addProd->addChild(quantizeMetaOp,0,2); + + return quantizeMetaOp; +} + +std::shared_ptr<Node> createDequantizeLinearNode(Tensor descalingFactor, uint8_t zeroPoint,DataType castDtype, const std::string basename){ + const std::shared_ptr<Node> castNode = Cast(castDtype,basename == "" ? "" : basename + "_CastDequant"); + const std::shared_ptr<Node> castSubNode = Cast(castDtype,basename == "" ? "" : basename + "_Cast_ZeroPointDequant"); + const std::shared_ptr<Node> subNode = Sub(basename == "" ? "" : basename + "_SubDequant"); + const std::shared_ptr<Node> mulNode = Mul(basename == "" ? "" : basename + "_MulDequant"); + + castNode->getOperator()->setDataType(castDtype); + subNode->getOperator()->setDataType(DataType::Float32); + mulNode->getOperator()->setDataType(DataType::Float32); + + const std::shared_ptr<GraphView> dequantGraph = Sequential({castNode, subNode, mulNode});//Would be less wasteful to just use multiple addChild? + castSubNode->addChild(subNode,0,1); + const std::shared_ptr<GraphView> dequantizegraph = getConnectedGraphView(mulNode); + + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> dequantOrdInputs = dequantizegraph->getOrderedInputs(); + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> newDequantOrder = {dequantOrdInputs[0],//input + dequantOrdInputs[2],//scaling factor + dequantOrdInputs[1]};//zero point + dequantizegraph->setOrderedInputs(newDequantOrder); + + const std::shared_ptr<Node> dequantMetaOp = MetaOperator("DequantizeLinear", dequantizegraph,{},basename == "" ? "" : basename + "_DequantLinear"); + + //producer must be uint8 for correct zp dtype in onnx export + const std::shared_ptr<Node> subProd = Producer(std::make_shared<Tensor>(Array1D<uint8_t, 1>{0}),subNode->name() == "" ? "" : subNode->name() + "_ZeroPoint",true); + const std::shared_ptr<Node> mulProd = Producer(std::make_shared<Tensor>(descalingFactor),mulNode->name() == "" ? "" : mulNode->name() + "_ScaleFactor",true); + + mulProd->addChild(dequantMetaOp,0,1); + subProd->addChild(dequantMetaOp,0,2); + + return dequantMetaOp; +} + +void quantizeMatchingtoExport(std::shared_ptr<GraphView> graphView, bool QoperatorFormat, bool foldWeights){ + //Add quantizeLinear Metaop at the beginning of the graph + //according to aidge's quantification, the first input is not quantized so sf of 1 and Zp of 0 is performed + //Operator is added to conform with ONNX's quantize models form + + int inptIdx = 0; + for (const auto& node : graphView->inputNodes()){ + const std::shared_ptr<Node> quantizeLinearNode = createQuantizeLinearNode(1.0,0,"in"+std::to_string(inptIdx)); + const std::shared_ptr<GraphView> quantizeLinearGraph = getConnectedGraphView(quantizeLinearNode); + + graphView->add(quantizeLinearNode); + quantizeLinearGraph->add(node); + //a better function may be used + graphView->addChild(quantizeLinearGraph,std::pair<NodePtr, IOIndex_t>(quantizeLinearNode, IOIndex_t(0)),std::pair<NodePtr, IOIndex_t>(node, IOIndex_t(0))); + inptIdx++; + } + + const std::set<SinglePassGraphMatching::MatchingResult> quantizeMatches = SinglePassGraphMatching(graphView).match("Mul#0->Round?;Mul#0<-Producer#0;Mul#0<1-Producer#1"); + + if(quantizeMatches.size()<1) Log::warn("no matches found to convert to Quantize/Dequantize operators"); + int nbfusions = 0; + //QuantizeLinear Creation from Mul->Round? 
+ //Each quantizeLinear will have an additional Add node(additioning 0) and a DequantizeLinear to conform with quantized ONNX models + for (const auto& match : quantizeMatches) { + // std::shared_ptr<Node> quantMulProd = nullptr; + std::shared_ptr<Node> quantMulOp = nullptr; + for (const auto& node: match.graph->getNodes()){ + if(node->type() == "Mul"){ + quantMulOp = node; + break; + } + } + + AIDGE_ASSERT(quantMulOp != nullptr,"Unexpected error, Mul operator, root of QuantizeLinear, not found"); + + if (!(quantMulOp->attributes()->hasAttr("isScaling")) && !(quantMulOp->attributes()->hasAttr("isProducerScaling"))){ + //Mul operator does not have the 'isScaling or 'isProducerScaling' tag so it is not product of quantization, match skipped + Log::info("mul operator {} skipped, not part of quantization process",quantMulOp->name()); + continue; + } + const std::string mulQuantName = quantMulOp->name(); + SinglePassGraphMatching::MatchingResult quantizeLinearSubGraph = *SinglePassGraphMatching(match.graph).match("Mul#0->Round?").begin(); + + const std::shared_ptr<Node> addNode = Add(mulQuantName == "" ? "" : mulQuantName + "_Add"); + const std::shared_ptr<Node> castNode = Cast(DataType::UInt8,mulQuantName == "" ? "" : mulQuantName + "_Cast"); + addNode->getOperator()->setDataType(DataType::Float32); + castNode->getOperator()->setDataType(DataType::UInt8); + + const std::shared_ptr<GraphView> qlinearGraph = quantizeLinearSubGraph.graph->clone(); + qlinearGraph->addChild(addNode); + qlinearGraph->addChild(castNode); + + const std::shared_ptr<Node> quantMetaOp = MetaOperator("QuantizeLinear", qlinearGraph, {}, mulQuantName == "" ? "" : mulQuantName + "_QuantLinear"); + + //TODO : define datatype of producer tensors CHANGE BACK + const std::shared_ptr<Node> addNodeProd = Producer(std::make_shared<Tensor>(Array1D<uint8_t, 1>{0}),mulQuantName == "" ? "" : mulQuantName + "_ZeroPoint",true); + addNodeProd->addChild(quantMetaOp,0,2); + + const std::shared_ptr<Tensor> quantizeSF = std::static_pointer_cast<OperatorTensor>(quantMulOp->getParent(1)->clone()->getOperator())->getOutput(0); + const Tensor tempTensor = Tensor(Array1D<float, 1>{1}); + //Dequantize Scaling factor is the inverse of quantize scaling factor + const Tensor dequantizeSF = tempTensor / *quantizeSF; + + const std::shared_ptr<Node> dequantMetaOp = createDequantizeLinearNode(dequantizeSF,0,quantizeSF->dataType(),mulQuantName); + quantMetaOp->addChild(dequantMetaOp,0,0); + const std::shared_ptr<GraphView> metaOpGraph = getConnectedGraphView(dequantMetaOp); + + graphView->replace(quantizeLinearSubGraph.graph, metaOpGraph); + nbfusions++; + } + Log::info("{} QuantizeLinear and DequantizeLinear added", nbfusions); + nbfusions = 0; + + //Modify quantizer so it posseses zero point and conforms with expected metaop in output + const std::set<std::shared_ptr<Node>> nodeList = graphView->getNodes(); + for(const std::shared_ptr<Node> node: nodeList){ + + if(node->type() == "Quantizer"){ + const std::string quantizerName = node->name(); + const std::shared_ptr<MetaOperator_Op> metaNode = std::static_pointer_cast<MetaOperator_Op>(node->getOperator()); + const std::shared_ptr<GraphView> quantizeMicro = metaNode->getMicroGraph(); + + const std::shared_ptr<Node> addNodeProd = Producer(std::make_shared<Tensor>(Array1D<uint8_t, 1>{0}), quantizerName == "" ? "" : quantizerName + "_ZeroPoint", true); + const std::shared_ptr<Node> addNode = Add(quantizerName == "" ? 
"" : quantizerName + "_Add"); + const std::shared_ptr<Node> castNode = Cast(DataType::UInt8, quantizerName == "" ? "" : quantizerName + "_Cast"); + + addNode->getOperator()->setDataType(DataType::Float32); + castNode->getOperator()->setDataType(DataType::UInt8); + + for(const auto quantNode : quantizeMicro->getNodes()){ + //Shape of quantizer may vary so Clip will be used as root + if(quantNode->type() == "Clip"){ + //parent of clip may be a mul or round node + const std::shared_ptr<Node> oldParent = quantNode->getParent(0); + + oldParent->addChild(addNode,0,0); + addNode->addChild(quantNode,0,0); + quantNode->addChild(castNode,0,0); + break; + } + } + addNodeProd->addChild(addNode,0,1); + quantizeMicro->add(castNode); + quantizeMicro->add(addNode); + quantizeMicro->add(addNodeProd); + + fuseToMetaOps(quantizeMicro,"Clip#0<-Add<-Round?<-Mul; Clip#0<1-Producer#0; Clip#0<2-Producer#1; Clip#0->Cast","QuantizeLinear"); + graphView->add(addNodeProd); + + //debug code: + // Hard coded visualisation and fix of cast with incorrect type===== + Log::debug("debug======"); + for (const auto nde : quantizeMicro->getNodes()){ + if(nde->type() == "QuantizeLinear"){ + Log::debug("{} ==================",nde->name()); + const auto quantigraph = std::static_pointer_cast<MetaOperator_Op>(nde->getOperator())->getMicroGraph(); + for(const auto nde2 : quantigraph->getNodes() ){ + if(nde2->type() == "Cast"){ + Log::debug("-- type {}",nde2->type()); + Log::debug("dtype {}", std::static_pointer_cast<OperatorTensor>(nde2->getOperator())->getOutput(0)->dataType()); + nde2->getOperator()->setDataType(DataType::UInt8); + Log::debug("newdtype {}", std::static_pointer_cast<OperatorTensor>(nde2->getOperator())->getOutput(0)->dataType()); + } + } + } + } + //end debug code======== + + std::shared_ptr<GraphView> replacedGraph = std::make_shared<GraphView>(); + replacedGraph->add(node); + + graphView->replace(replacedGraph, quantizeMicro); + } + } + + const std::set<SinglePassGraphMatching::MatchingResult> wholeQlinearMatches = SinglePassGraphMatching(graphView).match( + //Query is subject to change as quantization operators change + "Conv2D#0<1-DequantizeLinear#0<-QuantizeLinear#0<1-Producer#0;" + "Conv2D#0<1-DequantizeLinear#0;" + "Conv2D#0<2-(DequantizeLinear#1<-QuantizeLinear#1<-Producer#1)?;" + "Conv2D#0<2-(DequantizeLinear#1<-QuantizeLinear#1<1-Producer#2)?;" + "Conv2D#0<2-DequantizeLinear#1?;" + "Conv2D#0<2-(DequantizeLinear#1<1-Producer#0)?;" + "Conv2D#0<2-(DequantizeLinear#1<2-Producer#1)?;" + "Conv2D#0->QuantizeLinear#2" + ); + + if(wholeQlinearMatches.size()<1) Log::warn("No quantized convolutions found"); + + for (const auto match : wholeQlinearMatches) { + bool hasBias = false; + + for (const auto& node: match.graph->getNodes()){ + //Search the convolution node and look for bias presence + if(node->type() == "Conv2D"){ + if(node->getParents().size() > 2) hasBias = true; + //If previous output is quantized add a dequantizelinear node + if (node->getParent(0)->type() == "QlinearConv" || node->getParent(0)->type() == "QuantizeLinear"){ + const std::shared_ptr<Node> quantizeNode = node->getParent(0); + + int idxInput = 1; + if(quantizeNode->type() == "QLinearConv") idxInput = 4; + + const std::shared_ptr<Tensor> quantizeSF = std::static_pointer_cast<OperatorTensor>(quantizeNode->getParent(idxInput)->getOperator())->getOutput(0); + const std::shared_ptr<Node> dequantMetaOp = createDequantizeLinearNode(quantizeSF->clone(),0,quantizeSF->dataType(),node->name()); + const std::shared_ptr<GraphView> dequantGraph = 
getConnectedGraphView(dequantMetaOp); + + quantizeNode->addChild(dequantMetaOp,0,0); + dequantMetaOp->addChild(node,0,0); + graphView->add(dequantGraph); + match.graph->add(dequantGraph); + } + + //if conv has bias re calculate values of scaling factor and bias + if (QoperatorFormat && hasBias){ + //bias and bias scaling factor have to be modified so it corresponds to ONNX's bias scaling factor formula: biasSF = inputSF * weightSF + + const std::shared_ptr<Tensor> weightSFTensor = std::static_pointer_cast<OperatorTensor>(node->getParent(1)->getParent(0)->getParent(1)->getOperator())->getOutput(0); + + std::shared_ptr<Tensor> inputSFTensor; + if(node->getParent(0)->getParent(0)->type() == "QuantizeLinear"){ + inputSFTensor = std::static_pointer_cast<OperatorTensor>(node->getParent(0)->getParent(0)->getParent(1)->getOperator())->getOutput(0); + } + else{ + inputSFTensor = std::make_shared<Tensor>(Array1D<double, 1> {1}); + inputSFTensor->setDataType(weightSFTensor->dataType()); + } + + const std::shared_ptr<Node> biasProd = node->getParent(2)->getParent(0)->getParent(0); + const std::shared_ptr<Node> biasSFProd = node->getParent(2)->getParent(0)->getParent(1); + const std::shared_ptr<Tensor> biasTensor = std::static_pointer_cast<OperatorTensor>(biasProd->getOperator())->getOutput(0); + const std::shared_ptr<Tensor> biasSFTensor = std::static_pointer_cast<OperatorTensor>(biasSFProd->getOperator())->getOutput(0); + + const Tensor newBiasSFTensor = *inputSFTensor* *weightSFTensor; + const Tensor newBiasTensor = (*biasSFTensor* *biasTensor)/newBiasSFTensor; + + bool biasProdWasConstant = std::static_pointer_cast<Producer_Op>(biasProd->getOperator())->constant(); + if(biasProdWasConstant){ + const std::shared_ptr<Node> newBiasProd = Producer(std::make_shared<Tensor>(newBiasTensor),biasProd->name(),true); + graphView->replace(std::set<std::shared_ptr<Node>>{biasProd},std::set<std::shared_ptr<Node>>{newBiasProd}); + } + else biasProd->getOperator()->setOutput(0,std::make_shared<Tensor>(newBiasTensor)); + + biasProdWasConstant = std::static_pointer_cast<Producer_Op>(biasSFProd->getOperator())->constant(); + if(biasProdWasConstant){ + const std::shared_ptr<Node> newBiasSFProd = Producer(std::make_shared<Tensor>(newBiasSFTensor),biasSFProd->name(),true); + graphView->replace(std::set<std::shared_ptr<Node>>{biasSFProd},std::set<std::shared_ptr<Node>>{newBiasSFProd}); + } + else biasSFProd->getOperator()->setOutput(0,std::make_shared<Tensor>(newBiasSFTensor)); + + Log::info("Bias and Bias Scaling factor values changed to ONNX standard"); + } + break; //only one conv per match + + } + } + + //if qop desired match for qlinearconv form and create the corresponding metaoperator + if (QoperatorFormat){ + const std::set<SinglePassGraphMatching::MatchingResult> qlinearMatchs = SinglePassGraphMatching(match.graph).match("Conv2D#0<-DequantizeLinear#0;" + "Conv2D#0<1-DequantizeLinear#1;" + "Conv2D#0<2-(DequantizeLinear#2<1-Producer#0)?;" + "Conv2D#0<2-(DequantizeLinear#2<2-Producer#1)?;" + "Conv2D#0->QuantizeLinear"); + //Only one match is present in match.graph + SinglePassGraphMatching::MatchingResult onlyMatch = *qlinearMatchs.begin(); + + //convolution's name to be able to name metaop accordingly + std::string convBaseName; + for (const auto node :onlyMatch.graph->getNodes()) + { + if(node->type() == "Conv2D"){ + convBaseName = node->name(); + break; + } + } + + if(hasBias){ + //metaop/graph inputs reordered to ONNX standard + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> qConvOrdInputs = 
onlyMatch.graph->getOrderedInputs(); + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> newQConvOrder = {qConvOrdInputs[0],//x input + qConvOrdInputs[1],//x scale + qConvOrdInputs[2],//x zero point + qConvOrdInputs[3],//w + qConvOrdInputs[4],//w scale + qConvOrdInputs[5],//w zero point + qConvOrdInputs[7],//y scale + qConvOrdInputs[8],//y zero point + qConvOrdInputs[6]};//b + onlyMatch.graph->setOrderedInputs(newQConvOrder); + } + + const std::shared_ptr<Node> qlinearMetaOp = MetaOperator("QLinearConv", onlyMatch.graph->clone(),{},convBaseName == "" ? "" : convBaseName+"_QlinearConv"); + const std::shared_ptr<GraphView> metaOpGraph = std::make_shared<GraphView>(); + metaOpGraph->add(qlinearMetaOp, false); + const bool qlinearReplaced = graphView->replace(onlyMatch.graph, metaOpGraph); + AIDGE_ASSERT(qlinearReplaced,"Unexpected error, couldn't replace subgraph with QlinearConv operator") + + if(hasBias){ + //up to current opset qlnearconv bias input must be in int32 + const std::shared_ptr<Node> quantizeLinearB = qlinearMetaOp->getParent(8); + const auto quantizeNodes = std::static_pointer_cast<MetaOperator_Op>(quantizeLinearB->getOperator())->getMicroGraph()->getNodes(); + + //TODO: correct overflow and differences when quantization is performed in Int32 and uint8 (may need to fold in int32 or float and skip this quantizelinear node entirely) + for (const auto node : quantizeNodes){ + const std::string nodeOPtype= node->type(); + if(nodeOPtype == "Cast" ){ + node->getOperator()->setDataType(DataType::Int32); + if(nodeOPtype == "Cast") std::static_pointer_cast<Cast_Op>(node->getOperator())->targetType() = DataType::Int32; + } + } + std::static_pointer_cast<OperatorTensor>(quantizeLinearB->getParent(2)->getOperator())->getOutput(0)->setDataType(DataType::Int32); + } + nbfusions++; + Log::info("{} QlinearConvs added", nbfusions); + } + } + + //add a dequantize node to every output node of types qlinearconv out quantizelinear: (Float output expected by default) + for (const auto& node : graphView->outputNodes()){ + int idxInput; + if(node->type() == "QLinearConv") idxInput = 4; + else if (node->type() == "QuantizeLinear") idxInput = 1; + else continue; + + const std::shared_ptr<Tensor> quantizeSF = std::static_pointer_cast<OperatorTensor>(node->getParent(idxInput)->getOperator())->getOutput(0); + const Tensor tempTensor = Tensor(Array1D<float, 1>{1}); + const Tensor dequantizeSF = tempTensor / *quantizeSF; + + const std::shared_ptr<Node> dequantMetaOp = createDequantizeLinearNode(dequantizeSF,0,quantizeSF->dataType(),node->name()); + const std::shared_ptr<GraphView> dequantGraph = getConnectedGraphView(dequantMetaOp); + + graphView->addChild(dequantGraph,std::pair<NodePtr, IOIndex_t>(node, IOIndex_t(0)),std::pair<NodePtr, IOIndex_t>(dequantMetaOp, IOIndex_t(0))); + } + + graphView->setBackend("cpu");//TODO get dynamically + + //TODO: Bias must be always folded, it may be interesting to just fold when possible instead of giving the choice + if(foldWeights){ + //Fold quantize linear of weights and bias, leaving the quantized producer + const std::set<SinglePassGraphMatching::MatchingResult> foldQuantize = SinglePassGraphMatching(graphView).match( + //find quantizelinears with only producers as input, meaning they can be folded + "QuantizeLinear#0<-Producer#0;" + "QuantizeLinear#0<1-Producer#1;" + "QuantizeLinear#0<2-Producer#2?" 
+ ); + for(const auto match : foldQuantize){ + auto quantizeFolder = SequentialScheduler(match.graph); + quantizeFolder.forward(); + + const auto quantizeLinearNode = *match.graph->outputNodes().begin(); + + const std::shared_ptr<Tensor> foldedTensor = std::make_shared<Tensor>(std::static_pointer_cast<OperatorTensor>((quantizeLinearNode)->getOperator())->getOutput(0)->clone()); + const std::shared_ptr<Node> foldedProd = Producer(foldedTensor, quantizeLinearNode->name(), true); + const std::shared_ptr<GraphView> foldedGraph = std::make_shared<GraphView>(); + + foldedGraph->add(foldedProd); + graphView->replace(match.graph,foldedGraph); + } + } + + } + +} \ No newline at end of file
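
For reference, here is a minimal usage sketch of the new Python entry points exposed by this change (aidge_quantization.freeze_weights and aidge_quantization.quantize_matching_to_export). It assumes a GraphView that has already been quantized by the existing PTQ pipeline; the aidge_onnx load/export function names and the file names are illustrative assumptions, not part of this diff.

import aidge_onnx
import aidge_quantization

# Assumption: an Aidge GraphView already processed by the PTQ pipeline,
# loaded here from an ONNX file (the load_onnx/export_onnx names are assumed).
model = aidge_onnx.load_onnx("quantized_model.onnx")

# Freeze the weight and bias producers of Conv/FC nodes so that
# constant folding can be applied to those branches of the graph.
aidge_quantization.freeze_weights(model, all_producers=True)

# Regroup the quantization nodes into ONNX-style meta-operators:
# qop=True -> QLinearConv format, qop=False -> QDQ (QuantizeLinear/DequantizeLinear) format.
aidge_quantization.quantize_matching_to_export(model, qop=True, fold_weights=True)

# Export the prepared graph back to ONNX.
aidge_onnx.export_onnx(model, "qlinear_model.onnx")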