Skip to content
Snippets Groups Projects
Commit ad28a0e8 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Fix] replace() member function to run python tests successfully

parent 2e9f0013
No related branches found
No related tags found
No related merge requests found
......@@ -12,6 +12,9 @@
#ifndef AIDGE_CORE_UTILS_RECIPIES_H_
#define AIDGE_CORE_UTILS_RECIPIES_H_
#include <memory>
#include <set>
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
......@@ -47,7 +50,7 @@ void removeFlatten(std::set<std::shared_ptr<Node>> nodes);
* @param graphView Graph view to use graph matching on, in order to apply transfomrations.
*/
void removeFlatten(std::shared_ptr<GraphView> graphView);
// FUSE BN + FC || CONV -> FC || CONV
/**
......
......@@ -619,14 +619,17 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s
for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); }
// copy output connections
for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) {
auto outputPairs = copyOutputs[o];
for (const auto& onePair : outputPairs) {
newOutputNode->addChild(onePair.first, o, onePair.second);
if (newOutputNode) {
for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) {
auto outputPairs = copyOutputs[o];
for (const auto& onePair : outputPairs) {
newOutputNode->addChild(onePair.first, o, onePair.second);
}
}
}
// copy input connections
if (!newNodes.empty()) {
if (!newNodes.empty() && externalInput) {
for (const auto& newInputNode : newG->inputNodes()) {
IOIndex_t inputId = 0;
for (const auto& input : newInputNode->inputs()) {
......
......@@ -12,6 +12,7 @@
#include <cassert>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <catch2/catch_test_macros.hpp>
......@@ -419,6 +420,44 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2}));
}
SECTION("Change every Nodes in a GraphView") {
auto matmulWeight0 = GenericOperator("Producer", 0, 0, 1, "matmul_w0");
auto addBias0 = GenericOperator("Producer", 0, 0, 1, "add_b0");
auto matmul0 = GenericOperator("MatMul", 1, 2, 1, "matmul0");
auto add0 = GenericOperator("Add", 1, 2, 1, "add0");
auto matmulWeight1 = GenericOperator("Producer", 0, 0, 1, "matmul_w1");
auto addBias1 = GenericOperator("Producer", 0, 0, 1, "add_b1");
auto matmul1 = GenericOperator("MatMul", 1, 2, 1, "matmul1");
auto add1 = GenericOperator("Add", 1, 2, 1, "add1");
matmulWeight0 -> addChild(matmul0, 0, 1);
addBias0 -> addChild(add0, 0, 1);
matmulWeight1 -> addChild(matmul1, 0, 1);
addBias1 -> addChild(add1, 0, 1);
matmul0 -> addChild(add0, 0, 0);
add0 -> addChild(matmul1, 0, 0);
matmul1 -> addChild(add1, 0, 0);
auto g = std::make_shared<GraphView>("TestGraph");
g -> add({matmulWeight0, addBias0, matmulWeight1, addBias1, matmul0, add0, matmul1, add1});
auto newMatmulWeight0 = matmulWeight0->cloneSharedOperators();
auto newAddBias0 = addBias0->cloneSharedOperators();
auto newMatmulWeight1 = matmulWeight1->cloneSharedOperators();
auto newAddBias1 = addBias1->cloneSharedOperators();
auto fc0 = GenericOperator("FC", 1, 3, 1, "fc0");
auto fc1 = GenericOperator("FC", 1, 3, 1, "fc1");
newMatmulWeight0 -> addChild(fc0, 0, 1);
newAddBias0 -> addChild(fc0, 0, 2);
newMatmulWeight1 -> addChild(fc1, 0, 1);
newAddBias1 -> addChild(fc1, 0, 2);
GraphView::replace({matmul0, add0, matmulWeight0, addBias0}, {newMatmulWeight0, newAddBias0, fc0});
GraphView::replace({matmul1, add1, matmulWeight1, addBias1}, {newMatmulWeight1, newAddBias1, fc1});
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight0, newAddBias0, newAddBias1, newMatmulWeight1, fc1, fc0}));
}
}
TEST_CASE("[GraphView] clone") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment