From e7f5fa0147a03810a10dd38828df6eb9ac0f2cc2 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Sat, 19 Oct 2024 20:38:12 +0000 Subject: [PATCH] Fix Scheduler::StaticSchedulingElement shared_ptr circular reference - Change shared_ptr to raw ptr. It is possible without issue here as each pointer is stored and owned by Scheduler::mStaticSchedule and deleted with it - Change Scheduler::resetScheduling() and Scheduler::~Scheduler() to delete raw pointers properly --- include/aidge/scheduler/Scheduler.hpp | 14 +++++------ src/scheduler/ParallelScheduler.cpp | 8 +++--- src/scheduler/Scheduler.cpp | 35 ++++++++++++++++++--------- src/scheduler/SequentialScheduler.cpp | 2 +- 4 files changed, 36 insertions(+), 23 deletions(-) diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 981920ea1..2d03f4e8b 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -61,8 +61,8 @@ protected: std::shared_ptr<Node> node; /** Scheduled `Node` */ std::size_t early; /** Earliest possible execution time */ std::size_t late; /** Latest possible execution time */ - std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan; /** Nodes that must be executed earlier */ - std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan; /** Nodes that must be executed later */ + std::vector<StaticSchedulingElement*> earlierThan; /** Nodes that must be executed earlier */ + std::vector<StaticSchedulingElement*> laterThan; /** Nodes that must be executed later */ }; /** @@ -110,7 +110,7 @@ public: // ctor }; - virtual ~Scheduler() noexcept; + virtual ~Scheduler(); public: /** @@ -192,9 +192,9 @@ protected: * @brief Generate an initial base scheduling for the GraphView. * The scheduling is entirely sequential and garanteed to be valid w.r.t. * each node producer-consumer model. - * @return Vector of shared pointers to `StaticSchedulingElement` representing the base schedule. + * @return Vector of pointers to `StaticSchedulingElement` representing the base schedule. */ - std::vector<std::shared_ptr<StaticSchedulingElement>> generateBaseScheduling() const; + std::vector<StaticSchedulingElement*> generateBaseScheduling() const; /** * @brief Calculates early and late execution times for each node in an initial base scheduling. @@ -207,7 +207,7 @@ protected: * * @param schedule Vector of shared pointers to StaticSchedulingElements to be processed */ - void generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const; + void generateEarlyLateScheduling(std::vector<StaticSchedulingElement*>& schedule) const; private: /** @@ -227,7 +227,7 @@ protected: /** @brief List of SchedulingElement (i.e: Nodes with their computation time) */ std::vector<SchedulingElement> mScheduling; /** @brief List of nodes ordered by their */ - std::vector<std::vector<std::shared_ptr<StaticSchedulingElement>>> mStaticSchedule; + std::vector<std::vector<StaticSchedulingElement*>> mStaticSchedule; std::size_t mStaticScheduleStep = 0; mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache; }; diff --git a/src/scheduler/ParallelScheduler.cpp b/src/scheduler/ParallelScheduler.cpp index 1d70646b7..2b9a1f5b6 100644 --- a/src/scheduler/ParallelScheduler.cpp +++ b/src/scheduler/ParallelScheduler.cpp @@ -48,7 +48,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std:: // Sort static scheduling, the order will be the prefered threads scheduling // order for non critical nodes - std::deque<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end()); + std::deque<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end()); std::stable_sort(staticSchedule.begin(), staticSchedule.end(), [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); }); @@ -59,12 +59,12 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std:: size_t latest = 0; std::mutex schedulingMutex; - std::map<std::shared_ptr<StaticSchedulingElement>, std::atomic<bool>> finished; + std::map<StaticSchedulingElement*, std::atomic<bool>> finished; while (!staticSchedule.empty()) { Log::debug("Step {}", latest); - std::vector<std::shared_ptr<StaticSchedulingElement>> mustFinish; + std::vector<StaticSchedulingElement*> mustFinish; // Run all nodes that must be run at this step: latest (critical nodes) for (size_t i = 0; i < staticSchedule.size(); ) { @@ -188,7 +188,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std:: } // Wait for all nodes that must finish at latest to be finished - // By scheduling construction, no other node can be started before all + // By scheduling construction, no other node can be started before all // nodes at latest step are finished while (true) { bool ready = true; diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 958b25432..34aea5ffd 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -37,7 +37,14 @@ #include "aidge/utils/Types.h" -Aidge::Scheduler::~Scheduler() noexcept = default; +Aidge::Scheduler::~Scheduler() { + for (auto& staticScheduleVec : mStaticSchedule) { + for (auto& staticScheduleElt : staticScheduleVec) { + delete staticScheduleElt; + } + staticScheduleVec.clear(); + } +} Aidge::Scheduler::PriorProducersConsumers::PriorProducersConsumers() = default; Aidge::Scheduler::PriorProducersConsumers::PriorProducersConsumers(const PriorProducersConsumers&) = default; Aidge::Scheduler::PriorProducersConsumers::~PriorProducersConsumers() noexcept = default; @@ -48,7 +55,7 @@ void Aidge::Scheduler::generateScheduling() { mStaticSchedule.push_back(schedule); } -std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::Scheduler::generateBaseScheduling() const { +std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::generateBaseScheduling() const { // 0) setup useful variables // map associating each node with string "name (type#rank)" @@ -60,7 +67,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S // producers-consumers model! std::set<std::shared_ptr<Node>> stillConsumers; - std::vector<std::shared_ptr<StaticSchedulingElement>> schedule; + std::vector<StaticSchedulingElement*> schedule; // 1) Initialize consumers list: start from the output nodes and @@ -124,7 +131,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S // Producers are special nodes that generate data on demand. for (const auto& requiredProducer : requiredProducers) { requiredProducer->getOperator()->updateConsummerProducer(); - schedule.push_back(std::make_shared<StaticSchedulingElement>(requiredProducer)); + schedule.push_back(new StaticSchedulingElement(requiredProducer)); } // 5) Find runnable consumers. @@ -178,7 +185,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S for (const auto& runnable : runnableConsumers) { Log::debug("Runnable: {}", namePtrTable.at(runnable)); runnable->getOperator()->updateConsummerProducer(); - schedule.push_back(std::make_shared<StaticSchedulingElement>(runnable)); + schedule.push_back(new StaticSchedulingElement(runnable)); } // 7) Update consumers list @@ -310,7 +317,7 @@ void Aidge::Scheduler::summarizeConsumerState(const std::shared_ptr<Aidge::Node> } -void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const { +void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingElement*>& schedule) const { std::size_t latest = 0; // Calculate early (logical) start for (std::size_t elt = 0; elt < schedule.size(); ++elt) { @@ -390,15 +397,20 @@ void Aidge::Scheduler::resetScheduling() { for (auto node : mGraphView->getNodes()) { node->getOperator()->resetConsummerProducer(); } - + for (auto& staticScheduleVec : mStaticSchedule) { + for (auto& staticScheduleElt : staticScheduleVec) { + delete staticScheduleElt; + } + staticScheduleVec.clear(); + } mStaticSchedule.clear(); mStaticScheduleStep = 0; mScheduling.clear(); } /** - * This version is a simplified version without special handling of concatenation. -*/ + * @warning This version is a simplified version without special handling of concatenation. + */ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const { MemoryManager memManager; @@ -669,8 +681,8 @@ Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& return Elts_t::NoneElts(); } -Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersConsumers( - const std::shared_ptr<Node>& node) const +Aidge::Scheduler::PriorProducersConsumers +Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node) const { const auto priorCache = mPriorCache.find(node); if (priorCache != mPriorCache.end()) { @@ -707,6 +719,7 @@ Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersCon const auto& parentPrior = getPriorProducersConsumers(parent.first); if (!parentPrior.isPrior) { + // only happens in case of cyclic graphs return PriorProducersConsumers(); // not scheduled } else { diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp index 88b5e98bc..4e6e91f51 100644 --- a/src/scheduler/SequentialScheduler.cpp +++ b/src/scheduler/SequentialScheduler.cpp @@ -45,7 +45,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std } // Sort static scheduling according to the policy - std::vector<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end()); + std::vector<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end()); if (mSchedulingPolicy == SchedulingPolicy::AsSoonAsPossible) { std::stable_sort(staticSchedule.begin(), staticSchedule.end(), -- GitLab