diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index 3c895c4b162e04fd3efe1904e42e428e6c9e3db9..15d4bb6e0ae86c6e361d034f5b211c4ecdd5abeb 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -20,6 +20,8 @@ #include "aidge/graph/Node.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/GenericOperator.hpp" +#include "aidge/utils/ErrorHandling.hpp" + // Graph Regex #include "aidge/graphmatching/GRegex.hpp" #include "aidge/graphmatching/NodeRegex.hpp" @@ -49,19 +51,19 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ // Fetch the output dimension throught the bias size std::shared_ptr<Node> bias = (add->getParent(1)) ? add->getParent(1)->cloneSharedOperators() : nullptr; - Tensor& bias_tensor = bias->getOperator()->output(0); - - std::shared_ptr<Node> weight = (matmul->getParent(1)) ? matmul->getParent(1)->cloneSharedOperators() : nullptr; + if (!(matmul->getParent(1))) { + AIDGE_INTERNAL_ASSERT("No weight detected to produce the fuseMulAdd recipe."); + } + std::shared_ptr<Node> weight = matmul->getParent(1)->cloneSharedOperators(); + DimSize_t outSize = weight->getOperator()->output(0).dims<2>()[1]; // Instanciate FC //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc"); - std::shared_ptr<Node> fc = std::make_shared<Node>(std::make_shared<FC_Op>(bias_tensor.dims()[0], false)); + std::shared_ptr<Node> fc = std::make_shared<Node>(std::make_shared<FC_Op>(outSize, bias ? false : true)); // Step 2 : Branch existing producers & create the others // link weights & bias - if (weight) { - weight->addChild(fc, 0, 1); - } + weight->addChild(fc, 0, 1); if (bias) { bias->addChild(fc, 0, 2); } @@ -74,7 +76,8 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ // auto nodeToReplace = std::make_shared<GraphView>(); // nodeToReplace->add(nodes, false); // nodeToReplace->replaceWith({fc}); - GraphView::replace({matmul, add, add->getParent(1), matmul->getParent(1)}, {fc, weight, bias}); + auto newNodes = std::set<std::shared_ptr<Node>>({fc, weight, fc->getParent(2)}); + GraphView::replace({matmul, add, add->getParent(1), matmul->getParent(1)}, newNodes); }