diff --git a/CMakeLists.txt b/CMakeLists.txt index df8d6d4dff7be783aef58a5beca33d3a922caa1f..ec6aacd723a50eba2bfed0184941410340c6a7aa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,6 +44,8 @@ set(FMT_SYSTEM_HEADERS ON) FetchContent_MakeAvailable(fmt) set_property(TARGET fmt PROPERTY POSITION_INDEPENDENT_CODE ON) +find_package(Threads REQUIRED) + ############################################## # Create target and set properties @@ -88,7 +90,7 @@ if (PYBIND) ) endif() -target_link_libraries(${module_name} PUBLIC fmt::fmt) +target_link_libraries(${module_name} PUBLIC Threads::Threads fmt::fmt) target_compile_features(${module_name} PRIVATE cxx_std_14) if (DOSANITIZE STREQUAL "ON") diff --git a/aidge_core-config.cmake.in b/aidge_core-config.cmake.in index 9862b640541458bdab1b1b8bc2a90297625e35ee..d97afe8a2a1ca98eb862d66c388081bca7b72edc 100644 --- a/aidge_core-config.cmake.in +++ b/aidge_core-config.cmake.in @@ -2,6 +2,7 @@ include(CMakeFindDependencyMacro) find_dependency(fmt) +find_dependency(Threads) include(${CMAKE_CURRENT_LIST_DIR}/aidge_core-config-version.cmake) diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index 04044ed1c77915ec10b5af5b660cf8e6b20c81b2..6a9056723df133fef62e56f969d39d8f69390a76 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -16,6 +16,7 @@ #include <vector> #include "aidge/utils/Types.h" +#include "aidge/data/Elts.hpp" namespace Aidge { class Operator; @@ -36,13 +37,13 @@ public: * @param inputIdx Index of the input analysed. * @return std::size_t */ - virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; + virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const; // Amount of input data that cannot be overwritten during the execution. - virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; + virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; // Memory required at an output for a given input size. - virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const; + virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const; /** * @brief Total amount of consumed data from a specific input. @@ -50,7 +51,7 @@ public: * @param inputIdx Index of the input analysed. * @return DimSize_t */ - virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; + virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const; /** * @brief Total amount of produced data ready to be used on a specific output. @@ -58,7 +59,7 @@ public: * @param outputIdx Index of the output analysed. 
* @return DimSize_t */ - virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const; + virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const; /** * @brief Update the Consummer Producer system by simulating the consumption and production of i/o @@ -77,8 +78,8 @@ public: protected: const Operator &mOp; const std::string mBackend; - std::vector<NbElts_t> mNbConsumedData; - std::vector<NbElts_t> mNbProducedData; + std::vector<Elts_t> mNbConsumedData; + std::vector<Elts_t> mNbProducedData; }; } // namespace Aidge diff --git a/include/aidge/data/Elts.hpp b/include/aidge/data/Elts.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1a5a9e10ea131751ff5616eb2c310068d42ce991 --- /dev/null +++ b/include/aidge/data/Elts.hpp @@ -0,0 +1,124 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_ELTS_H_ +#define AIDGE_ELTS_H_ + +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +/** + * Base object for Aidge consumer-producer model (C-P model). + * It is a hybrid model: operator implementations can specify their C-P model + * with precise data (bytes) or with tokens. +*/ +struct Elts_t { + enum EltType { + Data, + Token, + Undef + }; + + NbElts_t data; + NbElts_t token; + EltType type; + + // Addition operator + inline Elts_t operator+(const Elts_t& other) const { + AIDGE_ASSERT(type == other.type || other.type == Undef || type == Undef, + "Incompatible C-P model types: {} + {}. Data and Token cannot be mixed.", type, other.type); + return Elts_t(data + other.data, token + other.token, (other.type == Undef) ? type : other.type); + } + + // Addition assignment operator + inline Elts_t& operator+=(const Elts_t& other) { + AIDGE_ASSERT(type == other.type || other.type == Undef || type == Undef, + "Incompatible C-P model types: {} += {}. Data and Token cannot be mixed.", type, other.type); + data += other.data; + token += other.token; + type = (other.type == Undef) ? type : other.type; + return *this; + } + + // Comparison operators + inline bool operator<(const Elts_t& other) const { + if (type == Elts_t::Undef || type == Elts_t::Token) { + // Nothing, or only a token is required: don't care about how much data has been produced for the token + return (token < other.token); + } + else if (type == Elts_t::Data && other.type != Elts_t::Token) { + // A precise amount of data is required, so the amount of produced data must be specified, a token is not enough + return (data < other.data); + } + else { + AIDGE_THROW_OR_ABORT(std::runtime_error, + "Incompatible C-P model types: {} < {}. 
Data is expected for right-hand side.", type, other.type); + } + } + + inline bool operator>(const Elts_t& other) const { + if (type == Elts_t::Undef || type == Elts_t::Token) { + // Nothing, or only a token is required: don't care about how much data has been produced for the token + return (token > other.token); + } + else if (type == Elts_t::Data && other.type != Elts_t::Token) { + // A precise amount of data is required, so the amount of produced data must be specified, a token is not enough + return (data > other.data); + } + else { + AIDGE_THROW_OR_ABORT(std::runtime_error, + "Incompatible C-P model types: {} > {}. Data is expected for right-hand side.", type, other.type); + } + } + + inline static Elts_t NoneElts() { + return Elts_t(0, 0, Elts_t::Undef); + } + + inline static Elts_t DataElts(NbElts_t data, NbElts_t token = 1) { + return Elts_t(data, token, Elts_t::Data); + } + + inline static Elts_t TokenElts(NbElts_t token) { + return Elts_t(0, token, Elts_t::Token); + } + +private: + inline Elts_t(NbElts_t data_, NbElts_t token_, EltType type_): + data(data_), token(token_), type(type_) {} +}; +} // end namespace Aidge + +template<> +struct fmt::formatter<Aidge::Elts_t> { + template<typename ParseContext> + inline constexpr auto parse(ParseContext& ctx) { + return ctx.begin(); + } + + template<typename FormatContext> + inline auto format(Aidge::Elts_t const& elt, FormatContext& ctx) { + return fmt::format_to(ctx.out(), "{}:{}", elt.data, elt.token); + } +}; + +namespace { +template <> +const char* const EnumStrings<Aidge::Elts_t::EltType>::data[] + = {"Data", "Token", "Undef"}; +} + +namespace Aidge { +inline auto format_as(Elts_t::EltType elt) { return EnumStrings<Aidge::Elts_t::EltType>::data[static_cast<int>(elt)]; } +} + +#endif /* AIDGE_ELTS_H_ */ diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 4d719b6cb755bb2ddff96905f2e5b6bc24844e37..5ac9cf3c92b1951407e4c1892b1a8dc70a724013 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -21,7 +21,7 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/graph/OpArgs.hpp" #include "aidge/operator/OperatorTensor.hpp" -#include "aidge/scheduler/Scheduler.hpp" +#include "aidge/scheduler/SequentialScheduler.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" @@ -115,11 +115,11 @@ public: mGraph->setDataType(datatype); } - NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override; - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override; - NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override; - NbElts_t getNbConsumedData(IOIndex_t inputIdx) const override; - NbElts_t getNbProducedData(IOIndex_t outputIdx) const override; + Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override; + Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override; + Elts_t getNbConsumedData(IOIndex_t inputIdx) const override; + Elts_t getNbProducedData(IOIndex_t outputIdx) const override; void updateConsummerProducer() override; void forward() override; diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 17c8204c1fec4a54e8194bf2db1dc6e5a616fd23..5c6ffad27b3ac6702c8fdbf6113e334e56a2deed 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -134,31 
+134,31 @@ public: /** * @brief Minimum amount of data from a specific input for one computation pass. * @param inputIdx Index of the input analysed. - * @return NbElts_t + * @return Elts_t */ - virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; + virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const; // Amount of input data that cannot be overwritten during the execution. - virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; + virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; // Memory required at an output for a given input size. - virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const; + virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const; /** * @brief Total amount of consumed data from a specific input. * * @param inputIdx Index of the input analysed. - * @return NbElts_t + * @return Elts_t */ - virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; + virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const; /** * @brief Total amount of produced data ready to be used on a specific output. * * @param outputIdx Index of the output analysed. - * @return NbElts_t + * @return Elts_t */ - virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const; + virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const; virtual void updateConsummerProducer(); diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index 66c66d90b4ed465d31ed20dd41245fed7a71d58e..1e5a3940ba22c659121e76e1855353168d68441a 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -47,8 +47,7 @@ public: Attributes_(attr<ProdAttr::Constant>(constant)) { mOutputs[0]->resize(dims); - // mImpl = std::make_shared<OperatorImpl>(*this, ""); - mImpl = nullptr; + mImpl = std::make_shared<OperatorImpl>(*this, ""); } /** diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index 2f77ae707ff66a6d68f649796d1bf07cce1e4498..97c608cd38ca76a4f40b8fb02282751a97ceed4e 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -22,6 +22,8 @@ namespace Aidge { +void constantFolding(std::shared_ptr<GraphView> graph); + // FUSE MATMUL + ADD -> FC /** diff --git a/include/aidge/scheduler/ParallelScheduler.hpp b/include/aidge/scheduler/ParallelScheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d471c65ff2d3e8a81c3992d1df06ba387559025e --- /dev/null +++ b/include/aidge/scheduler/ParallelScheduler.hpp @@ -0,0 +1,44 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_PARALLELSCHEDULER_H_ +#define AIDGE_PARALLELSCHEDULER_H_ + +#include <chrono> +#include <memory> +#include <set> +#include <string> +#include <vector> +#include <map> + +#include "aidge/scheduler/Scheduler.hpp" + +namespace Aidge { +/** + * Multi-threaded parallel scheduler with dynamic scheduling. 
+*/ +class ParallelScheduler : public Scheduler { +public: + ParallelScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr) + : Scheduler(graphView, upperNode) + { + // ctor + }; + ~ParallelScheduler() = default; + + /** + * @brief Run the provided Computational Graph with a batch of data + */ + virtual void forward(bool forwardDims = true, std::vector<std::shared_ptr<Aidge::Tensor>> data = {}); +}; +} // namespace Aidge + +#endif /* AIDGE_PARALLELSCHEDULER_H_ */ diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index b25ebd3c8de3830174c11d93d6eb60c8703c6a0d..79eeefb2b7e4c0ba94cc062955d0bcc326ffe468 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -28,8 +28,22 @@ namespace Aidge { class Node; class GraphView; -class SequentialScheduler { -private: +class Scheduler { +protected: + struct StaticSchedulingElement { + StaticSchedulingElement( + std::shared_ptr<Node> node_, + size_t early_ = static_cast<size_t>(-1), + size_t late_ = static_cast<size_t>(-1)) + : node(node_), early(early_), late(late_) {} + + std::shared_ptr<Node> node; + size_t early; + size_t late; + std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan; + std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan; + }; + struct SchedulingElement { SchedulingElement( std::shared_ptr<Node> node_, @@ -49,15 +63,25 @@ private: }; public: - SequentialScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr) + Scheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr) : mGraphView(graphView), mUpperNode(upperNode) { // ctor }; - ~SequentialScheduler() = default; + virtual ~Scheduler() = default; - void generateScheduling(bool verbose = false); + /** + * Generate full static scheduling of the GraphView. + * For each node, an earliest and latest possible execution logical step + * is specified. Nodes that may be scheduled at the same logical step have + * no data dependency and can be run in parallel. + */ + void generateScheduling(); + + /** + * Reset all scheduling and associated nodes producer consumer. + */ void resetScheduling(); /** @@ -75,14 +99,10 @@ public: void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data); /** - * @brief Run the provided Computational Graph with a batch of data - */ - void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {}); - - /** - * @brief Run the provided Computational Graph with a batch of data + * @brief Save in a Markdown file the static scheduling with early and late relative order for the nodes. + * @param fileName Name of the generated file. */ - void backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instantiateGrad = true, bool verbose = false); + void saveStaticSchedulingDiagram(const std::string& fileName) const; /** * @brief Save in a Markdown file the order of layers execution. 
@@ -94,14 +114,26 @@ public: * @brief Return a vector of Node ordered by the order they are called by the scheduler * @return std::vector<std::shared_ptr<Node>> */ - inline std::vector<std::shared_ptr<Node>> getStaticScheduling(size_t step = 0) const noexcept { - return mStaticSchedule.at(step); - } + std::vector<std::shared_ptr<Node>> getStaticScheduling(size_t step = 0) const; inline std::shared_ptr<GraphView> getGraphView() const noexcept { return mGraphView; } -private: +protected: + /** + * Generate an initial base scheduling for the GraphView. + * The scheduling is entirely sequential and garanteed to be valid w.r.t. + * each node producer-consumer model. + */ + std::vector<std::shared_ptr<StaticSchedulingElement>> generateBaseScheduling() const; + + /** + * Fill-in early and late scheduling step from initial base scheduling. + * For each node, specifies the earliest and latest possible execution + * logical step. + */ + void generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const; + /** * @brief Set of layers receiving an input from currently processing layers * @@ -109,7 +141,7 @@ private: * @return std::set<std::shared_ptr<Node>> */ std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const; - NbElts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const; + Elts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const; PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const; /** @brief Shared ptr to the scheduled graph view */ @@ -119,7 +151,7 @@ private: /** @brief List of SchedulingElement (i.e: Nodes with their computation time) */ std::vector<SchedulingElement> mScheduling; /** @brief List of nodes ordered by their */ - std::vector<std::vector<std::shared_ptr<Node>>> mStaticSchedule; + std::vector<std::vector<std::shared_ptr<StaticSchedulingElement>>> mStaticSchedule; size_t mStaticScheduleStep = 0; mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache; }; diff --git a/include/aidge/scheduler/SequentialScheduler.hpp b/include/aidge/scheduler/SequentialScheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..be0a4a991adba0bef40acc26b8d84267a2010d4b --- /dev/null +++ b/include/aidge/scheduler/SequentialScheduler.hpp @@ -0,0 +1,62 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_SEQUENTIALSCHEDULER_H_ +#define AIDGE_SEQUENTIALSCHEDULER_H_ + +#include <chrono> +#include <memory> +#include <set> +#include <string> +#include <vector> +#include <map> + +#include "aidge/scheduler/Scheduler.hpp" + +namespace Aidge { +/** + * Multi-threaded parallel scheduler with dynamic scheduling. 
+*/ +class SequentialScheduler : public Scheduler { +public: + enum SchedulingPolicy { + Default, + AsSoonAsPossible, + AsLateAsPossible + }; + + SequentialScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr) + : Scheduler(graphView, upperNode), + mSchedulingPolicy(Default) + { + // ctor + }; + inline void setSchedulingPolicy(SchedulingPolicy policy) { + mSchedulingPolicy = policy; + } + ~SequentialScheduler() = default; + + /** + * @brief Run the provided Computational Graph with a batch of data + */ + virtual void forward(bool forwardDims = true, std::vector<std::shared_ptr<Aidge::Tensor>> data = {}); + + /** + * @brief Run the provided Computational Graph with a batch of data + */ + void backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instantiateGrad = true); + +private: + SchedulingPolicy mSchedulingPolicy; +}; +} // namespace Aidge + +#endif /* AIDGE_SEQUENTIALSCHEDULER_H_ */ diff --git a/include/aidge/scheduler/ThreadPool.hpp b/include/aidge/scheduler/ThreadPool.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5f2d9192def412d6084abc4eddf36ec31fe3aa84 --- /dev/null +++ b/include/aidge/scheduler/ThreadPool.hpp @@ -0,0 +1,42 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_THREADPOOL_H_ +#define AIDGE_THREADPOOL_H_ + +#include <thread> +#include <mutex> +#include <queue> +#include <vector> +#include <functional> +#include <condition_variable> +#include <atomic> + +namespace Aidge { +class ThreadPool { +public: + ThreadPool(size_t nbThreads = std::thread::hardware_concurrency()); + void queueJob(const std::function<void()>& job); + bool busy(); + virtual ~ThreadPool(); + +private: + void threadLoop(); + + bool mTerminate = false; + std::mutex mQueueMutex; + std::condition_variable mMutexCondition; + std::vector<std::thread> mThreads; + std::queue<std::function<void()>> mJobs; +}; +} // namespace Aidge + +#endif /* AIDGE_THREADPOOL_H_ */ diff --git a/include/aidge/utils/ErrorHandling.hpp b/include/aidge/utils/ErrorHandling.hpp index d4235d2db9b06597df80966e67306d84ac814a3c..f6a9aefe24a420e261a100041578dd751c4a1ee2 100644 --- a/include/aidge/utils/ErrorHandling.hpp +++ b/include/aidge/utils/ErrorHandling.hpp @@ -14,6 +14,7 @@ #define AIDGE_ERRORHANDLING_H_ #include <memory> +#include <cassert> #include <fmt/format.h> #include <fmt/ranges.h> diff --git a/include/aidge/utils/Log.hpp b/include/aidge/utils/Log.hpp index 8a18bbab34d3c1c86252833852abc5faca41dd96..f20a619c21f611fdbff9ce0cd4c912c0fcd54a9d 100644 --- a/include/aidge/utils/Log.hpp +++ b/include/aidge/utils/Log.hpp @@ -18,7 +18,15 @@ #include <fmt/format.h> #include <fmt/ranges.h> +#include "aidge/utils/Attributes.hpp" + namespace Aidge { +/** + * Helper to define a context anywhere, hidding the scoped variable name + * which has no relevance. +*/ +#define AIDGE_LOG_CONTEXT(...) const Log::Context logContext_##__LINE__(__VA_ARGS__) + /** * Aidge logging class, for displaying and file logging of events. */ @@ -33,6 +41,18 @@ public: Fatal }; + class Context { + public: + template <typename... Args> + Context(Args&&... 
args) { + Log::mContext.push_back(fmt::format(std::forward<Args>(args)...)); + } + + ~Context() { + Log::mContext.pop_back(); + } + }; + /** * Detailed messages for debugging purposes, providing information helpful * for developers to trace and identify issues. @@ -142,7 +162,13 @@ private: static Level mFileLevel; static std::string mFileName; static std::unique_ptr<FILE, decltype(&std::fclose)> mFile; + static std::vector<std::string> mContext; }; } +namespace { +template <> +const char *const EnumStrings<Aidge::Log::Level>::data[] = {"Debug", "Info", "Notice", "Warn", "Error", "Fatal"}; +} + #endif //AIDGE_LOG_H_ diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index e116fa91cac4d3828e998c6a06825afb118ac52c..a6d1d7a9eb5d88dedaf73564847b0f4fbd797c43 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -130,20 +130,15 @@ void declare_registrable(py::module& m, const std::string& class_name){ */ #ifdef PYBIND #define SET_IMPL_MACRO(T_Op, op, backend_name) \ - \ - if (Registrar<T_Op>::exists(backend_name)) { \ - if(Py_IsInitialized()) { \ - auto obj = py::cast(&(op)); \ - (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ - } else { \ - (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ - } \ - } -#else -#define SET_IMPL_MACRO(T_Op, op, backend_name) \ - if (Registrar<T_Op>::exists(backend_name)) { \ - (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ + if(Py_IsInitialized()) { \ + auto obj = py::cast(&(op)); \ + (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ + } else { \ + (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ } +#else +#define SET_IMPL_MACRO(T_Op, op, backend_name) \ + (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); #endif } diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp index 97cf817176c733000eda8da6c6a213ccc22f1dc4..6a83805fc1af2e111dd1c9f49c669e0c2f9422aa 100644 --- a/python_binding/backend/pybind_OperatorImpl.cpp +++ b/python_binding/backend/pybind_OperatorImpl.cpp @@ -43,18 +43,18 @@ public: ); } - NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override { + Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override { PYBIND11_OVERRIDE_NAME( - NbElts_t, + Elts_t, OperatorImpl, "get_nb_required_data", getNbRequiredData, inputIdx ); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override { + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override { PYBIND11_OVERRIDE_NAME( - NbElts_t, + Elts_t, OperatorImpl, "get_nb_required_protected", getNbRequiredProtected, @@ -62,10 +62,10 @@ public: ); } - NbElts_t getRequiredMemory(const IOIndex_t outputIdx, + Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override { PYBIND11_OVERRIDE_NAME( - NbElts_t, + Elts_t, OperatorImpl, "get_required_memory", getRequiredMemory, @@ -74,9 +74,9 @@ public: ); } - NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override { + Elts_t getNbConsumedData(const IOIndex_t inputIdx) const override { PYBIND11_OVERRIDE_NAME( - NbElts_t, + Elts_t, OperatorImpl, "get_nb_consumed_data", getNbConsumedData, @@ -84,9 +84,9 @@ public: ); } - NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override { + Elts_t getNbProducedData(const IOIndex_t outputIdx) const override { PYBIND11_OVERRIDE_NAME( - NbElts_t, + Elts_t, OperatorImpl, "get_nb_produced_data", getNbProducedData, diff --git 
a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index 1b541b60672cc28cfe318b7bcc029627d6491818..c0966e54d4f025a607aa9763a3657de5b39d2ff4 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -12,20 +12,31 @@ #include <pybind11/pybind11.h> #include <pybind11/stl.h> #include "aidge/scheduler/Scheduler.hpp" +#include "aidge/scheduler/SequentialScheduler.hpp" +#include "aidge/scheduler/ParallelScheduler.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/data/Tensor.hpp" namespace py = pybind11; namespace Aidge { void init_Scheduler(py::module& m){ - py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>>(m, "SequentialScheduler") + py::class_<Scheduler, std::shared_ptr<Scheduler>>(m, "Scheduler") .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) - .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false, py::arg("data")=std::vector<Tensor>()) - .def("backward", &SequentialScheduler::backward, py::arg("data"), py::arg("instanciate_grad")=true, py::arg("verbose")=false) - .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name")) - .def("resetScheduling", &SequentialScheduler::resetScheduling) - .def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false) - .def("get_static_scheduling", &SequentialScheduler::getStaticScheduling, py::arg("step") = 0) + .def("save_scheduling_diagram", &Scheduler::saveSchedulingDiagram, py::arg("file_name")) + .def("resetScheduling", &Scheduler::resetScheduling) + .def("generate_scheduling", &Scheduler::generateScheduling) + .def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0) + ; + + py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler") + .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) + .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>()) + .def("backward", &SequentialScheduler::backward, py::arg("data"), py::arg("instanciate_grad")=true) + ; + + py::class_<ParallelScheduler, std::shared_ptr<ParallelScheduler>, Scheduler>(m, "ParallelScheduler") + .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) + .def("forward", &ParallelScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>()) ; } } diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index 48d615a2b0a5ccb5a51a3edb28ac68dbd7d67501..2277db2421c36704270b81bdb6c45f19aaa891e4 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -20,48 +20,91 @@ Aidge::OperatorImpl::OperatorImpl(const Operator& op, const std::string& backend): mOp(op), mBackend(backend), - mNbConsumedData(mOp.nbInputs(), 0), - mNbProducedData(mOp.nbOutputs(), 0) + mNbConsumedData(mOp.nbInputs(), Elts_t::NoneElts()), + mNbProducedData(mOp.nbOutputs(), Elts_t::NoneElts()) { //ctor } -Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { AIDGE_ASSERT(mOp.getRawInput(inputIdx), "a valid input is required at index {} for operator type {}", inputIdx, mOp.type()); - // Requires the whole tensor by default - return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size(); + if 
(mOp.getRawInput(inputIdx)) { + const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx)); + if (!input->empty()) { + // Known amount of data: requires the whole tensor by default + return Elts_t::DataElts(input->size()); + } + else { + // Unknown amount of data: require a single token by default + return Elts_t::TokenElts(1); + } + } + + // Input not connected, meaning it is an optional input: do no require anything! + return Elts_t::NoneElts(); } -Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const { AIDGE_ASSERT(mOp.getRawInput(inputIdx), "a valid input is required at index {} for operator type {}", inputIdx, mOp.type()); - // Protect the whole tensor by default - return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size(); + if (mOp.getRawInput(inputIdx)) { + const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx)); + if (!input->empty()) { + // Known amount of data: protect the whole tensor by default + return Elts_t::DataElts(input->size()); + } + else { + // Unknown amount of data: protect a single token by default + // (this does not really make sense for now, as getNbRequiredProtected() + // is supposed to give a precise amount of data to protect for + // memory management purpose...) + return Elts_t::TokenElts(1); + } + } + + // Input not connected, meaning it is an optional input: do no require anything! + return Elts_t::NoneElts(); } -Aidge::NbElts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx, +Aidge::Elts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx, const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const { AIDGE_ASSERT(mOp.getRawOutput(outputIdx), "a valid output is required at index {} for operator type {}", outputIdx, mOp.type()); - // Requires the whole tensor by default, regardless of available data on inputs - return std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx))->size(); + if (mOp.getRawOutput(outputIdx)) { + const auto output = std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx)); + if (!output->empty()) { + // Known amount of data: requires the whole tensor by default, + // regardless of available data on inputs + return Elts_t::DataElts(output->size()); + } + else { + // Unknown amount of data: require a single token by default + // (this does not really make sense for now, as getRequiredMemory() + // is supposed to give a precise amount of data to allocate for + // memory management purpose...) + return Elts_t::TokenElts(1); + } + } + + // Output not set, meaning it is an optional output: do no require anything! 
+ return Elts_t::NoneElts(); } -Aidge::NbElts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { AIDGE_ASSERT(static_cast<std::size_t>(inputIdx) < mNbConsumedData.size(), "input index ({}) is out of bound ({}) for operator type {}", inputIdx, mNbConsumedData.size(), mOp.type()); return mNbConsumedData[static_cast<std::size_t>(inputIdx)]; } -Aidge::NbElts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const { +Aidge::Elts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const { AIDGE_ASSERT(static_cast<std::size_t>(outputIdx) < mNbProducedData.size(), "output index ({}) is out of bound ({}) for operator type {}", outputIdx, mNbProducedData.size(), mOp.type()); @@ -81,8 +124,8 @@ void Aidge::OperatorImpl::updateConsummerProducer(){ } void Aidge::OperatorImpl::resetConsummerProducer(){ - std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), 0); - std::fill(mNbProducedData.begin(), mNbProducedData.end(), 0); + std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), Elts_t::NoneElts()); + std::fill(mNbProducedData.begin(), mNbProducedData.end(), Elts_t::NoneElts()); } void Aidge::OperatorImpl::forward() { diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 586adbff50facec5d9fab4a447011b34e8090a2b..f498d5e82710f7fe78f27323f252a5c8f07ef96c 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -165,7 +165,7 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd size_t inputIdx = 0; for (auto input : mInputNodes) { if (input.first != nullptr) { - fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|→{}|{}_{}\n", inputIdx, inputIdx, + fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"→{}\"|{}_{}\n", inputIdx, inputIdx, input.second, input.first->type(), namePtrTable.at(input.first)); } else { @@ -1258,7 +1258,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo if (deletedNode == mRootNode) { const std::pair<std::vector<NodePtr>, size_t> ranked_nodes = getRankedNodes(); - if(ranked_nodes.second== 0 ) + if(ranked_nodes.second== 0 || ranked_nodes.first.size() <= 1) { mRootNode = nullptr; } else { diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 8867a40de6f13994eec367afa8cc5dc2c994a3cf..149691f796d1d84212e9d7842a28e4cb79469e6a 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -173,7 +173,7 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) "Input index ({}) is out of bound ({}) for node {} (of type {})", inId, nbInputs(), name(), type()); if (mIdOutParents[inId] != gk_IODefaultIndex) { - Log::warn("Warning: filling a Tensor already attributed\n"); + Log::notice("Notice: filling a Tensor already attributed"); auto originalParent = input(inId); // remove original parent reference to child // find the output ID for original Parent diff --git a/src/operator/Cast.cpp b/src/operator/Cast.cpp index 3e594b49404999fee10eed3a22a7c0a78f765df0..4f1ac55898b11668ba1c2f5299f8e1ca1d4e5df1 100644 --- a/src/operator/Cast.cpp +++ b/src/operator/Cast.cpp @@ -34,6 +34,8 @@ void Aidge::Cast_Op::forward() { } void Aidge::Cast_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { - SET_IMPL_MACRO(Cast_Op, *this, name); + if (Registrar<Cast_Op>::exists({name})) { + SET_IMPL_MACRO(Cast_Op, *this, name); + } mOutputs[0]->setBackend(name, device); } diff --git a/src/operator/MetaOperator.cpp 
b/src/operator/MetaOperator.cpp index 45e7556265d1af4e95e50be4cf60e8067ded332f..46e9e1173af98ed5711aa0bbce54705fb61dc03c 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -37,7 +37,7 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shar } } -Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { if (mImpl) { return mImpl->getNbRequiredData(inputIdx); } @@ -47,12 +47,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI return inputOp.first->getOperator()->getNbRequiredData(inputOp.second); } else { - return 0; + return Elts_t::NoneElts(); } } } -Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inputIdx) const { if (mImpl) { return mImpl->getNbRequiredProtected(inputIdx); } @@ -62,12 +62,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t i return inputOp.first->getOperator()->getNbRequiredProtected(inputOp.second); } else { - return 0; + return Elts_t::NoneElts(); } } } -Aidge::NbElts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const { +Aidge::Elts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const { if (mImpl) { return mImpl->getRequiredMemory(outputIdx, inputsSize); } @@ -77,12 +77,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t output return outputOp.first->getOperator()->getRequiredMemory(outputOp.second, inputsSize); } else { - return 0; + return Elts_t::NoneElts(); } } } -Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const { if (mImpl) { return mImpl->getNbConsumedData(inputIdx); } @@ -92,12 +92,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) co return inputOp.first->getOperator()->getNbConsumedData(inputOp.second); } else { - return 0; + return Elts_t::NoneElts(); } } } -Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) const { +Aidge::Elts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) const { if (mImpl) { return mImpl->getNbProducedData(outputIdx); } @@ -107,7 +107,7 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) c return outputOp.first->getOperator()->getNbProducedData(outputOp.second); } else { - return 0; + return Elts_t::NoneElts(); } } } diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp index e4213cad80ebdc177649b0c25e4fc49222993211..317bbd364572f49a714e328bf33f3cd58c19215f 100644 --- a/src/operator/Operator.cpp +++ b/src/operator/Operator.cpp @@ -31,27 +31,27 @@ Aidge::Operator::~Operator() noexcept = default; // IMPLEMENTATION /////////////////////////////////////////////////////// -Aidge::NbElts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredData(): an implementation is required for {}!", type()); return mImpl->getNbRequiredData(inputIdx); } -Aidge::NbElts_t Aidge::Operator::getNbRequiredProtected(const Aidge::IOIndex_t 
inputIdx) const { +Aidge::Elts_t Aidge::Operator::getNbRequiredProtected(const Aidge::IOIndex_t inputIdx) const { AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredProtected(): an implementation is required for {}!", type()); return mImpl->getNbRequiredProtected(inputIdx); } -Aidge::NbElts_t Aidge::Operator::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const { +Aidge::Elts_t Aidge::Operator::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const { AIDGE_ASSERT(mImpl != nullptr, "getRequiredMemory(): an implementation is required for {}!", type()); return mImpl->getRequiredMemory(outputIdx, inputsSize); } -Aidge::NbElts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { AIDGE_ASSERT(mImpl != nullptr, "getNbConsumedData(): an implementation is required for {}!", type()); return mImpl->getNbConsumedData(inputIdx); } -Aidge::NbElts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const { +Aidge::Elts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const { AIDGE_ASSERT(mImpl != nullptr, "getNbProducedData(): an implementation is required for {}!", type()); return mImpl->getNbProducedData(outputIdx); } diff --git a/src/operator/Producer.cpp b/src/operator/Producer.cpp index 43e991288c483f07138a2b236a2c4925ea0a3754..38bbbc14846f8f4356602b1d3a66058439bb37d0 100644 --- a/src/operator/Producer.cpp +++ b/src/operator/Producer.cpp @@ -39,20 +39,20 @@ Aidge::Producer_Op::Producer_Op(const std::shared_ptr<Aidge::Tensor> tensor, boo (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ? Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) : - nullptr); + std::make_shared<OperatorImpl>(*this, "")); } else { setImpl((mOutputs[0]->hasImpl()) ? (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ? Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) : - nullptr); + std::make_shared<OperatorImpl>(*this, "")); } #else setImpl((mOutputs[0]->hasImpl()) ? (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ? Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) : - nullptr); + std::make_shared<OperatorImpl>(*this, "")); #endif } @@ -73,20 +73,20 @@ Aidge::Producer_Op::Producer_Op(const Aidge::Producer_Op& op) (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ? Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) : - nullptr); + std::make_shared<OperatorImpl>(*this, "")); } else { setImpl((mOutputs[0]->hasImpl()) ? (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ? Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) : - nullptr); + std::make_shared<OperatorImpl>(*this, "")); } #else setImpl((mOutputs[0]->hasImpl()) ? (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ? 
Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) : - nullptr); + std::make_shared<OperatorImpl>(*this, "")); #endif // if (mOutputs[0]->hasImpl()) { // if (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){ @@ -107,16 +107,16 @@ void Aidge::Producer_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t auto obj = py::cast(&(*this)); setImpl((Registrar<Producer_Op>::exists({name})) ? Registrar<Producer_Op>::create(name)(*this) : - std::make_shared<OperatorImpl>(*this, name)); + std::make_shared<OperatorImpl>(*this, "")); } else { setImpl((Registrar<Producer_Op>::exists({name})) ? Registrar<Producer_Op>::create(name)(*this) : - std::make_shared<OperatorImpl>(*this, name)); + std::make_shared<OperatorImpl>(*this, "")); } #else setImpl((Registrar<Producer_Op>::exists({name})) ? Registrar<Producer_Op>::create(name)(*this) : - std::make_shared<OperatorImpl>(*this, name)); + std::make_shared<OperatorImpl>(*this, "")); #endif mOutputs[0]->setBackend(name, device); } \ No newline at end of file diff --git a/src/recipes/ConstantFolding.cpp b/src/recipes/ConstantFolding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..42fb45224614ca2655165a69b974cfe229e27f90 --- /dev/null +++ b/src/recipes/ConstantFolding.cpp @@ -0,0 +1,86 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#include <cassert> +#include <memory> +#include <set> +#include <string> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/recipes/Recipes.hpp" +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Types.h" + +void Aidge::constantFolding(std::shared_ptr<GraphView> graph) { + bool folded; + do { + folded = false; + std::set<std::shared_ptr<Node>> candidates; + for (const std::shared_ptr<Node>& nodePtr : graph->getNodes()) { + if (nodePtr->type() == Producer_Op::Type) { + const auto& childs = nodePtr->getChildren(); + candidates.insert(childs.begin(), childs.end()); + } + } + + for (const auto& node : candidates) { + bool foldable = true; + auto replaceGraph = std::make_shared<GraphView>(); + for (const auto& input : node->inputs()) { + if (input.first) { + if (input.first->type() != Producer_Op::Type) { + foldable = false; + break; + } + + const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator()); + if (!producer->getAttr<bool>("Constant")) { + Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant", + node->name(), node->type(), input.first->name()); + foldable = false; + break; + } + + replaceGraph->add(input.first, false); + } + } + + if (foldable) { + Log::info("Folding node {} (of type {})", node->name(), node->type()); + replaceGraph->add(node, false); + + node->forward(); + + auto prodGraph = std::make_shared<GraphView>(); + const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator()); + + for (IOIndex_t output = 0; output < node->nbOutputs(); ++output) { + const auto computedOutput = 
std::make_shared<Tensor>(op->getOutput(output)->clone()); + const auto newProd = Producer(computedOutput, node->name() + "_" + std::to_string(output), true); + + // Add output in right order + prodGraph->add(newProd); + } + + if (GraphView::replace(replaceGraph, prodGraph)) { + folded = true; + } + else { + Log::warn("Error with replace when folding node {} (of type {})", + node->name(), node->type()); + } + } + } + } + while (folded); +} diff --git a/src/recipes/FuseMulAdd.cpp b/src/recipes/FuseMulAdd.cpp index b57c1c3fc5e4b12dbd0004472a864ddaa864116e..6582038e981bb58534d04ded57052f6a0f83e9a9 100644 --- a/src/recipes/FuseMulAdd.cpp +++ b/src/recipes/FuseMulAdd.cpp @@ -39,11 +39,11 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< std::shared_ptr<Node> bias = nullptr; if (addNode->getParent(0) == matmulNode) { AIDGE_ASSERT(matmulNode->getParent(1), "No bias detected to produce the fuseMulAdd recipe."); - bias = addNode->getParent(1)->cloneSharedOperators(); + bias = addNode->getParent(1); } else if (addNode->getParent(1) == matmulNode) { AIDGE_ASSERT(matmulNode->getParent(0), "No bias detected to produce the fuseMulAdd recipe."); - bias = addNode->getParent(0)->cloneSharedOperators(); + bias = addNode->getParent(0); } std::shared_ptr<Node> weight = nullptr; @@ -51,13 +51,13 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< || (matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type && matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() != Producer_Op::Type)) { - weight = matmulNode->getParent(1)->cloneSharedOperators(); + weight = matmulNode->getParent(1); } else if ((matmulNode->getParent(0) && !matmulNode->getParent(1)) || (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type && matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() != Producer_Op::Type)) { - weight = matmulNode->getParent(0)->cloneSharedOperators(); + weight = matmulNode->getParent(0); } else if (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type && matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type) @@ -65,7 +65,7 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< // If both inputs are producers, there is an ambiguity, but both options // result in a correct solution. 
Log::notice("Notice: both MatMul inputs are Producers, assume data at input#0 and weights at input#1."); - weight = matmulNode->getParent(1)->cloneSharedOperators(); + weight = matmulNode->getParent(1); } AIDGE_ASSERT(weight != nullptr, "Could not deduce weight input for MatMul operator."); @@ -90,9 +90,9 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< // Step 2 : Branch existing producers & create the others // link weights & bias - weight->addChild(fc, 0, 1); + weight->cloneSharedOperators()->addChild(fc, 0, 1); if (bias) { - bias->addChild(fc, 0, 2); + bias->cloneSharedOperators()->addChild(fc, 0, 2); } @@ -100,8 +100,8 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< // Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview // Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory? - auto newNodes = std::set<std::shared_ptr<Node>>({fc, weight, fc->getParent(2)}); - GraphView::replace({matmulNode, addNode, addNode->getParent(1), matmulNode->getParent(1)}, newNodes); + auto newNodes = std::set<std::shared_ptr<Node>>({fc, fc->getParent(1), fc->getParent(2)}); + GraphView::replace({matmulNode, addNode, bias, weight}, newNodes); } diff --git a/src/scheduler/ParallelScheduler.cpp b/src/scheduler/ParallelScheduler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1dd13fe2100122002d4ed068ada4851b1bfba463 --- /dev/null +++ b/src/scheduler/ParallelScheduler.cpp @@ -0,0 +1,200 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/scheduler/ParallelScheduler.hpp" +#include "aidge/scheduler/ThreadPool.hpp" + +#include <chrono> +#include <memory> +#include <set> +#include <string> + +#include <fmt/ranges.h> +#include <fmt/color.h> + +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/operator/Memorize.hpp" +#include "aidge/operator/MetaOperator.hpp" + +void Aidge::ParallelScheduler::forward(bool forwardDims, std::vector<std::shared_ptr<Aidge::Tensor>> data) { + // Collect all data input of the graph (that are producers) + if (!data.empty()){ + connectInputs(data); + } + + // Forward dims (if allowed) + if (forwardDims) {mGraphView->forwardDims(); } + + // Generate scheduling *only if empty* + // If scheduling was already generated (in one or several steps, i.e. 
one or + // several successive call to generateScheduling()), do not generate it twice + if (mStaticSchedule.empty()) { + this->generateScheduling(); + } + + const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); + + // Sort static scheduling, the order will be the prefered threads scheduling + // order for non critical nodes + std::deque<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end()); + std::stable_sort(staticSchedule.begin(), staticSchedule.end(), + [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); }); + + // The thread pool has N threads running to process nodes. + // Thread pooling avoid the overhead of threads creation and deletion for + // each node execution. + ThreadPool pool; + + size_t latest = 0; + std::mutex schedulingMutex; + std::map<std::shared_ptr<StaticSchedulingElement>, std::atomic<bool>> finished; + + while (!staticSchedule.empty()) { + Log::debug("Step {}", latest); + + std::vector<std::shared_ptr<StaticSchedulingElement>> mustFinish; + + // Run all nodes that must be run at this step: latest (critical nodes) + for (size_t i = 0; i < staticSchedule.size(); ) { + auto runnable = staticSchedule[i]; + + if (runnable->late == latest) { + // Wait for potential preceding non-critical nodes to finish + while (true) { + bool ready = true; + for (auto elt : runnable->laterThan) { + ready = ready && finished.at(elt); + } + if (!ready) { + std::this_thread::yield(); + } + else { + break; + } + } + + // Add the critical node to the thread pool queue, to be run ASAP + finished[runnable] = false; + pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { + const auto tStart = std::chrono::high_resolution_clock::now(); + node->forward(); + const auto tEnd = std::chrono::high_resolution_clock::now(); + finished = true; + { + std::unique_lock<std::mutex> lock(schedulingMutex); + mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd)); + } + }); + staticSchedule.erase(staticSchedule.begin() + i); + mustFinish.push_back(runnable); + + Log::debug(" run critical {}", namePtrTable.at(runnable->node)); + + // Ensure the following nodes cannot start earlier than next step + for (auto elt : runnable->earlierThan) { + if (elt->early < latest + 1) { + Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1); + elt->early = latest + 1; + AIDGE_INTERNAL_ASSERT(elt->early <= elt->late); + } + } + } + else if (runnable->early > latest + 1) { + // There cannot be more node that must be run at latest + 1 + // latest + 1 and not latest because early may have been updated + // for some elements to latest + 1 (above). + break; + } + else { + ++i; + } + } + + // If some threads are still available, run next early nodes + // These nodes are non-critical, meaning they can still be run at least + // in the next step + for (size_t i = 0; i < staticSchedule.size(); ) { + auto runnable = staticSchedule[i]; + if (!pool.busy() && runnable->early <= latest) { + // Check that potential preceding non-critical nodes are finished + bool ready = true; + for (auto elt : runnable->laterThan) { + ready = ready && finished.at(elt); + } + + if (ready) { + // All preceding nodes have finished, this node can be run. 
+ // Add the node to the thread pool queue, to be run ASAP + finished[runnable] = false; + pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { + const auto tStart = std::chrono::high_resolution_clock::now(); + node->forward(); + const auto tEnd = std::chrono::high_resolution_clock::now(); + finished = true; + { + std::unique_lock<std::mutex> lock(schedulingMutex); + mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd)); + } + }); + staticSchedule.erase(staticSchedule.begin() + i); + + Log::debug(" run {}", namePtrTable.at(runnable->node)); + + // Ensure the following nodes cannot start earlier than next step + for (auto elt : runnable->earlierThan) { + if (elt->early < latest + 1) { + Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1); + elt->early = latest + 1; + AIDGE_INTERNAL_ASSERT(elt->early <= elt->late); + } + } + } + else { + // The node cannot be run yet, because preceding nodes are + // still running, move to the next one in schedule + ++i; + } + } + else { + // Thread pool is already full or no more node can be run at + // this step (latest) + break; + } + } + + // Wait for all nodes that must finish at latest to be finished + // By scheduling construction, no other node can be started before all + // nodes at latest step are finished + while (true) { + bool ready = true; + for (auto elt : mustFinish) { + ready = ready && finished.at(elt); + } + if (!ready) { + std::this_thread::yield(); + } + else { + break; + } + } + + ++latest; + } + + ++mStaticScheduleStep; + if (mStaticScheduleStep == mStaticSchedule.size()) { + mStaticScheduleStep = 0; + } +} diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 94baf6a3e7b6e2e86de4e2d72ed19bfd9338392e..b3b2d5e5b3944e64b6df7b8499237477d88d5b50 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -23,30 +23,17 @@ #include "aidge/graph/Node.hpp" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/utils/Types.h" -#include "aidge/recipes/GraphViewHelper.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/Memorize.hpp" #include "aidge/operator/MetaOperator.hpp" -void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") { - putchar('['); - int pos = static_cast<int>(barWidth * progress); - for (int i = 0; i < barWidth; ++i) { - if (i <= pos) - putchar('#'); - else - putchar(' '); - } - fmt::print("] {}% | {}\r", static_cast<int>(progress * 100), additionalInfo); - fflush(stdout); +void Aidge::Scheduler::generateScheduling() { + auto schedule = generateBaseScheduling(); + generateEarlyLateScheduling(schedule); + mStaticSchedule.push_back(schedule); } -void Aidge::SequentialScheduler::generateScheduling(bool verbose) { - // TODO: For loop on the list of node to run - // run sequencially every runnable consumers once - // TODO: handle memory allocation in scheduler - // TODO: optimize memory usage - +std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::Scheduler::generateBaseScheduling() const { // 1) Setup initial consumers list: // It is the list of input nodes std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes(); @@ -60,15 +47,15 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { const auto producersConsumers = getConsumers(producers); consumers.insert(producersConsumers.begin(), producersConsumers.end()); - std::map<std::shared_ptr<Node>, std::string> namePtrTable; - if (verbose) 
-        namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
+    const std::map<std::shared_ptr<Node>, std::string> namePtrTable
+        = mGraphView->getRankedNodesName("{0} ({1}#{3})");
 
     // Still consumers are consumers that were run but can still consume data.
     // They must be run AFTER the remaining consumers to ensure a non-greedy
     // producers-consumers model!
     std::set<std::shared_ptr<Node>> stillConsumers;
 
-    mStaticSchedule.push_back(std::vector<std::shared_ptr<Node>>());
+    std::vector<std::shared_ptr<StaticSchedulingElement>> schedule;
 
     do {
         // 2) From the current consumers list, check if any prior consumer node
@@ -81,34 +68,28 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
         // If the prior node is of another type, it replaces the initial consumer
         // in the new priorConsumers list. The initial consumer will become
         // again a consumer later, by construction.
-        if (verbose) fmt::print("List of consumers with their priors:\n");
+        Log::debug("List of consumers with their priors:");
         std::set<std::shared_ptr<Node>> requiredProducers;
         std::set<std::shared_ptr<Node>> priorConsumers;
         mPriorCache.clear();
 
         for (const auto& consumer : consumers) {
-            if (verbose) {
-                fmt::print("\t- consumer: ");
-                fmt::print(fg(fmt::color::orange), namePtrTable[consumer]);
-                fmt::print("\n");
-            }
+            Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange)));
 
             const auto& prior = getPriorProducersConsumers(consumer);
 
             if (prior.isPrior) {
-                if (verbose) {
-                    std::vector<std::string> requiredProducersName;
-                    std::transform(prior.requiredProducers.begin(), prior.requiredProducers.end(),
-                        std::back_inserter(requiredProducersName),
-                        [&namePtrTable](auto val){ return namePtrTable[val]; });
-                    fmt::print("\t\trequired producers: {}\n", requiredProducersName);
-
-                    std::vector<std::string> priorConsumersName;
-                    std::transform(prior.priorConsumers.begin(), prior.priorConsumers.end(),
-                        std::back_inserter(priorConsumersName),
-                        [&namePtrTable](auto val){ return namePtrTable[val]; });
-                    fmt::print("\t\tprior consumers: {}\n", priorConsumersName);
-                }
+                std::vector<std::string> requiredProducersName;
+                std::transform(prior.requiredProducers.begin(), prior.requiredProducers.end(),
+                    std::back_inserter(requiredProducersName),
+                    [&namePtrTable](auto val){ return namePtrTable.at(val); });
+                Log::debug("\t\trequired producers: {}", requiredProducersName);
+
+                std::vector<std::string> priorConsumersName;
+                std::transform(prior.priorConsumers.begin(), prior.priorConsumers.end(),
+                    std::back_inserter(priorConsumersName),
+                    [&namePtrTable](auto val){ return namePtrTable.at(val); });
+                Log::debug("\t\tprior consumers: {}", priorConsumersName);
 
                 requiredProducers.insert(prior.requiredProducers.cbegin(), prior.requiredProducers.cend());
                 priorConsumers.insert(prior.priorConsumers.cbegin(), prior.priorConsumers.cend());
@@ -127,7 +108,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
         // Producers are special nodes that generate data on demand.
         for (const auto& requiredProducer : requiredProducers) {
             requiredProducer->getOperator()->updateConsummerProducer();
-            mStaticSchedule.back().push_back(requiredProducer);
+            schedule.push_back(std::make_shared<StaticSchedulingElement>(requiredProducer));
         }
 
         // 5) Find runnable consumers.
@@ -136,32 +117,33 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
        // runnable because some may depend on the execution of others (when
        // there are multiple successive priors for example).
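[Editor's note] The runnability test applied just below boils down to per-input bookkeeping: a consumer is runnable only if, for every input, consumed + required <= available. A standalone toy version with invented numbers (plain arithmetic, not Aidge types):

#include <cstddef>
#include <cstdio>
#include <vector>

// One record per input: C (already consumed), R (required for the next
// execution) and P (currently available from the producer side).
struct InputState { long consumed, required, available; };

int main() {
    const std::vector<InputState> inputs = {
        {0, 100, 100}, // input #0: 0 + 100 <= 100 -> fine
        {0, 100,  50}, // input #1: 0 + 100 >   50 -> blocks the node
    };
    bool isRunnable = true;
    for (std::size_t i = 0; i < inputs.size(); ++i) {
        if (inputs[i].consumed + inputs[i].required > inputs[i].available) {
            std::printf("not runnable: C%ld + R%ld > P%ld for input #%zu\n",
                        inputs[i].consumed, inputs[i].required,
                        inputs[i].available, i);
            isRunnable = false;
        }
    }
    std::printf("runnable: %s\n", isRunnable ? "yes" : "no");
    return 0;
}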
std::set<std::shared_ptr<Node>> runnableConsumers;
-        if (verbose) fmt::print("Updated list of consumers:\n");
+        Log::debug("Updated list of consumers:");
         for (const auto& consumer : consumers) {
-            if (verbose) {
-                fmt::print("\t- consumer: ");
-                fmt::print(fg(fmt::color::orange), namePtrTable[consumer]);
-                fmt::print("\n\t\tC/R:\t");
-                for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
-                    fmt::print("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
-                            consumer->getOperator()->getNbRequiredData(inId));
-                }
-                fmt::print("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
-                        consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
-                fmt::print("\n\t\tP:\t");
-                for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
-                    fmt::print("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
-                }
-                fmt::print("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
-                fmt::print("\n");
+            Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange)));
+
+            std::string crLog = "\t\tC/R:\t";
+            for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
+                crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
+                        consumer->getOperator()->getNbRequiredData(inId));
+            }
+            crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
+                    consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
+            Log::debug("{}", crLog);
+
+            std::string pLog = "\t\tP:\t";
+            for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
+                pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
             }
+            pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
+            Log::debug("{}", pLog);
 
             bool isRunnable = true;
             for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
-                if (/*consumer->getOperator()->getNbRequiredData(inputIdx) > 0
-                        && */(consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
+                AIDGE_LOG_CONTEXT("Consumer node {} input #{}", namePtrTable.at(consumer), inputIdx);
+
+                if ((consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
                             getNbAvailableData(consumer, inputIdx)) {
-                    if (verbose) fmt::print("  not runnable: C{} + R{} > P{} for input #{}\n",
+                    Log::debug("  not runnable: C{} + R{} > P{} for input #{}",
                         consumer->getOperator()->getNbConsumedData(inputIdx),
                         consumer->getOperator()->getNbRequiredData(inputIdx),
                         getNbAvailableData(consumer, inputIdx), inputIdx);
@@ -179,7 +161,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
 
         // 5) If no consumer is runnable, it is a stop condition!
         if (runnableConsumers.empty()) {
-            if (verbose) fmt::print("********************\n");
+            Log::debug("********************");
             // No consumer is runnable: some required data is missing for all of
             // them. There are two possibilities:
             // - At least one required data source is exhausted, which may be
@@ -194,39 +176,42 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
 
         // At this point, simultaneously runnable consumers have no data
        // dependency and could be run in parallel!
for (const auto& runnable : runnableConsumers) { - if (verbose) fmt::print("Runnable: {}\n", namePtrTable[runnable]); + Log::debug("Runnable: {}", namePtrTable.at(runnable)); runnable->getOperator()->updateConsummerProducer(); - mStaticSchedule.back().push_back(runnable); + schedule.push_back(std::make_shared<StaticSchedulingElement>(runnable)); } // 7) Update consumers list - if (verbose) fmt::print("Updating producer and consumer lists...\n"); + Log::debug("Updating producer and consumer lists..."); for (const auto& consumer : runnableConsumers) { - if (verbose) { - fmt::print("\t- consumer: {}\n\t\tC/R:\t", - namePtrTable[consumer]); - for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { - fmt::print("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), - consumer->getOperator()->getNbRequiredData(inId)); - } - fmt::print("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), - consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); - fmt::print("\n\t\tP:\t"); - for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) { - fmt::print("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); - } - fmt::print("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); - fmt::print("\n"); + Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange))); + + std::string crLog = "\t\tC/R:\t"; + for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { + crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), + consumer->getOperator()->getNbRequiredData(inId)); + } + crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), + consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); + Log::debug("{}", crLog); + + std::string pLog = "\t\tP:\t"; + for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) { + pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); } + pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); + Log::debug("{}", pLog); // 7.1) If the current consumer has still data to consume, it will // be put back in the consumers list once the remaining consumers // have been exhausted. 
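[Editor's illustration, invented numbers] If a node that just ran has consumed C = 100 elements on some input while P = 150 are available, then C < P holds in the check below: the node is flagged as a still-consumer and re-queued only after the remaining consumers have been exhausted. This ordering is what keeps the producer-consumer model non-greedy.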
bool isStillConsumer = false;
             for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
+                AIDGE_LOG_CONTEXT("Consumer node {} input #{}", namePtrTable.at(consumer), inputIdx);
+
                 if (consumer->getOperator()->getNbConsumedData(inputIdx) <
                             getNbAvailableData(consumer, inputIdx)) {
-                    if (verbose) fmt::print("  still consumer: C{} < P{} for input #{}\n",
+                    Log::debug("  still consumer: C{} < P{} for input #{}",
                         consumer->getOperator()->getNbConsumedData(inputIdx),
                         getNbAvailableData(consumer, inputIdx), inputIdx);
@@ -245,17 +230,25 @@
             IOIndex_t inputIdx = 0;
             for (const auto& childParent : child->getParents()) {
                 if (childParent == consumer) {
-                    if (consumer->getOperator()->getNbProducedData(outId) > child->getOperator()->getNbConsumedData(inputIdx)) {
+                    AIDGE_LOG_CONTEXT("Consumer node {} input #{} / Producer node {} output #{}",
+                        namePtrTable.at(child), inputIdx, namePtrTable.at(consumer), outId);
+
+                    if (child->getOperator()->getNbConsumedData(inputIdx) < consumer->getOperator()->getNbProducedData(outId)) {
                         isProducer = true;
+                        break;
                     }
                 }
                 ++inputIdx;
             }
+
+            if (isProducer) {
+                break;
+            }
         }
     }
 
/*
        if (consumer->getOperator()->getNbProducedData(outId) > 0) {
-            if (verbose) fmt::print("  also producer\n");
+            Log::debug("  also producer");
 
            // make sure consumer is also a producer
            producers.insert(consumer);
@@ -269,7 +262,7 @@
         consumers.erase(consumer);
 
         if (isProducer) {
-            if (verbose) fmt::print("  also producer\n");
+            Log::debug("  also producer");
 
             // make sure consumer is also a producer
             producers.insert(consumer);
@@ -292,18 +285,95 @@
             stillConsumers.clear();
         }
 
-        if (verbose) fmt::print("********************\n");
+        Log::debug("********************");
     } while (!consumers.empty());
 
-    if (verbose) {
-        if (!consumers.empty()) {
-            fmt::print("/!\\ Remaining consumers: possible dead-lock\n");
-            fmt::print("********************\n");
+    mPriorCache.clear();
+
+    if (!consumers.empty()) {
+        Log::warn("Remaining consumers: possible dead-lock");
+    }
+
+    return schedule;
+}
+
+void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const {
+    size_t latest = 0;
+    // Calculate early (logical) start
+    for (size_t elt = 0; elt < schedule.size(); ++elt) {
+        const auto node = schedule[elt]->node;
+        const auto itNode = std::find_if(schedule.rend() - elt, schedule.rend(),
+            [node](const auto& v) { return (v->node == node); });
+
+        // Node can be run the earliest just after its children were run the last time!
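[Editor's illustration, not part of the patch] On a single-pass schedule every node appears once, so the "children" clause above is vacuous and only the parent rule below applies. For the diamond graph A -> {B, C} -> D scheduled as [A, B, C, D], it yields early(A) = 0, early(B) = early(C) = early(A) + 1 = 1 and early(D) = max(early(B), early(C)) + 1 = 2; the symmetric late pass further down gives late(B) = late(C) = 1, so B and C share logical step 1 and may run in parallel.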
+        size_t early = 0;
+        if (itNode != schedule.rend()) {
+            for (const auto& child : node->getChildren()) {
+                // Find child node next scheduled position
+                const auto it = std::find_if(schedule.rend() - elt, itNode,
+                    [child](const auto& v) { return (v->node == child); });
+                AIDGE_INTERNAL_ASSERT(it != schedule.rend());
+
+                const size_t step = std::distance(schedule.begin(), it.base()) - 1;
+                early = std::max(early, schedule[step]->early + 1);
+                schedule[step]->earlierThan.push_back(schedule[elt]);
+            }
+        }
+
+        // Node can be run the earliest just after its latest parent was run
+        for (const auto& parent : node->getParents()) {
+            // Find parent node latest scheduled position
+            const auto it = std::find_if(schedule.rend() - elt, schedule.rend(),
+                [parent](const auto& v) { return (v->node == parent); });
+            if (it != schedule.rend()) {
+                const size_t step = std::distance(schedule.begin(), it.base()) - 1;
+                early = std::max(early, schedule[step]->early + 1);
+                schedule[step]->earlierThan.push_back(schedule[elt]);
+            }
+        }
+
+        latest = std::max(latest, early);
+        schedule[elt]->early = early;
+    }
+
+    // Calculate late (logical) start
+    for (size_t elt = schedule.size(); elt-- != 0; ) {
+        const auto node = schedule[elt]->node;
+        const auto itNode = std::find_if(schedule.begin() + elt + 1, schedule.end(),
+            [node](const auto& v) { return (v->node == node); });
+
+        // Node can be run the latest just before its parents are run the next time!
+        size_t late = latest;
+        if (itNode != schedule.end()) {
+            for (const auto& parent : node->getParents()) {
+                // Find parent node next scheduled position
+                const auto it = std::find_if(schedule.begin() + elt + 1, itNode,
+                    [parent](const auto& v) { return (v->node == parent); });
+                AIDGE_INTERNAL_ASSERT(it != schedule.end());
+
+                const size_t step = std::distance(schedule.begin(), it);
+                late = std::min(late, schedule[step]->late - 1);
+                schedule[step]->laterThan.push_back(schedule[elt]);
+            }
+        }
+
+        // Node can be run the latest just before its earliest child is run
+        for (const auto& child : node->getChildren()) {
+            // Find child node earliest scheduled position
+            const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(),
+                [child](const auto& v) { return (v->node == child); });
+            if (it != schedule.end()) {
+                const size_t step = std::distance(schedule.begin(), it);
+                late = std::min(late, schedule[step]->late - 1);
+                schedule[step]->laterThan.push_back(schedule[elt]);
+            }
+        }
+
+        schedule[elt]->late = late;
+    }
+}
 
-void Aidge::SequentialScheduler::resetScheduling() {
+void Aidge::Scheduler::resetScheduling() {
     for (auto node : mGraphView->getNodes()) {
         node->getOperator()->resetConsummerProducer();
     }
@@ -316,28 +386,33 @@
 
/**
 * This is a simplified version without special handling of concatenation.
*/ -Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const { +Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const { MemoryManager memManager; - for (const auto& shedule : mStaticSchedule) { - for (const auto& node : shedule) { + for (size_t step = 0; step < mStaticSchedule.size(); ++step) { + for (const auto& node : getStaticScheduling(step)) { if (!incProducers && node->type() == Producer_Op::Type) { memManager.releaseDependencies(node); continue; } const auto childs = node->getChildren(); - AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type."); + AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, + "Operator must be of Tensor type for node {} (of type {}).", + node->name(), node->type()); const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator()); std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane; // Allocate a memory plane for each node's output for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) { - const size_t requiredSize = op->getRequiredMemory(outputIdx, {}); + const auto requiredSize = op->getRequiredMemory(outputIdx, {}); + AIDGE_ASSERT(requiredSize.type == Elts_t::Data, + "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).", + node->name(), node->type()); // By default, specifies a fully monolithic memory block - size_t size = requiredSize; + size_t size = requiredSize.data; size_t stride = 0; size_t length = 1; size_t count = 1; @@ -369,21 +444,27 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer // memSpace should not be already released && memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1) { - const bool isWrappable = (op->getNbRequiredProtected(inputIdx) < op->getNbRequiredData(inputIdx)); + const auto requiredData = op->getNbRequiredData(inputIdx); + const auto requiredProtected = op->getNbRequiredProtected(inputIdx); + AIDGE_ASSERT(requiredData.type == Elts_t::Data && requiredProtected.type == Elts_t::Data, + "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).", + node->name(), node->type()); + + const bool isWrappable = (requiredProtected.data < requiredData.data); const MemoryManager::MemoryPlane& memPlane = memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second]; if (isWrappable || !memManager.isWrapAround( memPlane.memSpace, memPlane.getFinalOffset() - memPlane.memSpace->offset, - requiredSize)) + requiredSize.data)) { - if (memPlane.getSize() > wrapAroundSize + op->getNbRequiredProtected(inputIdx) + if (memPlane.getSize() > wrapAroundSize + requiredProtected.data && std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end()) { - wrapAroundSize = memPlane.getSize() - op->getNbRequiredProtected(inputIdx); - if (requiredSize > wrapAroundSize) { - wrapAroundExtra = requiredSize - wrapAroundSize; + wrapAroundSize = memPlane.getSize() - requiredProtected.data; + if (requiredSize.data > wrapAroundSize) { + wrapAroundExtra = requiredSize.data - wrapAroundSize; } wrapAroundMemPlane[outputIdx] = &memPlane; } @@ -400,17 +481,17 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer const MemoryManager::MemoryPlane& memPlane = (wrapAroundBuffer && wrapAroundSize > 0) ? 
(*wrapAroundMemPlane[outputIdx]) : - memManager.allocate(requiredSize, childs, stride, length, count); + memManager.allocate(requiredSize.data, childs, stride, length, count); if (wrapAroundBuffer && wrapAroundSize > 0) { memManager.reallocate(memPlane, node, 0, - requiredSize, true, wrapAroundExtra, childs, stride, length, count); + requiredSize.data, true, wrapAroundExtra, childs, stride, length, count); } else { memManager.reallocate(memPlane.memSpace, node, memPlane.offset, - requiredSize, false, 0, childs, stride, length, count); + requiredSize.data, false, 0, childs, stride, length, count); } } @@ -422,7 +503,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer return memManager; } -void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data){ +void Aidge::Scheduler::connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data){ // This version of connect inputs only connects tensor inputs in input data producers. auto inputNodes = mGraphView->getOrderedInputs(); @@ -435,102 +516,7 @@ void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge } } - -void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) { - - // Collect all data input of the graph (that are producers) - if (!data.empty()){ - connectInputs(data); - } - - // Forward dims (if allowed) - if (forwardDims) {mGraphView->forwardDims(); } - - // Generate scheduling *only if empty* - // If scheduling was already generated (in one or several steps, i.e. one or - // several successive call to generateScheduling()), do not generate it twice - if (mStaticSchedule.empty()) { - this->generateScheduling(verbose); - } - - const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); - - size_t cpt = 0; - for (const auto& runnable : mStaticSchedule.at(mStaticScheduleStep)) { - if (verbose) - fmt::print("run: {}\n", namePtrTable.at(runnable)); - else - drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50, - (std::string("running ") + namePtrTable.at(runnable))); - const auto tStart = std::chrono::high_resolution_clock::now(); - runnable->forward(); - const auto tEnd = std::chrono::high_resolution_clock::now(); - mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd)); - cpt++; - } - if (!verbose) drawProgressBar(1.0, 50, " "); - fmt::print("\n"); - - ++mStaticScheduleStep; - if (mStaticScheduleStep == mStaticSchedule.size()) { - mStaticScheduleStep = 0; - } -} - -void Aidge::SequentialScheduler::backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instanciateGrad, bool verbose) { - // create ad set Grad values - if (instanciateGrad) { compile_gradient(mGraphView); } - - const auto& ordered_outputs = mGraphView->getOrderedOutputs(); - AIDGE_ASSERT(ordered_outputs.size() == data.size(), "You must provide the \ - right number of data objects to run the backward function. 
\ - {} outputs detected for the current GraphView when {} were \ - provided.", ordered_outputs.size(), data.size()); - for (std::size_t i = 0; i < ordered_outputs.size(); ++i) { - const std::shared_ptr<OperatorTensor> op_ = std::dynamic_pointer_cast<OperatorTensor>(ordered_outputs[i].first->getOperator()); - const std::shared_ptr<Tensor> t_grad = op_->getOutput(ordered_outputs[i].second)->grad(); - AIDGE_ASSERT(data[i]->dims() == t_grad->dims(), "Wrong gradient size."); - *t_grad = data[i]->clone(); - } - // Generate scheduling *only if empty* - // If scheduling was already generated (in one or several steps, i.e. one or - // several successive call to generateScheduling()), do not generate it twice - if (mStaticSchedule.empty()) { - this->generateScheduling(); - } - - // map of node <-> info to display with verbose - const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); - - // Clear previous scheduling results - mScheduling.clear(); - - std::size_t cpt = 0; - // run scheduled operators in reverse order - const auto& runnableList = mStaticSchedule.at(mStaticScheduleStep); - for (auto runnable = runnableList.crbegin(); runnable != runnableList.crend(); ++runnable) { - if (verbose) - fmt::print("run: {}\n", namePtrTable.at(*runnable)); - else - drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50, - (std::string("running ") + namePtrTable.at(*runnable))); - const auto tStart = std::chrono::high_resolution_clock::now(); - (*runnable)->backward(); - const auto tEnd = std::chrono::high_resolution_clock::now(); - mScheduling.push_back(SchedulingElement(*runnable, tStart, tEnd)); - cpt++; - } - if (!verbose) drawProgressBar(1.0, 50, " "); - fmt::print("\n"); - - ++mStaticScheduleStep; - if (mStaticScheduleStep == mStaticSchedule.size()) { - mStaticScheduleStep = 0; - } -} - - -void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const { +void Aidge::Scheduler::saveSchedulingDiagram(const std::string& fileName) const { auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); if (!fp) { @@ -560,7 +546,43 @@ void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileNa fmt::print(fp.get(), "\n"); } -std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers( +void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) const { + auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); + + if (!fp) { + AIDGE_THROW_OR_ABORT(std::runtime_error, + "Could not create scheduling diagram log file: {}", fileName + ".mmd"); + } + + fmt::print(fp.get(), "gantt\ndateFormat x\naxisFormat %Q\n\n"); + + if (!mStaticSchedule.empty()) { + const std::map<std::shared_ptr<Node>, std::string> namePtrTable + = mGraphView->getRankedNodesName("{0} ({1}#{3})"); + + for (const auto& schedule : mStaticSchedule) { + for (const auto& element : schedule) { + auto name = namePtrTable.at(element->node); + // Mermaid does not allow : character in task title + std::replace(name.begin(), name.end(), ':', '_'); + + fmt::print(fp.get(), "{} :{}, {}\n", + name, element->early, element->late); + } + } + } + + fmt::print(fp.get(), "\n"); +} + +std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(size_t step) const { + const auto& staticSchedule = mStaticSchedule.at(step); + std::vector<std::shared_ptr<Node>> schedule; + 
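[Editor's note] The std::transform below simply projects each StaticSchedulingElement down to its bare Node pointer, so getStaticScheduling() hands callers the plain node sequence of a step without the early/late annotations.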
std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; });
+    return schedule;
+}
+
+std::set<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getConsumers(
     const std::set<std::shared_ptr<Node>>& producers) const {
     std::set<std::shared_ptr<Node>> consumers;
@@ -577,7 +599,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
     return consumers;
 }
 
-Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const {
+Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const {
     const auto parent = node->inputs()[inputIdx];
 
     if (parent.first) {
@@ -608,17 +630,17 @@
         // In this case, we assume single-use data (unlike a Producer, which
         // keeps producing the data each time it is needed).
         fmt::print("No producer node attached to input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
-        return std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size();
+        return Elts_t::DataElts(std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size());
     }
     else {
         // Input is not connected, this is an error
         AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
     }
 
-    return 0;
+    return Elts_t::NoneElts();
 }
 
-Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers(
+Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersConsumers(
     const std::shared_ptr<Node>& node) const {
     const auto priorCache = mPriorCache.find(node);
@@ -630,32 +652,36 @@
     IOIndex_t inputIdx = 0;
     for (const auto& parent : node->inputs()) {
-        if (parent.first &&
-            (node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) >
-                parent.first->getOperator()->getNbProducedData(parent.second))
-        {
-            if (!mGraphView->inView(parent.first)) {
-                // Do not schedule prior outside the current graph!
-                return PriorProducersConsumers();
-            }
-
-            if (parent.first->type() == Producer_Op::Type) {
-                prior.requiredProducers.insert(parent.first);
-                prior.priorConsumers.insert(node);
-            }
-            else if (parent.first->type() == Memorize_Op::Type) {
-                // Break cycles
-                return PriorProducersConsumers();
-            }
-            else {
-                const auto& parentPrior = getPriorProducersConsumers(parent.first);
+        if (parent.first) {
+            AIDGE_LOG_CONTEXT("Producer node {} (of type {}) output #{}",
+                parent.first->name(), parent.first->type(), parent.second);
+
+            if ((node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) >
+                parent.first->getOperator()->getNbProducedData(parent.second))
+            {
+                if (!mGraphView->inView(parent.first)) {
+                    // Do not schedule prior outside the current graph!
+ return PriorProducersConsumers(); + } - if (!parentPrior.isPrior) { + if (parent.first->type() == Producer_Op::Type) { + prior.requiredProducers.insert(parent.first); + prior.priorConsumers.insert(node); + } + else if (parent.first->type() == Memorize_Op::Type) { + // Break cycles return PriorProducersConsumers(); } else { - prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend()); - prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend()); + const auto& parentPrior = getPriorProducersConsumers(parent.first); + + if (!parentPrior.isPrior) { + return PriorProducersConsumers(); + } + else { + prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend()); + prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend()); + } } } } diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1454cb2bac982465248e7761673ad3646165c2dc --- /dev/null +++ b/src/scheduler/SequentialScheduler.cpp @@ -0,0 +1,116 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/scheduler/SequentialScheduler.hpp" + +#include <chrono> +#include <memory> +#include <set> +#include <string> + +#include <fmt/ranges.h> +#include <fmt/color.h> + +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/operator/Memorize.hpp" +#include "aidge/operator/MetaOperator.hpp" +#include "aidge/recipes/GraphViewHelper.hpp" + +void Aidge::SequentialScheduler::forward(bool forwardDims, std::vector<std::shared_ptr<Aidge::Tensor>> data) { + // Collect all data input of the graph (that are producers) + if (!data.empty()){ + connectInputs(data); + } + + // Forward dims (if allowed) + if (forwardDims) {mGraphView->forwardDims(); } + + // Generate scheduling *only if empty* + // If scheduling was already generated (in one or several steps, i.e. 
one or
+    // several successive calls to generateScheduling()), do not generate it twice
+    if (mStaticSchedule.empty()) {
+        this->generateScheduling();
+    }
+
+    // Sort static scheduling according to the policy
+    std::vector<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
+
+    if (mSchedulingPolicy == AsSoonAsPossible) {
+        std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
+            [](const auto& lhs, const auto& rhs) { return (lhs->early < rhs->early); });
+    }
+    else if (mSchedulingPolicy == AsLateAsPossible) {
+        std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
+            [](const auto& lhs, const auto& rhs) { return (lhs->late < rhs->late); });
+    }
+
+    const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
+
+    for (const auto& runnable : staticSchedule) {
+        Log::debug("run: {}", namePtrTable.at(runnable->node));
+
+        const auto tStart = std::chrono::high_resolution_clock::now();
+        runnable->node->forward();
+        const auto tEnd = std::chrono::high_resolution_clock::now();
+        mScheduling.push_back(SchedulingElement(runnable->node, tStart, tEnd));
+    }
+
+    ++mStaticScheduleStep;
+    if (mStaticScheduleStep == mStaticSchedule.size()) {
+        mStaticScheduleStep = 0;
+    }
+}
+
+void Aidge::SequentialScheduler::backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instanciateGrad) {
+    // create and set Grad values
+    if (instanciateGrad) { compile_gradient(mGraphView); }
+
+    const auto& ordered_outputs = mGraphView->getOrderedOutputs();
+    AIDGE_ASSERT(ordered_outputs.size() == data.size(), "You must provide the \
+                    right number of data objects to run the backward function. \
+                    {} outputs detected for the current GraphView when {} were \
+                    provided.", ordered_outputs.size(), data.size());
+    for (std::size_t i = 0; i < ordered_outputs.size(); ++i) {
+        const std::shared_ptr<OperatorTensor> op_ = std::dynamic_pointer_cast<OperatorTensor>(ordered_outputs[i].first->getOperator());
+        const std::shared_ptr<Tensor> t_grad = op_->getOutput(ordered_outputs[i].second)->grad();
+        AIDGE_ASSERT(data[i]->dims() == t_grad->dims(), "Wrong gradient size.");
+        *t_grad = data[i]->clone();
+    }
+    // Generate scheduling *only if empty*
+    // If scheduling was already generated (in one or several steps, i.e. one or
+    // several successive calls to generateScheduling()), do not generate it twice
+    if (mStaticSchedule.empty()) {
+        this->generateScheduling();
+    }
+
+    // map of node <-> name used for logging
+    const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
+
+    // run scheduled operators in reverse order
+    const auto& runnableList = mStaticSchedule.at(mStaticScheduleStep);
+    for (auto runnable = runnableList.crbegin(); runnable != runnableList.crend(); ++runnable) {
+        Log::debug("run: {}", namePtrTable.at((*runnable)->node));
+
+        const auto tStart = std::chrono::high_resolution_clock::now();
+        (*runnable)->node->backward();
+        const auto tEnd = std::chrono::high_resolution_clock::now();
+        mScheduling.push_back(SchedulingElement((*runnable)->node, tStart, tEnd));
+    }
+
+    ++mStaticScheduleStep;
+    if (mStaticScheduleStep == mStaticSchedule.size()) {
+        mStaticScheduleStep = 0;
+    }
+}
diff --git a/src/scheduler/ThreadPool.cpp b/src/scheduler/ThreadPool.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e81ab7a76f8a063b3ef33f5b24ecd2396267852e
--- /dev/null
+++ b/src/scheduler/ThreadPool.cpp
@@ -0,0 +1,65 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "aidge/scheduler/ThreadPool.hpp"
+
+Aidge::ThreadPool::ThreadPool(size_t nbThreads) {
+    for (size_t i = 0; i < nbThreads; ++i) {
+        mThreads.emplace_back(std::thread(&ThreadPool::threadLoop, this));
+    }
+}
+
+void Aidge::ThreadPool::threadLoop() {
+    while (true) {
+        std::function<void()> job;
+        {
+            std::unique_lock<std::mutex> lock(mQueueMutex);
+            mMutexCondition.wait(lock, [this] {
+                return !mJobs.empty() || mTerminate;
+            });
+            if (mTerminate) {
+                return;
+            }
+            job = mJobs.front();
+            mJobs.pop();
+        }
+        job();
+    }
+}
+
+void Aidge::ThreadPool::queueJob(const std::function<void()>& job) {
+    {
+        std::unique_lock<std::mutex> lock(mQueueMutex);
+        mJobs.push(job);
+    }
+    mMutexCondition.notify_one();
+}
+
+bool Aidge::ThreadPool::busy() {
+    bool poolbusy;
+    {
+        std::unique_lock<std::mutex> lock(mQueueMutex);
+        poolbusy = !mJobs.empty();
+    }
+    return poolbusy;
+}
+
+Aidge::ThreadPool::~ThreadPool() {
+    {
+        std::unique_lock<std::mutex> lock(mQueueMutex);
+        mTerminate = true;
+    }
+    mMutexCondition.notify_all();
+    for (std::thread& active_thread : mThreads) {
+        active_thread.join();
+    }
+    mThreads.clear();
+}
diff --git a/src/utils/Log.cpp b/src/utils/Log.cpp
index 7649809339f4ebf716a7287f5744fb94a5b67ce2..03ecded8f5a193a8ab00cf9dc7be502b98205de2 100644
--- a/src/utils/Log.cpp
+++ b/src/utils/Log.cpp
@@ -12,13 +12,42 @@
 #include "aidge/utils/Log.hpp"
 #include "aidge/utils/ErrorHandling.hpp"
 
+#include <cstdlib>
+
 #include <fmt/color.h>
 #include <fmt/chrono.h>
 
-Aidge::Log::Level Aidge::Log::mConsoleLevel = Info;
-Aidge::Log::Level Aidge::Log::mFileLevel = Debug;
-std::string Aidge::Log::mFileName = "aidge.log";
+Aidge::Log::Level Aidge::Log::mConsoleLevel = []() {
+    const char* logLevel = std::getenv("AIDGE_LOGLEVEL_CONSOLE");
+    if (logLevel != nullptr) {
+        for (std::size_t i = 0; i < size(EnumStrings<Log::Level>::data); ++i) {
+            if (std::string(logLevel) == EnumStrings<Log::Level>::data[i]) {
+                return
static_cast<Log::Level>(i); + } + } + } + return Info; +}(); +Aidge::Log::Level Aidge::Log::mFileLevel = []() { + const char* logLevel = std::getenv("AIDGE_LOGLEVEL_FILE"); + if (logLevel != nullptr) { + for (std::size_t i = 0; i < size(EnumStrings<Log::Level>::data); ++i) { + if (std::string(logLevel) == EnumStrings<Log::Level>::data[i]) { + return static_cast<Log::Level>(i); + } + } + } + return Debug; +}(); +std::string Aidge::Log::mFileName = []() { + const char* logFile = std::getenv("AIDGE_LOG_FILE"); + if (logFile != nullptr) { + return std::string(logFile); + } + return std::string(); +}(); std::unique_ptr<FILE, decltype(&std::fclose)> Aidge::Log::mFile {nullptr, nullptr}; +std::vector<std::string> Aidge::Log::mContext; void Aidge::Log::log(Level level, const std::string& msg) { if (level >= mConsoleLevel) { @@ -33,6 +62,10 @@ void Aidge::Log::log(Level level, const std::string& msg) { : (level == Fatal) ? fmt::bg(fmt::color::red) : fmt::text_style(); + for (const auto& context : mContext) { + fmt::println("Context: {}", context); + } + fmt::println("{}", fmt::styled(msg, modifier)); } @@ -41,6 +74,10 @@ void Aidge::Log::log(Level level, const std::string& msg) { initFile(mFileName); } + for (const auto& context : mContext) { + fmt::println("Context: {}", context); + } + fmt::println(mFile.get(), msg); } } diff --git a/unit_tests/graphRegex/Test_examples.cpp b/unit_tests/graphRegex/Test_examples.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d85ae5c893a7ae4497125a62dad3cde97dac5195 --- /dev/null +++ b/unit_tests/graphRegex/Test_examples.cpp @@ -0,0 +1,55 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <set> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/OpArgs.hpp" +#include "aidge/operator/ReLU.hpp" +#include "aidge/operator/MetaOperatorDefs.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/graphRegex/GraphRegex.hpp" +#include "aidge/recipes/Recipes.hpp" + +namespace Aidge { + + +TEST_CASE("Examples", "[GraphMatching]") { + auto g1 = Sequential({ + Producer({16, 3, 512, 512}, "dataProvider"), + Conv(3, 4, {5, 5}, "conv1"), + ReLU(), + PaddedConv(4, 8, {5, 5}, "conv2", {1, 1}, {2, 2, 2, 2}), + ReLU(), + PaddedConv(8, 16, {5, 5}, "conv3", {1, 1}, {2, 2, 2, 2}), + ReLU() + }); + + expandMetaOps(g1); + g1->save("Test_examples"); + + auto regex = std::make_shared<GraphRegex>(); + regex->setKeyFromGraph(g1); + regex->addQuery("Pad->Conv->ReLU"); + // Won't work, wrong number of matches: + //regex->addQuery("Pad*->Conv->ReLU*"); + + const auto match = regex->match(g1); + REQUIRE(match.size() == 2); + + for (const auto& solution : match) { + REQUIRE(solution->getAll().size() == 3); + } +} + +} // namespace Aidge \ No newline at end of file diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 514fb3b494b50112f26efbaba831e2b46429adcd..7eb0290d9d2dadbb7328604332fb84af2a1be941 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -25,7 +25,7 @@ #include "aidge/graph/Testing.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Producer.hpp" -#include "aidge/scheduler/Scheduler.hpp" +#include "aidge/scheduler/SequentialScheduler.hpp" using namespace Aidge;
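[Editor's addendum] A minimal usage sketch tying the refactored pieces of this patch together. The constructor shape is an assumption (it is not shown in this diff); the method names and the forward(bool, std::vector<...>) signature are the ones introduced above. Debug traces now come from the logger, e.g. by exporting AIDGE_LOGLEVEL_CONSOLE before launch (the accepted strings follow EnumStrings<Log::Level> and are not verified here), instead of the removed verbose flags.

#include <memory>
#include <vector>

#include "aidge/graph/GraphView.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"

// Hedged sketch, not part of the patch: drive the new scheduler API once.
void runOnce(const std::shared_ptr<Aidge::GraphView>& graph) {
    Aidge::SequentialScheduler scheduler(graph);   // assumed constructor shape
    scheduler.generateScheduling();                // base schedule + early/late annotations
    scheduler.forward(true, {});                   // forwardDims = true, no extra inputs
    scheduler.saveStaticSchedulingDiagram("static_schedule"); // Mermaid gantt of early/late steps
    scheduler.saveSchedulingDiagram("timings");    // Mermaid gantt of measured run times
}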