Skip to content
Snippets Groups Projects
Commit c2e7afcd authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Maxence Naud
Browse files

fix FuseMulAdd recipe and its tests

parent ad799c3a
No related branches found
No related tags found
No related merge requests found
......@@ -41,7 +41,19 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
AIDGE_ASSERT(matmulNode->getParent(1), "No weight detected to produce the fuseMulAdd recipe.");
std::shared_ptr<Node> weight = matmulNode->getParent(1)->cloneSharedOperators();
const DimSize_t outSize = std::dynamic_pointer_cast<MatMul_Op>(matmulNode->getOperator()) -> getAttr<DimSize_t>("OutChannels");
// TODO: find another way to get OutChannels for FC operator.
// This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output
DimSize_t outSize = 0;
const auto& op = std::dynamic_pointer_cast<OperatorTensor>(addNode->getOperator());
for (size_t i = 0; i < op->nbInputs(); i++)
{
const auto& inTensor = op->getInput(i);
if(inTensor->nbDims() > 0) {
outSize = inTensor->dims()[inTensor->nbDims()-1];
break;
}
}
AIDGE_ASSERT(outSize, "Couldnt get output number of channels for FC operator.");
// Instanciate FC
//std::shared_ptr<Node> fc = FC(dim[0], false, "Fc");
......
......@@ -126,9 +126,9 @@ TEST_CASE("GraphRegexUser") {
SECTION("Applied Recipes"){
// generate the original GraphView
auto matmul0 = MatMul(5, 5, "matmul0");
auto matmul0 = MatMul("matmul0");
auto add0 = Add(2, "add0");
auto matmul1 = MatMul(5, 5, "matmul1");
auto matmul1 = MatMul("matmul1");
auto add1 = Add(2, "add1");
auto b0 = Producer({5}, "B0");
......@@ -154,7 +154,7 @@ TEST_CASE("GraphRegexUser") {
auto g = std::make_shared<GraphView>();
g->add({matmul0, add0, matmul1, add1, b0, b1,fl,fc});
g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fl,fc});
std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>();
......
......@@ -25,9 +25,9 @@ namespace Aidge {
TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") {
// generate the original GraphView
auto matmul0 = MatMul(5, 5, "matmul0");
auto matmul0 = MatMul("matmul0");
auto add0 = Add(2, "add0");
auto matmul1 = MatMul(5, 5, "matmul1");
auto matmul1 = MatMul("matmul1");
auto add1 = Add(2, "add1");
auto b0 = Producer({5}, "B0");
......@@ -49,7 +49,7 @@ TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") {
b1->addChild(add1, 0, 1);
auto g = std::make_shared<GraphView>();
g->add({matmul0, add0, matmul1, add1, b0, b1});
g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1});
// Check original graph
REQUIRE(g->getNodes() ==
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment