Skip to content
Snippets Groups Projects
Commit 080743a9 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd] replace() instead of replaceWith() in GraphView

parent 4b783082
No related branches found
No related tags found
No related merge requests found
......@@ -322,17 +322,17 @@ public:
/**
* @brief Insert a node (newParentNode) as a parent of the passed node (childNode).
*
*
* @param childNode Node that gets a new parent.
* @param newParentNode Inserted Node.
* @param childInputTensorIdx Index of the input Tensor for the childNode linked to the inserted Node output.
* @param newParentInputTensorIdx Index of the input Tensor for the newParentNode linked to the former parent of childNode.
* @param newParentOutputTensorIdx Index of the output Tensor for the newParentNode linked to the childNode's input Tensor.
*/
void insertParent(NodePtr childNode,
NodePtr newParentNode,
IOIndex_t childInputTensorIdx,
IOIndex_t newParentInputTensorIdx,
void insertParent(NodePtr childNode,
NodePtr newParentNode,
IOIndex_t childInputTensorIdx,
IOIndex_t newParentInputTensorIdx,
IOIndex_t newParentOutputTensorIdx);
/**
......@@ -342,6 +342,20 @@ public:
* @return false
*/
bool replaceWith(std::set<NodePtr> newNodes);
/**
* @brief Replace a set of Nodes in the current GraphView with a new set of Nodes if possible.
* Both sets should include all the necessary Producers.
* @details Replaced Nodes are only removed from the current GraphView. Other GraphView containing
* them will not be affected by the replacement. The oldNodes set should have only one input/output
* Node for automatic connections of newNodes set.
* @param oldNodes actual set of shared_ptr<Node> to replace.
* @param newNodes new set of shared_ptr<Node>.
* @return true
* @return false
*/
bool replace(std::set<NodePtr>& oldNodes, std::set<NodePtr>& newNodes);
void updateInputNodes();
/**
* @brief Process from zero the set of output Nodes.
......@@ -379,6 +393,12 @@ public:
*/
std::shared_ptr<GraphView> cloneCallback(NodePtr(*cloneNode)(NodePtr)) const;
/**
* @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object.
* @return IOIndex_t
*/
IOIndex_t getNbFreeDataInputs() const;
private:
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
......@@ -390,12 +410,6 @@ private:
*/
IOIndex_t getNbDataInputs() const;
/**
* @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object.
* @return IOIndex_t
*/
IOIndex_t getNbFreeDataInputs() const;
/**
* @brief Update the set of inputNodes with a new Node, checking if it can be
* added and removing any Node not part of mInputNode anymore.
......
......@@ -17,6 +17,7 @@
#include "aidge/utils/Types.h"
#include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
......@@ -594,6 +595,106 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
return replacable;
}
bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidge::NodePtr>& newNodes) {
for (const auto& node : oldNodes) {
if (mNodes.find(node) == mNodes.end()) {
AIDGE_INTERNAL_ASSERT("GraphView asked to replace a Node it does not contain.");
}
}
// TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes)
// How to distinguish it from data input?
// TODO: Parameter Tensors could be identified with their dimensions
// TODO: Take GraphView as input parameters since new Nodes should be connected whatever.
// It also avoids specifying each producer since they are automatically included
auto oldG = std::make_shared<GraphView>();
oldG->add(oldNodes, false);
auto newG = std::make_shared<GraphView>();
newG->add(newNodes, false);
if ((oldG->inputNodes().size() != 1) || (oldG->outputNodes().size() != 1)) {
return false;
}
if (!(newNodes.empty()) && ((newG->inputNodes().size() != 1) ||
(newG->outputNodes().size() != 1))) {
return false;
}
std::shared_ptr<Node> previousInputNode = (*(oldG->inputNodes()).begin());
std::shared_ptr<Node> previousOutputNode = (*(oldG->outputNodes()).begin());
// find Node to link to new input Node
//compute number of input for previousInputNode not in oldNodes set
std::size_t nbExternalInputs = 0;
std::shared_ptr<Node> externalInput = nullptr;
IOIndex_t externalInputId = gk_IODefaultIndex;
for (const auto& input : previousInputNode->inputs()) {
if (oldNodes.find(input.first) == oldNodes.end()) {
nbExternalInputs++;
externalInput = input.first;
externalInputId = input.second;
}
}
if (nbExternalInputs > 1) {
AIDGE_INTERNAL_ASSERT("To many input to link for oldNodes set");
}
if (previousOutputNode->nbOutputs() != 1) {
return false;
}
// find Node to replicate output connections
std::shared_ptr<Node> newOutputNode = newNodes.empty() ? externalInput : *(newG->outputNodes().begin());
auto copyOutputs = previousOutputNode->outputs();
// manage Views for newNodes
// only keep common views to each node for the new set
std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views();
for (const auto& nodePtr : oldNodes) {
const auto nodeView = nodePtr->views();
std::set<std::shared_ptr<GraphView>> intersection;
std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(),
nodeView.begin(), nodeView.end(),
std::inserter(intersection, intersection.begin()));
commonGraphViews = intersection;
}
// clean Nodes to replace
// Do not include common Nodes to avoid cleaning Producers linked to newNodes
std::set<std::shared_ptr<Node>> nodesToClean;
std::set_difference(oldNodes.begin(), oldNodes.end(),
newNodes.begin(), newNodes.end(),
std::inserter(nodesToClean, nodesToClean.begin()));
for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); }
// copy output connections
for (IOIndex_t o = 0; o < previousOutputNode->nbOutputs(); ++o) {
auto outputPairs = copyOutputs[o];
for (const auto& onePair : outputPairs) {
newOutputNode->addChild(onePair.first, o, onePair.second);
}
}
// copy input connections
if (!newNodes.empty()) {
std::shared_ptr<Node> newInputNode = (*(newG->inputNodes()).begin());
if (newInputNode->getNbFreeDataInputs() > 1) {
return false;
}
// one non-connected input in newNodes set
externalInput->addChild(newInputNode, externalInputId, newInputNode->getFirstFreeDataInput());
}
// insert new Nodes in the right GraphViews
for (auto& graphPtr : commonGraphViews) {
graphPtr->add(newNodes, false);
if (newNodes.empty()) {
graphPtr->updateInputNodes();
graphPtr->updateOutputNodes();
}
}
return true;
}
void Aidge::GraphView::updateInputNodes() {
mInputNodes.clear();
for (const std::shared_ptr<Node>& go_ptr : mNodes) {
......
......@@ -332,6 +332,70 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") {
}
}
TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
SECTION("replace small pattern") {
// create original graph
std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input");
auto matmulWeight = GenericOperator("Producer", 0, 0, 1, "matmul_w");
auto addBias = GenericOperator("Producer", 0, 0, 1, "add_b");
auto other1 = GenericOperator("Other", 1, 1, 1, "other1");
auto other2 = GenericOperator("Other", 1, 1, 1, "other2");
auto matmul = GenericOperator("MatMul", 1, 2, 1, "matmul");
auto add = GenericOperator("Add", 1, 2, 1, "add");
otherInput->addChild(other1);
other1->addChild(matmul);
matmul->addChild(add);
add->addChild(other2);
matmulWeight->addChild(matmul, 0, 1);
addBias->addChild(add, 0, 1);
g->add({other1, matmul, add, other2});
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, matmul, add}));
// create graph to replace
std::set<std::shared_ptr<Node>> nodeToReplace = std::set<std::shared_ptr<Node>>({matmulWeight, addBias, matmul, add});
// create replacing graph
std::shared_ptr<Node> myFC = GenericOperator("FC", 1, 3, 1, "fc");
auto newMatmulWeight = matmulWeight->cloneSharedOperators();
newMatmulWeight->addChild(myFC, 0, 1);
auto newAddBias = addBias->cloneSharedOperators();
newAddBias->addChild(myFC, 0, 2);
std::set<std::shared_ptr<Node>> newNodes = std::set<std::shared_ptr<Node>>({myFC, newMatmulWeight, newAddBias});
// replace
g->replace(nodeToReplace, newNodes);
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight, newAddBias, other1, other2, myFC}));
REQUIRE(((myFC->getParent(0) == other1) && (myFC->getParent(1) == newMatmulWeight) && (myFC->getParent(2) == newAddBias)));
}
SECTION("replace with nothing") {
std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
auto r1 = GenericOperator("relu", 0, 0, 1);
auto r2 = GenericOperator("relu", 1, 1, 1);
auto r3 = GenericOperator("relu", 1, 1, 1);
auto r4 = GenericOperator("relu", 1, 1, 0);
r1->addChild(r2);
r2->addChild(r3);
r3->addChild(r4);
g->add({r1, r2, r3, r4});
auto nodesToReplace = std::set<std::shared_ptr<Node>>({r2, r3});
auto newNodes = std::set<std::shared_ptr<Node>>({});
g->replace(nodesToReplace, newNodes);
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4}));
REQUIRE((r1->output(0))[0].first == r4);
}
// SECTION("replace for tiling") {
// std::shared_ptr<GraphView> g = std::make_shared<GraphView>();
// auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input");
// auto other1 = GenericOperator("Other", 1, 1, 1, "other1");
// auto myConv = GenericOperator("Conv", 1, 1, 1, "myConv");
// auto other2 = GenericOperator("Other", 1, 1, 1, "other2");
// otherInput->addChild()
// }
}
TEST_CASE("[GraphView] clone") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment