diff --git a/aidge_quantization/unit_tests/assets/BranchNetV4.onnx b/aidge_quantization/unit_tests/assets/BranchNetV4.onnx new file mode 100644 index 0000000000000000000000000000000000000000..34cccc47c4b5014f0adc4757d0b8e362a8e5ddce Binary files /dev/null and b/aidge_quantization/unit_tests/assets/BranchNetV4.onnx differ diff --git a/aidge_quantization/unit_tests/assets/MLP.onnx b/aidge_quantization/unit_tests/assets/MLP.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f6b72dbbd8c829a1d3609d923478887892b9e231 Binary files /dev/null and b/aidge_quantization/unit_tests/assets/MLP.onnx differ diff --git a/aidge_quantization/unit_tests/assets/TestNet.onnx b/aidge_quantization/unit_tests/assets/TestNet.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7f73e9b11d8a2ca43c88e52295dd201211f1e741 Binary files /dev/null and b/aidge_quantization/unit_tests/assets/TestNet.onnx differ diff --git a/cmake/PybindModuleCreation.cmake b/cmake/PybindModuleCreation.cmake index 07cdd658b0e6ae6549b5dfd7663e9973c59c6a9f..e3fe6a7383656e053fe7f89da2fda1083d6374ae 100644 --- a/cmake/PybindModuleCreation.cmake +++ b/cmake/PybindModuleCreation.cmake @@ -4,7 +4,7 @@ function(generate_python_binding pybind_module_name target_to_bind) Include(FetchContent) - set(PYBIND_VERSION v2.10.4) + set(PYBIND_VERSION v2.13.6) message(STATUS "Retrieving pybind ${PYBIND_VERSION} from git") FetchContent_Declare( diff --git a/include/aidge/operator/LSQ.hpp b/include/aidge/operator/LSQ.hpp index 970c476cb7be18b8d001edb27d60079de85b9349..b6abf90371a3053fa7971b9242a5309362ea478e 100644 --- a/include/aidge/operator/LSQ.hpp +++ b/include/aidge/operator/LSQ.hpp @@ -55,7 +55,7 @@ public: */ LSQ_Op(const LSQ_Op& op) : OperatorTensor(op), - mAttributes(op.mAttributes) + mAttributes(std::make_shared<Attributes_>(*op.mAttributes)) { if (op.mImpl){ SET_IMPL_MACRO(LSQ_Op, *this, op.backend()); diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index 52edce0a1cad3453bf7dc0ba3f7ec2de0590bf47..3775d18989306ab871702a65f2e5e5d67c5d9c9f 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -20,6 +20,7 @@ #include "aidge/scheduler/SequentialScheduler.hpp" #include "aidge/scheduler/Scheduler.hpp" #include "aidge/utils/Log.hpp" +#include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/Mul.hpp" @@ -30,6 +31,9 @@ #include "aidge/operator/ArgMax.hpp" #include "aidge/operator/Reshape.hpp" #include "aidge/operator/MatMul.hpp" +#include "aidge/operator/Cast.hpp" + + #include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/QuantRecipes.hpp" @@ -81,6 +85,20 @@ bool isNotQuantized(std::shared_ptr<Node> node) return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end()); } +std::shared_ptr<Aidge::Node> getFirstNode(std::shared_ptr<GraphView> graphView) +{ + return graphView->getOrderedInputs()[0].first; +} +void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView,DataType targetType) +{ + for (std::shared_ptr<Aidge::Node> inputNode: graphView->inputNodes()) + { + for (Aidge::IOIndex_t index = inputNode->getFirstFreeDataInput();index < inputNode->getNbFreeDataInputs(); index++) + { + inputNode->getOperator()->resetInput(index); + } + } +} bool checkArchitecture(std::shared_ptr<GraphView> graphView) { std::set<std::string> removedNodeTypes({"Flatten", "Softmax", "BatchNorm2D"}); @@ -235,6 +253,59 @@ static void insertChildren(std::shared_ptr<Node> parent, std::shared_ptr<Node> n graphView->add(newNode); } +void applyConstFold(std::shared_ptr<GraphView> &graphView) +{ + for (const std::shared_ptr<Node> node : graphView->getNodes()) + { + if (node->type() == "Producer" ) + { + const auto& producer = std::static_pointer_cast<Producer_Op>(node->getOperator()); + producer->constant() = true; + } + } + constantFolding(graphView); +} + +bool castQuantizedGraph(std::shared_ptr<GraphView> &graphView, Aidge::DataType targetType, bool singleShift,bool bitshiftRounding) +{ + //We need a deepcopy of the graphs nodes since we will replace some nodes + std::vector<std::shared_ptr<Node>> nodeVector(graphView->getNodes().begin(), graphView->getNodes().end()); + + for (std::shared_ptr<Node> node : nodeVector) + { + if (node->type() == "Round" && node->attributes()->hasAttr("quantization.ptq.isProducerRounding")) + { + std::shared_ptr<Aidge::Node> castNode = Cast(targetType,node->name() + "_Cast"); + castNode->getOperator()->setDataType(targetType); + castNode->getOperator()->setBackend(node->getOperator()->backend()); + insertChildren(node,castNode,graphView); + castNode->attributes()->addAttr("quantization.ptq.isProducerCasting",0.0); + node->getOperator()->setDataType(DataType::Float64); + } + else if(node->type() == "Quantizer") + { + if(singleShift) + { + std::shared_ptr<Node> newBitShiftQuantizer = BitShiftQuantizer(node,targetType,bitshiftRounding,node->name()+"_BitShift_Quantizer"); + newBitShiftQuantizer->getOperator()->setBackend(node->getOperator()->backend()); + graphView->replace({node},{newBitShiftQuantizer}); + + } + else //If single shift is not enabled we keep using the alternative Int Quantizer (which cast the data before and after the regular Quantizer Operations) + { + std::shared_ptr<Node> newIntQuantizer = IntQuantizer(node,targetType,node->name()); + newIntQuantizer->getOperator()->setBackend(node->getOperator()->backend()); + graphView->replace({node},{newIntQuantizer}); + } + } + else if (node->type() != "Producer" && + !node->attributes()->hasAttr("quantization.ptq.isProducerScaling")) + { + node->getOperator()->setDataType(targetType); + } + } + return true; +} void foldProducerQuantizers(std::shared_ptr<GraphView> graphView) { @@ -440,10 +511,6 @@ std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> return nodeVector; } -static std::shared_ptr<Node> getFirstNode(std::shared_ptr<GraphView> graphView) -{ - return retrieveNodeVector(graphView)[0]; -} // TODO : enhance this by modifying OperatorImpl in "core" ... static DataType getDataType(std::shared_ptr<Node> node) @@ -605,7 +672,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) // ITERATE OVER THE GRAPH ///////////////////////////////////////////////// - std::shared_ptr<Node> firstNode = getFirstNode(graphView); + std::shared_ptr<Node> firstNode =getFirstNode(graphView); for (std::shared_ptr<Node> node : nodeVector) { @@ -719,8 +786,6 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr< std::unordered_map<std::shared_ptr<Node>, double> valueRanges; std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes(); - // std::shared_ptr<Node> inputNode = getFirstNode(graphView); - for (std::shared_ptr<Node> node : nodeSet) if ((scalingNodesOnly && hasAttr(node, "isActivationQuantizer")) || (!scalingNodesOnly && (node->type() != "Producer"))) valueRanges.insert(std::make_pair(node, 0)); @@ -788,7 +853,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr< void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges) { - std::shared_ptr<Node> firstNode = getFirstNode(graphView); + std::shared_ptr<Node> firstNode = getFirstNode(graphView); // CREATE THE ACCUMULATED RATIO MAP /////////////////////////////////////// @@ -888,7 +953,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose) { - std::shared_ptr<Node> firstNode = getFirstNode(graphView); + std::shared_ptr<Node> firstNode = getFirstNode(graphView); std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap; @@ -1319,13 +1384,33 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, Log::notice(" Folding the Producer's Quantizers ..."); foldProducerQuantizers(graphView); } + if( targetType != DataType::Float64 && targetType != DataType::Float32 && targetType != DataType::Float16) + { + AIDGE_ASSERT(!noQuant,"Cannot cast operators with the noQuant (Fake Quantization) flag set to true!") + Log::notice("Starting to cast operators into the desired type ..."); + castQuantizedGraph(graphView,targetType,singleShift,bitshiftRounding); + + graphView->updateInputsOutputs(); + clearGraphViewInputNodes(graphView,targetType); //Convert all input tensors of the GV into targetType + } + else + { + setupDataType(graphView, inputDataSet, targetType); + } + if(foldGraph) + { + Log::notice("Applying constant folding recipe to the graph ..."); + applyConstFold(graphView); + } + //Mandatory to handle all of the newly added connections! + graphView->updateInputsOutputs(); + + //Clearing input nodes + Log::notice("Clearing all input nodes ..."); if (verbose) printScalingFactors(graphView); - - if (useCuda) - graphView->setBackend("cuda"); - + Log::notice(" Reseting the scheduler ..."); SequentialScheduler scheduler(graphView); scheduler.resetScheduling(); diff --git a/src/operator/FixedQ.cpp b/src/operator/FixedQ.cpp index 9828ce98f4918b3d2336c57fe018c9129804cf01..ce9a65defc71909a61e03b1d603b6037a777697a 100644 --- a/src/operator/FixedQ.cpp +++ b/src/operator/FixedQ.cpp @@ -22,7 +22,7 @@ const std::string Aidge::FixedQ_Op::Type = "FixedQ"; Aidge::FixedQ_Op::FixedQ_Op(const Aidge::FixedQ_Op& op) : OperatorTensor(op), - mAttributes(op.mAttributes) + mAttributes(std::make_shared<Attributes_>(*op.mAttributes)) { if (op.mImpl){ SET_IMPL_MACRO(FixedQ_Op, *this, op.backend()); diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp index 450880d2fd8d553192146e611e6a084abbf73eb7..8686242668e610923abb20c5c002eda0292b2157 100644 --- a/src/operator/PTQMetaOps.cpp +++ b/src/operator/PTQMetaOps.cpp @@ -34,6 +34,12 @@ namespace Aidge { +static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, std::string nodeType) +{ + std::shared_ptr<Node> mulNode = nullptr; + for(std::shared_ptr<Node> node : graphView->getNodes()) + if (node->type() == nodeType) + mulNode = node; static bool hasAttr(std::shared_ptr<Aidge::Node> node, std::string attr) { diff --git a/src/operator/SAT/DoReFa.cpp b/src/operator/SAT/DoReFa.cpp index 426e330e7f8426d256ca76a843548a91a62b036a..f722631e543c46b1307a372a8d2cb35e65215b2f 100644 --- a/src/operator/SAT/DoReFa.cpp +++ b/src/operator/SAT/DoReFa.cpp @@ -23,7 +23,7 @@ const std::string DoReFa_Op::Type = "DoReFa"; DoReFa_Op::DoReFa_Op(const DoReFa_Op& op) : OperatorTensor(op), - mAttributes(op.mAttributes) + mAttributes(std::make_shared<Attributes_>(*op.mAttributes)) { if (op.mImpl) { SET_IMPL_MACRO(DoReFa_Op, *this, op.backend());