Skip to content
Snippets Groups Projects
Commit a95cddb1 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed missing producer consumer system reset

parent 79570aec
No related branches found
No related tags found
No related merge requests found
...@@ -63,6 +63,12 @@ public: ...@@ -63,6 +63,12 @@ public:
*/ */
virtual void updateConsummerProducer(); virtual void updateConsummerProducer();
/**
* @brief Reset the Consummer Producer system.
*
*/
virtual void resetConsummerProducer();
virtual ~OperatorImpl() = default; virtual ~OperatorImpl() = default;
protected: protected:
......
...@@ -153,6 +153,8 @@ public: ...@@ -153,6 +153,8 @@ public:
virtual void updateConsummerProducer(); virtual void updateConsummerProducer();
virtual void resetConsummerProducer();
virtual void forward(); virtual void forward();
virtual void backward(); virtual void backward();
......
...@@ -58,11 +58,7 @@ public: ...@@ -58,11 +58,7 @@ public:
~SequentialScheduler() = default; ~SequentialScheduler() = default;
void generateScheduling(bool verbose = false); void generateScheduling(bool verbose = false);
inline void resetScheduling() { void resetScheduling();
mScheduling.clear();
mStaticSchedule.clear();
mStaticScheduleStep = 0;
}
/** /**
* Generate the memory layout for the current static scheduling. * Generate the memory layout for the current static scheduling.
......
...@@ -102,6 +102,15 @@ public: ...@@ -102,6 +102,15 @@ public:
); );
} }
void resetConsummerProducer() override {
PYBIND11_OVERRIDE_NAME(
void,
OperatorImpl,
"reset_consummer_producer",
resetConsummerProducer,
);
}
}; };
void init_OperatorImpl(py::module& m){ void init_OperatorImpl(py::module& m){
...@@ -116,6 +125,7 @@ void init_OperatorImpl(py::module& m){ ...@@ -116,6 +125,7 @@ void init_OperatorImpl(py::module& m){
.def("get_nb_consumed_data", &OperatorImpl::getNbConsumedData) .def("get_nb_consumed_data", &OperatorImpl::getNbConsumedData)
.def("get_nb_produced_data", &OperatorImpl::getNbProducedData) .def("get_nb_produced_data", &OperatorImpl::getNbProducedData)
.def("update_consummer_producer", &OperatorImpl::updateConsummerProducer) .def("update_consummer_producer", &OperatorImpl::updateConsummerProducer)
.def("reset_consummer_producer", &OperatorImpl::resetConsummerProducer)
; ;
} }
} }
...@@ -68,6 +68,11 @@ void Aidge::OperatorImpl::updateConsummerProducer(){ ...@@ -68,6 +68,11 @@ void Aidge::OperatorImpl::updateConsummerProducer(){
} }
} }
void Aidge::OperatorImpl::resetConsummerProducer(){
std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), 0);
std::fill(mNbProducedData.begin(), mNbProducedData.end(), 0);
}
void Aidge::OperatorImpl::forward() { void Aidge::OperatorImpl::forward() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "forward() not implemented"); AIDGE_THROW_OR_ABORT(std::runtime_error, "forward() not implemented");
} }
......
...@@ -59,6 +59,10 @@ void Aidge::Operator::updateConsummerProducer(){ ...@@ -59,6 +59,10 @@ void Aidge::Operator::updateConsummerProducer(){
AIDGE_ASSERT(mImpl != nullptr, "updateConsummerProducer(): an implementation is required for {}!", type()); AIDGE_ASSERT(mImpl != nullptr, "updateConsummerProducer(): an implementation is required for {}!", type());
mImpl->updateConsummerProducer(); mImpl->updateConsummerProducer();
} }
void Aidge::Operator::resetConsummerProducer(){
AIDGE_ASSERT(mImpl != nullptr, "resetConsummerProducer(): an implementation is required for {}!", type());
mImpl->resetConsummerProducer();
}
void Aidge::Operator::runHooks() const { void Aidge::Operator::runHooks() const {
for (auto& hook : mHooks) { for (auto& hook : mHooks) {
......
...@@ -301,6 +301,16 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -301,6 +301,16 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
} }
} }
void Aidge::SequentialScheduler::resetScheduling() {
for (auto node : mGraphView->getNodes()) {
node->getOperator()->resetConsummerProducer();
}
mStaticSchedule.clear();
mStaticScheduleStep = 0;
mScheduling.clear();
}
/** /**
* This version is a simplified version without special handling of concatenation. * This version is a simplified version without special handling of concatenation.
*/ */
...@@ -423,7 +433,6 @@ void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge ...@@ -423,7 +433,6 @@ void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge
} }
// TODO: handle multiple inputs/outputs
void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) { void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
// Collect all data input of the graph (that are producers) // Collect all data input of the graph (that are producers)
...@@ -461,6 +470,9 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::ve ...@@ -461,6 +470,9 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::ve
fmt::print("\n"); fmt::print("\n");
++mStaticScheduleStep; ++mStaticScheduleStep;
if (mStaticScheduleStep == mStaticSchedule.size()) {
mStaticScheduleStep = 0;
}
} }
void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const { void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment