diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 747785bf886889aed273c944904ddbb6198c4968..c737680bf3d9227161eed250c2cb52a443c37ab3 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -116,6 +116,7 @@ private: /** @brief List of nodes ordered by their */ std::vector<std::vector<std::shared_ptr<Node>>> mStaticSchedule; size_t mStaticScheduleStep = 0; + mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache; }; } // namespace Aidge diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index c84e19df9fe65b5afddf250d1cec55f148797481..7da7032012f79719f7cf64abba9a98cb84f8018a 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -171,7 +171,6 @@ Aidge::IOIndex_t Aidge::Node::nbValidOutputs() const { void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) { assert(inId != gk_IODefaultIndex && (inId < nbInputs()) && "Must be a valid index"); if (mIdOutParents[inId] != gk_IODefaultIndex) { - fmt::print("Warning: filling a Tensor already attributed\n"); auto originalParent = input(inId); // remove original parent reference to child // find the output ID for original Parent diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 4733385fc795de948f828f76ff13781c25f222f5..f6ef29698e34cf03125e23ac925f8a8d93321ff9 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -83,6 +83,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { if (verbose) fmt::print("List of consumers with their priors:\n"); std::set<std::shared_ptr<Node>> requiredProducers; std::set<std::shared_ptr<Node>> priorConsumers; + mPriorCache.clear(); for (const auto& consumer : consumers) { if (verbose) { @@ -567,12 +568,15 @@ Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers( const std::shared_ptr<Node>& node) const { + const auto priorCache = mPriorCache.find(node); + if (priorCache != mPriorCache.end()) { + return priorCache->second; + } + PriorProducersConsumers prior; IOIndex_t inputIdx = 0; - std::cout << *node << std::endl; for (const auto& parent : node->inputs()) { - std::cout << "parent.first " << *parent.first << std::endl; if (parent.first && (node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) > parent.first->getOperator()->getNbProducedData(parent.second)) @@ -609,5 +613,6 @@ Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler:: if (prior.priorConsumers.empty()) { prior.priorConsumers.insert(node); } + mPriorCache.insert(std::make_pair(node, prior)); return prior; } diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 66e0d4aeffa21c685444772d27b4663e275359d5..b84ac47566b4869fa2c8135245c392c5aa362255 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -21,8 +21,8 @@ #include "aidge/backend/OperatorImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" -#include "aidge/graph/Testing.hpp" #include "aidge/graph/OpArgs.hpp" +#include "aidge/graph/Testing.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/scheduler/Scheduler.hpp" @@ -30,53 +30,103 @@ using namespace Aidge; TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { - const size_t nbTests = 10; - size_t nbUnicity = 0; - std::uniform_int_distribution<std::size_t> nb_nodes_dist(70, 100); - - for (int test = 0; test < nbTests; ++test) { - std::random_device rd; - const std::mt19937::result_type seed(rd()); - std::mt19937 gen(rd()); - - RandomGraph randGraph; - randGraph.acyclic = true; - const auto g1 = std::make_shared<GraphView>("g1"); - // const size_t nb_nodes = nb_nodes_dist(gen); - const size_t nb_nodes = 85; - const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes)); - g1->save("test_graph_"+std::to_string(test)); - - if (unicity1) { - for (auto& node : g1->getNodes()) { - std::static_pointer_cast<GenericOperator_Op>(node->getOperator())->setComputeOutputDims(GenericOperator_Op::InputIdentity(0, node->nbOutputs())); - } - - const auto orderedInputs = g1->getOrderedInputs(); - for (const auto& input : orderedInputs) { - auto prod = Producer({16, 32}); - prod->addChild(input.first, 0, input.second); - g1->add(prod); - } - - g1->save("schedule"); - g1->forwardDims(); - - auto scheduler = SequentialScheduler(g1); - scheduler.generateScheduling(true); - const auto sch = scheduler.getStaticScheduling(); - - const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); - - std::vector<std::string> nodesName; - std::transform(sch.begin(), sch.end(), - std::back_inserter(nodesName), - [&namePtrTable](auto val){ return namePtrTable.at(val); }); - - fmt::print("schedule: {}\n", nodesName); - CHECK(sch.size() == nb_nodes + orderedInputs.size()); + const size_t nbTests = 10; + size_t nbUnicity = 0; + std::uniform_int_distribution<std::size_t> nb_nodes_dist(100, 500); + + for (int test = 0; test < nbTests; ++test) { + std::random_device rd; + const std::mt19937::result_type seed(rd()); + std::mt19937 gen(rd()); + + RandomGraph randGraph; + const auto g1 = std::make_shared<GraphView>("g1"); + const size_t nb_nodes = nb_nodes_dist(gen); + + SECTION("Acyclic Graph") { + fmt::print("gen acyclic graph of {} nodes...\n", nb_nodes); + randGraph.acyclic = true; + + const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes)); + // g1->save("test_graph_" + std::to_string(test)); + + if (unicity1) { + for (auto &node : g1->getNodes()) { + std::static_pointer_cast<GenericOperator_Op>(node->getOperator()) + ->setComputeOutputDims( + GenericOperator_Op::InputIdentity(0, node->nbOutputs())); + } + + const auto orderedInputs = g1->getOrderedInputs(); + for (const auto &input : orderedInputs) { + auto prod = Producer({16, 32}); + prod->addChild(input.first, 0, input.second); + g1->add(prod); } + + g1->save("schedule"); + g1->forwardDims(); + + fmt::print("gen scheduling...\n"); + auto scheduler = SequentialScheduler(g1); + scheduler.generateScheduling(); + fmt::print("gen scheduling finished\n"); + const auto sch = scheduler.getStaticScheduling(); + + const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); + + std::vector<std::string> nodesName; + std::transform( + sch.begin(), sch.end(), std::back_inserter(nodesName), + [&namePtrTable](auto val) { return namePtrTable.at(val); }); + + fmt::print("schedule: {}\n", nodesName); + REQUIRE(sch.size() == nb_nodes + orderedInputs.size()); + ++nbUnicity; + } } + SECTION("Cyclic graph") { + fmt::print("gen cyclic graph of {} nodes...\n", nb_nodes); + randGraph.acyclic = false; + + const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes)); + // g1->save("test_graph_" + std::to_string(test)); + + if (unicity1) { + for (auto &node : g1->getNodes()) { + std::static_pointer_cast<GenericOperator_Op>(node->getOperator()) + ->setComputeOutputDims( + GenericOperator_Op::InputIdentity(0, node->nbOutputs())); + } + + const auto orderedInputs = g1->getOrderedInputs(); + for (const auto &input : orderedInputs) { + auto prod = Producer({16, 32}); + prod->addChild(input.first, 0, input.second); + g1->add(prod); + } + + g1->save("schedule"); + g1->forwardDims(); + + fmt::print("gen scheduling...\n"); + auto scheduler = SequentialScheduler(g1); + scheduler.generateScheduling(); + fmt::print("gen scheduling finished\n"); + const auto sch = scheduler.getStaticScheduling(); - fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests); + const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); + + std::vector<std::string> nodesName; + std::transform( + sch.begin(), sch.end(), std::back_inserter(nodesName), + [&namePtrTable](auto val) { return namePtrTable.at(val); }); + + fmt::print("schedule: {}\n", nodesName); + REQUIRE(sch.size() == nb_nodes + orderedInputs.size()); + ++nbUnicity; + } + } + } + fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests); }