diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index dbf4b6234d2a64c7fd9572c77882a1dcf64f3f49..9cbe63fa36976166af13d6032f045f199514d0bc 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -1164,7 +1164,8 @@ static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std: tensor->setDataType(dataType); } -void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose) +void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, + Clipping clippingMode, DataType targetType,bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose) { Log::notice(" === QUANT PTQ 0.2.21 === "); @@ -1211,7 +1212,18 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, Log::notice(" Performing the Single-Shift approximation ..."); performSingleShiftApproximation(graphView, noQuant); } - + if( targetType != DataType::Float64 && targetType != DataType::Float32 && targetType != DataType::Float16) + { + Log::warn("HEREA\n"); + AIDGE_ASSERT(!noQuant,"Cannot cast operators with the noQuant (Fake Quantization) flag set to true!") + Log::notice("Starting to cast operators into the desired type ..."); + castQuantizedGraph(graphView,DataType::Int32,singleShift); + Log::warn("HEREB\n"); + } + else + { + setupDataType(graphView, inputDataSet, targetType); + } if (verbose) printScalingFactors(graphView);