From 6c9a568d18530bf1d60d63eb51413fb5c826746c Mon Sep 17 00:00:00 2001
From: Noam ZERAH <noam.zerah@cea.fr>
Date: Tue, 4 Mar 2025 15:46:39 +0000
Subject: [PATCH] Adding clearInput method to supress Int32 inputs after PTQ

---
 include/aidge/quantization_version.h |  8 ++++----
 src/PTQ/PTQ.cpp                      | 21 ++++++++++++++++-----
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/include/aidge/quantization_version.h b/include/aidge/quantization_version.h
index 1ca687d..546263a 100644
--- a/include/aidge/quantization_version.h
+++ b/include/aidge/quantization_version.h
@@ -3,9 +3,9 @@
 
 namespace Aidge {
 static constexpr const int PROJECT_VERSION_MAJOR = 0;
-static constexpr const int PROJECT_VERSION_MINOR = 3;
-static constexpr const int PROJECT_VERSION_PATCH = 1;
-static constexpr const char * PROJECT_VERSION = "0.3.1";
-static constexpr const char * PROJECT_GIT_HASH = "418bc3e";
+static constexpr const int PROJECT_VERSION_MINOR = 2;
+static constexpr const int PROJECT_VERSION_PATCH = 0;
+static constexpr const char * PROJECT_VERSION = "0.2.0";
+static constexpr const char * PROJECT_GIT_HASH = "f50c860";
 }
 #endif // VERSION_H
diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index e1500e3..b934665 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -57,7 +57,16 @@ bool isNotQuantized(std::shared_ptr<Node> node)
 {
     return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end());
 }
-
+void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView)
+{
+    for (std::shared_ptr<Aidge::Node> input_node: graphView->inputNodes())
+    {
+        for (Aidge::IOIndex_t index = input_node->getFirstFreeDataInput();index < input_node->getNbFreeDataInputs(); index++)
+        {
+            std::static_pointer_cast<OperatorTensor>(input_node->getOperator())->resetInput(index);
+        }
+    }
+}
 bool checkArchitecture(std::shared_ptr<GraphView> graphView)
 {
     std::set<std::string> otherNodeTypes({"Flatten", "Softmax", "BatchNorm2D", "ReLU", "Producer"});
@@ -1275,12 +1284,14 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView,
     }
     //Mandatory to handle all of the newly added connections!
     graphView->updateInputsOutputs();
+    
+    //Clearing input nodes
+    Log::notice("Clearing all input nodes ...");
+    clearGraphViewInputNodes(graphView);
+
     if (verbose)
         printScalingFactors(graphView);
-
-    if (useCuda)
-       // graphView->setBackend("cuda");
-
+    
     Log::notice(" Reseting the scheduler ...");
     SequentialScheduler scheduler(graphView);
     scheduler.resetScheduling();
-- 
GitLab