From 811b7e16e851459344f9eda3e18e4a89ada57375 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Tue, 3 Oct 2023 13:41:06 +0000
Subject: [PATCH] [Add] FuseMulAdd test and fix typo in test_tensor

---
 aidge_core/unit_tests/test_tensor.py    |  2 +-
 unit_tests/recipies/Test_FuseMulAdd.cpp | 77 +++++++++++++++++++++++++
 2 files changed, 78 insertions(+), 1 deletion(-)
 create mode 100644 unit_tests/recipies/Test_FuseMulAdd.cpp

diff --git a/aidge_core/unit_tests/test_tensor.py b/aidge_core/unit_tests/test_tensor.py
index 15d2f1a7b..a214a0e35 100644
--- a/aidge_core/unit_tests/test_tensor.py
+++ b/aidge_core/unit_tests/test_tensor.py
@@ -14,7 +14,7 @@ import aidge_core
 from functools import reduce
 import numpy as np
 
-class test_tesnor(unittest.TestCase):
+class test_tensor(unittest.TestCase):
     """
     """
     def setUp(self):
diff --git a/unit_tests/recipies/Test_FuseMulAdd.cpp b/unit_tests/recipies/Test_FuseMulAdd.cpp
new file mode 100644
index 000000000..da5364205
--- /dev/null
+++ b/unit_tests/recipies/Test_FuseMulAdd.cpp
@@ -0,0 +1,77 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <catch2/catch_test_macros.hpp>
+#include <set>
+
+// #include "aidge/backend/cpu/operator/AddImpl.hpp"
+// #include "aidge/backend/cpu/operator/ConvImpl.hpp"
+// #include "aidge/backend/cpu/operator/FCImpl.hpp"
+// #include "aidge/backend/cpu/operator/MatMulImpl.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/operator/Add.hpp"
+#include "aidge/operator/FC.hpp"
+#include "aidge/operator/MatMul.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/utils/Recipies.hpp"
+
+namespace Aidge {
+
+TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") {
+    // generate the original GraphView
+    auto matmul0 = MatMul(5, "matmul0");
+    auto add0 = Add<2>("add0");
+    auto matmul1 = MatMul(5, "matmul1");
+    auto add1 = Add<2>("add1");
+
+    auto b0 = Producer({5}, "B0");
+    auto w0 = Producer({5, 5}, "W0");
+    auto b1 = Producer({5}, "B1");
+    auto w1 = Producer({5,5},"W1");
+    auto input = Producer({2,5}, "input");
+
+    input->addChild(matmul0, 0, 0);
+    w0->addChild(matmul0, 0, 1);
+
+    matmul0->addChild(add0, 0, 0);
+    b0->addChild(add0, 0, 1);
+
+    add0->addChild(matmul1, 0, 0);
+    w1->addChild(matmul1, 0, 1);
+
+    matmul1->addChild(add1, 0, 0);
+    b1->addChild(add1, 0, 1);
+
+    auto g = std::make_shared<GraphView>();
+    g->add({matmul0, add0, matmul1, add1, b0, b1});
+
+    // Check original graph
+    REQUIRE(g->getNodes() ==
+            std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1}));
+    REQUIRE(((matmul0->getParent(0) == input) && (matmul0->getParent(1) == w0)));
+    REQUIRE(((add0->getParent(0) == matmul0) && (add0->getParent(1) == b0)));
+    REQUIRE(((matmul1->getParent(0) == add0) && (matmul1->getParent(1) == w1)));
+    REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1)));
+
+	// Transform GraphView inplace
+    fuseMulAdd(g);
+	g->save("bonjour");
+
+	// Check new GraphView
+	 std::set<std::shared_ptr<Node>> newNodes = g->getNodes();
+	REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1}));
+	REQUIRE(newNodes.size() == 6);
+	for (const auto& node : newNodes) {
+		REQUIRE(((node->type() == "Producer") || (node->type() == "FC")));
+	}
+}
+}  // namespace Aidge
\ No newline at end of file
-- 
GitLab