Skip to content
Snippets Groups Projects
Commit f38da961 authored by Noam Zerah's avatar Noam Zerah
Browse files

first change

parent f2713d00
No related branches found
No related tags found
No related merge requests found
...@@ -1164,7 +1164,8 @@ static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std: ...@@ -1164,7 +1164,8 @@ static void setupDataType(std::shared_ptr<GraphView> graphView, std::vector<std:
tensor->setDataType(dataType); tensor->setDataType(dataType);
} }
void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose) void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet,
Clipping clippingMode, DataType targetType,bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose)
{ {
Log::notice(" === QUANT PTQ 0.2.21 === "); Log::notice(" === QUANT PTQ 0.2.21 === ");
...@@ -1211,7 +1212,18 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, ...@@ -1211,7 +1212,18 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
Log::notice(" Performing the Single-Shift approximation ..."); Log::notice(" Performing the Single-Shift approximation ...");
performSingleShiftApproximation(graphView, noQuant); performSingleShiftApproximation(graphView, noQuant);
} }
if( targetType != DataType::Float64 && targetType != DataType::Float32 && targetType != DataType::Float16)
{
Log::warn("HEREA\n");
AIDGE_ASSERT(!noQuant,"Cannot cast operators with the noQuant (Fake Quantization) flag set to true!")
Log::notice("Starting to cast operators into the desired type ...");
castQuantizedGraph(graphView,DataType::Int32,singleShift);
Log::warn("HEREB\n");
}
else
{
setupDataType(graphView, inputDataSet, targetType);
}
if (verbose) if (verbose)
printScalingFactors(graphView); printScalingFactors(graphView);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment