diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 4c5b3bd4c88c2cc6ced816d0caa259a8cab644e5..7c2647a13d94a6097f638031d52c5f58ca2c2382 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -153,6 +153,7 @@ protected: /** @brief List of nodes ordered by their */ std::vector<std::vector<std::shared_ptr<StaticSchedulingElement>>> mStaticSchedule; size_t mStaticScheduleStep = 0; + mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache; }; } // namespace Aidge diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 14e166402039230a283ce617e4997c9ad099eed9..265611c944817d9ecbbd132215f63de35960e5a3 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -173,7 +173,7 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) "Input index ({}) is out of bound ({}) for node {} (of type {})", inId, nbInputs(), name(), type()); if (mIdOutParents[inId] != gk_IODefaultIndex) { - fmt::print("Warning: filling a Tensor already attributed\n"); + Log::notice("Notice: filling a Tensor already attributed"); auto originalParent = input(inId); // remove original parent reference to child // find the output ID for original Parent diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 906b3fa719ab2b0187e42be5ac239df2382c8749..5a056ff3a6ab3609b0b5f74ad87a7583a891b4e5 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -71,6 +71,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S Log::debug("List of consumers with their priors:"); std::set<std::shared_ptr<Node>> requiredProducers; std::set<std::shared_ptr<Node>> priorConsumers; + mPriorCache.clear(); for (const auto& consumer : consumers) { Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange))); @@ -280,6 +281,8 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S Log::debug("********************"); } while (!consumers.empty()); + mPriorCache.clear(); + if (!consumers.empty()) { Log::warn("Remaining consumers: possible dead-lock"); } @@ -633,6 +636,11 @@ Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersConsumers( const std::shared_ptr<Node>& node) const { + const auto priorCache = mPriorCache.find(node); + if (priorCache != mPriorCache.end()) { + return priorCache->second; + } + PriorProducersConsumers prior; IOIndex_t inputIdx = 0; @@ -673,5 +681,6 @@ Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersCon if (prior.priorConsumers.empty()) { prior.priorConsumers.insert(node); } + mPriorCache.insert(std::make_pair(node, prior)); return prior; } diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 75a0daed6cc63cefb9c13412ba694c30e0dbc107..bacea379ea0883bca498c6d84a1acaf69db8adef 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -30,7 +30,8 @@ using namespace Aidge; TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { - const size_t nbTests = 100; + const size_t nbTests = 1; + const size_t graphSize = 1000; size_t nbUnicity = 0; for (int test = 0; test < nbTests; ++test) { @@ -40,7 +41,9 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { RandomGraph randGraph; randGraph.acyclic = true; const auto g1 = std::make_shared<GraphView>("g1"); - const bool unicity1 = g1->add(randGraph.gen(seed, 10)); + fmt::print("gen graph of {} nodes...\n", graphSize); + const bool unicity1 = g1->add(randGraph.gen(seed, graphSize)); + fmt::print("gen graph finished\n", graphSize); if (unicity1) { for (auto& node : g1->getNodes()) { @@ -57,8 +60,10 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { g1->save("schedule"); g1->forwardDims(); + fmt::print("gen scheduling...\n"); auto scheduler = SequentialScheduler(g1); scheduler.generateScheduling(); + fmt::print("gen scheduling finished\n"); const auto sch = scheduler.getStaticScheduling(); const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); @@ -69,7 +74,8 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { [&namePtrTable](auto val){ return namePtrTable.at(val); }); fmt::print("schedule: {}\n", nodesName); - REQUIRE(sch.size() == 10 + orderedInputs.size()); + REQUIRE(sch.size() == graphSize + orderedInputs.size()); + ++nbUnicity; } } @@ -112,6 +118,7 @@ TEST_CASE("randomScheduling_tokens", "[Scheduler][randomGen]") { fmt::print("schedule: {}\n", nodesName); REQUIRE(sch.size() == 10 + orderedInputs.size()); + ++nbUnicity; } }