Skip to content
Snippets Groups Projects
Commit ce4864db authored by Noam Zerah's avatar Noam Zerah
Browse files

Last Fix

parent cd79e1c3
No related branches found
No related tags found
No related merge requests found
Pipeline #68773 passed
...@@ -58,21 +58,20 @@ bool isNotQuantized(std::shared_ptr<Node> node) ...@@ -58,21 +58,20 @@ bool isNotQuantized(std::shared_ptr<Node> node)
{ {
return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end()); return (notQuantizedNodeTypes.find(node->type()) != notQuantizedNodeTypes.end());
} }
std::shared_ptr<Aidge::Node> getFirstNode(std::shared_ptr<GraphView> graphView)
{
return graphView->getOrderedInputs()[0].first;
}
void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView,DataType targetType) void clearGraphViewInputNodes(std::shared_ptr<GraphView> graphView,DataType targetType)
{ {
for (std::shared_ptr<Aidge::Node> inputNode: graphView->inputNodes()) for (std::shared_ptr<Aidge::Node> inputNode: graphView->inputNodes())
{ {
if(expandMetaOp(inputNode)) //Reset Tensor does not work on MetaOps
{
inputNode = std::static_pointer_cast<MetaOperator_Op>(inputNode->getOperator())->getMicroGraph()->getOrderedNodes()[0];
}
for (Aidge::IOIndex_t index = inputNode->getFirstFreeDataInput();index < inputNode->getNbFreeDataInputs(); index++) for (Aidge::IOIndex_t index = inputNode->getFirstFreeDataInput();index < inputNode->getNbFreeDataInputs(); index++)
{ {
std::shared_ptr<OperatorTensor> firstOperatorTensor = std::static_pointer_cast<OperatorTensor>(inputNode->getOperator()); inputNode->getOperator()->resetInput(index);
firstOperatorTensor->resetInput(inputNode->getFirstFreeDataInput());
} }
} }
} }
bool checkArchitecture(std::shared_ptr<GraphView> graphView) bool checkArchitecture(std::shared_ptr<GraphView> graphView)
{ {
...@@ -564,7 +563,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) ...@@ -564,7 +563,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
// ITERATE OVER THE GRAPH ///////////////////////////////////////////////// // ITERATE OVER THE GRAPH /////////////////////////////////////////////////
std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first; std::shared_ptr<Node> firstNode =getFirstNode(graphView);
for (std::shared_ptr<Node> node : nodeVector) for (std::shared_ptr<Node> node : nodeVector)
{ {
...@@ -762,7 +761,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr< ...@@ -762,7 +761,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr<
void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges) void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_map<std::shared_ptr<Node>, double> valueRanges)
{ {
std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first; std::shared_ptr<Node> firstNode = getFirstNode(graphView);
// CREATE THE ACCUMULATED RATIO MAP /////////////////////////////////////// // CREATE THE ACCUMULATED RATIO MAP ///////////////////////////////////////
...@@ -856,7 +855,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m ...@@ -856,7 +855,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m
std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose) std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap(std::shared_ptr<GraphView> graphView, bool verbose)
{ {
std::shared_ptr<Node> firstNode = graphView->getOrderedInputs()[0].first; std::shared_ptr<Node> firstNode = getFirstNode(graphView);
std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap; std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> signMap;
...@@ -1291,6 +1290,7 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, ...@@ -1291,6 +1290,7 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView,
AIDGE_ASSERT(!noQuant,"Cannot cast operators with the noQuant (Fake Quantization) flag set to true!") AIDGE_ASSERT(!noQuant,"Cannot cast operators with the noQuant (Fake Quantization) flag set to true!")
Log::notice("Starting to cast operators into the desired type ..."); Log::notice("Starting to cast operators into the desired type ...");
castQuantizedGraph(graphView,targetType,singleShift,bitshiftRounding); castQuantizedGraph(graphView,targetType,singleShift,bitshiftRounding);
graphView->updateInputsOutputs(); graphView->updateInputsOutputs();
clearGraphViewInputNodes(graphView,targetType); //Convert all input tensors of the GV into targetType clearGraphViewInputNodes(graphView,targetType); //Convert all input tensors of the GV into targetType
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment