diff --git a/include/aidge/graph/Testing.hpp b/include/aidge/graph/Testing.hpp index b448c3672493b8460eac32c8362fa4f582264889..06bbd0f554bab0b9a5ef123dc287c2397258fae3 100644 --- a/include/aidge/graph/Testing.hpp +++ b/include/aidge/graph/Testing.hpp @@ -23,9 +23,11 @@ namespace Aidge { /** - * Random DAG generator + * Random (directed) graph generator */ -struct RandomDAG { +struct RandomGraph { + /// @brief If true, the generated graph is a DAG (no cycle) + bool acyclic = false; /// @brief Connection density (between 0 and 1) float density = 0.5; /// @brief Max number of inputs per node (regardless if they are connected or not) diff --git a/src/graph/Testing.cpp b/src/graph/Testing.cpp index 6685b8862ba8bdc9e0ab7661ab8a445c15b27203..349b1e3456437e6be3fec251626107ed99679ec1 100644 --- a/src/graph/Testing.cpp +++ b/src/graph/Testing.cpp @@ -12,7 +12,7 @@ #include "aidge/graph/Testing.hpp" #include "aidge/operator/GenericOperator.hpp" -std::pair<Aidge::NodePtr, std::set<Aidge::NodePtr>> Aidge::RandomDAG::gen(std::mt19937::result_type seed, size_t nbNodes) const { +std::pair<Aidge::NodePtr, std::set<Aidge::NodePtr>> Aidge::RandomGraph::gen(std::mt19937::result_type seed, size_t nbNodes) const { std::mt19937 gen(seed); std::binomial_distribution<> dIn(maxIn - 1, avgIn/maxIn); std::binomial_distribution<> dOut(maxOut - 1, avgOut/maxOut); @@ -39,7 +39,7 @@ std::pair<Aidge::NodePtr, std::set<Aidge::NodePtr>> Aidge::RandomDAG::gen(std::m } for (size_t i = 0; i < nbNodes; ++i) { - for (size_t j = i + 1; j < nbNodes; ++j) { + for (size_t j = (acyclic) ? i + 1 : 0; j < nbNodes; ++j) { for (size_t outId = 0; outId < nodes[i]->nbOutputs(); ++outId) { for (size_t inId = 0; inId < nodes[j]->nbInputs(); ++inId) { if (dLink(gen)) { diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 392c2478012eaabe5fb8d508aa014a4592537cc5..66f110cb13c4c68f062c9393817ebd589c40f6e6 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -40,7 +40,7 @@ public: } }; -TEST_CASE("genRandomDAG") { +TEST_CASE("genRandomGraph") { const size_t nbTests = 100; size_t nbUnicity = 0; @@ -48,14 +48,14 @@ TEST_CASE("genRandomDAG") { std::random_device rd; const std::mt19937::result_type seed(rd()); - RandomDAG randDAG; + RandomGraph randGraph; const auto g1 = std::make_shared<GraphView_Test>("g1"); - const bool unicity1 = g1->add(randDAG.gen(seed, 10)); + const bool unicity1 = g1->add(randGraph.gen(seed, 10)); const auto g2 = std::make_shared<GraphView>("g2"); - const bool unicity2 = g2->add(randDAG.gen(seed, 10)); + const bool unicity2 = g2->add(randGraph.gen(seed, 10)); - g1->save("./genRandomDAG1"); - g2->save("./genRandomDAG2"); + g1->save("./genRandomGraph1"); + g2->save("./genRandomGraph2"); REQUIRE(unicity1 == unicity2); @@ -96,9 +96,9 @@ TEST_CASE("clone") { std::random_device rd; const std::mt19937::result_type seed(rd()); - RandomDAG randDAG; + RandomGraph randGraph; const auto g1 = std::make_shared<GraphView>("g1"); - g1->add(randDAG.gen(seed, 10)); + g1->add(randGraph.gen(seed, 10)); const auto g2 = g1->clone(); @@ -126,16 +126,16 @@ TEST_CASE("clone_with_delete") { std::mt19937::result_type seed(42); for (int test = 0; test < nbTests; ++test) { - RandomDAG randDAG; - randDAG.types = {"Fictive", "DelFictive"}; - randDAG.typesWeights = {0.9, 0.1}; + RandomGraph randGraph; + randGraph.types = {"Fictive", "DelFictive"}; + randGraph.typesWeights = {0.9, 0.1}; const auto g1 = std::make_shared<GraphView>("g1"); - const bool unicity1 = g1->add(randDAG.gen(seed, 10)); + const bool unicity1 = g1->add(randGraph.gen(seed, 10)); if (unicity1) { - randDAG.omitType = "DelFictive"; + randGraph.omitType = "DelFictive"; const auto g2 = std::make_shared<GraphView>("g2"); - const bool unicity2 = g2->add(randDAG.gen(seed, 10)); + const bool unicity2 = g2->add(randGraph.gen(seed, 10)); g1->save("./clone_with_delete1"); g2->save("./clone_with_delete2"); @@ -167,11 +167,11 @@ TEST_CASE("remove") { std::random_device rd; const std::mt19937::result_type seed(rd()); - RandomDAG randDAG; - randDAG.types = {"Fictive", "DelFictive"}; - randDAG.typesWeights = {0.8, 0.2}; + RandomGraph randGraph; + randGraph.types = {"Fictive", "DelFictive"}; + randGraph.typesWeights = {0.8, 0.2}; const auto g1 = std::make_shared<GraphView>("g1"); - const bool unicity1 = g1->add(randDAG.gen(seed, 10)); + const bool unicity1 = g1->add(randGraph.gen(seed, 10)); if (unicity1) { g1->save("./remove1_before"); @@ -185,9 +185,9 @@ TEST_CASE("remove") { } } - randDAG.omitType = "DelFictive"; + randGraph.omitType = "DelFictive"; const auto g2 = std::make_shared<GraphView>("g2"); - g2->add(randDAG.gen(seed, 10)); + g2->add(randGraph.gen(seed, 10)); g1->save("./remove1"); g2->save("./remove2");