From 6c9a568d18530bf1d60d63eb51413fb5c826746c Mon Sep 17 00:00:00 2001 From: Noam ZERAH <noam.zerah@cea.fr> Date: Tue, 4 Mar 2025 15:46:39 +0000 Subject: [PATCH] Adding clearInput method to supress Int32 inputs after PTQ --- include/aidge/quantization_version.h | 8 ++++---- src/PTQ/PTQ.cpp | 21 ++++++++++++++++----- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/include/aidge/quantization_version.h b/include/aidge/quantization_version.h index 1ca687d..546263a 100644 --- a/include/aidge/quantization_version.h +++ b/include/aidge/quantization_version.h @@ -3,9 +3,9 @@ namespace Aidge { static constexpr const int PROJECT_VERSION_MAJOR = 0; -static constexpr const int PROJECT_VERSION_MINOR = 3; -static constexpr const int PROJECT_VERSION_PATCH = 1; -static constexpr const char * PROJECT_VERSION = "0.3.1"; -static constexpr const char * PROJECT_GIT_HASH = "418bc3e"; +static constexpr const int PROJECT_VERSION_MINOR = 2; +static constexpr const int PROJECT_VERSION_PATCH = 0; +static constexpr const char * PROJECT_VERSION = "0.2.0"; +static constexpr const char * PROJECT_GIT_HASH = "f50c860"; } #endif // VERSION_H diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index e1500e3..b934665 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -57,7 +57,16 @@ bool isNotQuantized(std::shared_ptr<Node> node) { return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end()); } - +void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView) +{ + for (std::shared_ptr<Aidge::Node> input_node: graphView->inputNodes()) + { + for (Aidge::IOIndex_t index = input_node->getFirstFreeDataInput();index < input_node->getNbFreeDataInputs(); index++) + { + std::static_pointer_cast<OperatorTensor>(input_node->getOperator())->resetInput(index); + } + } +} bool checkArchitecture(std::shared_ptr<GraphView> graphView) { std::set<std::string> otherNodeTypes({"Flatten", "Softmax", "BatchNorm2D", "ReLU", "Producer"}); @@ -1275,12 +1284,14 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, } //Mandatory to handle all of the newly added connections! graphView->updateInputsOutputs(); + + //Clearing input nodes + Log::notice("Clearing all input nodes ..."); + clearGraphViewInputNodes(graphView); + if (verbose) printScalingFactors(graphView); - - if (useCuda) - // graphView->setBackend("cuda"); - + Log::notice(" Reseting the scheduler ..."); SequentialScheduler scheduler(graphView); scheduler.resetScheduling(); -- GitLab