Skip to content
Snippets Groups Projects
Commit 23183567 authored by Maxence Naud's avatar Maxence Naud
Browse files

Fix Scheduler::StaticSchedulingElement shared_ptr circular reference

- Change shared_ptr to raw ptr. It is possible without issue here as each pointer is stored and owned by Scheduler::mStaticSchedule and deleted with it
- Change Scheduler::resetScheduling() and Scheduler::~Scheduler() to delete raw pointers properly
parent 95b0cd6b
No related branches found
No related tags found
No related merge requests found
...@@ -61,8 +61,8 @@ protected: ...@@ -61,8 +61,8 @@ protected:
std::shared_ptr<Node> node; /** Scheduled `Node` */ std::shared_ptr<Node> node; /** Scheduled `Node` */
std::size_t early; /** Earliest possible execution time */ std::size_t early; /** Earliest possible execution time */
std::size_t late; /** Latest possible execution time */ std::size_t late; /** Latest possible execution time */
std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan; /** Nodes that must be executed earlier */ std::vector<StaticSchedulingElement*> earlierThan; /** Nodes that must be executed earlier */
std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan; /** Nodes that must be executed later */ std::vector<StaticSchedulingElement*> laterThan; /** Nodes that must be executed later */
}; };
/** /**
...@@ -110,7 +110,7 @@ public: ...@@ -110,7 +110,7 @@ public:
// ctor // ctor
}; };
virtual ~Scheduler() noexcept; virtual ~Scheduler();
public: public:
/** /**
...@@ -192,9 +192,9 @@ protected: ...@@ -192,9 +192,9 @@ protected:
* @brief Generate an initial base scheduling for the GraphView. * @brief Generate an initial base scheduling for the GraphView.
* The scheduling is entirely sequential and garanteed to be valid w.r.t. * The scheduling is entirely sequential and garanteed to be valid w.r.t.
* each node producer-consumer model. * each node producer-consumer model.
* @return Vector of shared pointers to `StaticSchedulingElement` representing the base schedule. * @return Vector of pointers to `StaticSchedulingElement` representing the base schedule.
*/ */
std::vector<std::shared_ptr<StaticSchedulingElement>> generateBaseScheduling() const; std::vector<StaticSchedulingElement*> generateBaseScheduling() const;
/** /**
* @brief Calculates early and late execution times for each node in an initial base scheduling. * @brief Calculates early and late execution times for each node in an initial base scheduling.
...@@ -207,7 +207,7 @@ protected: ...@@ -207,7 +207,7 @@ protected:
* *
* @param schedule Vector of shared pointers to StaticSchedulingElements to be processed * @param schedule Vector of shared pointers to StaticSchedulingElements to be processed
*/ */
void generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const; void generateEarlyLateScheduling(std::vector<StaticSchedulingElement*>& schedule) const;
private: private:
/** /**
...@@ -227,7 +227,7 @@ protected: ...@@ -227,7 +227,7 @@ protected:
/** @brief List of SchedulingElement (i.e: Nodes with their computation time) */ /** @brief List of SchedulingElement (i.e: Nodes with their computation time) */
std::vector<SchedulingElement> mScheduling; std::vector<SchedulingElement> mScheduling;
/** @brief List of nodes ordered by their */ /** @brief List of nodes ordered by their */
std::vector<std::vector<std::shared_ptr<StaticSchedulingElement>>> mStaticSchedule; std::vector<std::vector<StaticSchedulingElement*>> mStaticSchedule;
std::size_t mStaticScheduleStep = 0; std::size_t mStaticScheduleStep = 0;
mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache; mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache;
}; };
......
...@@ -48,7 +48,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std:: ...@@ -48,7 +48,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std::
// Sort static scheduling, the order will be the prefered threads scheduling // Sort static scheduling, the order will be the prefered threads scheduling
// order for non critical nodes // order for non critical nodes
std::deque<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end()); std::deque<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
std::stable_sort(staticSchedule.begin(), staticSchedule.end(), std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
[](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); }); [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); });
...@@ -59,12 +59,12 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std:: ...@@ -59,12 +59,12 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std::
size_t latest = 0; size_t latest = 0;
std::mutex schedulingMutex; std::mutex schedulingMutex;
std::map<std::shared_ptr<StaticSchedulingElement>, std::atomic<bool>> finished; std::map<StaticSchedulingElement*, std::atomic<bool>> finished;
while (!staticSchedule.empty()) { while (!staticSchedule.empty()) {
Log::debug("Step {}", latest); Log::debug("Step {}", latest);
std::vector<std::shared_ptr<StaticSchedulingElement>> mustFinish; std::vector<StaticSchedulingElement*> mustFinish;
// Run all nodes that must be run at this step: latest (critical nodes) // Run all nodes that must be run at this step: latest (critical nodes)
for (size_t i = 0; i < staticSchedule.size(); ) { for (size_t i = 0; i < staticSchedule.size(); ) {
...@@ -188,7 +188,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std:: ...@@ -188,7 +188,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std::
} }
// Wait for all nodes that must finish at latest to be finished // Wait for all nodes that must finish at latest to be finished
// By scheduling construction, no other node can be started before all // By scheduling construction, no other node can be started before all
// nodes at latest step are finished // nodes at latest step are finished
while (true) { while (true) {
bool ready = true; bool ready = true;
......
...@@ -37,7 +37,14 @@ ...@@ -37,7 +37,14 @@
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
Aidge::Scheduler::~Scheduler() noexcept = default; Aidge::Scheduler::~Scheduler() {
for (auto& staticScheduleVec : mStaticSchedule) {
for (auto& staticScheduleElt : staticScheduleVec) {
delete staticScheduleElt;
}
staticScheduleVec.clear();
}
}
Aidge::Scheduler::PriorProducersConsumers::PriorProducersConsumers() = default; Aidge::Scheduler::PriorProducersConsumers::PriorProducersConsumers() = default;
Aidge::Scheduler::PriorProducersConsumers::PriorProducersConsumers(const PriorProducersConsumers&) = default; Aidge::Scheduler::PriorProducersConsumers::PriorProducersConsumers(const PriorProducersConsumers&) = default;
Aidge::Scheduler::PriorProducersConsumers::~PriorProducersConsumers() noexcept = default; Aidge::Scheduler::PriorProducersConsumers::~PriorProducersConsumers() noexcept = default;
...@@ -48,7 +55,7 @@ void Aidge::Scheduler::generateScheduling() { ...@@ -48,7 +55,7 @@ void Aidge::Scheduler::generateScheduling() {
mStaticSchedule.push_back(schedule); mStaticSchedule.push_back(schedule);
} }
std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::Scheduler::generateBaseScheduling() const { std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::generateBaseScheduling() const {
// 0) setup useful variables // 0) setup useful variables
// map associating each node with string "name (type#rank)" // map associating each node with string "name (type#rank)"
...@@ -60,7 +67,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S ...@@ -60,7 +67,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
// producers-consumers model! // producers-consumers model!
std::set<std::shared_ptr<Node>> stillConsumers; std::set<std::shared_ptr<Node>> stillConsumers;
std::vector<std::shared_ptr<StaticSchedulingElement>> schedule; std::vector<StaticSchedulingElement*> schedule;
// 1) Initialize consumers list: // 1) Initialize consumers list:
...@@ -131,7 +138,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S ...@@ -131,7 +138,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
// Producers are special nodes that generate data on demand. // Producers are special nodes that generate data on demand.
for (const auto& requiredProducer : requiredProducers) { for (const auto& requiredProducer : requiredProducers) {
requiredProducer->getOperator()->updateConsummerProducer(); requiredProducer->getOperator()->updateConsummerProducer();
schedule.push_back(std::make_shared<StaticSchedulingElement>(requiredProducer)); schedule.push_back(new StaticSchedulingElement(requiredProducer));
} }
// 5) Find runnable consumers. // 5) Find runnable consumers.
...@@ -185,7 +192,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S ...@@ -185,7 +192,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
for (const auto& runnable : runnableConsumers) { for (const auto& runnable : runnableConsumers) {
Log::debug("Runnable: {}", namePtrTable.at(runnable)); Log::debug("Runnable: {}", namePtrTable.at(runnable));
runnable->getOperator()->updateConsummerProducer(); runnable->getOperator()->updateConsummerProducer();
schedule.push_back(std::make_shared<StaticSchedulingElement>(runnable)); schedule.push_back(new StaticSchedulingElement(runnable));
} }
// 7) Update consumers list // 7) Update consumers list
...@@ -317,7 +324,7 @@ void Aidge::Scheduler::summarizeConsumerState(const std::shared_ptr<Aidge::Node> ...@@ -317,7 +324,7 @@ void Aidge::Scheduler::summarizeConsumerState(const std::shared_ptr<Aidge::Node>
} }
void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const { void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingElement*>& schedule) const {
std::size_t latest = 0; std::size_t latest = 0;
// Calculate early (logical) start // Calculate early (logical) start
for (std::size_t elt = 0; elt < schedule.size(); ++elt) { for (std::size_t elt = 0; elt < schedule.size(); ++elt) {
...@@ -397,15 +404,20 @@ void Aidge::Scheduler::resetScheduling() { ...@@ -397,15 +404,20 @@ void Aidge::Scheduler::resetScheduling() {
for (auto node : mGraphView->getNodes()) { for (auto node : mGraphView->getNodes()) {
node->getOperator()->resetConsummerProducer(); node->getOperator()->resetConsummerProducer();
} }
for (auto& staticScheduleVec : mStaticSchedule) {
for (auto& staticScheduleElt : staticScheduleVec) {
delete staticScheduleElt;
}
staticScheduleVec.clear();
}
mStaticSchedule.clear(); mStaticSchedule.clear();
mStaticScheduleStep = 0; mStaticScheduleStep = 0;
mScheduling.clear(); mScheduling.clear();
} }
/** /**
* This version is a simplified version without special handling of concatenation. * @warning This version is a simplified version without special handling of concatenation.
*/ */
Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const { Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const {
MemoryManager memManager; MemoryManager memManager;
...@@ -676,8 +688,8 @@ Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& ...@@ -676,8 +688,8 @@ Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>&
return Elts_t::NoneElts(); return Elts_t::NoneElts();
} }
Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersConsumers( Aidge::Scheduler::PriorProducersConsumers
const std::shared_ptr<Node>& node) const Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node) const
{ {
const auto priorCache = mPriorCache.find(node); const auto priorCache = mPriorCache.find(node);
if (priorCache != mPriorCache.end()) { if (priorCache != mPriorCache.end()) {
...@@ -714,6 +726,7 @@ Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersCon ...@@ -714,6 +726,7 @@ Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersCon
const auto& parentPrior = getPriorProducersConsumers(parent.first); const auto& parentPrior = getPriorProducersConsumers(parent.first);
if (!parentPrior.isPrior) { if (!parentPrior.isPrior) {
// only happens in case of cyclic graphs
return PriorProducersConsumers(); // not scheduled return PriorProducersConsumers(); // not scheduled
} }
else { else {
......
...@@ -45,7 +45,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std ...@@ -45,7 +45,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std
} }
// Sort static scheduling according to the policy // Sort static scheduling according to the policy
std::vector<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end()); std::vector<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
if (mSchedulingPolicy == SchedulingPolicy::AsSoonAsPossible) { if (mSchedulingPolicy == SchedulingPolicy::AsSoonAsPossible) {
std::stable_sort(staticSchedule.begin(), staticSchedule.end(), std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment