Skip to content
Snippets Groups Projects
Commit 4b3540a3 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd] replace() member function for tiling handling

parent f646e90d
No related branches found
No related tags found
No related merge requests found
...@@ -344,17 +344,17 @@ public: ...@@ -344,17 +344,17 @@ public:
bool replaceWith(std::set<NodePtr> newNodes); bool replaceWith(std::set<NodePtr> newNodes);
/** /**
* @brief Replace a set of Nodes in the current GraphView with a new set of Nodes if possible. * @brief Replace a set of Nodes in every available GraphView with a new set of Nodes if possible.
* Both sets should include all the necessary Producers. * Both sets should include all the necessary Producers.
* @details Replaced Nodes are only removed from the current GraphView. Other GraphView containing * @details Replaced Nodes are removed from any GraphView pointing at them all.
* them will not be affected by the replacement. The oldNodes set should have only one input/output * The oldNodes set should have only one input/output
* Node for automatic connections of newNodes set. * Tensor for automatic connections of newNodes set.
* @param oldNodes actual set of shared_ptr<Node> to replace. * @param oldNodes actual set of shared_ptr<Node> to replace.
* @param newNodes new set of shared_ptr<Node>. * @param newNodes new set of shared_ptr<Node>.
* @return true * @return true
* @return false * @return false
*/ */
bool replace(std::set<NodePtr>& oldNodes, std::set<NodePtr>& newNodes); static bool replace(const std::set<NodePtr>& oldNodes, const std::set<NodePtr>& newNodes);
void updateInputNodes(); void updateInputNodes();
/** /**
......
...@@ -595,41 +595,38 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) { ...@@ -595,41 +595,38 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
return replacable; return replacable;
} }
bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidge::NodePtr>& newNodes) { bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) {
for (const auto& node : oldNodes) {
if (mNodes.find(node) == mNodes.end()) {
AIDGE_INTERNAL_ASSERT("GraphView asked to replace a Node it does not contain.");
}
}
// TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes) // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes)
// How to distinguish it from data input? // How to distinguish it from data input?
// TODO: Parameter Tensors could be identified with their dimensions // TODO: Parameter Tensors could be identified with their dimensions
// TODO: Take GraphView as input parameters since new Nodes should be connected whatever. // TODO: Take GraphView as input parameters since new Nodes should be connected whatever.
// It also avoids specifying each producer since they are automatically included // It also avoids specifying each producer since they are automatically included
auto oldG = std::make_shared<GraphView>(); auto oldG = std::make_shared<GraphView>("oldG");
oldG->add(oldNodes, false); oldG->add(oldNodes, false);
auto newG = std::make_shared<GraphView>(); auto newG = std::make_shared<GraphView>("newG");
newG->add(newNodes, false); newG->add(newNodes, false);
if ((oldG->inputNodes().size() != 1) || (oldG->outputNodes().size() != 1)) { if ((oldG->inputNodes().size() == 0) || (oldG->outputNodes().size() != 1)) {
return false; return false;
} }
if (!(newNodes.empty()) && ((newG->inputNodes().size() != 1) || if (!(newNodes.empty()) && ((newG->inputNodes().size() == 0) ||
(newG->outputNodes().size() != 1))) { (newG->outputNodes().size() != 1))) {
return false; return false;
} }
std::shared_ptr<Node> previousInputNode = (*(oldG->inputNodes()).begin()); // there is at least one inputNode in the old/new GraphView
std::shared_ptr<Node> previousOutputNode = (*(oldG->outputNodes()).begin()); std::shared_ptr<Node> firstPreviousInputNode = (*(oldG->inputNodes()).begin());
std::shared_ptr<Node> firstPreviousOutputNode = (*(oldG->outputNodes()).begin());
// find Node to link to new input Node // find Node to link to new input Node
//compute number of input for previousInputNode not in oldNodes set //compute number of input for firstPreviousInputNode not in oldNodes set
std::size_t nbExternalInputs = 0; std::size_t nbExternalInputs = 0;
std::shared_ptr<Node> externalInput = nullptr; std::shared_ptr<Node> externalInput = nullptr;
IOIndex_t externalInputId = gk_IODefaultIndex; IOIndex_t externalInputId = gk_IODefaultIndex;
for (const auto& input : previousInputNode->inputs()) { for (const auto& input : firstPreviousInputNode->inputs()) {
if (oldNodes.find(input.first) == oldNodes.end()) { if (oldNodes.find(input.first) == oldNodes.end()) { // Node connected to another Node outside of oldG
nbExternalInputs++; nbExternalInputs++;
externalInput = input.first; externalInput = input.first;
externalInputId = input.second; externalInputId = input.second;
...@@ -638,14 +635,28 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg ...@@ -638,14 +635,28 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
if (nbExternalInputs > 1) { if (nbExternalInputs > 1) {
AIDGE_INTERNAL_ASSERT("To many input to link for oldNodes set"); AIDGE_INTERNAL_ASSERT("To many input to link for oldNodes set");
} }
if (previousOutputNode->nbOutputs() != 1) {
if (oldG->inputNodes().size() > 1){
// one or no input has been identified. Checking every input points to the same source
for (const auto& previousInputNode : oldG->inputNodes()) {
for (const auto& input : previousInputNode->inputs()) {
if (oldNodes.find(input.first) == oldNodes.end()) {
if ( (externalInput != input.first) || (externalInputId != input.second) ) {
return false; // an inputNode points to an external Node different from the registered one
}
}
}
}
}
if (firstPreviousOutputNode->nbOutputs() != 1) {
return false; return false;
} }
// find Node to replicate output connections // find Node to replicate output connections
std::shared_ptr<Node> newOutputNode = newNodes.empty() ? externalInput : *(newG->outputNodes().begin()); std::shared_ptr<Node> newOutputNode = newNodes.empty() ? externalInput : *(newG->outputNodes().begin());
auto copyOutputs = previousOutputNode->outputs(); auto copyOutputs = firstPreviousOutputNode->outputs();
// manage Views for newNodes // manage Views for newNodes
// only keep common views to each node for the new set // only keep common views to each node for the new set
std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views(); std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views();
...@@ -657,6 +668,8 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg ...@@ -657,6 +668,8 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
std::inserter(intersection, intersection.begin())); std::inserter(intersection, intersection.begin()));
commonGraphViews = intersection; commonGraphViews = intersection;
} }
commonGraphViews.erase(oldG);
commonGraphViews.erase(newG);
// clean Nodes to replace // clean Nodes to replace
// Do not include common Nodes to avoid cleaning Producers linked to newNodes // Do not include common Nodes to avoid cleaning Producers linked to newNodes
...@@ -667,7 +680,7 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg ...@@ -667,7 +680,7 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); } for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); }
// copy output connections // copy output connections
for (IOIndex_t o = 0; o < previousOutputNode->nbOutputs(); ++o) { for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) {
auto outputPairs = copyOutputs[o]; auto outputPairs = copyOutputs[o];
for (const auto& onePair : outputPairs) { for (const auto& onePair : outputPairs) {
newOutputNode->addChild(onePair.first, o, onePair.second); newOutputNode->addChild(onePair.first, o, onePair.second);
...@@ -675,22 +688,32 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg ...@@ -675,22 +688,32 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
} }
// copy input connections // copy input connections
if (!newNodes.empty()) { if (!newNodes.empty()) {
std::shared_ptr<Node> newInputNode = (*(newG->inputNodes()).begin()); for (const auto& newInputNode : newG->inputNodes()) {
if (newInputNode->getNbFreeDataInputs() > 1) { IOIndex_t inputId = 0;
return false; for (const auto& input : newInputNode->inputs()) {
if (newNodes.find(input.first) == newNodes.end()) {
externalInput->addChild(newInputNode, externalInputId, inputId);
}
inputId++;
}
} }
// one non-connected input in newNodes set
externalInput->addChild(newInputNode, externalInputId, newInputNode->getFirstFreeDataInput());
} }
// insert new Nodes in the right GraphViews // insert new Nodes in the right GraphViews
for (auto& graphPtr : commonGraphViews) { for (const auto& graphPtr : commonGraphViews) {
graphPtr->add(newNodes, false); graphPtr->add(newNodes, false);
if (newNodes.empty()) { if (newNodes.empty()) {
graphPtr->updateInputNodes(); graphPtr->updateInputNodes();
graphPtr->updateOutputNodes(); graphPtr->updateOutputNodes();
} }
} }
for (const auto& node : oldNodes) {
node->removeView(oldG);
}
for (const auto& node : newNodes) {
node->removeView(newG);
}
return true; return true;
} }
......
...@@ -364,7 +364,7 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { ...@@ -364,7 +364,7 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
std::set<std::shared_ptr<Node>> newNodes = std::set<std::shared_ptr<Node>>({myFC, newMatmulWeight, newAddBias}); std::set<std::shared_ptr<Node>> newNodes = std::set<std::shared_ptr<Node>>({myFC, newMatmulWeight, newAddBias});
// replace // replace
g->replace(nodeToReplace, newNodes); GraphView::replace(nodeToReplace, newNodes);
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight, newAddBias, other1, other2, myFC})); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight, newAddBias, other1, other2, myFC}));
REQUIRE(((myFC->getParent(0) == other1) && (myFC->getParent(1) == newMatmulWeight) && (myFC->getParent(2) == newAddBias))); REQUIRE(((myFC->getParent(0) == other1) && (myFC->getParent(1) == newMatmulWeight) && (myFC->getParent(2) == newAddBias)));
...@@ -381,19 +381,42 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { ...@@ -381,19 +381,42 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
g->add({r1, r2, r3, r4}); g->add({r1, r2, r3, r4});
auto nodesToReplace = std::set<std::shared_ptr<Node>>({r2, r3}); auto nodesToReplace = std::set<std::shared_ptr<Node>>({r2, r3});
auto newNodes = std::set<std::shared_ptr<Node>>({}); auto newNodes = std::set<std::shared_ptr<Node>>({});
g->replace(nodesToReplace, newNodes); GraphView::replace(nodesToReplace, newNodes);
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4})); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4}));
REQUIRE((r1->output(0))[0].first == r4); REQUIRE((r1->output(0))[0].first == r4);
} }
// SECTION("replace for tiling") {
// std::shared_ptr<GraphView> g = std::make_shared<GraphView>(); SECTION("replace for tiling") {
// auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input"); std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph");
// auto other1 = GenericOperator("Other", 1, 1, 1, "other1"); auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input");
// auto myConv = GenericOperator("Conv", 1, 1, 1, "myConv"); auto other1 = GenericOperator("Other", 1, 1, 1, "other1");
// auto other2 = GenericOperator("Other", 1, 1, 1, "other2"); auto myConv = GenericOperator("Conv", 1, 1, 1, "myConv");
// otherInput->addChild() auto other2 = GenericOperator("Other", 1, 1, 1, "other2");
// } otherInput->addChild(other1);
other1->addChild(myConv);
myConv->addChild(other2);
g->add({other1, myConv, other2});
// create tiled Conv
auto conv1 = GenericOperator("Conv", 1, 1, 1, "myConv1");
auto conv2 = GenericOperator("Conv", 1, 1, 1, "myConv2");
auto conv3 = GenericOperator("Conv", 1, 1, 1, "myConv3");
auto conv4 = GenericOperator("Conv", 1, 1, 1, "myConv4");
auto concat = GenericOperator("Concat", 4, 4, 1, "myConcat");
conv1->addChild(concat);
conv2->addChild(concat);
conv3->addChild(concat);
conv4->addChild(concat);
GraphView::replace({myConv}, {conv1, conv2, conv3, conv4, concat});
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, conv1, conv2, conv3, conv4, concat, other2}));
GraphView::replace({conv1, conv2, conv3, conv4, concat}, {myConv});
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2}));
}
} }
TEST_CASE("[GraphView] clone") { TEST_CASE("[GraphView] clone") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment