From 811b7e16e851459344f9eda3e18e4a89ada57375 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Tue, 3 Oct 2023 13:41:06 +0000 Subject: [PATCH] [Add] FuseMulAdd test and fix typo in test_tensor --- aidge_core/unit_tests/test_tensor.py | 2 +- unit_tests/recipies/Test_FuseMulAdd.cpp | 77 +++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 unit_tests/recipies/Test_FuseMulAdd.cpp diff --git a/aidge_core/unit_tests/test_tensor.py b/aidge_core/unit_tests/test_tensor.py index 15d2f1a7b..a214a0e35 100644 --- a/aidge_core/unit_tests/test_tensor.py +++ b/aidge_core/unit_tests/test_tensor.py @@ -14,7 +14,7 @@ import aidge_core from functools import reduce import numpy as np -class test_tesnor(unittest.TestCase): +class test_tensor(unittest.TestCase): """ """ def setUp(self): diff --git a/unit_tests/recipies/Test_FuseMulAdd.cpp b/unit_tests/recipies/Test_FuseMulAdd.cpp new file mode 100644 index 000000000..da5364205 --- /dev/null +++ b/unit_tests/recipies/Test_FuseMulAdd.cpp @@ -0,0 +1,77 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <set> + +// #include "aidge/backend/cpu/operator/AddImpl.hpp" +// #include "aidge/backend/cpu/operator/ConvImpl.hpp" +// #include "aidge/backend/cpu/operator/FCImpl.hpp" +// #include "aidge/backend/cpu/operator/MatMulImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Add.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/operator/MatMul.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/Recipies.hpp" + +namespace Aidge { + +TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { + // generate the original GraphView + auto matmul0 = MatMul(5, "matmul0"); + auto add0 = Add<2>("add0"); + auto matmul1 = MatMul(5, "matmul1"); + auto add1 = Add<2>("add1"); + + auto b0 = Producer({5}, "B0"); + auto w0 = Producer({5, 5}, "W0"); + auto b1 = Producer({5}, "B1"); + auto w1 = Producer({5,5},"W1"); + auto input = Producer({2,5}, "input"); + + input->addChild(matmul0, 0, 0); + w0->addChild(matmul0, 0, 1); + + matmul0->addChild(add0, 0, 0); + b0->addChild(add0, 0, 1); + + add0->addChild(matmul1, 0, 0); + w1->addChild(matmul1, 0, 1); + + matmul1->addChild(add1, 0, 0); + b1->addChild(add1, 0, 1); + + auto g = std::make_shared<GraphView>(); + g->add({matmul0, add0, matmul1, add1, b0, b1}); + + // Check original graph + REQUIRE(g->getNodes() == + std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1})); + REQUIRE(((matmul0->getParent(0) == input) && (matmul0->getParent(1) == w0))); + REQUIRE(((add0->getParent(0) == matmul0) && (add0->getParent(1) == b0))); + REQUIRE(((matmul1->getParent(0) == add0) && (matmul1->getParent(1) == w1))); + REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1))); + + // Transform GraphView inplace + fuseMulAdd(g); + g->save("bonjour"); + + // Check new GraphView + std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); + REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1})); + REQUIRE(newNodes.size() == 6); + for (const auto& node : newNodes) { + REQUIRE(((node->type() == "Producer") || (node->type() == "FC"))); + } +} +} // namespace Aidge \ No newline at end of file -- GitLab