From 36994a1c80f18668750eea2cd921fb57e84e4f56 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Tue, 5 Sep 2023 13:07:11 +0000 Subject: [PATCH] [Scheduling] Refactor to separate scheduling & forward phase. --- include/aidge/backend/OperatorImpl.hpp | 10 ++- include/aidge/operator/Operator.hpp | 2 + include/aidge/scheduler/Scheduler.hpp | 28 ++++-- .../operator/pybind_GenericOperator.cpp | 6 +- python_binding/scheduler/pybind_Scheduler.cpp | 3 +- src/operator/Operator.cpp | 3 + src/scheduler/Scheduler.cpp | 88 +++++++++++-------- 7 files changed, 91 insertions(+), 49 deletions(-) diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index 5aa2829e1..d10270b62 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -20,7 +20,7 @@ namespace Aidge { class OperatorImpl { public: virtual void forward(){}; - virtual void backward() {} + virtual void backward(){}; /** * @brief Minimum amount of data from a specific input required by the @@ -46,13 +46,19 @@ public: virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const = 0; /** - * @brief TOtal amount of produced data ready to be used on a specific output. + * @brief Total amount of produced data ready to be used on a specific output. * * @param outputIdx Index of the output analysed. * @return DimSize_t */ virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const = 0; + /** + * @brief Update the Consummer Producer system by simulating the consumption and production of i/o + * + */ + virtual void updateConsummerProducer() = 0; + virtual ~OperatorImpl() = default; }; } // namespace Aidge diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 30e1ce2a7..36f846dda 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -78,6 +78,8 @@ public: */ NbElts_t getNbProducedData(const IOIndex_t outputIdx) const; + void updateConsummerProducer(); + virtual void forward(); virtual void backward(); diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index fb39bcf4a..9916ee200 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -43,6 +43,8 @@ public: }; ~SequentialScheduler() = default; + void generateScheduling(bool verbose = false); + /** * @brief Run the provided Computational Graph with a batch of data */ @@ -59,12 +61,8 @@ public: * * @return std::vector<std::shared_ptr<Node>> */ - std::vector<std::shared_ptr<Node>> getNodeScheduling(){ - std::vector<std::shared_ptr<Node>> nodeScheduling = {}; - for(SchedulingElement & scheduleElt: mScheduling){ - nodeScheduling.push_back(scheduleElt.node); - } - return nodeScheduling; + std::vector<std::shared_ptr<Node>> getStaticScheduling(){ + return mStaticSchedule; } private: @@ -76,8 +74,26 @@ private: */ std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const; + /** + * @brief Shared ptr to the scheduled graph view + * + */ std::shared_ptr<GraphView> mGraphView; + /** + * @brief List of SchedulingElement (i.e: Nodes with their computation time) + * + */ std::vector<SchedulingElement> mScheduling; + /** + * @brief List of nodes ordered by their + * + */ + std::vector<std::shared_ptr<Node>> mStaticSchedule; + /** + * @brief Number of computation node (i.e: nb nodes != Producer) + * + */ + std::size_t mComputationNumber = 0; // TODO: Check if not inferable from mStaticSchedule }; } // namespace Aidge diff --git a/python_binding/operator/pybind_GenericOperator.cpp b/python_binding/operator/pybind_GenericOperator.cpp index ee3ee74c1..bec59eaf2 100644 --- a/python_binding/operator/pybind_GenericOperator.cpp +++ b/python_binding/operator/pybind_GenericOperator.cpp @@ -22,7 +22,7 @@ namespace Aidge { void init_GenericOperator(py::module& m) { py::class_<GenericOperator_Op, std::shared_ptr<GenericOperator_Op>, Operator>(m, "GenericOperatorOp", py::multiple_inheritance()) - .def("get_parameter_type", &GenericOperator_Op::getParameterType) + .def("get_parameter_type", &GenericOperator_Op::getParameterType) .def("get_parameters_name", &GenericOperator_Op::getParametersName) .def("add_parameter", &GenericOperator_Op::addParameter<bool>) .def("add_parameter", &GenericOperator_Op::addParameter<int>) @@ -34,10 +34,10 @@ void init_GenericOperator(py::module& m) { .def("add_parameter", &GenericOperator_Op::addParameter<std::vector<std::string>>) .def("get_parameter", [](GenericOperator_Op& self, std::string key) -> py::object { /* - This getParameter method returns the good python type without having to have + This getParameter method returns the good python type without having to have prior knowledge of the parameter type. */ - py::object res = py::none(); + py::object res = py::none(); std::string paramType = self.getParameterType(key); if(paramType == typeid(int).name()) res = py::cast(self.getParameter<int>(key)); diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index ab7d3d850..85479d41f 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -21,7 +21,8 @@ void init_Scheduler(py::module& m){ .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false) .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name")) - .def("get_node_scheduling", &SequentialScheduler::getNodeScheduling) + .def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false) + .def("get_static_scheduling", &SequentialScheduler::getStaticScheduling) ; } } diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp index 99b07235e..b3896b121 100644 --- a/src/operator/Operator.cpp +++ b/src/operator/Operator.cpp @@ -38,6 +38,9 @@ Aidge::NbElts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) co Aidge::NbElts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const { return mImpl->getNbProducedData(outputIdx); } +void Aidge::Operator::updateConsummerProducer(){ + mImpl->updateConsummerProducer(); +} void Aidge::Operator::forward() { mImpl->forward(); } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index c8cc0a51d..dc0768d2b 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -33,26 +33,19 @@ void drawProgressBar(double progress, int barWidth, const std::string& additiona fflush(stdout); } -// TODO: handle multiple inputs/outputs -void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { - if (forwardDims) {mGraphView->forwardDims(); } - - mScheduling.clear(); - +void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // setup initial producers list - // add each Producer Node. - std::set<std::shared_ptr<Node>> computationOver; - std::size_t computationNumber = 0; + mComputationNumber = 0; std::set<std::shared_ptr<Node>> producers; for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) { if (nodePtr->type() == "Producer") { producers.insert(nodePtr); } else { - ++computationNumber; + ++mComputationNumber; } } // add Data Input - // FIXME : shoudl be changed when the real system for providing + // FIXME : should be changed when the real system for providing // data is implemented for (const std::shared_ptr<Node>& nodePtr : mGraphView->inputNodes()) { for (const auto& parentPtr : nodePtr->getParents()) { @@ -112,21 +105,10 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { } } - // run sequencially every runnable consumers once - // TODO: handle memory allocation in scheduler - // TODO: optimize memory usage + // Push consumers in the list of nodes to run and update the consumer producer system for (const auto& runnable : runnableConsumers) { - if (verbose) - printf("run: %s\n", - (runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str()); - else - drawProgressBar(static_cast<float>(computationOver.size()) / static_cast<float>(computationNumber), 50, - (std::string("running ") + runnable->type() + "_" + - std::to_string(reinterpret_cast<uintptr_t>(runnable.get())))); - const auto tStart = std::chrono::high_resolution_clock::now(); - runnable->forward(); - const auto tEnd = std::chrono::high_resolution_clock::now(); - mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd)); + runnable->getOperator()->updateConsummerProducer(); + mStaticSchedule.push_back(runnable); } // update producers and consumers list @@ -164,18 +146,6 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { } } - bool computationOverForConsumer = true; - for (IOIndex_t parentIDi = 0; parentIDi < consumer->nbInputs(); ++parentIDi) { - if (consumer->getOperator()->getNbConsumedData(parentIDi) < - consumer->getOperator()->getNbRequiredData(parentIDi)) { - computationOverForConsumer = false; - break; - } - } - if (computationOverForConsumer) { - computationOver.insert(consumer); - } - for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) { if (consumer->getOperator()->getNbProducedData(outId) > 0) { if (verbose) printf(" also producer\n"); @@ -197,8 +167,52 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { if (verbose) printf("*************\n"); } while (!consumers.empty()); + +} + +// TODO: handle multiple inputs/outputs +void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { + if (forwardDims) {mGraphView->forwardDims(); } + + // add each Producer Node. + std::set<std::shared_ptr<Node>> computationOver; + + mScheduling.clear(); + + this->generateScheduling(); + + // TODO: For loop on the list of node to run + // run sequencially every runnable consumers once + // TODO: handle memory allocation in scheduler + // TODO: optimize memory usage + for (const auto& runnable : mStaticSchedule) { + bool computationOverForConsumer = true; + for (IOIndex_t parentIDi = 0; parentIDi < runnable->nbInputs(); ++parentIDi) { + if (runnable->getOperator()->getNbConsumedData(parentIDi) < + runnable->getOperator()->getNbRequiredData(parentIDi)) { + computationOverForConsumer = false; + break; + } + } + if (computationOverForConsumer) { + computationOver.insert(runnable); + } + + if (verbose) + printf("run: %s\n", + (runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str()); + else + drawProgressBar(static_cast<float>(computationOver.size()) / static_cast<float>(mComputationNumber), 50, + (std::string("running ") + runnable->type() + "_" + + std::to_string(reinterpret_cast<uintptr_t>(runnable.get())))); + const auto tStart = std::chrono::high_resolution_clock::now(); + runnable->forward(); + const auto tEnd = std::chrono::high_resolution_clock::now(); + mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd)); + } if (!verbose) drawProgressBar(1.0, 50, " "); printf("\n"); + } void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const { -- GitLab