diff --git a/include/aidge/graph/Testing.hpp b/include/aidge/graph/Testing.hpp index 7de03daf9ea0a3e715198c834cd309fc1a10a9f4..b448c3672493b8460eac32c8362fa4f582264889 100644 --- a/include/aidge/graph/Testing.hpp +++ b/include/aidge/graph/Testing.hpp @@ -40,6 +40,8 @@ struct RandomDAG { std::vector<std::string> types = {"Fictive"}; /// @brief Weights of each node type, used to compute the probability of generating this type std::vector<float> typesWeights = {1.0}; + /// @brief Type of node that should be omitted from the generated topology + std::string omitType; /** * Generate a DAG according to the parameters of the class. diff --git a/src/graph/Testing.cpp b/src/graph/Testing.cpp index bb029520f25bcbee4a6a9d6530e3927aeed7583f..8b075d4088b0dcade503d105b6c8252a5c3d6eb2 100644 --- a/src/graph/Testing.cpp +++ b/src/graph/Testing.cpp @@ -20,9 +20,11 @@ std::pair<Aidge::NodePtr, std::set<Aidge::NodePtr>> Aidge::RandomDAG::gen(std::m std::discrete_distribution<> dType(typesWeights.begin(), typesWeights.end()); std::vector<std::pair<int, int>> nbIOs; + std::vector<std::string> nodesType; for (size_t i = 0; i < nbNodes; ++i) { const auto nbIn = 1 + dIn(gen); nbIOs.push_back(std::make_pair(nbIn, 1 + dOut(gen))); + nodesType.push_back(types[dType(gen)]); } std::vector<int> nodesSeq(nbNodes); @@ -32,9 +34,8 @@ std::pair<Aidge::NodePtr, std::set<Aidge::NodePtr>> Aidge::RandomDAG::gen(std::m std::vector<NodePtr> nodes(nbNodes, nullptr); for (auto idx : nodesSeq) { - const std::string type = types[dType(gen)]; - const std::string name = type + std::to_string(idx); - nodes[idx] = GenericOperator(type.c_str(), nbIOs[idx].first, nbIOs[idx].first, nbIOs[idx].second, name.c_str()); + const std::string name = nodesType[idx] + std::to_string(idx); + nodes[idx] = GenericOperator(nodesType[idx].c_str(), nbIOs[idx].first, nbIOs[idx].first, nbIOs[idx].second, name.c_str()); } for (size_t i = 0; i < nbNodes; ++i) { @@ -42,14 +43,28 @@ std::pair<Aidge::NodePtr, std::set<Aidge::NodePtr>> Aidge::RandomDAG::gen(std::m for (size_t outId = 0; outId < nodes[i]->nbOutputs(); ++outId) { for (size_t inId = 0; inId < nodes[j]->nbInputs(); ++inId) { if (dLink(gen)) { - nodes[i]->addChild(nodes[j], outId, inId); + if (nodes[i]->type() != omitType && nodes[j]->type() != omitType) { + nodes[i]->addChild(nodes[j], outId, inId); + } break; } } } } } - return std::make_pair(nodes[0], std::set<NodePtr>(nodes.begin(), nodes.end())); + + NodePtr rootNode = nullptr; + std::set<NodePtr> nodesSet; + for (size_t i = 0; i < nbNodes; ++i) { + if (nodes[i]->type() != omitType) { + if (rootNode == nullptr) { + rootNode = nodes[i]; + } + nodesSet.insert(nodes[i]); + } + } + + return std::make_pair(rootNode, nodesSet); } std::string Aidge::nodePtrToType(NodePtr node) { diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 36203532ed4fefca1421acab5e5827dbdc4ad9f7..f77e5ac635934b129ddca0de2509288ebba854c9 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -95,24 +95,27 @@ TEST_CASE("clone_with_delete") { randDAG.types = {"Fictive", "DelFictive"}; randDAG.typesWeights = {0.9, 0.1}; const auto g1 = std::make_shared<GraphView>("g1"); - g1->add(randDAG.gen(seed, 10)); + const bool unicity1 = g1->add(randDAG.gen(seed, 10)); - g1->save("./clone_with_delete1"); + if (unicity1) { + randDAG.omitType = "DelFictive"; + const auto g2 = std::make_shared<GraphView>("g2"); + const bool unicity2 = g2->add(randDAG.gen(seed, 10)); - try { - const auto g2 = g1->cloneCallback(&nodeDel); + g1->save("./clone_with_delete1"); + g2->save("./clone_with_delete2"); - if (g2->getNodes().size() < g1->getNodes().size()) { - g2->save("./clone_with_delete2"); + try { + const auto gCloned = g1->cloneCallback(&nodeDel); - // These tests are not necessarily true if the deleted node is an input/output node! - //REQUIRE(g1->getOrderedInputs().size() == g2->getOrderedInputs().size()); - //REQUIRE(g1->getOrderedOutputs().size() == g2->getOrderedOutputs().size()); + REQUIRE(nodePtrTo(gCloned->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); + REQUIRE(nodePtrTo(gCloned->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName)); + REQUIRE(nodePtrTo(gCloned->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName)); ++nbClonedWithDelete; } - } - catch (const std::runtime_error& error) { - // pass + catch (const std::runtime_error& error) { + // pass + } } }