Skip to content
Snippets Groups Projects
Commit 458c4ae9 authored by Maxence Naud's avatar Maxence Naud
Browse files

[MIN] make FuseMulAdd more resilient to edge cases

parent ad28a0e8
No related branches found
No related tags found
1 merge request!45[Upd] replace() instead of replaceWith() in GraphView
Pipeline #33998 passed
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/GenericOperator.hpp"
#include "aidge/utils/ErrorHandling.hpp"
// Graph Regex // Graph Regex
#include "aidge/graphmatching/GRegex.hpp" #include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp" #include "aidge/graphmatching/NodeRegex.hpp"
...@@ -49,19 +51,19 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ ...@@ -49,19 +51,19 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
// Fetch the output dimension throught the bias size // Fetch the output dimension throught the bias size
std::shared_ptr<Node> bias = (add->getParent(1)) ? add->getParent(1)->cloneSharedOperators() : nullptr; std::shared_ptr<Node> bias = (add->getParent(1)) ? add->getParent(1)->cloneSharedOperators() : nullptr;
Tensor& bias_tensor = bias->getOperator()->output(0); if (!(matmul->getParent(1))) {
AIDGE_INTERNAL_ASSERT("No weight detected to produce the fuseMulAdd recipe.");
std::shared_ptr<Node> weight = (matmul->getParent(1)) ? matmul->getParent(1)->cloneSharedOperators() : nullptr; }
std::shared_ptr<Node> weight = matmul->getParent(1)->cloneSharedOperators();
DimSize_t outSize = weight->getOperator()->output(0).dims<2>()[1];
// Instanciate FC // Instanciate FC
//std::shared_ptr<Node> fc = FC(dim[0], false, "Fc"); //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc");
std::shared_ptr<Node> fc = std::make_shared<Node>(std::make_shared<FC_Op>(bias_tensor.dims()[0], false)); std::shared_ptr<Node> fc = std::make_shared<Node>(std::make_shared<FC_Op>(outSize, bias ? false : true));
// Step 2 : Branch existing producers & create the others // Step 2 : Branch existing producers & create the others
// link weights & bias // link weights & bias
if (weight) { weight->addChild(fc, 0, 1);
weight->addChild(fc, 0, 1);
}
if (bias) { if (bias) {
bias->addChild(fc, 0, 2); bias->addChild(fc, 0, 2);
} }
...@@ -74,7 +76,8 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ ...@@ -74,7 +76,8 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
// auto nodeToReplace = std::make_shared<GraphView>(); // auto nodeToReplace = std::make_shared<GraphView>();
// nodeToReplace->add(nodes, false); // nodeToReplace->add(nodes, false);
// nodeToReplace->replaceWith({fc}); // nodeToReplace->replaceWith({fc});
GraphView::replace({matmul, add, add->getParent(1), matmul->getParent(1)}, {fc, weight, bias}); auto newNodes = std::set<std::shared_ptr<Node>>({fc, weight, fc->getParent(2)});
GraphView::replace({matmul, add, add->getParent(1), matmul->getParent(1)}, newNodes);
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment