From f38da961a7249bbc6a704b17d5b9f709eaffd3f0 Mon Sep 17 00:00:00 2001
From: Noam ZERAH <noam.zerah@cea.fr>
Date: Tue, 4 Mar 2025 10:19:21 +0000
Subject: [PATCH] first change

---
 src/PTQ/PTQ.cpp | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index dbf4b62..9cbe63f 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -1164,7 +1164,8 @@ static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std:
         tensor->setDataType(dataType);
 }
 
-void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose)
+void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet,
+    Clipping clippingMode, DataType targetType,bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose)
 {
     Log::notice(" === QUANT PTQ 0.2.21 === ");
 
@@ -1211,7 +1212,18 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
         Log::notice(" Performing the Single-Shift approximation ...");
         performSingleShiftApproximation(graphView, noQuant);
     }
-
+    if( targetType != DataType::Float64 && targetType != DataType::Float32 && targetType != DataType::Float16) 
+    {
+        Log::warn("HEREA\n");
+        AIDGE_ASSERT(!noQuant,"Cannot cast operators with the noQuant (Fake Quantization) flag set to true!")
+        Log::notice("Starting to cast operators into the desired type ...");
+        castQuantizedGraph(graphView,DataType::Int32,singleShift);
+        Log::warn("HEREB\n");
+    }
+    else
+    {
+        setupDataType(graphView, inputDataSet, targetType);
+    }
     if (verbose)
         printScalingFactors(graphView);
 
-- 
GitLab