From f38da961a7249bbc6a704b17d5b9f709eaffd3f0 Mon Sep 17 00:00:00 2001
From: Noam ZERAH <noam.zerah@cea.fr>
Date: Tue, 4 Mar 2025 10:19:21 +0000
Subject: [PATCH] first change

---
 src/PTQ/PTQ.cpp | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp
index dbf4b62..9cbe63f 100644
--- a/src/PTQ/PTQ.cpp
+++ b/src/PTQ/PTQ.cpp
@@ -1164,7 +1164,8 @@ static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std:
         tensor->setDataType(dataType);
 }
 
-void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose)
+void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet,
+    Clipping clippingMode, DataType targetType, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose)
 {
     Log::notice(" === QUANT PTQ 0.2.21 === ");
 
@@ -1211,7 +1212,18 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
         Log::notice(" Performing the Single-Shift approximation ...");
         performSingleShiftApproximation(graphView, noQuant);
     }
-
+    // Any non-floating-point target type: cast the quantized graph; otherwise only set the data type.
+    if (targetType != DataType::Float64 && targetType != DataType::Float32 && targetType != DataType::Float16)
+    {
+        AIDGE_ASSERT(!noQuant, "Cannot cast operators with the noQuant (Fake Quantization) flag set to true!");
+        Log::notice("Starting to cast operators into the desired type ...");
+        // NOTE(review): the cast type is hard-coded to Int32 and does not depend on targetType - confirm this is intended.
+        castQuantizedGraph(graphView, DataType::Int32, singleShift);
+    }
+    else
+    {
+        setupDataType(graphView, inputDataSet, targetType);
+    }
     if (verbose)
         printScalingFactors(graphView);
 
-- 
GitLab