diff --git a/include/aidge/quantization_version.h b/include/aidge/quantization_version.h index 1ca687d1b556ad312b11947ee21ea0d62d54b86f..546263af3a7e8b7a73991173f48d0b095c7d9501 100644 --- a/include/aidge/quantization_version.h +++ b/include/aidge/quantization_version.h @@ -3,9 +3,9 @@ namespace Aidge { static constexpr const int PROJECT_VERSION_MAJOR = 0; -static constexpr const int PROJECT_VERSION_MINOR = 3; -static constexpr const int PROJECT_VERSION_PATCH = 1; -static constexpr const char * PROJECT_VERSION = "0.3.1"; -static constexpr const char * PROJECT_GIT_HASH = "418bc3e"; +static constexpr const int PROJECT_VERSION_MINOR = 2; +static constexpr const int PROJECT_VERSION_PATCH = 0; +static constexpr const char * PROJECT_VERSION = "0.2.0"; +static constexpr const char * PROJECT_GIT_HASH = "f50c860"; } #endif // VERSION_H diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index e1500e3d83bf68833fc5cf0be4097cfc76d24f65..b93466577292d2be5d377e77e2ece623c91a1af2 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -57,7 +57,16 @@ bool isNotQuantized(std::shared_ptr<Node> node) { return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end()); } - +void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView) +{ + for (std::shared_ptr<Aidge::Node> input_node: graphView->inputNodes()) + { + for (Aidge::IOIndex_t index = input_node->getFirstFreeDataInput();index < input_node->getNbFreeDataInputs(); index++) + { + std::static_pointer_cast<OperatorTensor>(input_node->getOperator())->resetInput(index); + } + } +} bool checkArchitecture(std::shared_ptr<GraphView> graphView) { std::set<std::string> otherNodeTypes({"Flatten", "Softmax", "BatchNorm2D", "ReLU", "Producer"}); @@ -1275,12 +1284,14 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, } //Mandatory to handle all of the newly added connections! graphView->updateInputsOutputs(); + + //Clearing input nodes + Log::notice("Clearing all input nodes ..."); + clearGraphViewInputNodes(graphView); + if (verbose) printScalingFactors(graphView); - - if (useCuda) - // graphView->setBackend("cuda"); - + Log::notice(" Reseting the scheduler ..."); SequentialScheduler scheduler(graphView); scheduler.resetScheduling();