Skip to content
Snippets Groups Projects
Commit a1cbe41e authored by Thibault Allenet's avatar Thibault Allenet
Browse files

Implement insertParent function

parent 46767f41
No related branches found
No related tags found
1 merge request!10Implement insertParent function in GraphView
Pipeline #31876 passed
...@@ -320,8 +320,20 @@ public: ...@@ -320,8 +320,20 @@ public:
void link(std::string name1_inID, std::string name2_outID); void link(std::string name1_inID, std::string name2_outID);
void insert(Node &newNode, Node &inNode, std::initializer_list<Node> outNodes, /**
IOIndex_t tensorIdx); * @brief Insert a node (newParentNode) as a parent of the passed node (childNode).
*
* @param childNode
* @param newParentNode
* @param childInputTensorIdx
* @param newParentInputTensorIdx
* @param newParentOutputTensorIdx
*/
void insertParent(NodePtr childNode,
NodePtr newParentNode,
IOIndex_t childInputTensorIdx,
IOIndex_t newParentInputTensorIdx,
IOIndex_t newParentOutputTensorIdx);
/** /**
* @brief Replace the current GraphView with the set of given Nodes if possible * @brief Replace the current GraphView with the set of given Nodes if possible
......
...@@ -522,12 +522,24 @@ void Aidge::GraphView::link(std::string /*name1_inID*/, ...@@ -522,12 +522,24 @@ void Aidge::GraphView::link(std::string /*name1_inID*/,
printf("Not implemented yet.\n"); printf("Not implemented yet.\n");
} }
void Aidge::GraphView::insert(Node & /*newNode*/, Node & /*inNode*/, void Aidge::GraphView::insertParent(NodePtr childNode,
std::initializer_list<Node> /*outNodes*/, NodePtr newParentNode,
IOIndex_t /*tensorIdx*/) { IOIndex_t childInputTensorIdx,
printf("Not implemented yet.\n"); IOIndex_t newParentInputTensorIdx,
IOIndex_t newParentOutputTensorIdx){
NodePtr currentParentNode = childNode->getParent(childInputTensorIdx);
const IOIndex_t currentParentOutputTensorIdx = childNode->input(childInputTensorIdx).second;
// Remove child from current parent & current Parent from child
currentParentNode->removeChild(childNode, currentParentOutputTensorIdx);
// Add child
currentParentNode->addChild(newParentNode,currentParentOutputTensorIdx, newParentInputTensorIdx);
newParentNode->addChild(childNode, newParentOutputTensorIdx, childInputTensorIdx);
add(newParentNode);
} }
bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) { bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
// TODO : only supports one input/output node for now // TODO : only supports one input/output node for now
assert(mNodes.size()>0 && "There must be at least one Node to replace"); assert(mNodes.size()>0 && "There must be at least one Node to replace");
......
...@@ -330,4 +330,48 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") { ...@@ -330,4 +330,48 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") {
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4})); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4}));
REQUIRE((r1->output(0))[0].first == r4); REQUIRE((r1->output(0))[0].first == r4);
} }
}
TEST_CASE("[core/graph] GraphView(insertParent)") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
auto conv3 = Conv(32, 64, {1, 1}, "conv3");
auto g = std::make_shared<GraphView>("TestGraph");
dataProvider->addChild(conv1, 0);
g->add(conv1);
g->addChild(conv2, conv1, 0);
g->addChild(conv3, conv1, 0);
g->save("graphForwardDims");
g->forwardDims();
auto newConv = Conv(32, 32, {1, 1}, "newConv");
SECTION("Check insertParent conv2 then insertParent conv3") {
g->insertParent(conv2, newConv, 0, 0, 0);
std::set<NodePtr> expectedConv1Children = {conv3, newConv};
std::set<NodePtr> expectedNewConvChildren = {conv2};
REQUIRE(conv1->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0));
REQUIRE(newConv->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE((newConv->getChildren()) == expectedNewConvChildren);
REQUIRE((conv1->getChildren()) == expectedConv1Children);
g->insertParent(conv3, newConv, 0, 0, 0);
std::set<NodePtr> expectedConv1Children2 = {newConv};
std::set<NodePtr> expectedNewConvChildren2 = {conv2, conv3};
REQUIRE(conv1->getOperator()->getOutput(0) != conv3->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0));
REQUIRE(newConv->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE(newConv->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE((newConv->getChildren()) == expectedNewConvChildren2);
REQUIRE((conv1->getChildren()) == expectedConv1Children2);
}
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment