Skip to content
Snippets Groups Projects
Commit c2e7afcd authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Maxence Naud
Browse files

fix FuseMulAdd recipe and its tests

parent ad799c3a
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!76Matmul rework
...@@ -41,7 +41,19 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< ...@@ -41,7 +41,19 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
AIDGE_ASSERT(matmulNode->getParent(1), "No weight detected to produce the fuseMulAdd recipe."); AIDGE_ASSERT(matmulNode->getParent(1), "No weight detected to produce the fuseMulAdd recipe.");
std::shared_ptr<Node> weight = matmulNode->getParent(1)->cloneSharedOperators(); std::shared_ptr<Node> weight = matmulNode->getParent(1)->cloneSharedOperators();
const DimSize_t outSize = std::dynamic_pointer_cast<MatMul_Op>(matmulNode->getOperator()) -> getAttr<DimSize_t>("OutChannels"); // TODO: find another way to get OutChannels for FC operator.
// This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output
DimSize_t outSize = 0;
const auto& op = std::dynamic_pointer_cast<OperatorTensor>(addNode->getOperator());
for (size_t i = 0; i < op->nbInputs(); i++)
{
const auto& inTensor = op->getInput(i);
if(inTensor->nbDims() > 0) {
outSize = inTensor->dims()[inTensor->nbDims()-1];
break;
}
}
AIDGE_ASSERT(outSize, "Couldnt get output number of channels for FC operator.");
// Instanciate FC // Instanciate FC
//std::shared_ptr<Node> fc = FC(dim[0], false, "Fc"); //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc");
......
...@@ -126,9 +126,9 @@ TEST_CASE("GraphRegexUser") { ...@@ -126,9 +126,9 @@ TEST_CASE("GraphRegexUser") {
SECTION("Applied Recipes"){ SECTION("Applied Recipes"){
// generate the original GraphView // generate the original GraphView
auto matmul0 = MatMul(5, 5, "matmul0"); auto matmul0 = MatMul("matmul0");
auto add0 = Add(2, "add0"); auto add0 = Add(2, "add0");
auto matmul1 = MatMul(5, 5, "matmul1"); auto matmul1 = MatMul("matmul1");
auto add1 = Add(2, "add1"); auto add1 = Add(2, "add1");
auto b0 = Producer({5}, "B0"); auto b0 = Producer({5}, "B0");
...@@ -154,7 +154,7 @@ TEST_CASE("GraphRegexUser") { ...@@ -154,7 +154,7 @@ TEST_CASE("GraphRegexUser") {
auto g = std::make_shared<GraphView>(); auto g = std::make_shared<GraphView>();
g->add({matmul0, add0, matmul1, add1, b0, b1,fl,fc}); g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fl,fc});
std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>(); std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>();
......
...@@ -25,9 +25,9 @@ namespace Aidge { ...@@ -25,9 +25,9 @@ namespace Aidge {
TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") {
// generate the original GraphView // generate the original GraphView
auto matmul0 = MatMul(5, 5, "matmul0"); auto matmul0 = MatMul("matmul0");
auto add0 = Add(2, "add0"); auto add0 = Add(2, "add0");
auto matmul1 = MatMul(5, 5, "matmul1"); auto matmul1 = MatMul("matmul1");
auto add1 = Add(2, "add1"); auto add1 = Add(2, "add1");
auto b0 = Producer({5}, "B0"); auto b0 = Producer({5}, "B0");
...@@ -49,7 +49,7 @@ TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { ...@@ -49,7 +49,7 @@ TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") {
b1->addChild(add1, 0, 1); b1->addChild(add1, 0, 1);
auto g = std::make_shared<GraphView>(); auto g = std::make_shared<GraphView>();
g->add({matmul0, add0, matmul1, add1, b0, b1}); g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1});
// Check original graph // Check original graph
REQUIRE(g->getNodes() == REQUIRE(g->getNodes() ==
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment