diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index fab9be91556c5ffc0bd446edcbc5abb80e99a1bb..315844858103cbce91049ec2195ff0a3bd7a9d81 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -1114,6 +1114,28 @@ void Aidge::GraphView::insertParent(NodePtr childNode, add(newParentNode); } +/** + * Inputs conditions: + * | old \ new | 1 node, 1 input | >1 node, 1 input | 1 node, >1 inputs | >1 node, >1 inputs | + * | ------------------- | ---------------- | ----------------- | ------------------ | ------------------ | + * | 1 node, 1 input | trivial | trivial | broadcast | broadcast | + * | >1 node, 1 input | trivial | trivial | broadcast | broadcast | + * | 1 node, >1 inputs | (take first) | (take first) | same order | X | + * | >1 node, >1 inputs | X | X | X | X | + * + * Outputs conditions: + * | old \ new | 1 node, 1 output | >1 node, 1 output | 1 node, >1 outputs | >1 node, >1 outputs | + * | ------------------- | ---------------- | ----------------- | ------------------ | ------------------- | + * | 1 node, 1 output | trivial | trivial | take first | X | + * | >1 node, 1 output | trivial | trivial | take first | X | + * | 1 node, >1 outputs | (take first) | (take first) | same order | X | + * | >1 node, >1 outputs | X | X | X | X | + * + * Only the X cases cannot possibly be resolved deterministically with sets of node. + * These cases are therefore forbidden for the set-based `replace()` interface. + * The remaining cases are handled by the GraphView-based `replace()` interface. + * If they are not supported, the function returns false. + */ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) { // (1) create GraphViews from both sets of Nodes auto oldG = std::make_shared<GraphView>("oldG"); @@ -1121,6 +1143,14 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s auto newG = std::make_shared<GraphView>("newG"); newG->add(newNodes, false); + AIDGE_ASSERT(!((oldNodes.size() > 1 && oldG->getOrderedInputs().size() > 1) || (newNodes.size() > 1 && newG->getOrderedInputs().size() > 1 && oldG->getOrderedInputs().size() > 1)), + "GraphView::replace(): don't know how to match {} input(s) from {} node(s) (old set) to {} input(s) from {} node(s) (new set). Use GraphView instead of set in this case.", + oldG->getOrderedInputs().size(), oldNodes.size(), newG->getOrderedInputs().size(), newNodes.size()); + + AIDGE_ASSERT(!((oldNodes.size() > 1 && oldG->getOrderedOutputs().size() > 1) || (newNodes.size() > 1 && newG->getOrderedOutputs().size() > 1)), + "GraphView::replace(): don't know how to match {} output(s) from {} node(s) (old set) to {} output(s) from {} node(s) (new set). Use GraphView instead of set in this case.", + oldG->getOrderedOutputs().size(), oldNodes.size(), newG->getOrderedOutputs().size(), newNodes.size()); + return GraphView::replace(oldG, newG); } diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 5bd435e28718d663519e504995fd5b030913d254..d2f41269eb46e6935f11fbf2ac0ab7dcf1232945 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -659,9 +659,12 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, conv1, conv2, conv3, conv4, concat, other2})); - GraphView::replace({conv1, conv2, conv3, conv4, concat}, {myConv}); + // This doesn't make sense in the general case: replace() cannot possibly + // know how to match the inputs of the 4 conv with the single input of myCond + // The implicit assumption here is that they are the same! + //GraphView::replace({conv1, conv2, conv3, conv4, concat}, {myConv}); - REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2})); + //REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2})); } SECTION("replace same input category 1") {