Skip to content
Snippets Groups Projects
Commit e7f5fa01 authored by Maxence Naud's avatar Maxence Naud
Browse files

Fix Scheduler::StaticSchedulingElement shared_ptr circular reference

- Change shared_ptr to raw ptr. It is possible without issue here as each pointer is stored and owned by Scheduler::mStaticSchedule and deleted with it
- Change Scheduler::resetScheduling() and Scheduler::~Scheduler() to delete raw pointers properly
parent fc5ff602
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!163Export refactor
......@@ -61,8 +61,8 @@ protected:
std::shared_ptr<Node> node; /** Scheduled `Node` */
std::size_t early; /** Earliest possible execution time */
std::size_t late; /** Latest possible execution time */
std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan; /** Nodes that must be executed earlier */
std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan; /** Nodes that must be executed later */
std::vector<StaticSchedulingElement*> earlierThan; /** Nodes that must be executed earlier */
std::vector<StaticSchedulingElement*> laterThan; /** Nodes that must be executed later */
};
/**
......@@ -110,7 +110,7 @@ public:
// ctor
};
virtual ~Scheduler() noexcept;
virtual ~Scheduler();
public:
/**
......@@ -192,9 +192,9 @@ protected:
* @brief Generate an initial base scheduling for the GraphView.
* The scheduling is entirely sequential and garanteed to be valid w.r.t.
* each node producer-consumer model.
* @return Vector of shared pointers to `StaticSchedulingElement` representing the base schedule.
* @return Vector of pointers to `StaticSchedulingElement` representing the base schedule.
*/
std::vector<std::shared_ptr<StaticSchedulingElement>> generateBaseScheduling() const;
std::vector<StaticSchedulingElement*> generateBaseScheduling() const;
/**
* @brief Calculates early and late execution times for each node in an initial base scheduling.
......@@ -207,7 +207,7 @@ protected:
*
* @param schedule Vector of shared pointers to StaticSchedulingElements to be processed
*/
void generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const;
void generateEarlyLateScheduling(std::vector<StaticSchedulingElement*>& schedule) const;
private:
/**
......@@ -227,7 +227,7 @@ protected:
/** @brief List of SchedulingElement (i.e: Nodes with their computation time) */
std::vector<SchedulingElement> mScheduling;
/** @brief List of nodes ordered by their */
std::vector<std::vector<std::shared_ptr<StaticSchedulingElement>>> mStaticSchedule;
std::vector<std::vector<StaticSchedulingElement*>> mStaticSchedule;
std::size_t mStaticScheduleStep = 0;
mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache;
};
......
......@@ -48,7 +48,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std::
// Sort static scheduling, the order will be the prefered threads scheduling
// order for non critical nodes
std::deque<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
std::deque<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
[](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); });
......@@ -59,12 +59,12 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std::
size_t latest = 0;
std::mutex schedulingMutex;
std::map<std::shared_ptr<StaticSchedulingElement>, std::atomic<bool>> finished;
std::map<StaticSchedulingElement*, std::atomic<bool>> finished;
while (!staticSchedule.empty()) {
Log::debug("Step {}", latest);
std::vector<std::shared_ptr<StaticSchedulingElement>> mustFinish;
std::vector<StaticSchedulingElement*> mustFinish;
// Run all nodes that must be run at this step: latest (critical nodes)
for (size_t i = 0; i < staticSchedule.size(); ) {
......@@ -188,7 +188,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std::
}
// Wait for all nodes that must finish at latest to be finished
// By scheduling construction, no other node can be started before all
// By scheduling construction, no other node can be started before all
// nodes at latest step are finished
while (true) {
bool ready = true;
......
......@@ -37,7 +37,14 @@
#include "aidge/utils/Types.h"
Aidge::Scheduler::~Scheduler() noexcept = default;
Aidge::Scheduler::~Scheduler() {
for (auto& staticScheduleVec : mStaticSchedule) {
for (auto& staticScheduleElt : staticScheduleVec) {
delete staticScheduleElt;
}
staticScheduleVec.clear();
}
}
Aidge::Scheduler::PriorProducersConsumers::PriorProducersConsumers() = default;
Aidge::Scheduler::PriorProducersConsumers::PriorProducersConsumers(const PriorProducersConsumers&) = default;
Aidge::Scheduler::PriorProducersConsumers::~PriorProducersConsumers() noexcept = default;
......@@ -48,7 +55,7 @@ void Aidge::Scheduler::generateScheduling() {
mStaticSchedule.push_back(schedule);
}
std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::Scheduler::generateBaseScheduling() const {
std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::generateBaseScheduling() const {
// 0) setup useful variables
// map associating each node with string "name (type#rank)"
......@@ -60,7 +67,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
// producers-consumers model!
std::set<std::shared_ptr<Node>> stillConsumers;
std::vector<std::shared_ptr<StaticSchedulingElement>> schedule;
std::vector<StaticSchedulingElement*> schedule;
// 1) Initialize consumers list: start from the output nodes and
......@@ -124,7 +131,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
// Producers are special nodes that generate data on demand.
for (const auto& requiredProducer : requiredProducers) {
requiredProducer->getOperator()->updateConsummerProducer();
schedule.push_back(std::make_shared<StaticSchedulingElement>(requiredProducer));
schedule.push_back(new StaticSchedulingElement(requiredProducer));
}
// 5) Find runnable consumers.
......@@ -178,7 +185,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
for (const auto& runnable : runnableConsumers) {
Log::debug("Runnable: {}", namePtrTable.at(runnable));
runnable->getOperator()->updateConsummerProducer();
schedule.push_back(std::make_shared<StaticSchedulingElement>(runnable));
schedule.push_back(new StaticSchedulingElement(runnable));
}
// 7) Update consumers list
......@@ -310,7 +317,7 @@ void Aidge::Scheduler::summarizeConsumerState(const std::shared_ptr<Aidge::Node>
}
void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const {
void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingElement*>& schedule) const {
std::size_t latest = 0;
// Calculate early (logical) start
for (std::size_t elt = 0; elt < schedule.size(); ++elt) {
......@@ -390,15 +397,20 @@ void Aidge::Scheduler::resetScheduling() {
for (auto node : mGraphView->getNodes()) {
node->getOperator()->resetConsummerProducer();
}
for (auto& staticScheduleVec : mStaticSchedule) {
for (auto& staticScheduleElt : staticScheduleVec) {
delete staticScheduleElt;
}
staticScheduleVec.clear();
}
mStaticSchedule.clear();
mStaticScheduleStep = 0;
mScheduling.clear();
}
/**
* This version is a simplified version without special handling of concatenation.
*/
* @warning This version is a simplified version without special handling of concatenation.
*/
Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const {
MemoryManager memManager;
......@@ -669,8 +681,8 @@ Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>&
return Elts_t::NoneElts();
}
Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersConsumers(
const std::shared_ptr<Node>& node) const
Aidge::Scheduler::PriorProducersConsumers
Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node) const
{
const auto priorCache = mPriorCache.find(node);
if (priorCache != mPriorCache.end()) {
......@@ -707,6 +719,7 @@ Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersCon
const auto& parentPrior = getPriorProducersConsumers(parent.first);
if (!parentPrior.isPrior) {
// only happens in case of cyclic graphs
return PriorProducersConsumers(); // not scheduled
}
else {
......
......@@ -45,7 +45,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std
}
// Sort static scheduling according to the policy
std::vector<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
std::vector<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
if (mSchedulingPolicy == SchedulingPolicy::AsSoonAsPossible) {
std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment