Skip to content
Snippets Groups Projects
Commit ce4864db authored by Noam Zerah's avatar Noam Zerah
Browse files

Last Fix

parent cd79e1c3
No related branches found
No related tags found
No related merge requests found
Pipeline #68773 passed
This commit is part of merge request !46. Comments created here will be created in the context of that merge request.
......@@ -58,21 +58,20 @@ bool isNotQuantized(std::shared_ptr<Node> node)
{
return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end());
}
std::shared_ptr<Aidge::Node> getFirstNode(std::shared_ptr<GraphView> graphView)
{
return graphView->getOrderedInputs()[0].first;
}
void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView,DataType targetType)
{
for (std::shared_ptr<Aidge::Node> inputNode: graphView->inputNodes())
{
if(expandMetaOp(inputNode)) //Reset Tensor does not work on MetaOps
{
inputNode = std::static_pointer_cast<MetaOperator_Op>(inputNode->getOperator())->getMicroGraph()->getOrderedNodes()[0];
}
for (Aidge::IOIndex_t index = inputNode->getFirstFreeDataInput();index < inputNode->getNbFreeDataInputs(); index++)
{
std::shared_ptr<OperatorTensor> firstOperatorTensor = std::static_pointer_cast<OperatorTensor>(inputNode->getOperator());
firstOperatorTensor->resetInput(inputNode->getFirstFreeDataInput());
inputNode->getOperator()->resetInput(index);
}
}
}
bool checkArchitecture(std::shared_ptr<GraphView> graphView)
{
......@@ -564,7 +563,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
// ITERATE OVER THE GRAPH /////////////////////////////////////////////////
std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first;
std::shared_ptr<Node> firstNode =getFirstNode(graphView);
for (std::shared_ptr<Node> node : nodeVector)
{
......@@ -762,7 +761,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges)
{
std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first;
std::shared_ptr<Node> firstNode = getFirstNode(graphView);
// CREATE THE ACCUMULATED RATIO MAP ///////////////////////////////////////
......@@ -856,7 +855,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m
std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose)
{
std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first;
std::shared_ptr<Node> firstNode = getFirstNode(graphView);
std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap;
......@@ -1291,6 +1290,7 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView,
AIDGE_ASSERT(!noQuant,"Cannot cast operators with the noQuant (Fake Quantization) flag set to true!")
Log::notice("Starting to cast operators into the desired type ...");
castQuantizedGraph(graphView,targetType,singleShift,bitshiftRounding);
graphView->updateInputsOutputs();
clearGraphViewInputNodes(graphView,targetType); //Convert all input tensors of the GV into targetType
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment