/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <catch2/catch_test_macros.hpp>
#include <set>

#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipies/Recipies.hpp"

namespace Aidge {

/**
 * Builds a two-layer MatMul -> Add chain, runs the fuseMulAdd recipe on it
 * in place, and checks that the resulting GraphView holds exactly six nodes,
 * every one of which has type "Producer" or "FC".
 */
TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") {
    using NodeSet = std::set<std::shared_ptr<Node>>;

    // ---- Build the original graph ------------------------------------------
    auto matmul0 = MatMul("matmul0");
    auto add0    = Add(2, "add0");
    auto matmul1 = MatMul("matmul1");
    auto add1    = Add(2, "add1");

    auto b0    = Producer({5}, "B0");
    auto w0    = Producer({5, 5}, "W0");
    auto b1    = Producer({5}, "B1");
    auto w1    = Producer({5, 5}, "W1");
    auto input = Producer({2, 5}, "input");

    // Connects one layer, in this order:
    //   data    -> matmul input #0
    //   weights -> matmul input #1
    //   matmul  -> add    input #0
    //   bias    -> add    input #1
    const auto wireLayer = [](const auto& data, const auto& weights, const auto& bias,
                              const auto& matmul, const auto& add) {
        data->addChild(matmul, 0, 0);
        weights->addChild(matmul, 0, 1);
        matmul->addChild(add, 0, 0);
        bias->addChild(add, 0, 1);
    };
    wireLayer(input, w0, b0, matmul0, add0);
    wireLayer(add0, w1, b1, matmul1, add1);

    // Only the eight nodes below go into the GraphView; `input` is not added.
    auto g = std::make_shared<GraphView>();
    g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1});

    const NodeSet originalNodes{w0, matmul0, b0, add0, w1, matmul1, b1, add1};

    // ---- Sanity-check the original graph -----------------------------------
    // True when `node` has `first` on input #0 and `second` on input #1.
    const auto hasParents = [](const auto& node, const auto& first, const auto& second) {
        return node->getParent(0) == first && node->getParent(1) == second;
    };

    REQUIRE(g->getNodes() == originalNodes);
    REQUIRE(hasParents(matmul0, input, w0));
    REQUIRE(hasParents(add0, matmul0, b0));
    REQUIRE(hasParents(matmul1, add0, w1));
    REQUIRE(hasParents(add1, matmul1, b1));

    // ---- Apply the recipe in place -----------------------------------------
    fuseMulAdd(g);

    // ---- Verify the rewritten graph ----------------------------------------
    const NodeSet newNodes = g->getNodes();
    REQUIRE(newNodes != originalNodes);
    REQUIRE(newNodes.size() == 6);

    for (const auto& node : newNodes) {
        const auto& nodeType = node->type();
        // Extra parentheses: Catch2 cannot decompose a top-level `||`.
        REQUIRE(((nodeType == "Producer") || (nodeType == "FC")));
    }
}

} // namespace Aidge