diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 4256774056379969c7406a35e4bcde3ff25c6550..6b36832776146dedcd397491fbaa3771e6558fdd 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -21,6 +21,7 @@ include:
   - '.gitlab/ci/ubuntu_python.gitlab-ci.yml'
   - '.gitlab/ci/release/cibuildwheel_ubuntu.gitlab-ci.yml'
 
+  # Cannot find a successful job on aidge_backend_cuda yet
   # - '.gitlab/ci/windows_cpp.gitlab-ci.yml'
 
   # - '.gitlab/ci/windows_python.gitlab-ci.yml'
diff --git a/aidge_quantization/__init__.py b/aidge_quantization/__init__.py
index b00fae178421997967a79fc9fb0f680ed4afbe84..c321e4695d7a230eda90cc7edc9f3427fc45aa19 100644
--- a/aidge_quantization/__init__.py
+++ b/aidge_quantization/__init__.py
@@ -1 +1,2 @@
 from aidge_quantization.aidge_quantization import * # import so generated by PyBind
+from .freezeProducers import *
\ No newline at end of file
diff --git a/aidge_quantization/_version.py b/aidge_quantization/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d34d3557071ed5c22aea83c63bfb7684b180cf9
--- /dev/null
+++ b/aidge_quantization/_version.py
@@ -0,0 +1,4 @@
+# file generated by setuptools_scm
+# don't change, don't track in version control
+__version__ = version = '0.2.1.dev60+g8044e79.d20250106'
+__version_tuple__ = version_tuple = (0, 2, 1, 'dev60', 'g8044e79.d20250106')
\ No newline at end of file
diff --git a/aidge_quantization/freezeProducers.py b/aidge_quantization/freezeProducers.py
new file mode 100644
index 0000000000000000000000000000000000000000..87839718cce84fa233504aca712c30595797a307
--- /dev/null
+++ b/aidge_quantization/freezeProducers.py
@@ -0,0 +1,47 @@
+import aidge_core
+import aidge_onnx
+
+def freeze_weights(graphview: aidge_core.GraphView, all_producers: bool = False):
+    """Freeze the weights and biases of convolution and fully-connected nodes,
+    so that constant folding can be applied to those parts of the graph.
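+
+    Minimal usage sketch (hypothetical model file; assumes the
+    ``constant_folding`` recipe exposed by aidge_core):
+
+    >>> model = aidge_onnx.load_onnx("model.onnx")
+    >>> freeze_weights(model, all_producers=True)
+    >>> aidge_core.constant_folding(model)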
+
+    :param graphview: model to freeze the weights in
+    :type graphview: py:class:`aidge_core.GraphView`
+    :param all_producers: defaults to False; if True, freeze every producer
+        feeding the weight and bias inputs of the Conv or FC node
+    :type all_producers: bool
+    """
+    def freeze_all(node):
+        for inpt in node.get_parents():
+            if inpt is None:
+                continue  # disconnected input: check the remaining parents
+            elif inpt.type() != "Producer":
+                freeze_all(inpt)
+            else:
+                inpt.get_operator().attr.set_attr("constant", True)
+
+    # Possible improvement: keep a registry of visited nodes to prevent unnecessary iterations
+    for node in graphview.get_nodes():
+        # Search for convolution and fully-connected nodes
+        if node.type() in ["FC", "Conv1D", "Conv2D", "Conv3D", "ConvDepthWise1D", "ConvDepthWise2D", "ConvDepthWise3D"]:
+            # Iterate over the weight input and, if present, the bias input
+            for inputs_id in range(node.get_nb_inputs() - 1):
+                parent_node = node.get_parent(inputs_id + 1)
+
+                # Walk up the parents until the Producer is reached; if it is directly connected, no iteration is performed.
+                # The loop also makes it possible to freeze intermediate producers so that they can be constant folded.
+                if all_producers:
+                    freeze_all(parent_node)
+                else:
+                    while parent_node.type() != "Producer":
+                        parent_node = parent_node.get_parent(0)
+                        if parent_node is None:
+                            raise RuntimeError(f"Could not find a parent producer for node {node.name()}")
+                    parent_node.get_operator().attr.set_attr("constant", True)
diff --git a/include/aidge/quantization/QAT/QAT_LSQ.hpp b/include/aidge/quantization/QAT/QAT_LSQ.hpp
index b1e7b6fcf99a50e707da2fdc7f7c35cdb2d778f7..754848a51cfb37571820716f5b0e9396f5fda27d 100644
--- a/include/aidge/quantization/QAT/QAT_LSQ.hpp
+++ b/include/aidge/quantization/QAT/QAT_LSQ.hpp
@@ -9,30 +9,32 @@
  *
  ********************************************************************************/
 
- #ifndef AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_
- #define AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_
- 
- #include <cstddef>  // std::size_t
- #include <memory>
- 
- #include "aidge/data/Tensor.hpp"
- #include "aidge/graph/GraphView.hpp"
- 
- namespace Aidge {
- namespace QuantLSQ {
- 
- /**
-  * @brief Given a GraphView with parameters properly initialized, insert
-  * the LSQ quantizer nodes, and setup the adjustment their step-sizes.
-  * @param graphView The GraphView containing the network to quantize.
-  * @param nbBits Number of quantization bits.
-  */
- 
- void setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits);
- 
- } // namespace QuantLSQ
- } // namespace Aidge
- 
- #endif /* AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ */
- 
- 
\ No newline at end of file
+#ifndef AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_
+#define AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_
+
+#include <cstddef>  // std::size_t
+#include <memory>
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
+
+namespace Aidge {
+namespace QuantLSQ {
+
+/**
+ * @brief Given a GraphView with parameters properly initialized, insert
+ * the LSQ quantizer nodes and set up the adjustment of their step sizes.
+ * @param graphView The GraphView containing the network to quantize.
+ * @param nbBits Number of quantization bits.
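+ *
+ * Minimal usage sketch (assuming `model` is a trained GraphView whose
+ * parameters are already initialized):
+ * @code
+ * QuantLSQ::setupQuantizers(model, 8);  // insert 8-bit LSQ quantizers
+ * @endcode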
+ */
+void setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits);
+} // namespace QuantLSQ
+} // namespace Aidge
+
+#endif /* AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ */
diff --git a/include/aidge/quantization_version.h b/include/aidge/quantization_version.h
deleted file mode 100644
index 546263af3a7e8b7a73991173f48d0b095c7d9501..0000000000000000000000000000000000000000
--- a/include/aidge/quantization_version.h
+++ /dev/null
@@ -1,11 +0,0 @@
-#ifndef VERSION_H
-#define VERSION_H
-
-namespace Aidge {
-static constexpr const int PROJECT_VERSION_MAJOR = 0;
-static constexpr const int PROJECT_VERSION_MINOR = 2;
-static constexpr const int PROJECT_VERSION_PATCH = 0;
-static constexpr const char * PROJECT_VERSION = "0.2.0";
-static constexpr const char * PROJECT_GIT_HASH = "f50c860";
-}
-#endif // VERSION_H
diff --git a/include/aidge/recipes/ONNXRecipes.hpp b/include/aidge/recipes/ONNXRecipes.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..058deeae6e24221862f0ef2acadab406840059be
--- /dev/null
+++ b/include/aidge/recipes/ONNXRecipes.hpp
@@ -0,0 +1,48 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_QUANTIZATION_RECIPES_ONNXRECIPES_H_
+#define AIDGE_QUANTIZATION_RECIPES_ONNXRECIPES_H_
+
+#include <memory>
+
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/Matching.hpp"
+#include "aidge/operator/MetaOperator.hpp"
+
+namespace Aidge {
+    /**
+     * @brief Prepare an Aidge model for ONNX export: regroup Aidge nodes into QuantizeLinear, DequantizeLinear or QLinearConv operators.
+     * @param graphView The GraphView to process.
+     * @param QoperatorFormat If true, wrap quantized convolutions in the QLinearConv meta-operator; if false, the QDQ (QuantizeDequantize) format is used (see https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#onnx-quantization-representation-format).
+     * @param foldWeights If true, fold the QuantizeLinear nodes of weights and biases, leaving quantized producers.
+     */
+void quantizeMatchingtoExport(std::shared_ptr<GraphView> graphView, bool QoperatorFormat = true, bool foldWeights = true);
+
+    /**
+     * @brief Create a QuantizeLinear meta-operator implementing the ONNX QuantizeLinear semantics.
+     * @param scalingFactor Scaling factor used in the quantization operation.
+     * @param zeroPoint Zero point used in the quantization operation; for Aidge quantization this should always be equal to 0.
+     * @param basename Name used as base for the names of the QuantizeLinear meta-operator and its components.
+     */
+std::shared_ptr<Node> createQuantizeLinearNode(float scalingFactor = 1.0, uint8_t zeroPoint = 0, const std::string basename = "");
+
+    /**
+     * @brief Create a DequantizeLinear meta-operator implementing the ONNX DequantizeLinear semantics.
+     * @param descalingFactor Descaling factor (inverse of the quantization scaling factor) used in the dequantization operation.
+     * @param zeroPoint Zero point used in the dequantization operation; for Aidge quantization this should always be equal to 0.
+     * @param castDtype Dtype of the output of the DequantizeLinear meta-operator. This argument may be deprecated in the future because of ONNX's imposed dtypes.
+     * @param basename Name used as base for the names of the DequantizeLinear meta-operator and its components.
+     */
+std::shared_ptr<Node> createDequantizeLinearNode(Tensor descalingFactor, uint8_t zeroPoint, Aidge::DataType castDtype, const std::string basename = "");
+}
+
+#endif //AIDGE_QUANTIZATION_RECIPES_ONNXRECIPES_H_
diff --git a/python_binding/pybind_PTQ.cpp b/python_binding/pybind_PTQ.cpp
index 12d14340f9353114d06121fa8f1e1fd4f050e3f4..471d6e5ada638a4bf131d75813df1538e6ebb8c5 100644
--- a/python_binding/pybind_PTQ.cpp
+++ b/python_binding/pybind_PTQ.cpp
@@ -17,7 +17,6 @@
 #include "aidge/quantization/PTQ/Clipping.hpp"
 #include "aidge/quantization/PTQ/CLE.hpp"
 #include "aidge/quantization/PTQ/PTQ.hpp"
-
 #include "aidge/graph/GraphView.hpp"
 
 namespace py = pybind11;
@@ -48,6 +47,14 @@ void init_PTQ(py::module &m) {
     :type network: :py:class:`aidge_core.GraphView`
     )mydelimiter");
 
+    m.def("multiply_scaling_factor", &multiplyScalingFactor, py::arg("node"), py::arg("coeff"),
+    R"mydelimiter(
+    Update the scaling factor of a "Mul" node in a graph, if the node is marked as a scaling node, by multiplying the existing scaling factor by a given coefficient.
+    :param node: The node to modify.
+    :param coeff: Multiplication coefficient to apply to the scaling factor.
+    )mydelimiter"
+    );
+
     m.def("normalize_parameters", &normalizeParameters, py::arg("network"),
     R"mydelimiter(
     Normalize the parameters of each parametrized node, so that they fit in the [-1:1] range.
diff --git a/python_binding/pybind_QAT_LSQ.cpp b/python_binding/pybind_QAT_LSQ.cpp
index dd118dccc24dca71185c9401a924fbae0d22cc6c..fd263e3f4eff37b74b4155540972611cbd3108cf 100644
--- a/python_binding/pybind_QAT_LSQ.cpp
+++ b/python_binding/pybind_QAT_LSQ.cpp
@@ -9,21 +9,19 @@
  *
  ********************************************************************************/
 
- #include <pybind11/pybind11.h>
- #include <pybind11/stl.h>
- 
- #include "aidge/quantization/QAT/QAT_LSQ.hpp"
- #include "aidge/graph/GraphView.hpp"
- 
- namespace py = pybind11;
- 
- namespace Aidge {
- 
- void init_QAT_LSQ(py::module &m) {
- 
-     auto mQuantLSQ = m.def_submodule("lsq");
- 
-     mQuantLSQ.def("setup_quantizers", &QuantLSQ::setupQuantizers, py::arg("network"), py::arg("nb_bits"));
- }
- } // namespace Aidge
- 
\ No newline at end of file
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "aidge/quantization/QAT/QAT_LSQ.hpp"
+#include "aidge/graph/GraphView.hpp"
+
+namespace py = pybind11;
+
+namespace Aidge {
+
+void init_QAT_LSQ(py::module &m) {
+    auto mQuantLSQ = m.def_submodule("lsq");
+    mQuantLSQ.def("setup_quantizers", &QuantLSQ::setupQuantizers, py::arg("network"), py::arg("nb_bits"));
+}
+} // namespace Aidge
+
diff --git a/python_binding/pybind_Quantization.cpp b/python_binding/pybind_Quantization.cpp
index 7ac344dcfcd4fc93e3bba1dcd19c1413f5a29d0c..b91b0b80996c84c7fb1d4906fcb407acca465c63 100644
--- a/python_binding/pybind_Quantization.cpp
+++ b/python_binding/pybind_Quantization.cpp
@@ -34,6 +34,7 @@ void init_PTQ(py::module &m);
 void init_QAT_FixedQ(py::module &m);
 void init_QAT_LSQ(py::module &m);
 void init_QuantRecipes(py::module &m);
+void init_ONNXRecipes(py::module &m);
 void init_QuantizationVersionInfo(py::module &m);
 
@@ -48,6 +49,7 @@ PYBIND11_MODULE(aidge_quantization, m)
     init_QAT_FixedQ(m);
     init_QAT_LSQ(m);
     init_QuantRecipes(m);
+    init_ONNXRecipes(m);
     init_QuantizationVersionInfo(m);
 }
diff --git a/python_binding/recipes/pybind_ONNXRecipes.cpp b/python_binding/recipes/pybind_ONNXRecipes.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..9664f76782dcdbfff8735e59cdb178b0753bc993
--- /dev/null
+++ b/python_binding/recipes/pybind_ONNXRecipes.cpp
@@ -0,0 +1,28 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "aidge/recipes/ONNXRecipes.hpp"
+#include "aidge/graph/GraphView.hpp"
+
+namespace py = pybind11;
+
+namespace Aidge {
+
+void init_ONNXRecipes(py::module &m) {
+
+    m.def("quantize_matching_to_export", &quantizeMatchingtoExport, py::arg("graph_view"), py::arg("qop") = true, py::arg("fold_weights") = true);
+
+}
+
+} // namespace Aidge
diff --git a/setup.py b/setup.py
index cde7c1e513e8f3092474bddcb57842efced415e6..ad42ebeca6d482e353917fa0951e5f8a4414df83 100644
--- a/setup.py
+++ b/setup.py
@@ -74,30 +74,20 @@ class AidgePkgBuild(build_ext):
         )
         test_onoff = os.environ.get("AIDGE_BUILD_TEST", "OFF")
 
-        os.chdir(str(build_temp))
-
-        cmake_cmd = [
-            "cmake",
-            *build_gen_opts,
-            str(cwd),
-            f"-DTEST={test_onoff}",
-            f"-DCMAKE_INSTALL_PREFIX:PATH={install_path}",
-            f"-DCMAKE_BUILD_TYPE={build_type}",
-            f"-DCMAKE_C_COMPILER={c_compiler}",
-            f"-DCMAKE_CXX_COMPILER={cxx_compiler}",
-            f"-DENABLE_ASAN={asan}",
-            f"-DCUDA={with_cuda}",
-            "-DPYBIND=ON",
-            f"-DPYBIND_INSTALL_PREFIX:PATH={pybind_install_prefix}",
-            "-DCMAKE_EXPORT_COMPILE_COMMANDS=1",
-            "-DCOVERAGE=OFF",
-        ]
-
-        # Append architecture-specific arguments if provided
-        if cmake_arch:
-            cmake_cmd.append(cmake_arch)
-
-        self.spawn(cmake_cmd)
+        self.spawn(
+            [
+                "cmake",
+                *build_gen_opts,
+                str(cwd),
+                f"-DTEST={test_onoff}",
+                f"-DCMAKE_INSTALL_PREFIX:PATH={install_path}",
+                f"-DCMAKE_BUILD_TYPE={build_type}",
+                "-DPYBIND=ON",
+                f"-DPYBIND_INSTALL_PREFIX:PATH={pybind_install_prefix}",
+                "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON",
+                "-DCOVERAGE=OFF",
+            ]
+        )
 
         if not self.dry_run:
             self.spawn(
diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp
index 7115a2f17726c21666306aad8f75bd51eed3eb29..77c0ebf9fa54a8e9c99d4e038b01fe94c38c6e38 100644
--- a/src/PTQ/CLE.cpp
+++ b/src/PTQ/CLE.cpp
@@ -39,6 +39,10 @@
 #include "aidge/operator/Reshape.hpp"
 #include "aidge/operator/Round.hpp"
 
+#include "aidge/operator/Mul.hpp"
+#include "aidge/operator/ArgMax.hpp"
+#include "aidge/operator/Abs.hpp"
+
 namespace Aidge {
 
diff --git a/src/QAT/QAT_LSQ.cpp b/src/QAT/QAT_LSQ.cpp
index dcac6819365e134d777be7479a95d6b8e4093b5e..49a38cf7668b5e3fb8d81e6670937be284f9cd8f 100644
--- a/src/QAT/QAT_LSQ.cpp
+++ b/src/QAT/QAT_LSQ.cpp
@@ -9,164 +9,163 @@
  *
  ********************************************************************************/
 
- #include "aidge/quantization/QAT/QAT_LSQ.hpp"
- #include "aidge/operator/LSQ.hpp"
- #include "aidge/operator/ReLU.hpp"
- 
- 
- #include "aidge/data/Tensor.hpp"
- #include "aidge/graph/GraphView.hpp"
- #include "aidge/scheduler/SequentialScheduler.hpp"
- #include "aidge/scheduler/Scheduler.hpp"
- #include "aidge/graph/Matching.hpp"
- #include "aidge/recipes/QuantRecipes.hpp"
- 
- 
- namespace Aidge
- {
- 
- static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
- {
-     auto valueTensor = (*tensor).abs().mean();
-     std::shared_ptr<Tensor> fallback;
-     const Tensor& localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu");
-     return localTensor.get<float>(0);
- }
- 
- static float getTensorStd(std::shared_ptr<Tensor> tensor)
- {
-     auto valueTensor = (*tensor);
- 
-     auto skewedTensor = valueTensor - valueTensor.mean();
-     auto squaredTensor = skewedTensor * skewedTensor;
-     auto varianceTensor = squaredTensor.mean();
- 
-     std::shared_ptr<Tensor> fallback;
-     auto localTensor = varianceTensor.refCastFrom(fallback, DataType::Float32, "cpu");
- 
-     float variance = localTensor.get<float>(0);
-     return std::sqrt(variance);
- }
- 
- 
- // INIT THE STEP SIZE OF A QUANTIZER NODE
- 
- static bool initStepSize(std::shared_ptr<Node> quantizer)
- {
-     const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator());
- 
-     // This formula is the one proposed in the paper ...
- 
-     // float inputAbsMean = getTensorAbsMean(quantizerOp->getInput(0));
-     // float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second));
- 
-     // .. but this formula seems to work better !!!
- 
-     float inputStd = getTensorStd(quantizerOp->getInput(0));
-     float stepSize = 8.0f * (inputStd / (quantizerOp->range().second));
- 
-     // TODO : use the scalar constructor
-     auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
- 
-     // XXX Manage backend here ?
-     stepSizeTensor->setBackend(quantizerOp->getInput(0)->backend());
-     stepSizeTensor->setDataType(quantizerOp->getInput(0)->dataType());
- 
-     auto stepSizeProducer = quantizer->getParent(1);
- 
-     stepSizeProducer->getOperator()->setOutput(0, stepSizeTensor);
- 
-     Log::notice(" [ INIT STEP SIZE = {} ] ", stepSize);
- 
-     return false;
- }
- 
- static void setupInputQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
- {
-     const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
- 
-     for (const auto& match : matches)
-     {
-         auto linearNode = match.graph->rootNode();
- 
-         // Log::notice(" SET INPUT QUANTIZER : {} ", linearNode->type());
- 
-         std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
-         std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1};
- 
-         // Create the input quantizer node
- 
-         auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView);
-         auto quantizerNode = LSQ(signedRange, quantizerName);
- 
-         // Init the step-size using the node call stack
- 
-         quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
- 
-         // Absorb the ReLU when possible ...
- 
-         bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]); // XXX is this safe ?
- 
-         if (nodeHasParent)
-         {
-             bool allParentsAreReLU = true;
-             for (auto parentNode : linearNode->getParents())
-                 if (parentNode->type() != "ReLU")
-                     allParentsAreReLU = false;
- 
-             if (allParentsAreReLU) {
-                 auto quantizerOp = std::static_pointer_cast<LSQ_Op> (quantizerNode->getOperator());
-                 quantizerOp->range() = unsignedRange;
-             }
- 
-             // TODO : remove the ReLUs when possible
-         }
- 
-         // Insert the quantizer in the graphView ...
-         // (We need to handle the case where the linear node is the first one)
- 
-         if (nodeHasParent) {
-             graphView->insertParent(linearNode, quantizerNode, 0, 0, 0);
-         } else {
-             quantizerNode->addChild(graphView);
-             graphView->add(quantizerNode);
-         }
-     }
- }
- 
- // PARAM QUANTIZERS INSERTION
- 
- static void setupParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
- {
-     const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
- 
-     std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
- 
-     for (const auto& match : matches)
-     {
-         auto linearNode = match.graph->rootNode();
- 
-         // Log::notice(" SET PARAM QUANTIZER : {} ", linearNode->type());
- 
-         // TODO : double check this, and use createUniqueName()
-         auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView);
-         auto quantizerNode = LSQ(signedRange, quantizerName);
- 
-         // Init the step-size using the node call stack
- 
-         quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
- 
-         // Insert the quantizer in the graphView
- 
-         graphView->insertParent(linearNode, quantizerNode, 1, 0, 0);
-     }
- }
- 
- void QuantLSQ::setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
- {
-     sanitizeNodeNames(graphView);
-     setupInputQuantizers(graphView, nbBits);
-     setupParamQuantizers(graphView, nbBits);
- }
- 
- }
\ No newline at end of file
+#include "aidge/quantization/QAT/QAT_LSQ.hpp"
+#include "aidge/operator/LSQ.hpp"
+#include "aidge/operator/ReLU.hpp"
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/scheduler/SequentialScheduler.hpp"
+#include "aidge/scheduler/Scheduler.hpp"
+#include "aidge/graph/Matching.hpp"
+#include "aidge/recipes/QuantRecipes.hpp"
+
+
+namespace Aidge
+{
+
+static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
+{
+    auto valueTensor = (*tensor).abs().mean();
+    std::shared_ptr<Tensor> fallback;
+    const Tensor& localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu");
+    return localTensor.get<float>(0);
+}
+
+static float getTensorStd(std::shared_ptr<Tensor> tensor)
+{
+    auto valueTensor = (*tensor);
+
+    auto skewedTensor = valueTensor - valueTensor.mean();
+    auto squaredTensor = skewedTensor * skewedTensor;
+    auto varianceTensor = squaredTensor.mean();
+
+    std::shared_ptr<Tensor> fallback;
+    auto localTensor = varianceTensor.refCastFrom(fallback, DataType::Float32, "cpu");
+
+    float variance = localTensor.get<float>(0);
+    return std::sqrt(variance);
+}
+
+
+// INIT THE STEP SIZE OF A QUANTIZER NODE
+
+static bool initStepSize(std::shared_ptr<Node> quantizer)
+{
+    const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator());
+
+    // This formula is the one proposed in the paper ...
+
+    // float inputAbsMean = getTensorAbsMean(quantizerOp->getInput(0));
+    // float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second));
+
+    // .. but this formula seems to work better !!!
+
+    float inputStd = getTensorStd(quantizerOp->getInput(0));
+    float stepSize = 8.0f * (inputStd / (quantizerOp->range().second));
+
+    // TODO : use the scalar constructor
+    auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
+
+    // XXX Manage backend here ?
+    stepSizeTensor->setBackend(quantizerOp->getInput(0)->backend());
+    stepSizeTensor->setDataType(quantizerOp->getInput(0)->dataType());
+
+    auto stepSizeProducer = quantizer->getParent(1);
+
+    stepSizeProducer->getOperator()->setOutput(0, stepSizeTensor);
+
+    Log::notice(" [ INIT STEP SIZE = {} ] ", stepSize);
+
+    return false;
+}
+
+static void setupInputQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
+{
+    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
+
+    for (const auto& match : matches)
+    {
+        auto linearNode = match.graph->rootNode();
+
+        // Log::notice(" SET INPUT QUANTIZER : {} ", linearNode->type());
+
+        std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
+        std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1};
+
+        // Create the input quantizer node
+
+        auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView);
+        auto quantizerNode = LSQ(signedRange, quantizerName);
+
+        // Init the step-size using the node call stack
+
+        quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
+
+        // Absorb the ReLU when possible ...
+
+        bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]); // XXX is this safe ?
+
+        if (nodeHasParent)
+        {
+            bool allParentsAreReLU = true;
+            for (auto parentNode : linearNode->getParents())
+                if (parentNode->type() != "ReLU")
+                    allParentsAreReLU = false;
+
+            if (allParentsAreReLU) {
+                auto quantizerOp = std::static_pointer_cast<LSQ_Op> (quantizerNode->getOperator());
+                quantizerOp->range() = unsignedRange;
+            }
+
+            // TODO : remove the ReLUs when possible
+        }
+
+        // Insert the quantizer in the graphView ...
+        // (We need to handle the case where the linear node is the first one)
+
+        if (nodeHasParent) {
+            graphView->insertParent(linearNode, quantizerNode, 0, 0, 0);
+        } else {
+            quantizerNode->addChild(graphView);
+            graphView->add(quantizerNode);
+        }
+    }
+}
+
+// PARAM QUANTIZERS INSERTION
+
+static void setupParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
+{
+    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
+
+    std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
+
+    for (const auto& match : matches)
+    {
+        auto linearNode = match.graph->rootNode();
+
+        // Log::notice(" SET PARAM QUANTIZER : {} ", linearNode->type());
+
+        // TODO : double check this, and use createUniqueName()
+        auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView);
+        auto quantizerNode = LSQ(signedRange, quantizerName);
+
+        // Init the step-size using the node call stack
+
+        quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
+
+        // Insert the quantizer in the graphView
+
+        graphView->insertParent(linearNode, quantizerNode, 1, 0, 0);
+    }
+}
+
+void QuantLSQ::setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
+{
+    sanitizeNodeNames(graphView);
+    setupInputQuantizers(graphView, nbBits);
+    setupParamQuantizers(graphView, nbBits);
+}
+
+}
\ No newline at end of file
diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp
index f86d454245a7fe088edd027732a91f5775cd2acf..54fbde894921a7ee9526a1b68408e30975d66bd9 100644
--- a/src/operator/PTQMetaOps.cpp
+++ b/src/operator/PTQMetaOps.cpp
@@ -51,9 +51,18 @@ std::shared_ptr<Node> Quantizer(double scalingFactor, double clipMin, double cli
     // create the metaop graph
     std::shared_ptr<GraphView> graphView = Sequential({mulNode, roundNode, clipNode});
+
+    // Producers added to Clip so that it has no dangling inputs
+
+    std::shared_ptr<Node> clipMinProd = addProducer<1>(clipNode, 1, {}, "Min");
+    std::shared_ptr<Node> clipMaxProd = addProducer<1>(clipNode, 2, {}, "Max");
+    clipMinProd->getOperator()->setOutput(0, std::make_shared<Tensor>(clipMin));
+    clipMaxProd->getOperator()->setOutput(0, std::make_shared<Tensor>(clipMax));
+
+
     std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); // XXX why not use the graphView ???
 
-    // return the metaop
+    // return the metaop
     std::shared_ptr<Node> metaopNode = MetaOperator("Quantizer", connectedGraphView, {}, name); // XXX alternative prototype
diff --git a/src/recipes/ONNXRecipes.cpp b/src/recipes/ONNXRecipes.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..87b9d2c900f8c5718020cc916d5bf1c35f967d31
--- /dev/null
+++ b/src/recipes/ONNXRecipes.cpp
@@ -0,0 +1,414 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <memory>
+
+#include "aidge/recipes/ONNXRecipes.hpp"
+#include "aidge/graph/Matching.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/Node.hpp"
+
+#include "aidge/operator/Mul.hpp"
+#include "aidge/operator/Add.hpp"
+#include "aidge/operator/Sub.hpp"
+#include "aidge/operator/Round.hpp"
+#include "aidge/operator/Cast.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/operator/MetaOperator.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/recipes/Recipes.hpp"
+
+
+namespace Aidge {
+std::shared_ptr<Node> createQuantizeLinearNode(float scalingFactor, uint8_t zeroPoint, const std::string basename){
+    const std::shared_ptr<Node> mulNode = Mul(basename == "" ? "" : basename + "_MulQuant");
+    const std::shared_ptr<Node> roundNode = Round(basename == "" ? "" : basename + "_RoundQuant");
+    const std::shared_ptr<Node> addNode = Add(basename == "" ? "" : basename + "_AddQuant");
+    const std::shared_ptr<Node> castNode = Cast(DataType::UInt8, basename == "" ? "" : basename + "_CastQuant");
+    const std::shared_ptr<Node> castAddNode = Cast(DataType::Float32, basename == "" ? "" : basename + "_Cast_ZeroPointQuant");
+
+    mulNode->getOperator()->setDataType(DataType::Float32);
+    roundNode->getOperator()->setDataType(DataType::Float32);
+    castAddNode->getOperator()->setDataType(DataType::Float32);
+    addNode->getOperator()->setDataType(DataType::Float32);
+    castNode->getOperator()->setDataType(DataType::UInt8);
+
+    const std::shared_ptr<GraphView> qlGraph = Sequential({mulNode, roundNode, addNode, castNode}); // would be less wasteful to just use multiple addChild?
+    castAddNode->addChild(addNode, 0, 1);
+    const std::shared_ptr<GraphView> Quantizegraph = getConnectedGraphView(castNode);
+
+    const std::shared_ptr<Node> quantizeMetaOp = MetaOperator("QuantizeLinear", Quantizegraph, {}, basename == "" ? "" : basename + "_QuantLinear");
+
+    const std::shared_ptr<Node> addProd = Producer(std::make_shared<Tensor>(Array1D<uint8_t, 1>{zeroPoint}), addNode->name() == "" ? "" : addNode->name() + "_ZeroPoint", true);
+    const std::shared_ptr<Node> mulProd = Producer(std::make_shared<Tensor>(Array1D<float, 1>{scalingFactor}), mulNode->name() == "" ? "" : mulNode->name() + "_ScaleFactor", true);
+
+    mulProd->addChild(quantizeMetaOp, 0, 1);
+    addProd->addChild(quantizeMetaOp, 0, 2);
+
+    return quantizeMetaOp;
+}
+
+std::shared_ptr<Node> createDequantizeLinearNode(Tensor descalingFactor, uint8_t zeroPoint, DataType castDtype, const std::string basename){
+    const std::shared_ptr<Node> castNode = Cast(castDtype, basename == "" ? "" : basename + "_CastDequant");
+    const std::shared_ptr<Node> castSubNode = Cast(castDtype, basename == "" ? "" : basename + "_Cast_ZeroPointDequant");
+    const std::shared_ptr<Node> subNode = Sub(basename == "" ? "" : basename + "_SubDequant");
+    const std::shared_ptr<Node> mulNode = Mul(basename == "" ? "" : basename + "_MulDequant");
+
+    castNode->getOperator()->setDataType(castDtype);
+    subNode->getOperator()->setDataType(DataType::Float32);
+    mulNode->getOperator()->setDataType(DataType::Float32);
+
+    const std::shared_ptr<GraphView> dequantGraph = Sequential({castNode, subNode, mulNode}); // would be less wasteful to just use multiple addChild?
+    castSubNode->addChild(subNode, 0, 1);
+    const std::shared_ptr<GraphView> dequantizegraph = getConnectedGraphView(mulNode);
+
+    const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> dequantOrdInputs = dequantizegraph->getOrderedInputs();
+    const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> newDequantOrder = {dequantOrdInputs[0],  // input
+                                                                                      dequantOrdInputs[2],  // scaling factor
+                                                                                      dequantOrdInputs[1]}; // zero point
+    dequantizegraph->setOrderedInputs(newDequantOrder);
+
+    const std::shared_ptr<Node> dequantMetaOp = MetaOperator("DequantizeLinear", dequantizegraph, {}, basename == "" ? "" : basename + "_DequantLinear");
+
+    // The producer must be uint8 for a correct zero-point dtype in the ONNX export
+    const std::shared_ptr<Node> subProd = Producer(std::make_shared<Tensor>(Array1D<uint8_t, 1>{0}), subNode->name() == "" ? "" : subNode->name() + "_ZeroPoint", true);
+    const std::shared_ptr<Node> mulProd = Producer(std::make_shared<Tensor>(descalingFactor), mulNode->name() == "" ? "" : mulNode->name() + "_ScaleFactor", true);
+
+    mulProd->addChild(dequantMetaOp, 0, 1);
+    subProd->addChild(dequantMetaOp, 0, 2);
+
+    return dequantMetaOp;
+}
+
+void quantizeMatchingtoExport(std::shared_ptr<GraphView> graphView, bool QoperatorFormat, bool foldWeights){
+    // Add a QuantizeLinear meta-operator at the beginning of the graph.
+    // According to Aidge's quantization, the first input is not quantized, so a scaling factor of 1 and a zero point of 0 are used.
+    // The operator is added to conform with the ONNX quantized model format.
+
+    int inptIdx = 0;
+    for (const auto& node : graphView->inputNodes()){
+        const std::shared_ptr<Node> quantizeLinearNode = createQuantizeLinearNode(1.0, 0, "in" + std::to_string(inptIdx));
+        const std::shared_ptr<GraphView> quantizeLinearGraph = getConnectedGraphView(quantizeLinearNode);
+
+        graphView->add(quantizeLinearNode);
+        quantizeLinearGraph->add(node);
+        // A better-suited function may exist for this
+        graphView->addChild(quantizeLinearGraph, std::pair<NodePtr, IOIndex_t>(quantizeLinearNode, IOIndex_t(0)), std::pair<NodePtr, IOIndex_t>(node, IOIndex_t(0)));
+        inptIdx++;
+    }
+
+    const std::set<SinglePassGraphMatching::MatchingResult> quantizeMatches = SinglePassGraphMatching(graphView).match("Mul#0->Round?;Mul#0<-Producer#0;Mul#0<1-Producer#1");
+
+    if(quantizeMatches.size() < 1) Log::warn("No matches found to convert to Quantize/Dequantize operators");
+    int nbfusions = 0;
+    // QuantizeLinear creation from Mul->Round?
+    // Each QuantizeLinear gets an additional Add node (adding 0) and a DequantizeLinear, to conform with quantized ONNX models
+    for (const auto& match : quantizeMatches) {
+        // std::shared_ptr<Node> quantMulProd = nullptr;
+        std::shared_ptr<Node> quantMulOp = nullptr;
+        for (const auto& node : match.graph->getNodes()){
+            if(node->type() == "Mul"){
+                quantMulOp = node;
+                break;
+            }
+        }
+
+        AIDGE_ASSERT(quantMulOp != nullptr, "Unexpected error, Mul operator, root of QuantizeLinear, not found");
+
+        if (!(quantMulOp->attributes()->hasAttr("isScaling")) && !(quantMulOp->attributes()->hasAttr("isProducerScaling"))){
+            // The Mul operator has neither the 'isScaling' nor the 'isProducerScaling' tag, so it is not a product of quantization; match skipped
+            Log::info("Mul operator {} skipped, not part of the quantization process", quantMulOp->name());
+            continue;
+        }
+        const std::string mulQuantName = quantMulOp->name();
+        SinglePassGraphMatching::MatchingResult quantizeLinearSubGraph = *SinglePassGraphMatching(match.graph).match("Mul#0->Round?").begin();
+
+        const std::shared_ptr<Node> addNode = Add(mulQuantName == "" ? "" : mulQuantName + "_Add");
+        const std::shared_ptr<Node> castNode = Cast(DataType::UInt8, mulQuantName == "" ? "" : mulQuantName + "_Cast");
+        addNode->getOperator()->setDataType(DataType::Float32);
+        castNode->getOperator()->setDataType(DataType::UInt8);
+
+        const std::shared_ptr<GraphView> qlinearGraph = quantizeLinearSubGraph.graph->clone();
+        qlinearGraph->addChild(addNode);
+        qlinearGraph->addChild(castNode);
+
+        const std::shared_ptr<Node> quantMetaOp = MetaOperator("QuantizeLinear", qlinearGraph, {}, mulQuantName == "" ? "" : mulQuantName + "_QuantLinear");
+
+        // TODO : define datatype of producer tensors CHANGE BACK
+        const std::shared_ptr<Node> addNodeProd = Producer(std::make_shared<Tensor>(Array1D<uint8_t, 1>{0}), mulQuantName == "" ? "" : mulQuantName + "_ZeroPoint", true);
+        addNodeProd->addChild(quantMetaOp, 0, 2);
+
+        const std::shared_ptr<Tensor> quantizeSF = std::static_pointer_cast<OperatorTensor>(quantMulOp->getParent(1)->clone()->getOperator())->getOutput(0);
+        const Tensor tempTensor = Tensor(Array1D<float, 1>{1});
+        // The dequantize scaling factor is the inverse of the quantize scaling factor
+        const Tensor dequantizeSF = tempTensor / *quantizeSF;
+
+        const std::shared_ptr<Node> dequantMetaOp = createDequantizeLinearNode(dequantizeSF, 0, quantizeSF->dataType(), mulQuantName);
+        quantMetaOp->addChild(dequantMetaOp, 0, 0);
+        const std::shared_ptr<GraphView> metaOpGraph = getConnectedGraphView(dequantMetaOp);
+
+        graphView->replace(quantizeLinearSubGraph.graph, metaOpGraph);
+        nbfusions++;
+    }
+    Log::info("{} QuantizeLinear and DequantizeLinear added", nbfusions);
+    nbfusions = 0;
+
+    // Modify the Quantizer so it possesses a zero point and conforms with the expected output meta-operator
+    const std::set<std::shared_ptr<Node>> nodeList = graphView->getNodes();
+    for(const std::shared_ptr<Node> node : nodeList){
+
+        if(node->type() == "Quantizer"){
+            const std::string quantizerName = node->name();
+            const std::shared_ptr<MetaOperator_Op> metaNode = std::static_pointer_cast<MetaOperator_Op>(node->getOperator());
+            const std::shared_ptr<GraphView> quantizeMicro = metaNode->getMicroGraph();
+
+            const std::shared_ptr<Node> addNodeProd = Producer(std::make_shared<Tensor>(Array1D<uint8_t, 1>{0}), quantizerName == "" ? "" : quantizerName + "_ZeroPoint", true);
+            const std::shared_ptr<Node> addNode = Add(quantizerName == "" ? "" : quantizerName + "_Add");
+            const std::shared_ptr<Node> castNode = Cast(DataType::UInt8, quantizerName == "" ? "" : quantizerName + "_Cast");
+
+            addNode->getOperator()->setDataType(DataType::Float32);
+            castNode->getOperator()->setDataType(DataType::UInt8);
+
+            for(const auto quantNode : quantizeMicro->getNodes()){
+                // The shape of the Quantizer may vary, so Clip is used as the root
+                if(quantNode->type() == "Clip"){
+                    // The parent of Clip may be a Mul or a Round node
+                    const std::shared_ptr<Node> oldParent = quantNode->getParent(0);
+
+                    oldParent->addChild(addNode, 0, 0);
+                    addNode->addChild(quantNode, 0, 0);
+                    quantNode->addChild(castNode, 0, 0);
+                    break;
+                }
+            }
+            addNodeProd->addChild(addNode, 0, 1);
+            quantizeMicro->add(castNode);
+            quantizeMicro->add(addNode);
+            quantizeMicro->add(addNodeProd);
+
+            fuseToMetaOps(quantizeMicro, "Clip#0<-Add<-Round?<-Mul; Clip#0<1-Producer#0; Clip#0<2-Producer#1; Clip#0->Cast", "QuantizeLinear");
+            graphView->add(addNodeProd);
+
+            // Debug code: hard-coded visualisation and fix of Cast nodes with an incorrect type
+            Log::debug("debug======");
+            for (const auto nde : quantizeMicro->getNodes()){
+                if(nde->type() == "QuantizeLinear"){
+                    Log::debug("{} ==================", nde->name());
+                    const auto quantigraph = std::static_pointer_cast<MetaOperator_Op>(nde->getOperator())->getMicroGraph();
+                    for(const auto nde2 : quantigraph->getNodes()){
+                        if(nde2->type() == "Cast"){
+                            Log::debug("-- type {}", nde2->type());
+                            Log::debug("dtype {}", std::static_pointer_cast<OperatorTensor>(nde2->getOperator())->getOutput(0)->dataType());
+                            nde2->getOperator()->setDataType(DataType::UInt8);
+                            Log::debug("newdtype {}", std::static_pointer_cast<OperatorTensor>(nde2->getOperator())->getOutput(0)->dataType());
+                        }
+                    }
+                }
+            }
+            // end debug code ========
+
+            std::shared_ptr<GraphView> replacedGraph = std::make_shared<GraphView>();
+            replacedGraph->add(node);
+
+            graphView->replace(replacedGraph, quantizeMicro);
+        }
+    }
+
+    const std::set<SinglePassGraphMatching::MatchingResult> wholeQlinearMatches = SinglePassGraphMatching(graphView).match(
+        // The query is subject to change as the quantization operators change
+        "Conv2D#0<1-DequantizeLinear#0<-QuantizeLinear#0<1-Producer#0;"
+        "Conv2D#0<1-DequantizeLinear#0;"
+        "Conv2D#0<2-(DequantizeLinear#1<-QuantizeLinear#1<-Producer#1)?;"
+        "Conv2D#0<2-(DequantizeLinear#1<-QuantizeLinear#1<1-Producer#2)?;"
+        "Conv2D#0<2-DequantizeLinear#1?;"
+        "Conv2D#0<2-(DequantizeLinear#1<1-Producer#0)?;"
+        "Conv2D#0<2-(DequantizeLinear#1<2-Producer#1)?;"
+        "Conv2D#0->QuantizeLinear#2"
+    );
+
+    if(wholeQlinearMatches.size() < 1) Log::warn("No quantized convolutions found");
+
+    for (const auto match : wholeQlinearMatches) {
+        bool hasBias = false;
+
+        for (const auto& node : match.graph->getNodes()){
+            // Find the convolution node and check for the presence of a bias
+            if(node->type() == "Conv2D"){
+                if(node->getParents().size() > 2) hasBias = true;
+                // If the previous output is quantized, add a DequantizeLinear node
+                if (node->getParent(0)->type() == "QLinearConv" || node->getParent(0)->type() == "QuantizeLinear"){
+                    const std::shared_ptr<Node> quantizeNode = node->getParent(0);
+
+                    int idxInput = 1;
+                    if(quantizeNode->type() == "QLinearConv") idxInput = 4;
+
+                    const std::shared_ptr<Tensor> quantizeSF = std::static_pointer_cast<OperatorTensor>(quantizeNode->getParent(idxInput)->getOperator())->getOutput(0);
+                    const std::shared_ptr<Node> dequantMetaOp = createDequantizeLinearNode(quantizeSF->clone(), 0, quantizeSF->dataType(), node->name());
+                    const std::shared_ptr<GraphView> dequantGraph = getConnectedGraphView(dequantMetaOp);
+
+                    quantizeNode->addChild(dequantMetaOp, 0, 0);
+                    dequantMetaOp->addChild(node, 0, 0);
+                    graphView->add(dequantGraph);
+                    match.graph->add(dequantGraph);
+                }
+
+                // If the conv has a bias, recalculate the bias and scaling factor values
+                if (QoperatorFormat && hasBias){
+                    // The bias and its scaling factor have to be modified to match ONNX's bias scaling factor formula: biasSF = inputSF * weightSF
+
+                    const std::shared_ptr<Tensor> weightSFTensor = std::static_pointer_cast<OperatorTensor>(node->getParent(1)->getParent(0)->getParent(1)->getOperator())->getOutput(0);
+
+                    std::shared_ptr<Tensor> inputSFTensor;
+                    if(node->getParent(0)->getParent(0)->type() == "QuantizeLinear"){
+                        inputSFTensor = std::static_pointer_cast<OperatorTensor>(node->getParent(0)->getParent(0)->getParent(1)->getOperator())->getOutput(0);
+                    }
+                    else{
+                        inputSFTensor = std::make_shared<Tensor>(Array1D<double, 1>{1});
+                        inputSFTensor->setDataType(weightSFTensor->dataType());
+                    }
+
+                    const std::shared_ptr<Node> biasProd = node->getParent(2)->getParent(0)->getParent(0);
+                    const std::shared_ptr<Node> biasSFProd = node->getParent(2)->getParent(0)->getParent(1);
+                    const std::shared_ptr<Tensor> biasTensor = std::static_pointer_cast<OperatorTensor>(biasProd->getOperator())->getOutput(0);
+                    const std::shared_ptr<Tensor> biasSFTensor = std::static_pointer_cast<OperatorTensor>(biasSFProd->getOperator())->getOutput(0);
+
+                    const Tensor newBiasSFTensor = *inputSFTensor * *weightSFTensor;
+                    const Tensor newBiasTensor = (*biasSFTensor * *biasTensor) / newBiasSFTensor;
+
+                    bool biasProdWasConstant = std::static_pointer_cast<Producer_Op>(biasProd->getOperator())->constant();
+                    if(biasProdWasConstant){
+                        const std::shared_ptr<Node> newBiasProd = Producer(std::make_shared<Tensor>(newBiasTensor), biasProd->name(), true);
+                        graphView->replace(std::set<std::shared_ptr<Node>>{biasProd}, std::set<std::shared_ptr<Node>>{newBiasProd});
+                    }
+                    else biasProd->getOperator()->setOutput(0, std::make_shared<Tensor>(newBiasTensor));
+
+                    biasProdWasConstant = std::static_pointer_cast<Producer_Op>(biasSFProd->getOperator())->constant();
+                    if(biasProdWasConstant){
+                        const std::shared_ptr<Node> newBiasSFProd = Producer(std::make_shared<Tensor>(newBiasSFTensor), biasSFProd->name(), true);
+                        graphView->replace(std::set<std::shared_ptr<Node>>{biasSFProd}, std::set<std::shared_ptr<Node>>{newBiasSFProd});
+                    }
+                    else biasSFProd->getOperator()->setOutput(0, std::make_shared<Tensor>(newBiasSFTensor));
+
+                    Log::info("Bias and bias scaling factor values changed to the ONNX standard");
+                }
+                break; // only one conv per match
+
+            }
+        }
+
+        // If the QOperator format is desired, match the QLinearConv form and create the corresponding meta-operator
+        if (QoperatorFormat){
+            const std::set<SinglePassGraphMatching::MatchingResult> qlinearMatchs = SinglePassGraphMatching(match.graph).match("Conv2D#0<-DequantizeLinear#0;"
+                                                                                                                               "Conv2D#0<1-DequantizeLinear#1;"
+                                                                                                                               "Conv2D#0<2-(DequantizeLinear#2<1-Producer#0)?;"
+                                                                                                                               "Conv2D#0<2-(DequantizeLinear#2<2-Producer#1)?;"
+                                                                                                                               "Conv2D#0->QuantizeLinear");
+            // Only one match is present in match.graph
+            SinglePassGraphMatching::MatchingResult onlyMatch = *qlinearMatchs.begin();
+
+            // The convolution's name, used to name the meta-operator accordingly
+            std::string convBaseName;
+            for (const auto node : onlyMatch.graph->getNodes())
+            {
+                if(node->type() == "Conv2D"){
+                    convBaseName = node->name();
+                    break;
+                }
+            }
+
+            if(hasBias){
+                // The meta-operator/graph inputs are reordered to the ONNX standard
+                const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> qConvOrdInputs = onlyMatch.graph->getOrderedInputs();
+                const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> newQConvOrder = {qConvOrdInputs[0],  // x input
+                                                                                                qConvOrdInputs[1],  // x scale
+                                                                                                qConvOrdInputs[2],  // x zero point
+                                                                                                qConvOrdInputs[3],  // w
+                                                                                                qConvOrdInputs[4],  // w scale
+                                                                                                qConvOrdInputs[5],  // w zero point
+                                                                                                qConvOrdInputs[7],  // y scale
+                                                                                                qConvOrdInputs[8],  // y zero point
+                                                                                                qConvOrdInputs[6]}; // b
+                onlyMatch.graph->setOrderedInputs(newQConvOrder);
+            }
+
+            const std::shared_ptr<Node> qlinearMetaOp = MetaOperator("QLinearConv", onlyMatch.graph->clone(), {}, convBaseName == "" ? "" : convBaseName + "_QlinearConv");
+            const std::shared_ptr<GraphView> metaOpGraph = std::make_shared<GraphView>();
+            metaOpGraph->add(qlinearMetaOp, false);
+            const bool qlinearReplaced = graphView->replace(onlyMatch.graph, metaOpGraph);
+            AIDGE_ASSERT(qlinearReplaced, "Unexpected error, couldn't replace subgraph with the QLinearConv operator");
+
+            if(hasBias){
+                // Up to the current opset, the QLinearConv bias input must be in Int32
+                const std::shared_ptr<Node> quantizeLinearB = qlinearMetaOp->getParent(8);
+                const auto quantizeNodes = std::static_pointer_cast<MetaOperator_Op>(quantizeLinearB->getOperator())->getMicroGraph()->getNodes();
+
+                // TODO: correct overflow and differences when quantization is performed in Int32 and UInt8 (may need to fold in Int32 or float and skip this QuantizeLinear node entirely)
+                for (const auto node : quantizeNodes){
+                    const std::string nodeOPtype = node->type();
+                    if(nodeOPtype == "Cast"){
+                        node->getOperator()->setDataType(DataType::Int32);
+                        std::static_pointer_cast<Cast_Op>(node->getOperator())->targetType() = DataType::Int32;
+                    }
+                }
+                std::static_pointer_cast<OperatorTensor>(quantizeLinearB->getParent(2)->getOperator())->getOutput(0)->setDataType(DataType::Int32);
+            }
+            nbfusions++;
+            Log::info("{} QLinearConv added", nbfusions);
+        }
+    }
+
+    // Add a DequantizeLinear node to every output node of type QLinearConv or QuantizeLinear (a float output is expected by default)
+    for (const auto& node : graphView->outputNodes()){
+        int idxInput;
+        if(node->type() == "QLinearConv") idxInput = 4;
+        else if (node->type() == "QuantizeLinear") idxInput = 1;
+        else continue;
+
+        const std::shared_ptr<Tensor> quantizeSF = std::static_pointer_cast<OperatorTensor>(node->getParent(idxInput)->getOperator())->getOutput(0);
+        const Tensor tempTensor = Tensor(Array1D<float, 1>{1});
+        const Tensor dequantizeSF = tempTensor / *quantizeSF;
+
+        const std::shared_ptr<Node> dequantMetaOp = createDequantizeLinearNode(dequantizeSF, 0, quantizeSF->dataType(), node->name());
+        const std::shared_ptr<GraphView> dequantGraph = getConnectedGraphView(dequantMetaOp);
+
+        graphView->addChild(dequantGraph, std::pair<NodePtr, IOIndex_t>(node, IOIndex_t(0)), std::pair<NodePtr, IOIndex_t>(dequantMetaOp, IOIndex_t(0)));
+    }
+
+    graphView->setBackend("cpu"); // TODO: get the backend dynamically
+
+    // TODO: the bias must always be folded; it may be preferable to just fold whenever possible instead of giving the choice
+    if(foldWeights){
+        // Fold the QuantizeLinear of weights and biases, leaving the quantized producer
+        const std::set<SinglePassGraphMatching::MatchingResult> foldQuantize = SinglePassGraphMatching(graphView).match(
+            // Find QuantizeLinear nodes whose inputs are all Producers, meaning they can be folded
+            "QuantizeLinear#0<-Producer#0;"
+            "QuantizeLinear#0<1-Producer#1;"
+            "QuantizeLinear#0<2-Producer#2?"
+        );
+        for(const auto match : foldQuantize){
+            auto quantizeFolder = SequentialScheduler(match.graph);
+            quantizeFolder.forward();
+
+            const auto quantizeLinearNode = *match.graph->outputNodes().begin();
+
+            const std::shared_ptr<Tensor> foldedTensor = std::make_shared<Tensor>(std::static_pointer_cast<OperatorTensor>((quantizeLinearNode)->getOperator())->getOutput(0)->clone());
+            const std::shared_ptr<Node> foldedProd = Producer(foldedTensor, quantizeLinearNode->name(), true);
+            const std::shared_ptr<GraphView> foldedGraph = std::make_shared<GraphView>();
+
+            foldedGraph->add(foldedProd);
+            graphView->replace(match.graph, foldedGraph);
+        }
+    }
+
+}
+
+}
\ No newline at end of file