diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index de2a7b6aae5357d9a1304ec2b718a475abc1ea43..908f56295887bd2fbed3350a026045a4ab6b21d9 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -456,7 +456,17 @@ private: * @param inId index for adding the parent. */ void addParent(const NodePtr otherNode, const IOIndex_t inId); + + // OPERATOR FUNCTIONNAL but commented out to avoid iostream inclusion + // /** + // * @brief operator<< overload to ease print & debug of nodes + // * @param[inout] ostream to print to + // * @param[in] n node to print + // */ + // friend std::ostream& operator << (std::ostream& os, Node& n); }; + } // namespace Aidge + #endif /* AIDGE_CORE_GRAPH_NODE_H_ */ diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 0406835a5810c06262a3fbb1a87a8c51dbfc91fe..b25ebd3c8de3830174c11d93d6eb60c8703c6a0d 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -121,6 +121,7 @@ private: /** @brief List of nodes ordered by their */ std::vector<std::vector<std::shared_ptr<Node>>> mStaticSchedule; size_t mStaticScheduleStep = 0; + mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache; }; } // namespace Aidge diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 14e166402039230a283ce617e4997c9ad099eed9..a0e828cbce960132fd0c85dba7a3db9b95cc144f 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -173,7 +173,7 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) "Input index ({}) is out of bound ({}) for node {} (of type {})", inId, nbInputs(), name(), type()); if (mIdOutParents[inId] != gk_IODefaultIndex) { - fmt::print("Warning: filling a Tensor already attributed\n"); + Log::warn("Warning: filling a Tensor already attributed\n"); auto originalParent = input(inId); // remove original parent reference to child // find the output ID for original Parent @@ -390,6 +390,26 @@ std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta, std::set<Aidge::No return out; } + +// namespace Aidge { +// std::ostream& operator << (std::ostream& os, Aidge::Node& n) { +// using namespace std; +// os << "Node :\tName :\t\"" << n.name() << "\"\tType : \"" << n.getOperator()->type()<< "\"\tIN/OUTputs : "<< n.nbInputs() <<"/"<< n.nbOutputs() <<endl; +// os << "\tParents :\t" ; +// for (const auto & p : n.getParents()) +// { +// os << "\"" <<p->name() << "\"\t"; +// } +// os << endl; +// os << "\tChildren :\t" ; +// for (const auto & c : n.getChildren()) +// { +// os << "\"" << c->name() << "\"\t"; +// } +// os << endl; +// return os; +// } +// } ///////////////////////////////////////////////////////////////////////////////////////////// // private diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 49e8de80cbca4e3b43720d921e261599b0db9bfa..94baf6a3e7b6e2e86de4e2d72ed19bfd9338392e 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -84,6 +84,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { if (verbose) fmt::print("List of consumers with their priors:\n"); std::set<std::shared_ptr<Node>> requiredProducers; std::set<std::shared_ptr<Node>> priorConsumers; + mPriorCache.clear(); for (const auto& consumer : consumers) { if (verbose) { @@ -620,6 +621,11 @@ Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers( const std::shared_ptr<Node>& node) const { + const auto priorCache = mPriorCache.find(node); + if (priorCache != mPriorCache.end()) { + return priorCache->second; + } + PriorProducersConsumers prior; IOIndex_t inputIdx = 0; @@ -660,5 +666,6 @@ Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler:: if (prior.priorConsumers.empty()) { prior.priorConsumers.insert(node); } + mPriorCache.insert(std::make_pair(node, prior)); return prior; } diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 3ef70bcfb64afb8c5cfdf30bf6b1386541a5c2c3..514fb3b494b50112f26efbaba831e2b46429adcd 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -21,8 +21,8 @@ #include "aidge/backend/OperatorImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" -#include "aidge/graph/Testing.hpp" #include "aidge/graph/OpArgs.hpp" +#include "aidge/graph/Testing.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/scheduler/Scheduler.hpp" @@ -30,48 +30,105 @@ using namespace Aidge; TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { - const size_t nbTests = 100; - size_t nbUnicity = 0; - - for (int test = 0; test < nbTests; ++test) { - std::random_device rd; - const std::mt19937::result_type seed(rd()); - - RandomGraph randGraph; - randGraph.acyclic = true; - const auto g1 = std::make_shared<GraphView>("g1"); - const bool unicity1 = g1->add(randGraph.gen(seed, 10)); - - if (unicity1) { - for (auto& node : g1->getNodes()) { - std::static_pointer_cast<GenericOperator_Op>(node->getOperator())->setComputeOutputDims(GenericOperator_Op::InputIdentity(0, node->nbOutputs())); - } - - const auto orderedInputs = g1->getOrderedInputs(); - for (const auto& input : orderedInputs) { - auto prod = Producer({16, 32}); - prod->addChild(input.first, 0, input.second); - g1->add(prod); - } - - g1->save("schedule"); - g1->compile(); - - auto scheduler = SequentialScheduler(g1); - scheduler.generateScheduling(true); - const auto sch = scheduler.getStaticScheduling(); - - const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); - - std::vector<std::string> nodesName; - std::transform(sch.begin(), sch.end(), - std::back_inserter(nodesName), - [&namePtrTable](auto val){ return namePtrTable.at(val); }); - - fmt::print("schedule: {}\n", nodesName); - REQUIRE(sch.size() == 10 + orderedInputs.size()); + const size_t nbTests = 10; + size_t nbUnicity = 0; + std::uniform_int_distribution<std::size_t> nb_nodes_dist(100, 500); + + for (int test = 0; test < nbTests; ++test) { + std::random_device rd; + const std::mt19937::result_type seed(rd()); + std::mt19937 gen(rd()); + + RandomGraph randGraph; + const auto g1 = std::make_shared<GraphView>("g1"); + const size_t nb_nodes = nb_nodes_dist(gen); + + SECTION("Acyclic Graph") { + Aidge::Log::setConsoleLevel(Aidge::Log::Warn); + fmt::print("gen acyclic graph of {} nodes...\n", nb_nodes); + randGraph.acyclic = true; + + const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes)); + // g1->save("test_graph_" + std::to_string(test)); + + if (unicity1) { + for (auto &node : g1->getNodes()) { + std::static_pointer_cast<GenericOperator_Op>(node->getOperator()) + ->setComputeOutputDims( + GenericOperator_Op::InputIdentity(0, node->nbOutputs())); + } + + const auto orderedInputs = g1->getOrderedInputs(); + for (const auto &input : orderedInputs) { + auto prod = Producer({16, 32}); + prod->addChild(input.first, 0, input.second); + g1->add(prod); } + + g1->save("schedule"); + g1->compile(); + + fmt::print("gen scheduling...\n"); + auto scheduler = SequentialScheduler(g1); + scheduler.generateScheduling(); + fmt::print("gen scheduling finished\n"); + const auto sch = scheduler.getStaticScheduling(); + + const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); + + std::vector<std::string> nodesName; + std::transform( + sch.begin(), sch.end(), std::back_inserter(nodesName), + [&namePtrTable](auto val) { return namePtrTable.at(val); }); + + fmt::print("schedule: {}\n", nodesName); + REQUIRE(sch.size() == nb_nodes + orderedInputs.size()); + ++nbUnicity; + } } + // SECTION("Cyclic graph") { + // fmt::print("gen cyclic graph of {} nodes...\n", nb_nodes); + // randGraph.acyclic = false; + // randGraph.types={"Memorize"}; + + // const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes)); + // // g1->save("test_graph_" + std::to_string(test)); + + // if (unicity1) { + // for (auto &node : g1->getNodes()) { + // std::static_pointer_cast<GenericOperator_Op>(node->getOperator()) + // ->setComputeOutputDims( + // GenericOperator_Op::InputIdentity(0, node->nbOutputs())); + // } + + // const auto orderedInputs = g1->getOrderedInputs(); + // for (const auto &input : orderedInputs) { + // auto prod = Producer({16, 32}); + // prod->addChild(input.first, 0, input.second); + // g1->add(prod); + // } + + // g1->save("schedule"); + // g1->forwardDims(); + + // fmt::print("gen scheduling...\n"); + // auto scheduler = SequentialScheduler(g1); + // scheduler.generateScheduling(); + // fmt::print("gen scheduling finished\n"); + // const auto sch = scheduler.getStaticScheduling(); + + // const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); + + // std::vector<std::string> nodesName; + // std::transform( + // sch.begin(), sch.end(), std::back_inserter(nodesName), + // [&namePtrTable](auto val) { return namePtrTable.at(val); }); - fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests); + // fmt::print("schedule: {}\n", nodesName); + // REQUIRE(sch.size() == nb_nodes + orderedInputs.size()); + // ++nbUnicity; + // } + // } + } + fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests); }