Skip to content
Snippets Groups Projects
Commit 811b7e16 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Add] FuseMulAdd test and fix typo in test_tensor

parent cf344748
No related branches found
No related tags found
1 merge request!9Fuse bn
Pipeline #32369 failed
...@@ -14,7 +14,7 @@ import aidge_core ...@@ -14,7 +14,7 @@ import aidge_core
from functools import reduce from functools import reduce
import numpy as np import numpy as np
class test_tesnor(unittest.TestCase): class test_tensor(unittest.TestCase):
""" """
""" """
def setUp(self): def setUp(self):
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <set>
// #include "aidge/backend/cpu/operator/AddImpl.hpp"
// #include "aidge/backend/cpu/operator/ConvImpl.hpp"
// #include "aidge/backend/cpu/operator/FCImpl.hpp"
// #include "aidge/backend/cpu/operator/MatMulImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Recipies.hpp"
namespace Aidge {
TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") {
// generate the original GraphView
auto matmul0 = MatMul(5, "matmul0");
auto add0 = Add<2>("add0");
auto matmul1 = MatMul(5, "matmul1");
auto add1 = Add<2>("add1");
auto b0 = Producer({5}, "B0");
auto w0 = Producer({5, 5}, "W0");
auto b1 = Producer({5}, "B1");
auto w1 = Producer({5,5},"W1");
auto input = Producer({2,5}, "input");
input->addChild(matmul0, 0, 0);
w0->addChild(matmul0, 0, 1);
matmul0->addChild(add0, 0, 0);
b0->addChild(add0, 0, 1);
add0->addChild(matmul1, 0, 0);
w1->addChild(matmul1, 0, 1);
matmul1->addChild(add1, 0, 0);
b1->addChild(add1, 0, 1);
auto g = std::make_shared<GraphView>();
g->add({matmul0, add0, matmul1, add1, b0, b1});
// Check original graph
REQUIRE(g->getNodes() ==
std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1}));
REQUIRE(((matmul0->getParent(0) == input) && (matmul0->getParent(1) == w0)));
REQUIRE(((add0->getParent(0) == matmul0) && (add0->getParent(1) == b0)));
REQUIRE(((matmul1->getParent(0) == add0) && (matmul1->getParent(1) == w1)));
REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1)));
// Transform GraphView inplace
fuseMulAdd(g);
g->save("bonjour");
// Check new GraphView
std::set<std::shared_ptr<Node>> newNodes = g->getNodes();
REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1}));
REQUIRE(newNodes.size() == 6);
for (const auto& node : newNodes) {
REQUIRE(((node->type() == "Producer") || (node->type() == "FC")));
}
}
} // namespace Aidge
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment