Skip to content
Snippets Groups Projects
Commit 4b3540a3 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd] replace() member function for tiling handling

parent f646e90d
No related branches found
No related tags found
1 merge request!45[Upd] replace() instead of replaceWith() in GraphView
Pipeline #33911 failed
......@@ -344,17 +344,17 @@ public:
bool replaceWith(std::set<NodePtr> newNodes);
/**
* @brief Replace a set of Nodes in the current GraphView with a new set of Nodes if possible.
* @brief Replace a set of Nodes in every available GraphView with a new set of Nodes if possible.
* Both sets should include all the necessary Producers.
* @details Replaced Nodes are only removed from the current GraphView. Other GraphView containing
* them will not be affected by the replacement. The oldNodes set should have only one input/output
* Node for automatic connections of newNodes set.
* @details Replaced Nodes are removed from any GraphView pointing at them all.
* The oldNodes set should have only one input/output
* Tensor for automatic connections of newNodes set.
* @param oldNodes actual set of shared_ptr<Node> to replace.
* @param newNodes new set of shared_ptr<Node>.
* @return true
* @return false
*/
bool replace(std::set<NodePtr>& oldNodes, std::set<NodePtr>& newNodes);
static bool replace(const std::set<NodePtr>& oldNodes, const std::set<NodePtr>& newNodes);
void updateInputNodes();
/**
......
......@@ -595,41 +595,38 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
return replacable;
}
bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidge::NodePtr>& newNodes) {
for (const auto& node : oldNodes) {
if (mNodes.find(node) == mNodes.end()) {
AIDGE_INTERNAL_ASSERT("GraphView asked to replace a Node it does not contain.");
}
}
bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) {
// TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes)
// How to distinguish it from data input?
// TODO: Parameter Tensors could be identified with their dimensions
// TODO: Take GraphView as input parameters since new Nodes should be connected whatever.
// It also avoids specifying each producer since they are automatically included
auto oldG = std::make_shared<GraphView>();
auto oldG = std::make_shared<GraphView>("oldG");
oldG->add(oldNodes, false);
auto newG = std::make_shared<GraphView>();
auto newG = std::make_shared<GraphView>("newG");
newG->add(newNodes, false);
if ((oldG->inputNodes().size() != 1) || (oldG->outputNodes().size() != 1)) {
if ((oldG->inputNodes().size() == 0) || (oldG->outputNodes().size() != 1)) {
return false;
}
if (!(newNodes.empty()) && ((newG->inputNodes().size() != 1) ||
if (!(newNodes.empty()) && ((newG->inputNodes().size() == 0) ||
(newG->outputNodes().size() != 1))) {
return false;
}
std::shared_ptr<Node> previousInputNode = (*(oldG->inputNodes()).begin());
std::shared_ptr<Node> previousOutputNode = (*(oldG->outputNodes()).begin());
// there is at least one inputNode in the old/new GraphView
std::shared_ptr<Node> firstPreviousInputNode = (*(oldG->inputNodes()).begin());
std::shared_ptr<Node> firstPreviousOutputNode = (*(oldG->outputNodes()).begin());
// find Node to link to new input Node
//compute number of input for previousInputNode not in oldNodes set
//compute number of input for firstPreviousInputNode not in oldNodes set
std::size_t nbExternalInputs = 0;
std::shared_ptr<Node> externalInput = nullptr;
IOIndex_t externalInputId = gk_IODefaultIndex;
for (const auto& input : previousInputNode->inputs()) {
if (oldNodes.find(input.first) == oldNodes.end()) {
for (const auto& input : firstPreviousInputNode->inputs()) {
if (oldNodes.find(input.first) == oldNodes.end()) { // Node connected to another Node outside of oldG
nbExternalInputs++;
externalInput = input.first;
externalInputId = input.second;
......@@ -638,14 +635,28 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
if (nbExternalInputs > 1) {
AIDGE_INTERNAL_ASSERT("To many input to link for oldNodes set");
}
if (previousOutputNode->nbOutputs() != 1) {
if (oldG->inputNodes().size() > 1){
// one or no input has been identified. Checking every input points to the same source
for (const auto& previousInputNode : oldG->inputNodes()) {
for (const auto& input : previousInputNode->inputs()) {
if (oldNodes.find(input.first) == oldNodes.end()) {
if ( (externalInput != input.first) || (externalInputId != input.second) ) {
return false; // an inputNode points to an external Node different from the registered one
}
}
}
}
}
if (firstPreviousOutputNode->nbOutputs() != 1) {
return false;
}
// find Node to replicate output connections
std::shared_ptr<Node> newOutputNode = newNodes.empty() ? externalInput : *(newG->outputNodes().begin());
auto copyOutputs = previousOutputNode->outputs();
auto copyOutputs = firstPreviousOutputNode->outputs();
// manage Views for newNodes
// only keep common views to each node for the new set
std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views();
......@@ -657,6 +668,8 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
std::inserter(intersection, intersection.begin()));
commonGraphViews = intersection;
}
commonGraphViews.erase(oldG);
commonGraphViews.erase(newG);
// clean Nodes to replace
// Do not include common Nodes to avoid cleaning Producers linked to newNodes
......@@ -667,7 +680,7 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); }
// copy output connections
for (IOIndex_t o = 0; o < previousOutputNode->nbOutputs(); ++o) {
for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) {
auto outputPairs = copyOutputs[o];
for (const auto& onePair : outputPairs) {
newOutputNode->addChild(onePair.first, o, onePair.second);
......@@ -675,22 +688,32 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
}
// copy input connections
if (!newNodes.empty()) {
std::shared_ptr<Node> newInputNode = (*(newG->inputNodes()).begin());
if (newInputNode->getNbFreeDataInputs() > 1) {
return false;
for (const auto& newInputNode : newG->inputNodes()) {
IOIndex_t inputId = 0;
for (const auto& input : newInputNode->inputs()) {
if (newNodes.find(input.first) == newNodes.end()) {
externalInput->addChild(newInputNode, externalInputId, inputId);
}
inputId++;
}
}
// one non-connected input in newNodes set
externalInput->addChild(newInputNode, externalInputId, newInputNode->getFirstFreeDataInput());
}
// insert new Nodes in the right GraphViews
for (auto& graphPtr : commonGraphViews) {
for (const auto& graphPtr : commonGraphViews) {
graphPtr->add(newNodes, false);
if (newNodes.empty()) {
graphPtr->updateInputNodes();
graphPtr->updateOutputNodes();
}
}
for (const auto& node : oldNodes) {
node->removeView(oldG);
}
for (const auto& node : newNodes) {
node->removeView(newG);
}
return true;
}
......
......@@ -364,7 +364,7 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
std::set<std::shared_ptr<Node>> newNodes = std::set<std::shared_ptr<Node>>({myFC, newMatmulWeight, newAddBias});
// replace
g->replace(nodeToReplace, newNodes);
GraphView::replace(nodeToReplace, newNodes);
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight, newAddBias, other1, other2, myFC}));
REQUIRE(((myFC->getParent(0) == other1) && (myFC->getParent(1) == newMatmulWeight) && (myFC->getParent(2) == newAddBias)));
......@@ -381,19 +381,42 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
g->add({r1, r2, r3, r4});
auto nodesToReplace = std::set<std::shared_ptr<Node>>({r2, r3});
auto newNodes = std::set<std::shared_ptr<Node>>({});
g->replace(nodesToReplace, newNodes);
GraphView::replace(nodesToReplace, newNodes);
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4}));
REQUIRE((r1->output(0))[0].first == r4);
}
// SECTION("replace for tiling") {
// std::shared_ptr<GraphView> g = std::make_shared<GraphView>();
// auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input");
// auto other1 = GenericOperator("Other", 1, 1, 1, "other1");
// auto myConv = GenericOperator("Conv", 1, 1, 1, "myConv");
// auto other2 = GenericOperator("Other", 1, 1, 1, "other2");
// otherInput->addChild()
// }
SECTION("replace for tiling") {
std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph");
auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input");
auto other1 = GenericOperator("Other", 1, 1, 1, "other1");
auto myConv = GenericOperator("Conv", 1, 1, 1, "myConv");
auto other2 = GenericOperator("Other", 1, 1, 1, "other2");
otherInput->addChild(other1);
other1->addChild(myConv);
myConv->addChild(other2);
g->add({other1, myConv, other2});
// create tiled Conv
auto conv1 = GenericOperator("Conv", 1, 1, 1, "myConv1");
auto conv2 = GenericOperator("Conv", 1, 1, 1, "myConv2");
auto conv3 = GenericOperator("Conv", 1, 1, 1, "myConv3");
auto conv4 = GenericOperator("Conv", 1, 1, 1, "myConv4");
auto concat = GenericOperator("Concat", 4, 4, 1, "myConcat");
conv1->addChild(concat);
conv2->addChild(concat);
conv3->addChild(concat);
conv4->addChild(concat);
GraphView::replace({myConv}, {conv1, conv2, conv3, conv4, concat});
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, conv1, conv2, conv3, conv4, concat, other2}));
GraphView::replace({conv1, conv2, conv3, conv4, concat}, {myConv});
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2}));
}
}
TEST_CASE("[GraphView] clone") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment