Skip to content
Snippets Groups Projects
Commit 6c9a568d authored by Noam Zerah's avatar Noam Zerah
Browse files

Adding clearInput method to supress Int32 inputs after PTQ

parent 8db2a996
No related branches found
No related tags found
No related merge requests found
Pipeline #67167 failed
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
namespace Aidge { namespace Aidge {
static constexpr const int PROJECT_VERSION_MAJOR = 0; static constexpr const int PROJECT_VERSION_MAJOR = 0;
static constexpr const int PROJECT_VERSION_MINOR = 3; static constexpr const int PROJECT_VERSION_MINOR = 2;
static constexpr const int PROJECT_VERSION_PATCH = 1; static constexpr const int PROJECT_VERSION_PATCH = 0;
static constexpr const char * PROJECT_VERSION = "0.3.1"; static constexpr const char * PROJECT_VERSION = "0.2.0";
static constexpr const char * PROJECT_GIT_HASH = "418bc3e"; static constexpr const char * PROJECT_GIT_HASH = "f50c860";
} }
#endif // VERSION_H #endif // VERSION_H
...@@ -57,7 +57,16 @@ bool isNotQuantized(std::shared_ptr<Node> node) ...@@ -57,7 +57,16 @@ bool isNotQuantized(std::shared_ptr<Node> node)
{ {
return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end()); return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end());
} }
void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView)
{
for (std::shared_ptr<Aidge::Node> input_node: graphView->inputNodes())
{
for (Aidge::IOIndex_t index = input_node->getFirstFreeDataInput();index < input_node->getNbFreeDataInputs(); index++)
{
std::static_pointer_cast<OperatorTensor>(input_node->getOperator())->resetInput(index);
}
}
}
bool checkArchitecture(std::shared_ptr<GraphView> graphView) bool checkArchitecture(std::shared_ptr<GraphView> graphView)
{ {
std::set<std::string> otherNodeTypes({"Flatten", "Softmax", "BatchNorm2D", "ReLU", "Producer"}); std::set<std::string> otherNodeTypes({"Flatten", "Softmax", "BatchNorm2D", "ReLU", "Producer"});
...@@ -1275,12 +1284,14 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, ...@@ -1275,12 +1284,14 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView,
} }
//Mandatory to handle all of the newly added connections! //Mandatory to handle all of the newly added connections!
graphView->updateInputsOutputs(); graphView->updateInputsOutputs();
//Clearing input nodes
Log::notice("Clearing all input nodes ...");
clearGraphViewInputNodes(graphView);
if (verbose) if (verbose)
printScalingFactors(graphView); printScalingFactors(graphView);
if (useCuda)
// graphView->setBackend("cuda");
Log::notice(" Reseting the scheduler ..."); Log::notice(" Reseting the scheduler ...");
SequentialScheduler scheduler(graphView); SequentialScheduler scheduler(graphView);
scheduler.resetScheduling(); scheduler.resetScheduling();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment