Skip to content
Snippets Groups Projects
Commit 19e3e5c9 authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'scheduling' into 'dev'

Improved scheduling

See merge request !94
parents f3364c9f f1d4d12b
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!94Improved scheduling
Pipeline #42688 passed
Showing
with 475 additions and 90 deletions
......@@ -44,6 +44,8 @@ set(FMT_SYSTEM_HEADERS ON)
FetchContent_MakeAvailable(fmt)
set_property(TARGET fmt PROPERTY POSITION_INDEPENDENT_CODE ON)
find_package(Threads REQUIRED)
##############################################
# Create target and set properties
......@@ -88,7 +90,7 @@ if (PYBIND)
)
endif()
target_link_libraries(${module_name} PUBLIC fmt::fmt)
target_link_libraries(${module_name} PUBLIC Threads::Threads fmt::fmt)
target_compile_features(${module_name} PRIVATE cxx_std_14)
if (DOSANITIZE STREQUAL "ON")
......
......@@ -2,6 +2,7 @@
include(CMakeFindDependencyMacro)
find_dependency(fmt)
find_dependency(Threads)
include(${CMAKE_CURRENT_LIST_DIR}/aidge_core-config-version.cmake)
......
......@@ -16,6 +16,7 @@
#include <vector>
#include "aidge/utils/Types.h"
#include "aidge/data/Elts.hpp"
namespace Aidge {
class Operator;
......@@ -36,13 +37,13 @@ public:
* @param inputIdx Index of the input analysed.
* @return std::size_t
*/
virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const;
virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const;
// Amount of input data that cannot be overwritten during the execution.
virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
// Memory required at an output for a given input size.
virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
/**
* @brief Total amount of consumed data from a specific input.
......@@ -50,7 +51,7 @@ public:
* @param inputIdx Index of the input analysed.
* @return DimSize_t
*/
virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const;
virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const;
/**
* @brief Total amount of produced data ready to be used on a specific output.
......@@ -58,7 +59,7 @@ public:
* @param outputIdx Index of the output analysed.
* @return DimSize_t
*/
virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const;
virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const;
/**
* @brief Update the Consummer Producer system by simulating the consumption and production of i/o
......@@ -77,8 +78,8 @@ public:
protected:
const Operator &mOp;
const std::string mBackend;
std::vector<NbElts_t> mNbConsumedData;
std::vector<NbElts_t> mNbProducedData;
std::vector<Elts_t> mNbConsumedData;
std::vector<Elts_t> mNbProducedData;
};
} // namespace Aidge
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_ELTS_H_
#define AIDGE_ELTS_H_
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
/**
* Base object for Aidge consumer-producer model (C-P model).
* It is a hybrid model: operator implementations can specify their C-P model
* with precise data (bytes) or with tokens.
*/
struct Elts_t {
enum EltType {
Data,
Token,
Undef
};
NbElts_t data;
NbElts_t token;
EltType type;
// Addition operator
inline Elts_t operator+(const Elts_t& other) const {
AIDGE_ASSERT(type == other.type || other.type == Undef || type == Undef,
"Incompatible C-P model types: {} + {}. Data and Token cannot be mixed.", type, other.type);
return Elts_t(data + other.data, token + other.token, (other.type == Undef) ? type : other.type);
}
// Addition assignment operator
inline Elts_t& operator+=(const Elts_t& other) {
AIDGE_ASSERT(type == other.type || other.type == Undef || type == Undef,
"Incompatible C-P model types: {} += {}. Data and Token cannot be mixed.", type, other.type);
data += other.data;
token += other.token;
type = (other.type == Undef) ? type : other.type;
return *this;
}
// Comparison operators
inline bool operator<(const Elts_t& other) const {
if (type == Elts_t::Undef || type == Elts_t::Token) {
// Nothing, or only a token is required: don't care about how much data has been produced for the token
return (token < other.token);
}
else if (type == Elts_t::Data && other.type != Elts_t::Token) {
// A precise amount of data is required, so the amount of produced data must be specified, a token is not enough
return (data < other.data);
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error,
"Incompatible C-P model types: {} < {}. Data is expected for right-hand side.", type, other.type);
}
}
inline bool operator>(const Elts_t& other) const {
if (type == Elts_t::Undef || type == Elts_t::Token) {
// Nothing, or only a token is required: don't care about how much data has been produced for the token
return (token > other.token);
}
else if (type == Elts_t::Data && other.type != Elts_t::Token) {
// A precise amount of data is required, so the amount of produced data must be specified, a token is not enough
return (data > other.data);
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error,
"Incompatible C-P model types: {} > {}. Data is expected for right-hand side.", type, other.type);
}
}
inline static Elts_t NoneElts() {
return Elts_t(0, 0, Elts_t::Undef);
}
inline static Elts_t DataElts(NbElts_t data, NbElts_t token = 1) {
return Elts_t(data, token, Elts_t::Data);
}
inline static Elts_t TokenElts(NbElts_t token) {
return Elts_t(0, token, Elts_t::Token);
}
private:
inline Elts_t(NbElts_t data_, NbElts_t token_, EltType type_):
data(data_), token(token_), type(type_) {}
};
} // end namespace Aidge
template<>
struct fmt::formatter<Aidge::Elts_t> {
    // No custom format spec is supported: accept only an empty spec.
    template<typename ParseContext>
    constexpr auto parse(ParseContext& ctx) {
        return ctx.begin();
    }

    // Render an Elts_t as "<data>:<token>".
    template<typename FormatContext>
    auto format(Aidge::Elts_t const& value, FormatContext& ctx) {
        return fmt::format_to(ctx.out(), "{}:{}", value.data, value.token);
    }
};
namespace {
// String table for Aidge::Elts_t::EltType, indexed by the enum's underlying value.
// NOTE(review): EnumStrings is presumably a project-wide helper template declared
// elsewhere; this anonymous-namespace specialization in a header follows the
// project's existing convention (each translation unit gets its own copy).
template <>
const char* const EnumStrings<Aidge::Elts_t::EltType>::data[]
= {"Data", "Token", "Undef"};
}
namespace Aidge {
// fmt customization point (format_as): lets fmt print EltType values by name.
inline auto format_as(Elts_t::EltType elt) { return EnumStrings<Aidge::Elts_t::EltType>::data[static_cast<int>(elt)]; }
}
#endif /* AIDGE_ELTS_H_ */
......@@ -21,7 +21,7 @@
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
......@@ -115,11 +115,11 @@ public:
mGraph->setDataType(datatype);
}
NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override;
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override;
NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override;
NbElts_t getNbConsumedData(IOIndex_t inputIdx) const override;
NbElts_t getNbProducedData(IOIndex_t outputIdx) const override;
Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override;
Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override;
Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override;
Elts_t getNbConsumedData(IOIndex_t inputIdx) const override;
Elts_t getNbProducedData(IOIndex_t outputIdx) const override;
void updateConsummerProducer() override;
void forward() override;
......
......@@ -134,31 +134,31 @@ public:
/**
* @brief Minimum amount of data from a specific input for one computation pass.
* @param inputIdx Index of the input analysed.
* @return NbElts_t
* @return Elts_t
*/
virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const;
virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const;
// Amount of input data that cannot be overwritten during the execution.
virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
// Memory required at an output for a given input size.
virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
/**
* @brief Total amount of consumed data from a specific input.
*
* @param inputIdx Index of the input analysed.
* @return NbElts_t
* @return Elts_t
*/
virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const;
virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const;
/**
* @brief Total amount of produced data ready to be used on a specific output.
*
* @param outputIdx Index of the output analysed.
* @return NbElts_t
* @return Elts_t
*/
virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const;
virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const;
virtual void updateConsummerProducer();
......
......@@ -47,8 +47,7 @@ public:
Attributes_(attr<ProdAttr::Constant>(constant))
{
mOutputs[0]->resize(dims);
// mImpl = std::make_shared<OperatorImpl>(*this, "");
mImpl = nullptr;
mImpl = std::make_shared<OperatorImpl>(*this, "");
}
/**
......
......@@ -22,6 +22,8 @@
namespace Aidge {
void constantFolding(std::shared_ptr<GraphView> graph);
// FUSE MATMUL + ADD -> FC
/**
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_PARALLELSCHEDULER_H_
#define AIDGE_PARALLELSCHEDULER_H_
#include <chrono>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include <map>
#include "aidge/scheduler/Scheduler.hpp"
namespace Aidge {
/**
 * Multi-threaded parallel scheduler with dynamic scheduling.
 */
class ParallelScheduler : public Scheduler {
public:
    ParallelScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr)
        : Scheduler(graphView, upperNode)
    {
        // ctor
    };
    ~ParallelScheduler() = default;

    /**
     * @brief Run the provided Computational Graph with a batch of data
     * @param forwardDims Presumably propagates tensor dimensions through the
     * graph before execution — TODO confirm against the implementation.
     * @param data Optional input tensors fed to the graph inputs.
     */
    virtual void forward(bool forwardDims = true, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
};
} // namespace Aidge
#endif /* AIDGE_PARALLELSCHEDULER_H_ */
......@@ -28,8 +28,22 @@ namespace Aidge {
class Node;
class GraphView;
class SequentialScheduler {
private:
class Scheduler {
protected:
// Element of the static scheduling: a node annotated with its earliest and
// latest possible logical execution steps (filled by generateEarlyLateScheduling).
struct StaticSchedulingElement {
    StaticSchedulingElement(
        std::shared_ptr<Node> node_,
        size_t early_ = static_cast<size_t>(-1),
        size_t late_ = static_cast<size_t>(-1))
        : node(node_), early(early_), late(late_) {}

    std::shared_ptr<Node> node;  // the scheduled node
    size_t early;  // earliest possible logical step; size_t(-1) means not yet computed
    size_t late;   // latest possible logical step; size_t(-1) means not yet computed
    std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan;  // presumably elements this one must precede — confirm in .cpp
    std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan;  // presumably elements this one must follow — confirm in .cpp
};
struct SchedulingElement {
SchedulingElement(
std::shared_ptr<Node> node_,
......@@ -49,15 +63,25 @@ private:
};
public:
SequentialScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr)
Scheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr)
: mGraphView(graphView),
mUpperNode(upperNode)
{
// ctor
};
~SequentialScheduler() = default;
virtual ~Scheduler() = default;
void generateScheduling(bool verbose = false);
/**
* Generate full static scheduling of the GraphView.
* For each node, an earliest and latest possible execution logical step
* is specified. Nodes that may be scheduled at the same logical step have
* no data dependency and can be run in parallel.
*/
void generateScheduling();
/**
* Reset all scheduling and associated nodes producer consumer.
*/
void resetScheduling();
/**
......@@ -75,14 +99,10 @@ public:
void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data);
/**
* @brief Run the provided Computational Graph with a batch of data
*/
void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
/**
* @brief Run the provided Computational Graph with a batch of data
* @brief Save in a Markdown file the static scheduling with early and late relative order for the nodes.
* @param fileName Name of the generated file.
*/
void backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instantiateGrad = true, bool verbose = false);
void saveStaticSchedulingDiagram(const std::string& fileName) const;
/**
* @brief Save in a Markdown file the order of layers execution.
......@@ -94,14 +114,26 @@ public:
* @brief Return a vector of Node ordered by the order they are called by the scheduler
* @return std::vector<std::shared_ptr<Node>>
*/
inline std::vector<std::shared_ptr<Node>> getStaticScheduling(size_t step = 0) const noexcept {
return mStaticSchedule.at(step);
}
std::vector<std::shared_ptr<Node>> getStaticScheduling(size_t step = 0) const;
inline std::shared_ptr<GraphView> getGraphView() const noexcept {
return mGraphView;
}
private:
protected:
/**
* Generate an initial base scheduling for the GraphView.
* The scheduling is entirely sequential and guaranteed to be valid w.r.t.
* each node producer-consumer model.
*/
std::vector<std::shared_ptr<StaticSchedulingElement>> generateBaseScheduling() const;
/**
* Fill-in early and late scheduling step from initial base scheduling.
* For each node, specifies the earliest and latest possible execution
* logical step.
*/
void generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const;
/**
* @brief Set of layers receiving an input from currently processing layers
*
......@@ -109,7 +141,7 @@ private:
* @return std::set<std::shared_ptr<Node>>
*/
std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const;
NbElts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const;
Elts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const;
PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const;
/** @brief Shared ptr to the scheduled graph view */
......@@ -119,7 +151,7 @@ private:
/** @brief List of SchedulingElement (i.e: Nodes with their computation time) */
std::vector<SchedulingElement> mScheduling;
/** @brief List of nodes ordered by their */
std::vector<std::vector<std::shared_ptr<Node>>> mStaticSchedule;
std::vector<std::vector<std::shared_ptr<StaticSchedulingElement>>> mStaticSchedule;
size_t mStaticScheduleStep = 0;
mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache;
};
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_SEQUENTIALSCHEDULER_H_
#define AIDGE_SEQUENTIALSCHEDULER_H_
#include <chrono>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include <map>
#include "aidge/scheduler/Scheduler.hpp"
namespace Aidge {
/**
 * Single-threaded sequential scheduler: runs the static scheduling one node
 * after the other.
 * (The previous comment, "Multi-threaded parallel scheduler with dynamic
 * scheduling", was copy-pasted from ParallelScheduler.)
 */
class SequentialScheduler : public Scheduler {
public:
    // Policy choosing the execution order among valid static schedules,
    // based on the early/late logical steps computed by the base Scheduler.
    enum SchedulingPolicy {
        Default,
        AsSoonAsPossible,
        AsLateAsPossible
    };

    SequentialScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr)
        : Scheduler(graphView, upperNode),
        mSchedulingPolicy(Default)
    {
        // ctor
    };

    // Select the scheduling policy (presumably applied when the schedule is
    // generated or replayed — confirm in the implementation).
    inline void setSchedulingPolicy(SchedulingPolicy policy) {
        mSchedulingPolicy = policy;
    }
    ~SequentialScheduler() = default;

    /**
     * @brief Run the provided Computational Graph with a batch of data
     */
    virtual void forward(bool forwardDims = true, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});

    /**
     * @brief Run the provided Computational Graph backward with a batch of data.
     * @param instantiateGrad Presumably instantiates gradient tensors before
     * the backward pass — TODO confirm against the implementation.
     */
    void backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instantiateGrad = true);

private:
    SchedulingPolicy mSchedulingPolicy;
};
} // namespace Aidge
#endif /* AIDGE_SEQUENTIALSCHEDULER_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_THREADPOOL_H_
#define AIDGE_THREADPOOL_H_
#include <thread>
#include <mutex>
#include <queue>
#include <vector>
#include <functional>
#include <condition_variable>
#include <atomic>
namespace Aidge {
// Fixed-size pool of worker threads consuming jobs from a shared queue.
class ThreadPool {
public:
    // Starts nbThreads workers (defaults to the hardware concurrency).
    ThreadPool(size_t nbThreads = std::thread::hardware_concurrency());
    // Enqueue a job; presumably picked up by an idle worker — see .cpp.
    void queueJob(const std::function<void()>& job);
    // Presumably reports whether jobs are still queued or running — confirm in .cpp.
    bool busy();
    virtual ~ThreadPool();

private:
    // Worker entry point (presumably pops and executes queued jobs in a loop).
    void threadLoop();

    // NOTE(review): plain bool shared between threads; assumed to be read and
    // written only under mQueueMutex — confirm in the .cpp, otherwise this
    // should be std::atomic<bool>.
    bool mTerminate = false;
    std::mutex mQueueMutex;
    std::condition_variable mMutexCondition;
    std::vector<std::thread> mThreads;
    std::queue<std::function<void()>> mJobs;
};
} // namespace Aidge
#endif /* AIDGE_THREADPOOL_H_ */
......@@ -14,6 +14,7 @@
#define AIDGE_ERRORHANDLING_H_
#include <memory>
#include <cassert>
#include <fmt/format.h>
#include <fmt/ranges.h>
......
......@@ -18,7 +18,15 @@
#include <fmt/format.h>
#include <fmt/ranges.h>
#include "aidge/utils/Attributes.hpp"
namespace Aidge {
/**
 * Helper to define a context anywhere, hiding the scoped variable name
 * which has no relevance.
 * A two-level concatenation is required so that __LINE__ expands to its
 * numeric value before token pasting: a direct logContext_##__LINE__ would
 * produce the literal identifier "logContext___LINE__" every time, causing a
 * redefinition error when the macro is used twice in the same scope.
 */
#define AIDGE_LOG_CONTEXT_CONCAT_IMPL(a, b) a##b
#define AIDGE_LOG_CONTEXT_CONCAT(a, b) AIDGE_LOG_CONTEXT_CONCAT_IMPL(a, b)
#define AIDGE_LOG_CONTEXT(...) const Log::Context AIDGE_LOG_CONTEXT_CONCAT(logContext_, __LINE__)(__VA_ARGS__)
/**
* Aidge logging class, for displaying and file logging of events.
*/
......@@ -33,6 +41,18 @@ public:
Fatal
};
class Context {
public:
template <typename... Args>
Context(Args&&... args) {
Log::mContext.push_back(fmt::format(std::forward<Args>(args)...));
}
~Context() {
Log::mContext.pop_back();
}
};
/**
* Detailed messages for debugging purposes, providing information helpful
* for developers to trace and identify issues.
......@@ -142,7 +162,13 @@ private:
static Level mFileLevel;
static std::string mFileName;
static std::unique_ptr<FILE, decltype(&std::fclose)> mFile;
static std::vector<std::string> mContext;
};
}
namespace {
template <>
const char *const EnumStrings<Aidge::Log::Level>::data[] = {"Debug", "Info", "Notice", "Warn", "Error", "Fatal"};
}
#endif //AIDGE_LOG_H_
......@@ -130,20 +130,15 @@ void declare_registrable(py::module& m, const std::string& class_name){
*/
#ifdef PYBIND
#define SET_IMPL_MACRO(T_Op, op, backend_name) \
\
if (Registrar<T_Op>::exists(backend_name)) { \
if(Py_IsInitialized()) { \
auto obj = py::cast(&(op)); \
(op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
} else { \
(op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
} \
}
#else
#define SET_IMPL_MACRO(T_Op, op, backend_name) \
if (Registrar<T_Op>::exists(backend_name)) { \
(op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
if(Py_IsInitialized()) { \
auto obj = py::cast(&(op)); \
(op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
} else { \
(op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
}
#else
#define SET_IMPL_MACRO(T_Op, op, backend_name) \
(op).setImpl(Registrar<T_Op>::create(backend_name)(op));
#endif
}
......
......@@ -43,18 +43,18 @@ public:
);
}
NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override {
Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_NAME(
NbElts_t,
Elts_t,
OperatorImpl,
"get_nb_required_data",
getNbRequiredData,
inputIdx
);
}
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override {
Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_NAME(
NbElts_t,
Elts_t,
OperatorImpl,
"get_nb_required_protected",
getNbRequiredProtected,
......@@ -62,10 +62,10 @@ public:
);
}
NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
Elts_t getRequiredMemory(const IOIndex_t outputIdx,
const std::vector<DimSize_t> &inputsSize) const override {
PYBIND11_OVERRIDE_NAME(
NbElts_t,
Elts_t,
OperatorImpl,
"get_required_memory",
getRequiredMemory,
......@@ -74,9 +74,9 @@ public:
);
}
NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override {
Elts_t getNbConsumedData(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_NAME(
NbElts_t,
Elts_t,
OperatorImpl,
"get_nb_consumed_data",
getNbConsumedData,
......@@ -84,9 +84,9 @@ public:
);
}
NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override {
Elts_t getNbProducedData(const IOIndex_t outputIdx) const override {
PYBIND11_OVERRIDE_NAME(
NbElts_t,
Elts_t,
OperatorImpl,
"get_nb_produced_data",
getNbProducedData,
......
......@@ -12,20 +12,31 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/scheduler/ParallelScheduler.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Scheduler(py::module& m){
py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>>(m, "SequentialScheduler")
py::class_<Scheduler, std::shared_ptr<Scheduler>>(m, "Scheduler")
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false, py::arg("data")=std::vector<Tensor>())
.def("backward", &SequentialScheduler::backward, py::arg("data"), py::arg("instanciate_grad")=true, py::arg("verbose")=false)
.def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name"))
.def("resetScheduling", &SequentialScheduler::resetScheduling)
.def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false)
.def("get_static_scheduling", &SequentialScheduler::getStaticScheduling, py::arg("step") = 0)
.def("save_scheduling_diagram", &Scheduler::saveSchedulingDiagram, py::arg("file_name"))
.def("resetScheduling", &Scheduler::resetScheduling)
.def("generate_scheduling", &Scheduler::generateScheduling)
.def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0)
;
py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>())
.def("backward", &SequentialScheduler::backward, py::arg("data"), py::arg("instanciate_grad")=true)
;
py::class_<ParallelScheduler, std::shared_ptr<ParallelScheduler>, Scheduler>(m, "ParallelScheduler")
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("forward", &ParallelScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>())
;
}
}
......
......@@ -20,48 +20,91 @@
Aidge::OperatorImpl::OperatorImpl(const Operator& op, const std::string& backend):
mOp(op),
mBackend(backend),
mNbConsumedData(mOp.nbInputs(), 0),
mNbProducedData(mOp.nbOutputs(), 0)
mNbConsumedData(mOp.nbInputs(), Elts_t::NoneElts()),
mNbProducedData(mOp.nbOutputs(), Elts_t::NoneElts())
{
//ctor
}
Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
AIDGE_ASSERT(mOp.getRawInput(inputIdx),
"a valid input is required at index {} for operator type {}",
inputIdx, mOp.type());
// Requires the whole tensor by default
return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size();
if (mOp.getRawInput(inputIdx)) {
const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx));
if (!input->empty()) {
// Known amount of data: requires the whole tensor by default
return Elts_t::DataElts(input->size());
}
else {
// Unknown amount of data: require a single token by default
return Elts_t::TokenElts(1);
}
}
// Input not connected, meaning it is an optional input: do not require anything!
return Elts_t::NoneElts();
}
Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const {
Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const {
AIDGE_ASSERT(mOp.getRawInput(inputIdx),
"a valid input is required at index {} for operator type {}",
inputIdx, mOp.type());
// Protect the whole tensor by default
return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size();
if (mOp.getRawInput(inputIdx)) {
const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx));
if (!input->empty()) {
// Known amount of data: protect the whole tensor by default
return Elts_t::DataElts(input->size());
}
else {
// Unknown amount of data: protect a single token by default
// (this does not really make sense for now, as getNbRequiredProtected()
// is supposed to give a precise amount of data to protect for
// memory management purpose...)
return Elts_t::TokenElts(1);
}
}
// Input not connected, meaning it is an optional input: do not require anything!
return Elts_t::NoneElts();
}
Aidge::NbElts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
Aidge::Elts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
AIDGE_ASSERT(mOp.getRawOutput(outputIdx),
"a valid output is required at index {} for operator type {}",
outputIdx, mOp.type());
// Requires the whole tensor by default, regardless of available data on inputs
return std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx))->size();
if (mOp.getRawOutput(outputIdx)) {
const auto output = std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx));
if (!output->empty()) {
// Known amount of data: requires the whole tensor by default,
// regardless of available data on inputs
return Elts_t::DataElts(output->size());
}
else {
// Unknown amount of data: require a single token by default
// (this does not really make sense for now, as getRequiredMemory()
// is supposed to give a precise amount of data to allocate for
// memory management purpose...)
return Elts_t::TokenElts(1);
}
}
// Output not set, meaning it is an optional output: do not require anything!
return Elts_t::NoneElts();
}
Aidge::NbElts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
Aidge::Elts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
AIDGE_ASSERT(static_cast<std::size_t>(inputIdx) < mNbConsumedData.size(),
"input index ({}) is out of bound ({}) for operator type {}",
inputIdx, mNbConsumedData.size(), mOp.type());
return mNbConsumedData[static_cast<std::size_t>(inputIdx)];
}
Aidge::NbElts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
Aidge::Elts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
AIDGE_ASSERT(static_cast<std::size_t>(outputIdx) < mNbProducedData.size(),
"output index ({}) is out of bound ({}) for operator type {}",
outputIdx, mNbProducedData.size(), mOp.type());
......@@ -81,8 +124,8 @@ void Aidge::OperatorImpl::updateConsummerProducer(){
}
void Aidge::OperatorImpl::resetConsummerProducer(){
std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), 0);
std::fill(mNbProducedData.begin(), mNbProducedData.end(), 0);
std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), Elts_t::NoneElts());
std::fill(mNbProducedData.begin(), mNbProducedData.end(), Elts_t::NoneElts());
}
void Aidge::OperatorImpl::forward() {
......
......@@ -165,7 +165,7 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd
size_t inputIdx = 0;
for (auto input : mInputNodes) {
if (input.first != nullptr) {
fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|&rarr;{}|{}_{}\n", inputIdx, inputIdx,
fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"&rarr;{}\"|{}_{}\n", inputIdx, inputIdx,
input.second, input.first->type(), namePtrTable.at(input.first));
}
else {
......@@ -1258,7 +1258,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo
if (deletedNode == mRootNode) {
const std::pair<std::vector<NodePtr>, size_t> ranked_nodes = getRankedNodes();
if(ranked_nodes.second== 0 )
if(ranked_nodes.second== 0 || ranked_nodes.first.size() <= 1)
{
mRootNode = nullptr;
} else {
......
......@@ -173,7 +173,7 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId)
"Input index ({}) is out of bound ({}) for node {} (of type {})",
inId, nbInputs(), name(), type());
if (mIdOutParents[inId] != gk_IODefaultIndex) {
Log::warn("Warning: filling a Tensor already attributed\n");
Log::notice("Notice: filling a Tensor already attributed");
auto originalParent = input(inId);
// remove original parent reference to child
// find the output ID for original Parent
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment