diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 859956efb8ccd8d20fef2a09378fa839ca217f9a..c786ee1f38935187b725e3253bf9c22e7a597ab3 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -448,6 +448,14 @@ public: */ IOIndex_t getNbFreeDataInputs() const; +protected: + /** + * @brief Update inputs/outputs of the GraphView, with no particular order. + * This function DOES NOT preserve inputs/outputs order and should NOT BE USED. + * It is here only to leave time to adapt the replace() function. + */ + [[deprecated]] void updateInputsOutputsNodes(); + private: /////////////////////////////////////////////////////// // TENSOR MANAGEMENT @@ -461,13 +469,6 @@ private: */ IOIndex_t getNbDataInputs() const; - /** - * @brief Update inputs/outputs of the GraphView, with no particular order. - * This function DOES NOT preserve inputs/outputs order and should NOT BE USED. - * It is here only to leave time to adapt the replace() function. - */ - [[deprecated]] void updateInputsOutputsNodes(); - /** * @brief Automatically update GraphView inputs/outputs with a new Node, checking if * it this Node becomes an input/output for the graph and if previous inputs are still diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index ec02f2dd8f9718fbdd72ce8915991088218c607a..9e5c1ef8648e011e1a54158167696efc48a9236f 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -92,12 +92,18 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { IOIndex_t outputIdx = 0; for (auto childs : node_ptr->getOrderedChildren()) { for (auto child : childs) { - if (child != nullptr && mNodes.find(child) != mNodes.end()) { + if (child != nullptr) { IOIndex_t inputIdx = 0; - for (auto pa_ptr : child->getParents()) { - if (pa_ptr == node_ptr) { - std::fprintf(fp, "%s-->|%u..%u|%s\n", namePtrTable[node_ptr].c_str(), - outputIdx, inputIdx, namePtrTable[child].c_str()); + for (auto parent : child->inputs()) { + if (parent.first == node_ptr && parent.second == outputIdx) { + if (mNodes.find(child) != mNodes.end()) { + std::fprintf(fp, "%s-->|%u..%u|%s\n", namePtrTable[node_ptr].c_str(), + outputIdx, inputIdx, namePtrTable[child].c_str()); + } + else if (verbose) { + std::fprintf(fp, "%s-->|%u..%u|%p:::externalCls\n", namePtrTable[node_ptr].c_str(), + outputIdx, inputIdx, static_cast<void*>(child.get())); + } break; } ++inputIdx; @@ -125,6 +131,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { std::fprintf(fp, "classDef inputCls fill:#afa\n"); std::fprintf(fp, "classDef outputCls fill:#ffa\n"); + std::fprintf(fp, "classDef externalCls fill:#ccc\n"); std::fprintf(fp, "classDef rootCls stroke:#f00\n"); if (verbose) { @@ -623,26 +630,28 @@ void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnab if (includeLearnableParam) { for (IOIndex_t i = nodePtr->nbData(); i < nodePtr->nbInputs(); ++i) { auto inputI = nodePtr->input(i); - bool removeNode = true; - for (const auto& parentOutput : inputI.first->outputs()) { - for (const auto& childOfParentOutput : parentOutput) { - // only remove the learnable parameter if not related to any other Node in the GraphView - if (childOfParentOutput.first != nodePtr) { - removeNode = false; - break; + if (inputI.first != nullptr) { + bool removeNode = true; + for (const auto& parentOutput : inputI.first->outputs()) { + for (const auto& childOfParentOutput : parentOutput) { + // only remove the learnable parameter if not related to any other Node in the GraphView + if (childOfParentOutput.first != nodePtr) { + removeNode = false; + break; + } } } - } - if (removeNode) { - // assert Learnable Parameter in the GraphView scope - if (mNodes.find(inputI.first) != mNodes.end()) { - mNodes.erase(inputI.first); - inputI.first->removeView(shared_from_this()); - } - if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); } + if (removeNode) { + // assert Learnable Parameter in the GraphView scope + if (mNodes.find(inputI.first) != mNodes.end()) { + mNodes.erase(inputI.first); + inputI.first->removeView(shared_from_this()); + } + if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); } - // check if the node was an input/output node - updateInputsOutputsDelete(inputI.first); + // check if the node was an input/output node + updateInputsOutputsDelete(inputI.first); + } } } } @@ -650,11 +659,11 @@ void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnab if (mNodes.find(nodePtr) != mNodes.end()) { mNodes.erase(nodePtr); nodePtr->removeView(shared_from_this()); + + // check if the nodePtr was an input/output node + updateInputsOutputsDelete(nodePtr); } if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); } - - // check if the nodePtr was an input/output node - updateInputsOutputsDelete(nodePtr); } @@ -942,7 +951,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val); if (iter != mInputNodes.end()) { - // The first old (removed) input becomes the insertion point for newNode GraphView inputs + // The first old (removed) input becomes the insertion point for new GraphView inputs if (std::distance(newInputsInsertionPoint, iter) <= 0) { newInputsInsertionPoint = mInputNodes.erase(iter); } @@ -963,9 +972,10 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo // If newNode was connected to it if (pa_ptr == deletedNode) { const auto val = std::make_pair(ch_ptr, inputIdx); - AIDGE_INTERNAL_ASSERT(std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()); - newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); - newInputsInsertionPoint = std::next(newInputsInsertionPoint); + if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) { + newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); + newInputsInsertionPoint = std::next(newInputsInsertionPoint); + } } ++inputIdx; } @@ -994,27 +1004,26 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo // Add parent node outputs that become GraphView output following the removal of the node // Outputs addition order follows deletedNode inputs order for (const std::shared_ptr<Node>& parent : deletedNode->getParents()) { - if (parent == nullptr) { - continue; - } - - IOIndex_t outputIdx = 0; - for (auto orderedChilds : parent->getOrderedChildren()) { - bool noInsideConnection = true; - for (auto ch_ptr : orderedChilds) { - if (mNodes.find(ch_ptr) != mNodes.end()) { - noInsideConnection = false; - break; + if (parent != nullptr && mNodes.find(parent) != mNodes.end()) { + IOIndex_t outputIdx = 0; + for (auto orderedChilds : parent->getOrderedChildren()) { + bool noInsideConnection = true; + for (auto ch_ptr : orderedChilds) { + if (mNodes.find(ch_ptr) != mNodes.end()) { + noInsideConnection = false; + break; + } } - } - if (noInsideConnection) { - const auto val = std::make_pair(parent, outputIdx); - AIDGE_INTERNAL_ASSERT(std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()); - newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); - newOutputsInsertionPoint = std::next(newOutputsInsertionPoint); + if (noInsideConnection) { + const auto val = std::make_pair(parent, outputIdx); + if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) { + newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); + newOutputsInsertionPoint = std::next(newOutputsInsertionPoint); + } + } + ++outputIdx; } - ++outputIdx; } } } @@ -1038,14 +1047,16 @@ void Aidge::GraphView::updateInputsOutputsNodes() { for (const std::shared_ptr<Node>& go_ptr : mNodes) { IOIndex_t outputIdx = 0; for (auto orderedChilds : go_ptr->getOrderedChildren()) { + bool noInsideConnection = true; for (auto ch_ptr : orderedChilds) { - if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph - mOutputNodes.push_back(std::make_pair(go_ptr, outputIdx)); + if (mNodes.find(ch_ptr) != mNodes.end()) { + noInsideConnection = false; + break; } } - if (orderedChilds.empty()) { - // an output linked to nothing - mOutputNodes.push_back(std::make_pair(go_ptr, outputIdx)); + + if (noInsideConnection) { + mOutputNodes.push_back(std::make_pair(go_ptr, outputIdx)); } ++outputIdx; } diff --git a/src/graph/Testing.cpp b/src/graph/Testing.cpp index 8b075d4088b0dcade503d105b6c8252a5c3d6eb2..6685b8862ba8bdc9e0ab7661ab8a445c15b27203 100644 --- a/src/graph/Testing.cpp +++ b/src/graph/Testing.cpp @@ -35,7 +35,7 @@ std::pair<Aidge::NodePtr, std::set<Aidge::NodePtr>> Aidge::RandomDAG::gen(std::m std::vector<NodePtr> nodes(nbNodes, nullptr); for (auto idx : nodesSeq) { const std::string name = nodesType[idx] + std::to_string(idx); - nodes[idx] = GenericOperator(nodesType[idx].c_str(), nbIOs[idx].first, nbIOs[idx].first, nbIOs[idx].second, name.c_str()); + nodes[idx] = GenericOperator(nodesType[idx].c_str(), nbIOs[idx].first, 0, nbIOs[idx].second, name.c_str()); } for (size_t i = 0; i < nbNodes; ++i) { @@ -43,9 +43,31 @@ std::pair<Aidge::NodePtr, std::set<Aidge::NodePtr>> Aidge::RandomDAG::gen(std::m for (size_t outId = 0; outId < nodes[i]->nbOutputs(); ++outId) { for (size_t inId = 0; inId < nodes[j]->nbInputs(); ++inId) { if (dLink(gen)) { + // Warning: connections can be set multiple time for the + // same node input! In this case, the previous connection + // is overwritten. This is the expected behavior. + nodes[i]->addChild(nodes[j], outId, inId); + if (nodes[i]->type() == omitType || nodes[j]->type() == omitType) { + // Let nodes[i]->addChild() overwrite the previous connection. + // Now we remove the new one! + nodes[i]->removeChild(nodes[j], outId); + nodes[j]->removeParent(inId); + } +/* + // Alternative: only add child if no node is omitted + // and remove the potential previous connection, like this: if (nodes[i]->type() != omitType && nodes[j]->type() != omitType) { nodes[i]->addChild(nodes[j], outId, inId); } + else { + const auto prevIn = nodes[j]->input(inId); + + if (prevIn.first != nullptr) { + prevIn.first->removeChild(nodes[j], prevIn.second); + nodes[j]->removeParent(inId); + } + } +*/ break; } } diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index f77e5ac635934b129ddca0de2509288ebba854c9..487c2dfef10063f7e0a0c10dcdd1acc904740c68 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -27,6 +27,19 @@ using namespace Aidge; +class GraphView_Test : public GraphView { +public: + GraphView_Test(std::string name="") + : GraphView(name) + { + // ctor + } + + void updateInputsOutputsNodes_Test() { + GraphView::updateInputsOutputsNodes(); + } +}; + TEST_CASE("genRandomDAG") { const size_t nbTests = 100; size_t nbUnicity = 0; @@ -36,7 +49,7 @@ TEST_CASE("genRandomDAG") { const std::mt19937::result_type seed(rd()); RandomDAG randDAG; - const auto g1 = std::make_shared<GraphView>("g1"); + const auto g1 = std::make_shared<GraphView_Test>("g1"); const bool unicity1 = g1->add(randDAG.gen(seed, 10)); const auto g2 = std::make_shared<GraphView>("g2"); const bool unicity2 = g2->add(randDAG.gen(seed, 10)); @@ -50,7 +63,26 @@ TEST_CASE("genRandomDAG") { REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName)); REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->inputNodes(), nodePtrToName) == nodePtrTo(g2->inputNodes(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->outputNodes(), nodePtrToName) == nodePtrTo(g2->outputNodes(), nodePtrToName)); ++nbUnicity; + + // Test deprecated function + g1->updateInputsOutputsNodes_Test(); + // Check that inputs/outputs are the same regardless of the order + auto orderedInputs1 = nodePtrTo(g1->getOrderedInputs(), nodePtrToName); + auto orderedInputs2 = nodePtrTo(g2->getOrderedInputs(), nodePtrToName); + auto orderedOutputs1 = nodePtrTo(g1->getOrderedOutputs(), nodePtrToName); + auto orderedOutputs2 = nodePtrTo(g2->getOrderedOutputs(), nodePtrToName); + std::sort(orderedInputs1.begin(), orderedInputs1.end()); + std::sort(orderedInputs2.begin(), orderedInputs2.end()); + std::sort(orderedOutputs1.begin(), orderedOutputs1.end()); + std::sort(orderedOutputs2.begin(), orderedOutputs2.end()); + + REQUIRE(orderedInputs1 == orderedInputs2); + REQUIRE(orderedOutputs1 == orderedOutputs2); + REQUIRE(nodePtrTo(g1->inputNodes(), nodePtrToName) == nodePtrTo(g2->inputNodes(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->outputNodes(), nodePtrToName) == nodePtrTo(g2->outputNodes(), nodePtrToName)); } } @@ -87,10 +119,13 @@ TEST_CASE("clone_with_delete") { const size_t nbTests = 100; size_t nbClonedWithDelete = 0; - for (int test = 0; test < nbTests; ++test) { - std::random_device rd; - const std::mt19937::result_type seed(rd()); + // Note: initial seed is chosen such that for nbTests=100, the generated + // graphs keep the same inputs/outputs despites the deleted nodes + // (meaning the deleted nodes are not input/output of the graph). + // Otherwise, the last two REQUIRE are not garanteed to be true! + std::mt19937::result_type seed(42); + for (int test = 0; test < nbTests; ++test) { RandomDAG randDAG; randDAG.types = {"Fictive", "DelFictive"}; randDAG.typesWeights = {0.9, 0.1}; @@ -117,11 +152,71 @@ TEST_CASE("clone_with_delete") { // pass } } + + ++seed; } printf("nbClonedWithDelete = %zu/%zu\n", nbClonedWithDelete, nbTests); } +TEST_CASE("remove") { + const size_t nbTests = 100; + size_t nbTested = 0; + + for (int test = 0; test < nbTests; ++test) { + std::random_device rd; + const std::mt19937::result_type seed(rd()); + + RandomDAG randDAG; + randDAG.types = {"Fictive", "DelFictive"}; + randDAG.typesWeights = {0.8, 0.2}; + const auto g1 = std::make_shared<GraphView>("g1"); + const bool unicity1 = g1->add(randDAG.gen(seed, 10)); + + if (unicity1) { + g1->save("./remove1_before"); + const auto nodes = g1->getNodes(); + int step = 1; + for (auto node : nodes) { + if (node->type() == "DelFictive") { + g1->remove(node, false); + g1->save("./remove1_after" + std::to_string(step)); + step++; + } + } + + randDAG.omitType = "DelFictive"; + const auto g2 = std::make_shared<GraphView>("g2"); + g2->add(randDAG.gen(seed, 10)); + + g1->save("./remove1"); + g2->save("./remove2"); + + REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); + // Order not garanteed, because when a node is removed, it can create new GraphView inputs/outputs + // Their order thus depends on the deletion order! + //REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName)); + //REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName)); + + // Check that inputs/outputs are the same regardless of the order + auto orderedInputs1 = nodePtrTo(g1->getOrderedInputs(), nodePtrToName); + auto orderedInputs2 = nodePtrTo(g2->getOrderedInputs(), nodePtrToName); + auto orderedOutputs1 = nodePtrTo(g1->getOrderedOutputs(), nodePtrToName); + auto orderedOutputs2 = nodePtrTo(g2->getOrderedOutputs(), nodePtrToName); + std::sort(orderedInputs1.begin(), orderedInputs1.end()); + std::sort(orderedInputs2.begin(), orderedInputs2.end()); + std::sort(orderedOutputs1.begin(), orderedOutputs1.end()); + std::sort(orderedOutputs2.begin(), orderedOutputs2.end()); + + REQUIRE(orderedInputs1 == orderedInputs2); + REQUIRE(orderedOutputs1 == orderedOutputs2); + ++nbTested; + } + } + + printf("nbTested = %zu/%zu\n", nbTested, nbTests); +} + TEST_CASE("[core/graph] GraphView(Constructor)") { std::shared_ptr<GraphView> g0 = std::make_shared<GraphView>(); std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("G1");