diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index 19f0837504016f38ae96dd852bc6fa41b5ab53ba..8b5aba10dbc2691b5d607cda28eba621335881d1 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -63,6 +63,12 @@ public: */ virtual void updateConsummerProducer(); + /** + * @brief Reset the Consummer Producer system. + * + */ + virtual void resetConsummerProducer(); + virtual ~OperatorImpl() = default; protected: diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 7cfe6b92521c3ef00528d1b5eff602d9f52b11fd..9a32a9b9d16d32a49937804b5c0c596dd05cae1e 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -153,6 +153,8 @@ public: virtual void updateConsummerProducer(); + virtual void resetConsummerProducer(); + virtual void forward(); virtual void backward(); diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 06c3b6e9ed23f10fe32aeb807cfef112970897a0..747785bf886889aed273c944904ddbb6198c4968 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -58,11 +58,7 @@ public: ~SequentialScheduler() = default; void generateScheduling(bool verbose = false); - inline void resetScheduling() { - mScheduling.clear(); - mStaticSchedule.clear(); - mStaticScheduleStep = 0; - } + void resetScheduling(); /** * Generate the memory layout for the current static scheduling. diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp index 34610069079ee792ebbe4b261b57177b3bbe2997..a2a5e6b8bb2d0f2413ef94c360b383608c5b41b5 100644 --- a/python_binding/backend/pybind_OperatorImpl.cpp +++ b/python_binding/backend/pybind_OperatorImpl.cpp @@ -102,6 +102,15 @@ public: ); } + void resetConsummerProducer() override { + PYBIND11_OVERRIDE_NAME( + void, + OperatorImpl, + "reset_consummer_producer", + resetConsummerProducer, + + ); + } }; void init_OperatorImpl(py::module& m){ @@ -116,6 +125,7 @@ void init_OperatorImpl(py::module& m){ .def("get_nb_consumed_data", &OperatorImpl::getNbConsumedData) .def("get_nb_produced_data", &OperatorImpl::getNbProducedData) .def("update_consummer_producer", &OperatorImpl::updateConsummerProducer) + .def("reset_consummer_producer", &OperatorImpl::resetConsummerProducer) ; } } diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index b76bf33367221add6273e02590d6ec315cfa4544..1911da228c83d66117a2591adf47dc07cd8dc674 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -68,6 +68,11 @@ void Aidge::OperatorImpl::updateConsummerProducer(){ } } +void Aidge::OperatorImpl::resetConsummerProducer(){ + std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), 0); + std::fill(mNbProducedData.begin(), mNbProducedData.end(), 0); +} + void Aidge::OperatorImpl::forward() { AIDGE_THROW_OR_ABORT(std::runtime_error, "forward() not implemented"); } diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp index af697336284542edf38559f7b052e5211ddeb7d0..289b2be90735d848e5083090d2ae4319a7490fde 100644 --- a/src/operator/Operator.cpp +++ b/src/operator/Operator.cpp @@ -59,6 +59,10 @@ void Aidge::Operator::updateConsummerProducer(){ AIDGE_ASSERT(mImpl != nullptr, "updateConsummerProducer(): an implementation is required for {}!", type()); mImpl->updateConsummerProducer(); } +void Aidge::Operator::resetConsummerProducer(){ + AIDGE_ASSERT(mImpl != nullptr, "resetConsummerProducer(): an implementation is required for {}!", type()); + mImpl->resetConsummerProducer(); +} void Aidge::Operator::runHooks() const { for (auto& hook : mHooks) { diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index c82f58590ccc1b54e1367d8b23c3b2e64f7e845e..d45bc5f8eb1ac4b76bef3dd5c8e596efd933033b 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -301,6 +301,16 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { } } +void Aidge::SequentialScheduler::resetScheduling() { + for (auto node : mGraphView->getNodes()) { + node->getOperator()->resetConsummerProducer(); + } + + mStaticSchedule.clear(); + mStaticScheduleStep = 0; + mScheduling.clear(); +} + /** * This version is a simplified version without special handling of concatenation. */ @@ -423,7 +433,6 @@ void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge } -// TODO: handle multiple inputs/outputs void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) { // Collect all data input of the graph (that are producers) @@ -461,6 +470,9 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::ve fmt::print("\n"); ++mStaticScheduleStep; + if (mStaticScheduleStep == mStaticSchedule.size()) { + mStaticScheduleStep = 0; + } } void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {