diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 747785bf886889aed273c944904ddbb6198c4968..c737680bf3d9227161eed250c2cb52a443c37ab3 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -116,6 +116,7 @@ private: /** @brief List of nodes ordered by their */ std::vector<std::vector<std::shared_ptr<Node>>> mStaticSchedule; size_t mStaticScheduleStep = 0; + mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache; }; } // namespace Aidge diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index a9dec5a8fc21ca2a822ddb2addd076c66c30afed..2b1205c068d2c8510fb98e99c89b0988712fb193 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -173,7 +173,6 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) "Input index ({}) is out of bound ({}) for node {} (of type {})", inId, nbInputs(), name(), type()); if (mIdOutParents[inId] != gk_IODefaultIndex) { - fmt::print("Warning: filling a Tensor already attributed\n"); auto originalParent = input(inId); // remove original parent reference to child // find the output ID for original Parent diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 2975538bc3271f4dbf6faea920be3a05452a0859..8febe3cd267a4b5e9ebd1f5cc03805279fa5c382 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -83,6 +83,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { if (verbose) fmt::print("List of consumers with their priors:\n"); std::set<std::shared_ptr<Node>> requiredProducers; std::set<std::shared_ptr<Node>> priorConsumers; + mPriorCache.clear(); for (const auto& consumer : consumers) { if (verbose) { @@ -566,6 +567,11 @@ Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers( const std::shared_ptr<Node>& node) const { + const auto priorCache = mPriorCache.find(node); + if (priorCache != mPriorCache.end()) { + return priorCache->second; + } + PriorProducersConsumers prior; IOIndex_t inputIdx = 0; @@ -606,5 +612,6 @@ Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler:: if (prior.priorConsumers.empty()) { prior.priorConsumers.insert(node); } + mPriorCache.insert(std::make_pair(node, prior)); return prior; } diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 7e28f1fadc56855d266c1e8547261f5903f8c724..b928f408ee85f2b82c5721b574ac151ee0b782a3 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -21,8 +21,8 @@ #include "aidge/backend/OperatorImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" -#include "aidge/graph/Testing.hpp" #include "aidge/graph/OpArgs.hpp" +#include "aidge/graph/Testing.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/scheduler/Scheduler.hpp" @@ -30,48 +30,104 @@ using namespace Aidge; TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { - const size_t nbTests = 100; - size_t nbUnicity = 0; - - for (int test = 0; test < nbTests; ++test) { - std::random_device rd; - const std::mt19937::result_type seed(rd()); - - RandomGraph randGraph; - randGraph.acyclic = true; - const auto g1 = std::make_shared<GraphView>("g1"); - const bool unicity1 = g1->add(randGraph.gen(seed, 10)); - - if (unicity1) { - for (auto& node : g1->getNodes()) { - std::static_pointer_cast<GenericOperator_Op>(node->getOperator())->setComputeOutputDims(GenericOperator_Op::InputIdentity(0, node->nbOutputs())); - } - - const auto orderedInputs = g1->getOrderedInputs(); - for (const auto& input : orderedInputs) { - auto prod = Producer({16, 32}); - prod->addChild(input.first, 0, input.second); - g1->add(prod); - } - - g1->save("schedule"); - g1->forwardDims(); - - auto scheduler = SequentialScheduler(g1); - scheduler.generateScheduling(true); - const auto sch = scheduler.getStaticScheduling(); - - const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); - - std::vector<std::string> nodesName; - std::transform(sch.begin(), sch.end(), - std::back_inserter(nodesName), - [&namePtrTable](auto val){ return namePtrTable.at(val); }); - - fmt::print("schedule: {}\n", nodesName); - REQUIRE(sch.size() == 10 + orderedInputs.size()); + const size_t nbTests = 10; + size_t nbUnicity = 0; + std::uniform_int_distribution<std::size_t> nb_nodes_dist(100, 500); + + for (int test = 0; test < nbTests; ++test) { + std::random_device rd; + const std::mt19937::result_type seed(rd()); + std::mt19937 gen(rd()); + + RandomGraph randGraph; + const auto g1 = std::make_shared<GraphView>("g1"); + const size_t nb_nodes = nb_nodes_dist(gen); + + SECTION("Acyclic Graph") { + fmt::print("gen acyclic graph of {} nodes...\n", nb_nodes); + randGraph.acyclic = true; + + const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes)); + // g1->save("test_graph_" + std::to_string(test)); + + if (unicity1) { + for (auto &node : g1->getNodes()) { + std::static_pointer_cast<GenericOperator_Op>(node->getOperator()) + ->setComputeOutputDims( + GenericOperator_Op::InputIdentity(0, node->nbOutputs())); + } + + const auto orderedInputs = g1->getOrderedInputs(); + for (const auto &input : orderedInputs) { + auto prod = Producer({16, 32}); + prod->addChild(input.first, 0, input.second); + g1->add(prod); } + + g1->save("schedule"); + g1->forwardDims(); + + fmt::print("gen scheduling...\n"); + auto scheduler = SequentialScheduler(g1); + scheduler.generateScheduling(); + fmt::print("gen scheduling finished\n"); + const auto sch = scheduler.getStaticScheduling(); + + const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); + + std::vector<std::string> nodesName; + std::transform( + sch.begin(), sch.end(), std::back_inserter(nodesName), + [&namePtrTable](auto val) { return namePtrTable.at(val); }); + + fmt::print("schedule: {}\n", nodesName); + REQUIRE(sch.size() == nb_nodes + orderedInputs.size()); + ++nbUnicity; + } } + // SECTION("Cyclic graph") { + // fmt::print("gen cyclic graph of {} nodes...\n", nb_nodes); + // randGraph.acyclic = false; + // randGraph.types={"Memorize"}; + + // const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes)); + // // g1->save("test_graph_" + std::to_string(test)); + + // if (unicity1) { + // for (auto &node : g1->getNodes()) { + // std::static_pointer_cast<GenericOperator_Op>(node->getOperator()) + // ->setComputeOutputDims( + // GenericOperator_Op::InputIdentity(0, node->nbOutputs())); + // } + + // const auto orderedInputs = g1->getOrderedInputs(); + // for (const auto &input : orderedInputs) { + // auto prod = Producer({16, 32}); + // prod->addChild(input.first, 0, input.second); + // g1->add(prod); + // } + + // g1->save("schedule"); + // g1->forwardDims(); + + // fmt::print("gen scheduling...\n"); + // auto scheduler = SequentialScheduler(g1); + // scheduler.generateScheduling(); + // fmt::print("gen scheduling finished\n"); + // const auto sch = scheduler.getStaticScheduling(); + + // const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); + + // std::vector<std::string> nodesName; + // std::transform( + // sch.begin(), sch.end(), std::back_inserter(nodesName), + // [&namePtrTable](auto val) { return namePtrTable.at(val); }); - fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests); + // fmt::print("schedule: {}\n", nodesName); + // REQUIRE(sch.size() == nb_nodes + orderedInputs.size()); + // ++nbUnicity; + // } + // } + } + fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests); }