From a95cddb10fad1c2d785041d38a6142d73d1e3df6 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 1 Mar 2024 09:54:25 +0100 Subject: [PATCH] Fixed missing producer consumer system reset --- include/aidge/backend/OperatorImpl.hpp | 6 ++++++ include/aidge/operator/Operator.hpp | 2 ++ include/aidge/scheduler/Scheduler.hpp | 6 +----- python_binding/backend/pybind_OperatorImpl.cpp | 10 ++++++++++ src/backend/OperatorImpl.cpp | 5 +++++ src/operator/Operator.cpp | 4 ++++ src/scheduler/Scheduler.cpp | 14 +++++++++++++- 7 files changed, 41 insertions(+), 6 deletions(-) diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index 19f083750..8b5aba10d 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -63,6 +63,12 @@ public: */ virtual void updateConsummerProducer(); + /** + * @brief Reset the Consummer Producer system. + * + */ + virtual void resetConsummerProducer(); + virtual ~OperatorImpl() = default; protected: diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 7cfe6b925..9a32a9b9d 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -153,6 +153,8 @@ public: virtual void updateConsummerProducer(); + virtual void resetConsummerProducer(); + virtual void forward(); virtual void backward(); diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 06c3b6e9e..747785bf8 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -58,11 +58,7 @@ public: ~SequentialScheduler() = default; void generateScheduling(bool verbose = false); - inline void resetScheduling() { - mScheduling.clear(); - mStaticSchedule.clear(); - mStaticScheduleStep = 0; - } + void resetScheduling(); /** * Generate the memory layout for the current static scheduling. diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp index 346100690..a2a5e6b8b 100644 --- a/python_binding/backend/pybind_OperatorImpl.cpp +++ b/python_binding/backend/pybind_OperatorImpl.cpp @@ -102,6 +102,15 @@ public: ); } + void resetConsummerProducer() override { + PYBIND11_OVERRIDE_NAME( + void, + OperatorImpl, + "reset_consummer_producer", + resetConsummerProducer, + + ); + } }; void init_OperatorImpl(py::module& m){ @@ -116,6 +125,7 @@ void init_OperatorImpl(py::module& m){ .def("get_nb_consumed_data", &OperatorImpl::getNbConsumedData) .def("get_nb_produced_data", &OperatorImpl::getNbProducedData) .def("update_consummer_producer", &OperatorImpl::updateConsummerProducer) + .def("reset_consummer_producer", &OperatorImpl::resetConsummerProducer) ; } } diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index b76bf3336..1911da228 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -68,6 +68,11 @@ void Aidge::OperatorImpl::updateConsummerProducer(){ } } +void Aidge::OperatorImpl::resetConsummerProducer(){ + std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), 0); + std::fill(mNbProducedData.begin(), mNbProducedData.end(), 0); +} + void Aidge::OperatorImpl::forward() { AIDGE_THROW_OR_ABORT(std::runtime_error, "forward() not implemented"); } diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp index af6973362..289b2be90 100644 --- a/src/operator/Operator.cpp +++ b/src/operator/Operator.cpp @@ -59,6 +59,10 @@ void Aidge::Operator::updateConsummerProducer(){ AIDGE_ASSERT(mImpl != nullptr, "updateConsummerProducer(): an implementation is required for {}!", type()); mImpl->updateConsummerProducer(); } +void Aidge::Operator::resetConsummerProducer(){ + AIDGE_ASSERT(mImpl != nullptr, "resetConsummerProducer(): an implementation is required for {}!", type()); + mImpl->resetConsummerProducer(); +} void Aidge::Operator::runHooks() const { for (auto& hook : mHooks) { diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index c82f58590..d45bc5f8e 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -301,6 +301,16 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { } } +void Aidge::SequentialScheduler::resetScheduling() { + for (auto node : mGraphView->getNodes()) { + node->getOperator()->resetConsummerProducer(); + } + + mStaticSchedule.clear(); + mStaticScheduleStep = 0; + mScheduling.clear(); +} + /** * This version is a simplified version without special handling of concatenation. */ @@ -423,7 +433,6 @@ void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge } -// TODO: handle multiple inputs/outputs void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) { // Collect all data input of the graph (that are producers) @@ -461,6 +470,9 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::ve fmt::print("\n"); ++mStaticScheduleStep; + if (mStaticScheduleStep == mStaticSchedule.size()) { + mStaticScheduleStep = 0; + } } void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const { -- GitLab