diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index 322b1d9a0632b893a912c6225ac5b13d63278f8d..85bfc408f092d9f234265db51a01eff1ab64005b 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -41,7 +41,19 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< AIDGE_ASSERT(matmulNode->getParent(1), "No weight detected to produce the fuseMulAdd recipe."); std::shared_ptr<Node> weight = matmulNode->getParent(1)->cloneSharedOperators(); - const DimSize_t outSize = std::dynamic_pointer_cast<MatMul_Op>(matmulNode->getOperator()) -> getAttr<DimSize_t>("OutChannels"); + // TODO: find another way to get OutChannels for FC operator. + // This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output + DimSize_t outSize = 0; + const auto& op = std::dynamic_pointer_cast<OperatorTensor>(addNode->getOperator()); + for (size_t i = 0; i < op->nbInputs(); i++) + { + const auto& inTensor = op->getInput(i); + if(inTensor->nbDims() > 0) { + outSize = inTensor->dims()[inTensor->nbDims()-1]; + break; + } + } + AIDGE_ASSERT(outSize, "Couldnt get output number of channels for FC operator."); // Instanciate FC //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc"); diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp index 924aac79ea8492f6ea0f2cd4d93676876c5a8331..1330a8e620ae5d49d6ef61257a587b914ffed1cd 100644 --- a/unit_tests/graphRegex/Test_GraphRegex.cpp +++ b/unit_tests/graphRegex/Test_GraphRegex.cpp @@ -126,9 +126,9 @@ TEST_CASE("GraphRegexUser") { SECTION("Applied Recipes"){ // generate the original GraphView - auto matmul0 = MatMul(5, 5, "matmul0"); + auto matmul0 = MatMul("matmul0"); auto add0 = Add(2, "add0"); - auto matmul1 = MatMul(5, 5, "matmul1"); + auto matmul1 = MatMul("matmul1"); auto add1 = Add(2, "add1"); auto b0 = Producer({5}, "B0"); @@ -154,7 +154,7 @@ TEST_CASE("GraphRegexUser") { auto g = std::make_shared<GraphView>(); - g->add({matmul0, add0, matmul1, add1, b0, b1,fl,fc}); + g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fl,fc}); std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>(); diff --git a/unit_tests/recipies/Test_FuseMulAdd.cpp b/unit_tests/recipies/Test_FuseMulAdd.cpp index 968826230dfdf85290ee377aee155e06855c4b28..d0875fe10078eb9d8e3a97e0703191b5697f3fda 100644 --- a/unit_tests/recipies/Test_FuseMulAdd.cpp +++ b/unit_tests/recipies/Test_FuseMulAdd.cpp @@ -25,9 +25,9 @@ namespace Aidge { TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { // generate the original GraphView - auto matmul0 = MatMul(5, 5, "matmul0"); + auto matmul0 = MatMul("matmul0"); auto add0 = Add(2, "add0"); - auto matmul1 = MatMul(5, 5, "matmul1"); + auto matmul1 = MatMul("matmul1"); auto add1 = Add(2, "add1"); auto b0 = Producer({5}, "B0"); @@ -49,7 +49,7 @@ TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { b1->addChild(add1, 0, 1); auto g = std::make_shared<GraphView>(); - g->add({matmul0, add0, matmul1, add1, b0, b1}); + g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1}); // Check original graph REQUIRE(g->getNodes() ==