Skip to content
Snippets Groups Projects
Commit ad28a0e8 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Fix] replace() member function to run python tests successfully

parent 2e9f0013
No related branches found
No related tags found
No related merge requests found
...@@ -12,6 +12,9 @@ ...@@ -12,6 +12,9 @@
#ifndef AIDGE_CORE_UTILS_RECIPIES_H_ #ifndef AIDGE_CORE_UTILS_RECIPIES_H_
#define AIDGE_CORE_UTILS_RECIPIES_H_ #define AIDGE_CORE_UTILS_RECIPIES_H_
#include <memory>
#include <set>
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
...@@ -47,7 +50,7 @@ void removeFlatten(std::set<std::shared_ptr<Node>> nodes); ...@@ -47,7 +50,7 @@ void removeFlatten(std::set<std::shared_ptr<Node>> nodes);
* @param graphView Graph view to use graph matching on, in order to apply transfomrations. * @param graphView Graph view to use graph matching on, in order to apply transfomrations.
*/ */
void removeFlatten(std::shared_ptr<GraphView> graphView); void removeFlatten(std::shared_ptr<GraphView> graphView);
// FUSE BN + FC || CONV -> FC || CONV // FUSE BN + FC || CONV -> FC || CONV
/** /**
......
...@@ -619,14 +619,17 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s ...@@ -619,14 +619,17 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s
for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); } for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); }
// copy output connections // copy output connections
for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) { if (newOutputNode) {
auto outputPairs = copyOutputs[o]; for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) {
for (const auto& onePair : outputPairs) { auto outputPairs = copyOutputs[o];
newOutputNode->addChild(onePair.first, o, onePair.second); for (const auto& onePair : outputPairs) {
newOutputNode->addChild(onePair.first, o, onePair.second);
}
} }
} }
// copy input connections // copy input connections
if (!newNodes.empty()) { if (!newNodes.empty() && externalInput) {
for (const auto& newInputNode : newG->inputNodes()) { for (const auto& newInputNode : newG->inputNodes()) {
IOIndex_t inputId = 0; IOIndex_t inputId = 0;
for (const auto& input : newInputNode->inputs()) { for (const auto& input : newInputNode->inputs()) {
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <cassert> #include <cassert>
#include <map> #include <map>
#include <memory> #include <memory>
#include <set>
#include <string> #include <string>
#include <catch2/catch_test_macros.hpp> #include <catch2/catch_test_macros.hpp>
...@@ -419,6 +420,44 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { ...@@ -419,6 +420,44 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2})); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2}));
} }
SECTION("Change every Nodes in a GraphView") {
auto matmulWeight0 = GenericOperator("Producer", 0, 0, 1, "matmul_w0");
auto addBias0 = GenericOperator("Producer", 0, 0, 1, "add_b0");
auto matmul0 = GenericOperator("MatMul", 1, 2, 1, "matmul0");
auto add0 = GenericOperator("Add", 1, 2, 1, "add0");
auto matmulWeight1 = GenericOperator("Producer", 0, 0, 1, "matmul_w1");
auto addBias1 = GenericOperator("Producer", 0, 0, 1, "add_b1");
auto matmul1 = GenericOperator("MatMul", 1, 2, 1, "matmul1");
auto add1 = GenericOperator("Add", 1, 2, 1, "add1");
matmulWeight0 -> addChild(matmul0, 0, 1);
addBias0 -> addChild(add0, 0, 1);
matmulWeight1 -> addChild(matmul1, 0, 1);
addBias1 -> addChild(add1, 0, 1);
matmul0 -> addChild(add0, 0, 0);
add0 -> addChild(matmul1, 0, 0);
matmul1 -> addChild(add1, 0, 0);
auto g = std::make_shared<GraphView>("TestGraph");
g -> add({matmulWeight0, addBias0, matmulWeight1, addBias1, matmul0, add0, matmul1, add1});
auto newMatmulWeight0 = matmulWeight0->cloneSharedOperators();
auto newAddBias0 = addBias0->cloneSharedOperators();
auto newMatmulWeight1 = matmulWeight1->cloneSharedOperators();
auto newAddBias1 = addBias1->cloneSharedOperators();
auto fc0 = GenericOperator("FC", 1, 3, 1, "fc0");
auto fc1 = GenericOperator("FC", 1, 3, 1, "fc1");
newMatmulWeight0 -> addChild(fc0, 0, 1);
newAddBias0 -> addChild(fc0, 0, 2);
newMatmulWeight1 -> addChild(fc1, 0, 1);
newAddBias1 -> addChild(fc1, 0, 2);
GraphView::replace({matmul0, add0, matmulWeight0, addBias0}, {newMatmulWeight0, newAddBias0, fc0});
GraphView::replace({matmul1, add1, matmulWeight1, addBias1}, {newMatmulWeight1, newAddBias1, fc1});
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight0, newAddBias0, newAddBias1, newMatmulWeight1, fc1, fc0}));
}
} }
TEST_CASE("[GraphView] clone") { TEST_CASE("[GraphView] clone") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment