From 5df2101f2b378657df179c31c75afc6c4b602811 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Wed, 31 Jan 2024 16:41:32 +0100 Subject: [PATCH] fix FuseMulAdd recipe and its tests --- src/recipies/FuseMulAdd.cpp | 14 +++++++++++++- unit_tests/graphRegex/Test_GraphRegex.cpp | 6 +++--- unit_tests/recipies/Test_FuseMulAdd.cpp | 6 +++--- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index 322b1d9a0..85bfc408f 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -41,7 +41,19 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< AIDGE_ASSERT(matmulNode->getParent(1), "No weight detected to produce the fuseMulAdd recipe."); std::shared_ptr<Node> weight = matmulNode->getParent(1)->cloneSharedOperators(); - const DimSize_t outSize = std::dynamic_pointer_cast<MatMul_Op>(matmulNode->getOperator()) -> getAttr<DimSize_t>("OutChannels"); + // TODO: find another way to get OutChannels for FC operator. + // This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output + DimSize_t outSize = 0; + const auto& op = std::dynamic_pointer_cast<OperatorTensor>(addNode->getOperator()); + for (size_t i = 0; i < op->nbInputs(); i++) + { + const auto& inTensor = op->getInput(i); + if(inTensor->nbDims() > 0) { + outSize = inTensor->dims()[inTensor->nbDims()-1]; + break; + } + } + AIDGE_ASSERT(outSize, "Couldnt get output number of channels for FC operator."); // Instanciate FC //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc"); diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp index 924aac79e..1330a8e62 100644 --- a/unit_tests/graphRegex/Test_GraphRegex.cpp +++ b/unit_tests/graphRegex/Test_GraphRegex.cpp @@ -126,9 +126,9 @@ TEST_CASE("GraphRegexUser") { SECTION("Applied Recipes"){ // generate the original GraphView - auto matmul0 = MatMul(5, 5, "matmul0"); + auto matmul0 = MatMul("matmul0"); auto add0 = Add(2, "add0"); - auto matmul1 = MatMul(5, 5, "matmul1"); + auto matmul1 = MatMul("matmul1"); auto add1 = Add(2, "add1"); auto b0 = Producer({5}, "B0"); @@ -154,7 +154,7 @@ TEST_CASE("GraphRegexUser") { auto g = std::make_shared<GraphView>(); - g->add({matmul0, add0, matmul1, add1, b0, b1,fl,fc}); + g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fl,fc}); std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>(); diff --git a/unit_tests/recipies/Test_FuseMulAdd.cpp b/unit_tests/recipies/Test_FuseMulAdd.cpp index 968826230..d0875fe10 100644 --- a/unit_tests/recipies/Test_FuseMulAdd.cpp +++ b/unit_tests/recipies/Test_FuseMulAdd.cpp @@ -25,9 +25,9 @@ namespace Aidge { TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { // generate the original GraphView - auto matmul0 = MatMul(5, 5, "matmul0"); + auto matmul0 = MatMul("matmul0"); auto add0 = Add(2, "add0"); - auto matmul1 = MatMul(5, 5, "matmul1"); + auto matmul1 = MatMul("matmul1"); auto add1 = Add(2, "add1"); auto b0 = Producer({5}, "B0"); @@ -49,7 +49,7 @@ TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { b1->addChild(add1, 0, 1); auto g = std::make_shared<GraphView>(); - g->add({matmul0, add0, matmul1, add1, b0, b1}); + g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1}); // Check original graph REQUIRE(g->getNodes() == -- GitLab