Skip to content
Snippets Groups Projects
Commit 54f12a68 authored by Cyril Moineau's avatar Cyril Moineau Committed by Cyril Moineau
Browse files

Fix unittest, to remove cases where implementation is required.

parent 9d58d7e8
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!137Fix fuse mulAdd (and more ;))
Pipeline #48498 failed
...@@ -58,7 +58,6 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< ...@@ -58,7 +58,6 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
std::vector<DimSize_t> shape = weightTensor->dims(); std::vector<DimSize_t> shape = weightTensor->dims();
std::reverse(shape.begin(), shape.end()); std::reverse(shape.begin(), shape.end());
weightTensor->copyTranspose(*weightTensor, std::vector<Aidge::DimSize_t>({1ul, 0ul})); weightTensor->copyTranspose(*weightTensor, std::vector<Aidge::DimSize_t>({1ul, 0ul}));
// weightOpTensor->setOutput(0, std::make_shared<Aidge::Tensor>(weightTensor->transpose(shape)));
} }
else if ((matmulNode->getParent(0) && !matmulNode->getParent(1)) else if ((matmulNode->getParent(0) && !matmulNode->getParent(1))
|| (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type || (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type
......
...@@ -31,11 +31,11 @@ TEST_CASE("GraphRegexUser") { ...@@ -31,11 +31,11 @@ TEST_CASE("GraphRegexUser") {
g1->addChild(fc, "c"); g1->addChild(fc, "c");
g1->addChild(conv2, "c1"); g1->addChild(conv2, "c1");
g1->addChild(fc2, "c2"); g1->addChild(fc2, "c2");
/// ///
std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();
sut->setNodeKey("C",+[](NodePtr NodeOp){return NodeOp->type() == "FC";}); sut->setNodeKey("C",+[](NodePtr NodeOp){return NodeOp->type() == "FC";});
sut->setNodeKey("A","C($)==True"); sut->setNodeKey("A","C($)==True");
sut->addQuery("A"); sut->addQuery("A");
auto match = sut->match(g1); auto match = sut->match(g1);
...@@ -163,14 +163,14 @@ TEST_CASE("GraphRegexUser") { ...@@ -163,14 +163,14 @@ TEST_CASE("GraphRegexUser") {
auto w1 = Producer({5,5},"W1"); auto w1 = Producer({5,5},"W1");
auto input = Producer({2,5}, "input"); auto input = Producer({2,5}, "input");
input->addChild(matmul0, 0, 0); input->addChild(matmul0, 0, 1);
w0->addChild(matmul0, 0, 1); w0->addChild(matmul0, 0, 0);
matmul0->addChild(add0, 0, 0); matmul0->addChild(add0, 0, 0);
b0->addChild(add0, 0, 1); b0->addChild(add0, 0, 1);
add0->addChild(matmul1, 0, 0); add0->addChild(matmul1, 0, 1);
w1->addChild(matmul1, 0, 1); w1->addChild(matmul1, 0, 0);
matmul1->addChild(add1, 0, 0); matmul1->addChild(add1, 0, 0);
b1->addChild(add1, 0, 1); b1->addChild(add1, 0, 1);
...@@ -201,4 +201,4 @@ TEST_CASE("GraphRegexUser") { ...@@ -201,4 +201,4 @@ TEST_CASE("GraphRegexUser") {
} }
} }
\ No newline at end of file
...@@ -42,8 +42,8 @@ TEST_CASE("[cpu/recipes] FuseMulAdd", "[FuseMulAdd][recipes]") { ...@@ -42,8 +42,8 @@ TEST_CASE("[cpu/recipes] FuseMulAdd", "[FuseMulAdd][recipes]") {
matmul0->addChild(add0, 0, 0); matmul0->addChild(add0, 0, 0);
b0->addChild(add0, 0, 1); b0->addChild(add0, 0, 1);
add0->addChild(matmul1, 0, 0); add0->addChild(matmul1, 0, 1);
w1->addChild(matmul1, 0, 1); w1->addChild(matmul1, 0, 0);
matmul1->addChild(add1, 0, 0); matmul1->addChild(add1, 0, 0);
b1->addChild(add1, 0, 1); b1->addChild(add1, 0, 1);
...@@ -56,14 +56,14 @@ TEST_CASE("[cpu/recipes] FuseMulAdd", "[FuseMulAdd][recipes]") { ...@@ -56,14 +56,14 @@ TEST_CASE("[cpu/recipes] FuseMulAdd", "[FuseMulAdd][recipes]") {
std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1})); std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1}));
REQUIRE(((matmul0->getParent(0) == input) && (matmul0->getParent(1) == w0))); REQUIRE(((matmul0->getParent(0) == input) && (matmul0->getParent(1) == w0)));
REQUIRE(((add0->getParent(0) == matmul0) && (add0->getParent(1) == b0))); REQUIRE(((add0->getParent(0) == matmul0) && (add0->getParent(1) == b0)));
REQUIRE(((matmul1->getParent(0) == add0) && (matmul1->getParent(1) == w1))); REQUIRE(((matmul1->getParent(1) == add0) && (matmul1->getParent(0) == w1)));
REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1))); REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1)));
// Transform GraphView inplace // Transform GraphView inplace
fuseMulAdd(g); fuseMulAdd(g);
// Check new GraphView // Check new GraphView
std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); std::set<std::shared_ptr<Node>> newNodes = g->getNodes();
REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1})); REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1}));
REQUIRE(newNodes.size() == 6); REQUIRE(newNodes.size() == 6);
for (const auto& node : newNodes) { for (const auto& node : newNodes) {
...@@ -71,4 +71,4 @@ TEST_CASE("[cpu/recipes] FuseMulAdd", "[FuseMulAdd][recipes]") { ...@@ -71,4 +71,4 @@ TEST_CASE("[cpu/recipes] FuseMulAdd", "[FuseMulAdd][recipes]") {
} }
} }
} // namespace Aidge } // namespace Aidge
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment