Skip to content
Snippets Groups Projects
Commit 9b85471f authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added cache system for getPriorProducersConsumers()

parent d9e59761
No related branches found
No related tags found
No related merge requests found
...@@ -153,6 +153,7 @@ protected: ...@@ -153,6 +153,7 @@ protected:
/** @brief List of nodes ordered by their */ /** @brief List of nodes ordered by their */
std::vector<std::vector<std::shared_ptr<StaticSchedulingElement>>> mStaticSchedule; std::vector<std::vector<std::shared_ptr<StaticSchedulingElement>>> mStaticSchedule;
size_t mStaticScheduleStep = 0; size_t mStaticScheduleStep = 0;
mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache;
}; };
} // namespace Aidge } // namespace Aidge
......
...@@ -173,7 +173,7 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) ...@@ -173,7 +173,7 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId)
"Input index ({}) is out of bound ({}) for node {} (of type {})", "Input index ({}) is out of bound ({}) for node {} (of type {})",
inId, nbInputs(), name(), type()); inId, nbInputs(), name(), type());
if (mIdOutParents[inId] != gk_IODefaultIndex) { if (mIdOutParents[inId] != gk_IODefaultIndex) {
fmt::print("Warning: filling a Tensor already attributed\n"); Log::notice("Notice: filling a Tensor already attributed");
auto originalParent = input(inId); auto originalParent = input(inId);
// remove original parent reference to child // remove original parent reference to child
// find the output ID for original Parent // find the output ID for original Parent
......
...@@ -71,6 +71,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S ...@@ -71,6 +71,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
Log::debug("List of consumers with their priors:"); Log::debug("List of consumers with their priors:");
std::set<std::shared_ptr<Node>> requiredProducers; std::set<std::shared_ptr<Node>> requiredProducers;
std::set<std::shared_ptr<Node>> priorConsumers; std::set<std::shared_ptr<Node>> priorConsumers;
mPriorCache.clear();
for (const auto& consumer : consumers) { for (const auto& consumer : consumers) {
Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange))); Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange)));
...@@ -280,6 +281,8 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S ...@@ -280,6 +281,8 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
Log::debug("********************"); Log::debug("********************");
} while (!consumers.empty()); } while (!consumers.empty());
mPriorCache.clear();
if (!consumers.empty()) { if (!consumers.empty()) {
Log::warn("Remaining consumers: possible dead-lock"); Log::warn("Remaining consumers: possible dead-lock");
} }
...@@ -633,6 +636,11 @@ Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& ...@@ -633,6 +636,11 @@ Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>&
Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersConsumers( Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersConsumers(
const std::shared_ptr<Node>& node) const const std::shared_ptr<Node>& node) const
{ {
const auto priorCache = mPriorCache.find(node);
if (priorCache != mPriorCache.end()) {
return priorCache->second;
}
PriorProducersConsumers prior; PriorProducersConsumers prior;
IOIndex_t inputIdx = 0; IOIndex_t inputIdx = 0;
...@@ -673,5 +681,6 @@ Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersCon ...@@ -673,5 +681,6 @@ Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersCon
if (prior.priorConsumers.empty()) { if (prior.priorConsumers.empty()) {
prior.priorConsumers.insert(node); prior.priorConsumers.insert(node);
} }
mPriorCache.insert(std::make_pair(node, prior));
return prior; return prior;
} }
...@@ -30,7 +30,8 @@ ...@@ -30,7 +30,8 @@
using namespace Aidge; using namespace Aidge;
TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
const size_t nbTests = 100; const size_t nbTests = 1;
const size_t graphSize = 1000;
size_t nbUnicity = 0; size_t nbUnicity = 0;
for (int test = 0; test < nbTests; ++test) { for (int test = 0; test < nbTests; ++test) {
...@@ -40,7 +41,9 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { ...@@ -40,7 +41,9 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
RandomGraph randGraph; RandomGraph randGraph;
randGraph.acyclic = true; randGraph.acyclic = true;
const auto g1 = std::make_shared<GraphView>("g1"); const auto g1 = std::make_shared<GraphView>("g1");
const bool unicity1 = g1->add(randGraph.gen(seed, 10)); fmt::print("gen graph of {} nodes...\n", graphSize);
const bool unicity1 = g1->add(randGraph.gen(seed, graphSize));
fmt::print("gen graph finished\n", graphSize);
if (unicity1) { if (unicity1) {
for (auto& node : g1->getNodes()) { for (auto& node : g1->getNodes()) {
...@@ -57,8 +60,10 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { ...@@ -57,8 +60,10 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
g1->save("schedule"); g1->save("schedule");
g1->forwardDims(); g1->forwardDims();
fmt::print("gen scheduling...\n");
auto scheduler = SequentialScheduler(g1); auto scheduler = SequentialScheduler(g1);
scheduler.generateScheduling(); scheduler.generateScheduling();
fmt::print("gen scheduling finished\n");
const auto sch = scheduler.getStaticScheduling(); const auto sch = scheduler.getStaticScheduling();
const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");
...@@ -69,7 +74,8 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { ...@@ -69,7 +74,8 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
[&namePtrTable](auto val){ return namePtrTable.at(val); }); [&namePtrTable](auto val){ return namePtrTable.at(val); });
fmt::print("schedule: {}\n", nodesName); fmt::print("schedule: {}\n", nodesName);
REQUIRE(sch.size() == 10 + orderedInputs.size()); REQUIRE(sch.size() == graphSize + orderedInputs.size());
++nbUnicity;
} }
} }
...@@ -112,6 +118,7 @@ TEST_CASE("randomScheduling_tokens", "[Scheduler][randomGen]") { ...@@ -112,6 +118,7 @@ TEST_CASE("randomScheduling_tokens", "[Scheduler][randomGen]") {
fmt::print("schedule: {}\n", nodesName); fmt::print("schedule: {}\n", nodesName);
REQUIRE(sch.size() == 10 + orderedInputs.size()); REQUIRE(sch.size() == 10 + orderedInputs.size());
++nbUnicity;
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment