Commit 10a0b754 authored by Grégoire Kubler

Merge remote-tracking branch 'EclipseRepo/dev' into feat/support_ASAN

parents 1dfbbd11 d00e9a7f
2 merge requests: !105 version 0.2.0, !100 fix/scheduler_exec_time
Showing 309 additions and 55 deletions
@@ -31,7 +31,7 @@ enum class PadBorderType { Constant, Edge, Reflect, Wrap };
template <DimIdx_t DIM>
class Pad_Op : public OperatorTensor,
public Registrable<Pad_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Pad_Op<DIM> &)>,
public Registrable<Pad_Op<DIM>, std::string, std::shared_ptr<OperatorImpl>(const Pad_Op<DIM> &)>,
public StaticAttributes<PadAttr,
std::array<DimSize_t, 2*DIM>,
PadBorderType,
@@ -98,7 +98,7 @@ public:
}
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Pad_Op<DIM>>::create(name)(*this);
SET_IMPL_MACRO(Pad_Op<DIM>, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
@@ -27,7 +27,7 @@
namespace Aidge {
class Pow_Op : public OperatorTensor,
public Registrable<Pow_Op, std::string, std::unique_ptr<OperatorImpl>(const Pow_Op&)> {
public Registrable<Pow_Op, std::string, std::shared_ptr<OperatorImpl>(const Pow_Op&)> {
public:
static const std::string Type;
@@ -40,7 +40,11 @@ public:
Pow_Op(const Pow_Op& op)
: OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<Pow_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(Pow_Op, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
@@ -55,7 +59,7 @@ public:
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Pow_Op>::create(name)(*this);
SET_IMPL_MACRO(Pow_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
@@ -72,4 +76,4 @@ inline std::shared_ptr<Node> Pow(const std::string& name = "") {
}
} // namespace Aidge
#endif /* AIDGE_CORE_OPERATOR_POW_H_ */
\ No newline at end of file
#endif /* AIDGE_CORE_OPERATOR_POW_H_ */
@@ -28,7 +28,7 @@ enum class ProdAttr { Constant };
class Producer_Op
: public OperatorTensor,
public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>(
public Registrable<Producer_Op, std::string, std::shared_ptr<OperatorImpl>(
const Producer_Op &)>,
public StaticAttributes<ProdAttr, bool> {
public:
@@ -67,9 +67,11 @@ public:
for (std::size_t i = 0; i < static_cast<std::size_t>(nbOutputs()); ++i) {
mOutputs[i] = std::make_shared<Tensor>(*(op.getOutput(i)));
}
mImpl = (mOutputs[0]->getImpl() && Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}))
? Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this)
: std::make_shared<OperatorImpl>(*this);
if (mOutputs[0]->getImpl() && Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){
SET_IMPL_MACRO(Producer_Op, *this, mOutputs[0]->getImpl()->backend());
}else{
mImpl = std::make_shared<OperatorImpl>(*this);
}
}
/**
@@ -92,9 +94,7 @@ public:
inline const std::vector<DimSize_t> dims() const noexcept { return mOutputs[0]->dims(); }
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
if (Registrar<Producer_Op>::exists({name})) {
mImpl = Registrar<Producer_Op>::create({name})(*this);
}
SET_IMPL_MACRO(Producer_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
@@ -26,7 +26,7 @@
namespace Aidge {
class ReLU_Op : public OperatorTensor,
public Registrable<ReLU_Op, std::string, std::unique_ptr<OperatorImpl>(const ReLU_Op&)> {
public Registrable<ReLU_Op, std::string, std::shared_ptr<OperatorImpl>(const ReLU_Op&)> {
public:
static const std::string Type;
@@ -39,7 +39,11 @@ public:
ReLU_Op(const ReLU_Op& op)
: OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<ReLU_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(ReLU_Op, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
@@ -52,7 +56,7 @@ public:
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<ReLU_Op>::create(name)(*this);
SET_IMPL_MACRO(ReLU_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
@@ -69,4 +73,4 @@ inline std::shared_ptr<Node> ReLU(const std::string& name = "") {
}
}
#endif /* AIDGE_CORE_OPERATOR_RELU_H_ */
\ No newline at end of file
#endif /* AIDGE_CORE_OPERATOR_RELU_H_ */
@@ -32,7 +32,7 @@ enum class ReduceMeanAttr { Axes, KeepDims };
template <DimIdx_t DIM>
class ReduceMean_Op : public OperatorTensor,
public Registrable<ReduceMean_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const ReduceMean_Op<DIM> &)>,
public Registrable<ReduceMean_Op<DIM>, std::string, std::shared_ptr<OperatorImpl>(const ReduceMean_Op<DIM> &)>,
public StaticAttributes<ReduceMeanAttr, std::array<std::int32_t, DIM>, DimSize_t> {
public:
@@ -57,7 +57,11 @@ class ReduceMean_Op : public OperatorTensor,
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<ReduceMean_Op<DIM>>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(ReduceMean_Op<DIM>, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
@@ -99,7 +103,7 @@ class ReduceMean_Op : public OperatorTensor,
}
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<ReduceMean_Op<DIM>>::create(name)(*this);
SET_IMPL_MACRO(ReduceMean_Op<DIM>, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
@@ -28,7 +28,7 @@ namespace Aidge {
enum class ReshapeAttr { Shape };
class Reshape_Op : public OperatorTensor,
public Registrable<Reshape_Op, std::string, std::unique_ptr<OperatorImpl>(const Reshape_Op&)>,
public Registrable<Reshape_Op, std::string, std::shared_ptr<OperatorImpl>(const Reshape_Op&)>,
public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>> {
public:
@@ -53,7 +53,11 @@ public:
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<Reshape_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(Reshape_Op, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
@@ -67,7 +71,7 @@ public:
void computeOutputDims() override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Reshape_Op>::create(name)(*this);
SET_IMPL_MACRO(Reshape_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
@@ -55,7 +55,11 @@ public:
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<Scaling_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(Scaling_Op, *this, op.mOutputs[0]->getImpl()->backend());
} else {
mImpl = nullptr;
}
}
/**
@@ -95,4 +99,4 @@ const char* const EnumStrings<Aidge::ScalingAttr>::data[]
= {"scalingFactor", "quantizedNbBits", "isOutputUnsigned"};
}
#endif /* __AIDGE_CORE_OPERATOR_RELU_H__ */
\ No newline at end of file
#endif /* __AIDGE_CORE_OPERATOR_RELU_H__ */
@@ -28,7 +28,7 @@ enum class SliceAttr { Starts, Ends, Axes };
class Slice_Op
: public OperatorTensor,
public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>,
public Registrable<Slice_Op, std::string, std::shared_ptr<OperatorImpl>(const Slice_Op &)>,
public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>> {
public:
static const std::string Type;
@@ -55,8 +55,11 @@ public:
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<Slice_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this)
: nullptr;
if (op.mImpl){
SET_IMPL_MACRO(Slice_Op, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
public:
@@ -69,7 +72,7 @@ public:
void computeOutputDims() override final;
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Slice_Op>::create(name)(*this);
SET_IMPL_MACRO(Slice_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
@@ -33,7 +33,7 @@ enum class SoftmaxAttr { AxisIdx };
class Softmax_Op : public OperatorTensor,
public Registrable<Softmax_Op,
std::string,
std::unique_ptr<OperatorImpl>(const Softmax_Op&)>,
std::shared_ptr<OperatorImpl>(const Softmax_Op&)>,
public StaticAttributes<SoftmaxAttr, int> {
public:
@@ -55,7 +55,11 @@ public:
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<Softmax_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(Softmax_Op, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
@@ -67,7 +71,7 @@ public:
}
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Softmax_Op>::create(name)(*this);
SET_IMPL_MACRO(Softmax_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
@@ -27,7 +27,7 @@
namespace Aidge {
class Sqrt_Op : public OperatorTensor,
public Registrable<Sqrt_Op, std::string, std::unique_ptr<OperatorImpl>(const Sqrt_Op&)> {
public Registrable<Sqrt_Op, std::string, std::shared_ptr<OperatorImpl>(const Sqrt_Op&)> {
public:
// FIXME: change accessibility
std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>();
@@ -45,7 +45,11 @@ public:
Sqrt_Op(const Sqrt_Op& op)
: OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<Sqrt_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(Sqrt_Op, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
@@ -57,7 +61,7 @@ public:
}
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Sqrt_Op>::create(name)(*this);
SET_IMPL_MACRO(Sqrt_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
@@ -27,7 +27,7 @@
namespace Aidge {
class Sub_Op : public OperatorTensor,
public Registrable<Sub_Op, std::string, std::unique_ptr<OperatorImpl>(const Sub_Op&)> {
public Registrable<Sub_Op, std::string, std::shared_ptr<OperatorImpl>(const Sub_Op&)> {
public:
// FIXME: change accessibility
std::array<std::shared_ptr<Tensor>, 2> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>()};
@@ -45,7 +45,11 @@ public:
Sub_Op(const Sub_Op& op)
: OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<Sub_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(Sub_Op, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
@@ -60,7 +64,7 @@ public:
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Sub_Op>::create(name)(*this);
SET_IMPL_MACRO(Sub_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
@@ -77,4 +81,4 @@ inline std::shared_ptr<Node> Sub(const std::string& name = "") {
}
} // namespace Aidge
#endif /* AIDGE_CORE_OPERATOR_SUB_H_ */
\ No newline at end of file
#endif /* AIDGE_CORE_OPERATOR_SUB_H_ */
@@ -30,7 +30,7 @@ enum class TransposeAttr { OutputDimsOrder };
template <DimIdx_t DIM>
class Transpose_Op : public OperatorTensor,
public Registrable<Transpose_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Transpose_Op<DIM> &)>,
public Registrable<Transpose_Op<DIM>, std::string, std::shared_ptr<OperatorImpl>(const Transpose_Op<DIM> &)>,
public StaticAttributes<TransposeAttr,
std::array<DimSize_t, DIM>> {
@@ -56,7 +56,11 @@ class Transpose_Op : public OperatorTensor,
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<Transpose_Op<DIM>>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(Transpose_Op<DIM>, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
@@ -80,7 +84,7 @@ class Transpose_Op : public OperatorTensor,
}
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Transpose_Op<DIM>>::create(name)(*this);
SET_IMPL_MACRO(Transpose_Op<DIM>, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
@@ -18,13 +18,15 @@
#include <fmt/format.h>
#include <fmt/ranges.h>
#include "aidge/utils/Log.hpp"
#ifdef NO_EXCEPTION
#define AIDGE_THROW_OR_ABORT(ex, ...) \
do { fmt::print(__VA_ARGS__); std::abort(); } while (false)
do { Aidge::Log::fatal(__VA_ARGS__); std::abort(); } while (false)
#else
#include <stdexcept>
#define AIDGE_THROW_OR_ABORT(ex, ...) \
throw ex(fmt::format(__VA_ARGS__))
do { Aidge::Log::fatal(__VA_ARGS__); throw ex(fmt::format(__VA_ARGS__)); } while (false)
#endif
/**
@@ -33,7 +35,7 @@ throw ex(fmt::format(__VA_ARGS__))
* If it asserts, it means a user error.
*/
#define AIDGE_ASSERT(stm, ...) \
if (!(stm)) { fmt::print("Assertion failed: " #stm " in {}:{}", __FILE__, __LINE__); \
if (!(stm)) { Aidge::Log::error("Assertion failed: " #stm " in {}:{}", __FILE__, __LINE__); \
AIDGE_THROW_OR_ABORT(std::runtime_error, __VA_ARGS__); }
/**
......
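For orientation, a minimal sketch of how the two macros above are typically used; the rank check is a hypothetical example, not taken from this diff:

#include <cstddef>
#include "aidge/utils/ErrorHandling.hpp"

// Hypothetical helper: on failure, AIDGE_ASSERT logs the failed condition via
// Aidge::Log::error, then AIDGE_THROW_OR_ABORT logs via Aidge::Log::fatal and
// throws std::runtime_error (or calls std::abort() when NO_EXCEPTION is defined).
void checkRank(std::size_t rank, std::size_t expected) {
    AIDGE_ASSERT(rank == expected,
                 "expected a tensor of rank {}, got {}", expected, rank);
}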
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_LOG_H_
#define AIDGE_LOG_H_
#include <memory>
#include <fmt/format.h>
#include <fmt/ranges.h>
namespace Aidge {
/**
* Aidge logging class, for console display and file logging of events.
*/
class Log {
public:
enum Level {
Debug = 0,
Info,
Notice,
Warn,
Error,
Fatal
};
/**
* Detailed messages for debugging purposes, providing information helpful
* for developers to trace and identify issues.
* Detailed insight into what is happening in an operation, not useful for the
* end-user. The operation is performed nominally.
* @note This level is disabled at compile time for Release, therefore
* inducing no runtime overhead for Release.
*/
template <typename... Args>
constexpr static void debug(Args&&... args) {
#ifndef NDEBUG
// only when compiled in Debug
log(Debug, fmt::format(std::forward<Args>(args)...));
#endif
}
/**
* Messages that provide a record of normal operation: the
* application's state, progress, or important events.
* Reports normal start, end and key steps in an operation. The operation is
* performed nominally.
*/
template <typename... Args>
constexpr static void info(Args&&... args) {
log(Info, fmt::format(std::forward<Args>(args)...));
}
/**
* Applies to normal but significant conditions that may require monitoring,
* like unusual or normal fallback events.
* Reports specific paths in an operation. The operation can still be
* performed normally.
*/
template <typename... Args>
constexpr static void notice(Args&&... args) {
log(Notice, fmt::format(std::forward<Args>(args)...));
}
/**
* Indicates potential issues or situations that may lead to errors but do
* not necessarily cause immediate problems.
* Some specific steps of the operation could not be performed, but it can
* still produce a usable result.
*/
template <typename... Args>
constexpr static void warn(Args&&... args) {
log(Warn, fmt::format(std::forward<Args>(args)...));
}
/**
* Signifies a problem or unexpected condition that the application can
* recover from, but attention is needed to prevent further issues.
* The operation could not be performed, but it does not prevent potential
* further operations.
*/
template <typename... Args>
constexpr static void error(Args&&... args) {
log(Error, fmt::format(std::forward<Args>(args)...));
}
/**
* Represents a critical error or condition that leads to the termination of
* the application, indicating a severe and unrecoverable problem.
* The operation could not be performed and any further operation is
* impossible.
*/
template <typename... Args>
constexpr static void fatal(Args&&... args) {
log(Fatal, fmt::format(std::forward<Args>(args)...));
}
/**
* Set the minimum log level displayed in the console.
*/
constexpr static void setConsoleLevel(Level level) {
mConsoleLevel = level;
}
/**
* Set the minimum log level saved in the log file.
*/
constexpr static void setFileLevel(Level level) {
mFileLevel = level;
}
/**
* Set the log file name.
* Close the current log file and open the one with the new file name.
* If empty, stop logging into a file.
*/
static void setFileName(const std::string& fileName) {
if (fileName != mFileName) {
mFileName = fileName;
mFile.release();
if (!fileName.empty()) {
initFile(fileName);
}
}
}
private:
static void log(Level level, const std::string& msg);
static void initFile(const std::string& fileName);
static Level mConsoleLevel;
static Level mFileLevel;
static std::string mFileName;
static std::unique_ptr<FILE, decltype(&std::fclose)> mFile;
};
}
#endif //AIDGE_LOG_H_
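As an illustration, a minimal sketch of how this logger might be driven from application code; the file name and messages are hypothetical:

#include "aidge/utils/Log.hpp"

int main() {
    // Hypothetical setup: quiet console, verbose log file.
    Aidge::Log::setConsoleLevel(Aidge::Log::Notice); // hide Debug/Info on the console
    Aidge::Log::setFileLevel(Aidge::Log::Debug);     // keep everything in the file
    Aidge::Log::setFileName("aidge_run.log");        // empty string disables file logging

    Aidge::Log::debug("tensor dims = {}", 4);        // compiled out in Release (NDEBUG)
    Aidge::Log::notice("falling back to backend '{}'", "cpu");
    Aidge::Log::error("operation failed, but later operations may proceed");
    return 0;
}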
@@ -14,6 +14,9 @@
#ifdef PYBIND
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // declare_registrable keys may require STL types
#include <pybind11/functional.h> // declare_registrable allows binding lambda functions
#endif
#include "aidge/utils/ErrorHandling.hpp"
@@ -27,6 +30,9 @@ namespace Aidge {
namespace py = pybind11;
#endif
// Abstract class used to test if a class is Registrable.
class AbstractRegistrable {};
template <class DerivedClass, class Key, class Func> // curiously recurring template pattern (CRTP)
class Registrable {
public:
@@ -58,8 +64,10 @@ struct Registrar {
Registrar(const registrar_key& key, registrar_type func) {
//fmt::print("REGISTRAR: {}\n", key);
bool newInsert;
std::tie(std::ignore, newInsert) = C::registry().insert(std::make_pair(key, func));
// bool newInsert;
// std::tie(std::ignore, newInsert) = C::registry().insert(std::make_pair(key, func));
C::registry().erase(key);
C::registry().insert(std::make_pair(key, func));
//assert(newInsert && "registrar already exists");
}
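The erase-then-insert above changes the semantics: registering a key that already exists now silently replaces the previous factory instead of tripping the (now commented-out) uniqueness assert. A sketch of the resulting behaviour, where MyOp is a hypothetical operator deriving from Registrable<MyOp, std::string, std::shared_ptr<OperatorImpl>(const MyOp&)>:

// First registration of the "cpu" key (hypothetical factory).
Registrar<MyOp>("cpu",
    [](const MyOp& op) { return std::make_shared<OperatorImpl>(op); });
// Re-registering the same key no longer asserts; the latest factory wins.
Registrar<MyOp>("cpu",
    [](const MyOp& op) { return std::make_shared<OperatorImpl>(op); });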
@@ -81,6 +89,62 @@ struct Registrar {
return keys;
}
};
#ifdef PYBIND
/**
* @brief Defines the register function for a registrable class.
* Defined here so that every module that wants to create a new
* registrable class has access to this function.
*
* @tparam C registrable class
* @param m pybind module
* @param class_name python name of the class
*/
template <class C>
void declare_registrable(py::module& m, const std::string& class_name){
typedef typename C::registrar_key registrar_key;
typedef typename C::registrar_type registrar_type;
m.def(("register_"+ class_name).c_str(), [](registrar_key& key, registrar_type function){
Registrar<C>(key, function);
})
.def(("get_keys_"+ class_name).c_str(), [](){
return Registrar<C>::getKeys();
});
}
#endif
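Each operator binding then invokes this helper so Python code can register and enumerate backends; the Add binding further down in this diff is a concrete example:

declare_registrable<Add_Op>(m, "AddOp");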
/*
 * This macro sets an implementation on an operator.
 * It is mandatory for using implementations registered from Python:
 * when the create method returns a Python function, PyBind's call to it
 * would invoke the copy ctor if op is not visible to the Python world.
 * See this issue for more information: https://github.com/pybind/pybind11/issues/4417
 * Note: using a method to do this is not possible, as any call to a function
 * would invoke the copy ctor. This is why a macro is used.
 * Note: the line
 * (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
 * is duplicated because the py::cast needs to be done in the same scope.
 * I only know this empirically; I am not sure what happens under the hood...
 *
 * If someone wants to find an alternative to this macro, you can contact me:
 * cyril.moineau@cea.fr
 */
#ifdef PYBIND
#define SET_IMPL_MACRO(T_Op, op, backend_name)                        \
    if (Py_IsInitialized()) {                                         \
        auto obj = py::cast(&(op));                                   \
        (op).setImpl(Registrar<T_Op>::create(backend_name)(op));      \
    } else {                                                          \
        (op).setImpl(Registrar<T_Op>::create(backend_name)(op));      \
    }
#else
#define SET_IMPL_MACRO(T_Op, op, backend_name)                        \
    if (Registrar<T_Op>::exists(backend_name)) {                      \
        (op).setImpl(Registrar<T_Op>::create(backend_name)(op));      \
    }
#endif
}
#endif //AIDGE_CORE_UTILS_REGISTRAR_H_
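For reference, in a build without PYBIND a call such as SET_IMPL_MACRO(Pow_Op, *this, name) in the operator headers above expands roughly to:

if (Registrar<Pow_Op>::exists(name)) {
    (*this).setImpl(Registrar<Pow_Op>::create(name)(*this));
}

which is also why setBackend now quietly does nothing when no implementation is registered for the requested backend, instead of calling Registrar::create unconditionally as the old code did.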
@@ -116,7 +116,7 @@ public:
void init_OperatorImpl(py::module& m){
py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr())
.def(py::init<const Operator&>())
.def(py::init<const Operator&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>())
.def("forward", &OperatorImpl::forward)
.def("backward", &OperatorImpl::backward)
.def("get_nb_required_data", &OperatorImpl::getNbRequiredData)
......
@@ -23,7 +23,7 @@ void declare_Add(py::module &m) {
py::class_<Add_Op, std::shared_ptr<Add_Op>, OperatorTensor>(m, "AddOp", py::multiple_inheritance())
.def("get_inputs_name", &Add_Op::getInputsName)
.def("get_outputs_name", &Add_Op::getOutputsName);
declare_registrable<Add_Op>(m, "AddOp");
m.def("Add", &Add, py::arg("nbIn"), py::arg("name") = "");
}
......
@@ -26,8 +26,9 @@ namespace py = pybind11;
namespace Aidge {
template <DimIdx_t DIM> void declare_AvgPoolingOp(py::module &m) {
const std::string pyClassName("AvgPoolingOp" + std::to_string(DIM) + "D");
py::class_<AvgPooling_Op<DIM>, std::shared_ptr<AvgPooling_Op<DIM>>, Attributes, OperatorTensor>(
m, ("AvgPoolingOp" + std::to_string(DIM) + "D").c_str(),
m, pyClassName.c_str(),
py::multiple_inheritance())
.def(py::init<const std::array<DimSize_t, DIM> &,
const std::array<DimSize_t, DIM> &>(),
@@ -36,7 +37,7 @@ template <DimIdx_t DIM> void declare_AvgPoolingOp(py::module &m) {
.def("get_inputs_name", &AvgPooling_Op<DIM>::getInputsName)
.def("get_outputs_name", &AvgPooling_Op<DIM>::getOutputsName)
.def("attributes_name", &AvgPooling_Op<DIM>::staticGetAttrsName);
declare_registrable<AvgPooling_Op<DIM>>(m, pyClassName);
m.def(("AvgPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims,
const std::string& name,
const std::vector<DimSize_t> &stride_dims) {
......
@@ -21,13 +21,12 @@ namespace Aidge {
template <DimSize_t DIM>
void declare_BatchNormOp(py::module& m) {
py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, Attributes, OperatorTensor>(m, ("BatchNormOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance())
.def(py::init<float, float>(),
py::arg("epsilon"),
py::arg("momentum"))
const std::string pyClassName("BatchNormOp" + std::to_string(DIM) + "D");
py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, Attributes, OperatorTensor>(m, pyClassName.c_str(), py::multiple_inheritance())
.def("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName)
.def("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName)
.def("attributes_name", &BatchNorm_Op<DIM>::staticGetAttrsName);
declare_registrable<BatchNorm_Op<DIM>>(m, pyClassName);
m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nbFeatures"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = "");
}
......
@@ -24,6 +24,7 @@ void init_Concat(py::module& m) {
.def("get_outputs_name", &Concat_Op::getOutputsName)
.def("attributes_name", &Concat_Op::staticGetAttrsName);
declare_registrable<Concat_Op>(m, "ConcatOp");
m.def("Concat", &Concat, py::arg("nbIn"), py::arg("axis"), py::arg("name") = "");
}
} // namespace Aidge