Skip to content
Snippets Groups Projects
Commit 44afa3b4 authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'graphView_insert' into 'main'

Implement insertParent function in GraphView

See merge request !10
parents 29ba2cbf 47ee48d5
No related branches found
No related tags found
1 merge request!10Implement insertParent function in GraphView
Pipeline #32176 passed
......@@ -320,8 +320,20 @@ public:
void link(std::string name1_inID, std::string name2_outID);
void insert(Node &newNode, Node &inNode, std::initializer_list<Node> outNodes,
IOIndex_t tensorIdx);
/**
* @brief Insert a node (newParentNode) as a parent of the passed node (childNode).
*
* @param childNode Node that gets a new parent.
* @param newParentNode Inserted Node.
* @param childInputTensorIdx Index of the input Tensor for the childNode linked to the inserted Node output.
* @param newParentInputTensorIdx Index of the input Tensor for the newParentNode linked to the former parent of childNode.
* @param newParentOutputTensorIdx Index of the output Tensor for the newParentNode linked to the childNode's input Tensor.
*/
void insertParent(NodePtr childNode,
NodePtr newParentNode,
IOIndex_t childInputTensorIdx,
IOIndex_t newParentInputTensorIdx,
IOIndex_t newParentOutputTensorIdx);
/**
* @brief Replace the current GraphView with the set of given Nodes if possible
......
......@@ -519,12 +519,24 @@ void Aidge::GraphView::link(std::string /*name1_inID*/,
printf("Not implemented yet.\n");
}
void Aidge::GraphView::insert(Node & /*newNode*/, Node & /*inNode*/,
std::initializer_list<Node> /*outNodes*/,
IOIndex_t /*tensorIdx*/) {
printf("Not implemented yet.\n");
void Aidge::GraphView::insertParent(NodePtr childNode,
NodePtr newParentNode,
IOIndex_t childInputTensorIdx,
IOIndex_t newParentInputTensorIdx,
IOIndex_t newParentOutputTensorIdx){
NodePtr currentParentNode = childNode->getParent(childInputTensorIdx);
const IOIndex_t currentParentOutputTensorIdx = childNode->input(childInputTensorIdx).second;
// Remove child from current parent & current Parent from child
currentParentNode->removeChild(childNode, currentParentOutputTensorIdx);
// Add child
currentParentNode->addChild(newParentNode,currentParentOutputTensorIdx, newParentInputTensorIdx);
newParentNode->addChild(childNode, newParentOutputTensorIdx, childInputTensorIdx);
add(newParentNode);
}
bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
// TODO : only supports one input/output node for now
assert(mNodes.size()>0 && "There must be at least one Node to replace");
......
......@@ -330,4 +330,48 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") {
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4}));
REQUIRE((r1->output(0))[0].first == r4);
}
}
TEST_CASE("[core/graph] GraphView(insertParent)") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
auto conv3 = Conv(32, 64, {1, 1}, "conv3");
auto g = std::make_shared<GraphView>("TestGraph");
dataProvider->addChild(conv1, 0);
g->add(conv1);
g->addChild(conv2, conv1, 0);
g->addChild(conv3, conv1, 0);
g->save("graphForwardDims");
g->forwardDims();
auto newConv = Conv(32, 32, {1, 1}, "newConv");
SECTION("Check insertParent conv2 then insertParent conv3") {
g->insertParent(conv2, newConv, 0, 0, 0);
std::set<NodePtr> expectedConv1Children = {conv3, newConv};
std::set<NodePtr> expectedNewConvChildren = {conv2};
REQUIRE(conv1->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0));
REQUIRE(newConv->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE((newConv->getChildren()) == expectedNewConvChildren);
REQUIRE((conv1->getChildren()) == expectedConv1Children);
g->insertParent(conv3, newConv, 0, 0, 0);
std::set<NodePtr> expectedConv1Children2 = {newConv};
std::set<NodePtr> expectedNewConvChildren2 = {conv2, conv3};
REQUIRE(conv1->getOperator()->getOutput(0) != conv3->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0));
REQUIRE(newConv->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE(newConv->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE((newConv->getChildren()) == expectedNewConvChildren2);
REQUIRE((conv1->getChildren()) == expectedConv1Children2);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment