Commit f616e3d2 authored by Olivier BICHLER, committed by Maxence Naud

Added default implementation for several operators

parent 5a285f17
Showing with 285 additions and 144 deletions
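The recurring pattern across these files: each operator that previously hard-coded a fallback `forward()` now gets a dedicated `OperatorImpl` subclass, installed by default in the constructor and replaced by a registered backend kernel whenever one exists. A minimal sketch of the scheme, using a hypothetical `Foo_Op` (not an operator from this commit; `Type` and other declarations omitted):

    class Foo_OpImpl : public OperatorImpl {
    public:
        Foo_OpImpl(const Operator& op, const std::string& backend = "")
            : OperatorImpl(op, backend) {}
        void forward() override;   // backend-agnostic fallback kernel
    };

    class Foo_Op : public OperatorTensor {
    public:
        Foo_Op() : OperatorTensor(Type, 1, 0, 1) {
            mImpl = std::make_shared<Foo_OpImpl>(*this);     // usable out of the box
        }
        void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
            if (Registrar<Foo_Op>::exists({name})) {
                SET_IMPL_MACRO(Foo_Op, *this, name);         // real backend kernel
            } else {
                mImpl = std::make_shared<Foo_OpImpl>(*this); // keep the fallback
            }
            mOutputs[0]->setBackend(name, device);
        }
    };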
@@ -23,7 +23,7 @@ class Operator;
 class OperatorImpl {
 public:
-    OperatorImpl(const Operator& op, const std::string& backend);
+    OperatorImpl(const Operator& op, const std::string& backend = "");
     virtual void forward();
     virtual void backward();
...
@@ -24,13 +24,20 @@
 #include "aidge/utils/Types.h"
 namespace Aidge {
+class Cast_OpImpl : public OperatorImpl {
+public:
+    Cast_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
+    void forward() override;
+};
+
 class Cast_Op : public OperatorTensor,
     public Registrable<Cast_Op, std::string, std::unique_ptr<OperatorImpl>(const Cast_Op&)> {
 public:
     static const std::string Type;
-    Cast_Op() : OperatorTensor(Type, 1, 0, 1) {}
+    Cast_Op() : OperatorTensor(Type, 1, 0, 1) {
+        mImpl = std::make_shared<Cast_OpImpl>(*this);
+    }
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
@@ -39,10 +46,11 @@ public:
     Cast_Op(const Cast_Op& op)
         : OperatorTensor(op)
     {
-        if (op.mImpl) {
+        if (!op.backend().empty()) {
             SET_IMPL_MACRO(Cast_Op, *this, op.backend());
-        } else {
-            mImpl = nullptr;
+        }
+        else {
+            mImpl = std::make_shared<Cast_OpImpl>(*this);
         }
     }
@@ -56,8 +64,6 @@ public:
     void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
-    void forward() override;
     static const std::vector<std::string> getInputsName(){
         return {"data_input"};
     }
...
@@ -26,6 +26,12 @@
 #include "aidge/utils/Types.h"
 namespace Aidge {
+class Concat_OpImpl : public OperatorImpl {
+public:
+    Concat_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
+    void forward() override;
+};
+
 enum class ConcatAttr { Axis };
 class Concat_Op : public OperatorTensor,
@@ -45,6 +51,7 @@ public:
         if (nbIn == 0) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "Add operator should have at least one input.");
         }
+        mImpl = std::make_shared<Concat_OpImpl>(*this);
     }
     /**
@@ -55,10 +62,11 @@ public:
         : OperatorTensor(op),
           Attributes_(op)
     {
-        if (op.mImpl){
+        if (!op.backend().empty()) {
             SET_IMPL_MACRO(Concat_Op, *this, op.backend());
-        }else{
-            mImpl = nullptr;
+        }
+        else {
+            mImpl = std::make_shared<Concat_OpImpl>(*this);
         }
     }
...
@@ -37,7 +37,7 @@ public:
     GenericOperator_Op(const std::string& type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut)
         : OperatorTensor(type, nbData, nbParam, nbOut)
     {
-        mImpl = std::make_shared<OperatorImpl>(*this, "");
+        mImpl = std::make_shared<OperatorImpl>(*this);
     }
     /**
...
@@ -42,7 +42,7 @@ public:
     Identity_Op()
         : OperatorTensor(Type, 1, 0, 1)
     {
-        mImpl = std::make_shared<OperatorImpl>(*this, "");
+        mImpl = std::make_shared<OperatorImpl>(*this);
     }
     /**
...
@@ -25,6 +25,15 @@
 #include "aidge/utils/Types.h"
 namespace Aidge {
+class Memorize_OpImpl : public OperatorImpl {
+public:
+    Memorize_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
+    Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override final;
+    Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override final;
+    void updateConsummerProducer() override;
+    void forward() override;
+};
+
 enum class MemorizeAttr { ScheduleStep, ForwardStep, EndStep };
 class Memorize_Op : public OperatorTensor,
...
@@ -24,13 +24,20 @@
 #include "aidge/utils/Types.h"
 namespace Aidge {
+class Move_OpImpl : public OperatorImpl {
+public:
+    Move_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
+    void forward() override;
+};
+
 class Move_Op : public OperatorTensor,
     public Registrable<Move_Op, std::tuple<std::string, std::string>, std::unique_ptr<OperatorImpl>(const Move_Op&)> {
 public:
     static const std::string Type;
-    Move_Op() : OperatorTensor(Type, 1, 0, 1) {}
+    Move_Op() : OperatorTensor(Type, 1, 0, 1) {
+        mImpl = std::make_shared<Move_OpImpl>(*this);
+    }
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
@@ -39,7 +46,12 @@ public:
     Move_Op(const Move_Op& op)
         : OperatorTensor(op)
     {
-        mImpl = op.mImpl ? Registrar<Move_Op>::create({mInputs[0]->getImpl()->backend(), mOutputs[0]->getImpl()->backend()})(*this) : nullptr;
+        if (!op.backend().empty()) {
+            SET_IMPL_MACRO(Move_Op, *this, {op.getInput(0)->getImpl()->backend(), op.backend()});
+        }
+        else {
+            mImpl = std::make_shared<Move_OpImpl>(*this);
+        }
     }
     /**
@@ -50,14 +62,7 @@ public:
         return std::make_shared<Move_Op>(*this);
     }
-    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
-        if (mInputs[0]->getImpl() && Registrar<Move_Op>::exists({mInputs[0]->getImpl()->backend(), name})) {
-            mImpl = Registrar<Move_Op>::create({mInputs[0]->getImpl()->backend(), name})(*this);
-        }
-        mOutputs[0]->setBackend(name, device);
-    }
-    void forward() override;
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
     static const std::vector<std::string> getInputsName(){
         return {"data_input"};
...
@@ -24,6 +24,13 @@
 #include "aidge/utils/Types.h"
 namespace Aidge {
+class Pop_OpImpl : public OperatorImpl {
+public:
+    Pop_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
+    Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override;
+    void forward() override;
+};
+
 enum class PopAttr { ForwardStep };
 class Pop_Op : public OperatorTensor,
@@ -39,7 +46,9 @@ public:
     Pop_Op()
         : OperatorTensor(Type, 1, 0, 1),
           Attributes_(attr<PopAttr::ForwardStep>(0))
-    {}
+    {
+        mImpl = std::make_shared<Pop_OpImpl>(*this);
+    }
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
@@ -49,10 +58,11 @@ public:
         : OperatorTensor(op),
           Attributes_(op)
     {
-        if (op.mImpl){
+        if (!op.backend().empty()) {
             SET_IMPL_MACRO(Pop_Op, *this, op.backend());
-        } else {
-            mImpl = nullptr;
+        }
+        else {
+            mImpl = std::make_shared<Pop_OpImpl>(*this);
         }
     }
...
@@ -47,7 +47,7 @@ public:
       Attributes_(attr<ProdAttr::Constant>(constant))
 {
     mOutputs[0]->resize(dims);
-    mImpl = std::make_shared<OperatorImpl>(*this, "");
+    mImpl = std::make_shared<OperatorImpl>(*this);
 }
 /**
@@ -102,9 +102,8 @@ public:
         return {"data_output"};
     }
-    void forward() override final {
-        fmt::print("Basic Producer forward() function.\n");
-    }
+    void forward() override final;
     void backward() override final {
         fmt::print("Basic Producer backward() function.\n");
     }
...
@@ -23,6 +23,11 @@
 #include "aidge/utils/Types.h"
 namespace Aidge {
+class Reshape_OpImpl : public OperatorImpl {
+public:
+    Reshape_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
+    void forward() override;
+};
 enum class ReshapeAttr { Shape };
@@ -42,7 +47,9 @@ public:
     Reshape_Op(const std::vector<std::int64_t>& shape)
         : OperatorTensor(Type, 1, 0, 1),
           Attributes_(attr<ReshapeAttr::Shape>(shape))
-    {}
+    {
+        mImpl = std::make_shared<Reshape_OpImpl>(*this);
+    }
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
@@ -52,10 +59,11 @@ public:
         : OperatorTensor(op),
           Attributes_(op)
     {
-        if (op.mImpl){
+        if (!op.backend().empty()) {
             SET_IMPL_MACRO(Reshape_Op, *this, op.backend());
-        } else {
-            mImpl = nullptr;
+        }
+        else {
+            mImpl = std::make_shared<Reshape_OpImpl>(*this);
         }
     }
...
@@ -129,16 +129,16 @@ void declare_registrable(py::module& m, const std::string& class_name){
  * cyril.moineau@cea.fr
  */
 #ifdef PYBIND
-#define SET_IMPL_MACRO(T_Op, op, backend_name) \
+#define SET_IMPL_MACRO(T_Op, op, ...) \
     if(Py_IsInitialized()) { \
         auto obj = py::cast(&(op)); \
-        (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
+        (op).setImpl(Registrar<T_Op>::create(__VA_ARGS__)(op)); \
     } else { \
-        (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
+        (op).setImpl(Registrar<T_Op>::create(__VA_ARGS__)(op)); \
     }
 #else
-#define SET_IMPL_MACRO(T_Op, op, backend_name) \
-    (op).setImpl(Registrar<T_Op>::create(backend_name)(op));
+#define SET_IMPL_MACRO(T_Op, op, ...) \
+    (op).setImpl(Registrar<T_Op>::create(__VA_ARGS__)(op));
 #endif
 }
...
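The switch from a named `backend_name` parameter to `...` matters because Move_Op's registrar key is a braced pair: the comma inside `{inputBackend, name}` would otherwise be read by the preprocessor as a macro-argument separator and split the key in two. `__VA_ARGS__` forwards everything after `op` verbatim, so both call shapes used in this commit expand correctly:

    // Single-string key, unchanged behaviour:
    SET_IMPL_MACRO(Cast_Op, *this, name);
    // expands (non-Python build) to:
    //   (*this).setImpl(Registrar<Cast_Op>::create(name)(*this));

    // Braced two-string key, now possible: the comma inside {...} is
    // swallowed by __VA_ARGS__ instead of splitting the argument list.
    SET_IMPL_MACRO(Move_Op, *this, {mInputs[0]->getImpl()->backend(), name});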
@@ -20,22 +20,19 @@
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
-const std::string Aidge::Cast_Op::Type = "Cast";
-void Aidge::Cast_Op::forward() {
-    if (mImpl) {
-        mImpl->forward();
-    }
-    else {
-        mOutputs[0]->copyCast(*(mInputs[0]));
-    }
-    runHooks();
+void Aidge::Cast_OpImpl::forward() {
+    const Cast_Op& op = dynamic_cast<const Cast_Op&>(mOp);
+    op.getOutput(0)->copyCast(*(op.getInput(0)));
 }
+const std::string Aidge::Cast_Op::Type = "Cast";
 void Aidge::Cast_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
     if (Registrar<Cast_Op>::exists({name})) {
         SET_IMPL_MACRO(Cast_Op, *this, name);
     }
+    else {
+        mImpl = std::make_shared<Cast_OpImpl>(*this);
+    }
     mOutputs[0]->setBackend(name, device);
 }
@@ -18,6 +18,45 @@
 #include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Types.h"
+void Aidge::Concat_OpImpl::forward() {
+    const Concat_Op& op = dynamic_cast<const Concat_Op&>(mOp);
+    const DimSize_t axis = op.template getAttr<DimSize_t>("Axis");
+
+    assert(op.getInput(0) && "missing input in Concat operator");
+    DataType datatypeFirstInput = op.getInput(0)->dataType();
+    for (IOIndex_t i = 1; i < mOp.nbInputs(); ++i) {
+        assert(op.getInput(i) && "missing input in Concat operator");
+        assert(op.getInput(i)->dataType() == datatypeFirstInput);
+    }
+
+    DimSize_t outputAxisValue = 0;
+    for (IOIndex_t i = 0; i < mOp.nbInputs(); ++i) {
+        outputAxisValue += op.getInput(i)->dims()[axis];
+    }
+
+    DimSize_t prodDimLower = 1;
+    for (DimIdx_t i = 0; i < axis; ++i) {
+        prodDimLower *= op.getInput(0)->dims()[i];
+    }
+    DimSize_t prodDimHigher = 1;
+    for (DimIdx_t i = axis + 1; static_cast<std::size_t>(i) < op.getInput(0)->dims().size(); ++i) {
+        prodDimHigher *= op.getInput(0)->dims()[i];
+    }
+
+    std::size_t oIndexStart = 0;
+    std::size_t oIndex = 0;
+    for (std::size_t inputId = 0; inputId < op.nbInputs(); ++inputId) {
+        oIndex = oIndexStart;
+        const DimSize_t iOffset = prodDimHigher*op.getInput(inputId)->dims()[axis];
+        for (std::size_t iIndex = 0; iIndex < prodDimLower; ++iIndex) {
+            op.getOutput(0)->getImpl()->copy(op.getInput(inputId)->getImpl()->rawPtr(iIndex*iOffset), iOffset, oIndex);
+            oIndex += prodDimHigher*outputAxisValue;
+        }
+        oIndexStart += op.getInput(inputId)->dims()[axis]*prodDimHigher;
+    }
+}
+
 const std::string Aidge::Concat_Op::Type = "Concat";
 bool Aidge::Concat_Op::computeOutputDims(bool /*allowDataDependency*/) {
@@ -54,6 +93,11 @@ bool Aidge::Concat_Op::computeOutputDims(bool /*allowDataDependency*/) {
 }
 void Aidge::Concat_Op::setBackend(const std::string& name, DeviceIdx_t device) {
-    SET_IMPL_MACRO(Concat_Op, *this, name);
+    if (Registrar<Concat_Op>::exists({name})) {
+        SET_IMPL_MACRO(Concat_Op, *this, name);
+    }
+    else {
+        mImpl = std::make_shared<Concat_OpImpl>(*this);
+    }
     mOutputs[0]->setBackend(name, device);
 }
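A worked example may help with the index arithmetic in `Concat_OpImpl::forward()` above; the shapes are made up for illustration:

    // Concatenate inputs of shape {2, 3, 4} and {2, 5, 4} along axis = 1.
    // prodDimLower    = 2        -> product of dims before the axis
    // prodDimHigher   = 4        -> product of dims after the axis
    // outputAxisValue = 3 + 5    -> output shape is {2, 8, 4}
    // Input 0: iOffset = 4*3 = 12 contiguous elements copied per lower-index
    //   slice, with oIndex advancing by prodDimHigher*outputAxisValue = 32
    //   (one full output slice) between copies.
    // Input 1: starts at oIndexStart = 3*4 = 12 inside each output slice,
    //   copying iOffset = 4*5 = 20 elements per slice.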
@@ -20,8 +20,73 @@
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/Types.h"
+Aidge::Elts_t Aidge::Memorize_OpImpl::getNbRequiredData(
+    Aidge::IOIndex_t inputIdx) const
+{
+    const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
+    const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>();
+
+    if (scheduleStep == 0 && inputIdx == 0) {
+        // No data input is required for the initial step.
+        // Initialization data is required however.
+        return Elts_t::NoneElts();
+    }
+    else if (scheduleStep > 0 && inputIdx == 1) {
+        // No initialization data is required after the initial step.
+        return Elts_t::NoneElts();
+    }
+    else {
+        return OperatorImpl::getNbRequiredData(inputIdx);
+    }
+}
+
+Aidge::Elts_t Aidge::Memorize_OpImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
+    const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
+    assert(mOp.getRawOutput(outputIdx) && "requires valid output");
+    const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
+    const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>();
+    const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>();
+
+    if (endStep > 0 && outputIdx == 1 && scheduleStep >= endStep) {
+        return Elts_t::NoneElts();
+    }
+    else {
+        return Elts_t::DataElts(op.getOutput(outputIdx)->size());
+    }
+}
+
+void Aidge::Memorize_OpImpl::updateConsummerProducer() {
+    OperatorImpl::updateConsummerProducer();
+    const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
+    const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>();
+    const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>();
+    AIDGE_ASSERT(endStep == 0 || scheduleStep <= endStep, "cannot update consumer producer anymore, number of cycles exceeded");
+}
+
+void Aidge::Memorize_OpImpl::forward() {
+    const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
+    const unsigned int forwardStep = op.template getAttr<MemorizeAttr::ForwardStep>();
+    const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>();
+    AIDGE_ASSERT(endStep == 0 || forwardStep <= endStep, "cannot forward anymore, number of cycles exceeded");
+
+    if (forwardStep == 0) {
+        op.getOutput(0)->getImpl()->copy(op.getInput(1)->getImpl()->rawPtr(), op.getInput(1)->size());
+    }
+    else {
+        op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(), op.getInput(0)->size());
+    }
+}
+
 const std::string Aidge::Memorize_Op::Type = "Memorize";
+void Aidge::Memorize_Op::updateConsummerProducer() {
+    Operator::updateConsummerProducer();
+    ++this->template getAttr<MemorizeAttr::ScheduleStep>();
+    this->template getAttr<MemorizeAttr::ForwardStep>() = 0;
+}
+
 bool Aidge::Memorize_Op::computeOutputDims(bool /*allowDataDependency*/) {
     for (size_t i = 0; i < 2; ++i) {
         if (!getInput(i)) {
@@ -45,11 +110,6 @@ bool Aidge::Memorize_Op::computeOutputDims(bool /*allowDataDependency*/) {
     return false;
 }
-void Aidge::Memorize_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
-    mImpl = Registrar<Memorize_Op>::create({name})(*this);
-    mOutputs[0]->setBackend(name, device);
-}
-
 bool Aidge::Memorize_Op::outputDimsForwarded() const {
     // Only check the output dims
     bool forwarded = true;
@@ -60,10 +120,14 @@ bool Aidge::Memorize_Op::outputDimsForwarded() const {
     return forwarded;
 }
-void Aidge::Memorize_Op::updateConsummerProducer() {
-    Operator::updateConsummerProducer();
-    ++this->template getAttr<MemorizeAttr::ScheduleStep>();
-    this->template getAttr<MemorizeAttr::ForwardStep>() = 0;
+void Aidge::Memorize_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
+    if (Registrar<Memorize_Op>::exists({name})){
+        SET_IMPL_MACRO(Memorize_Op, *this, name);
+    }
+    else {
+        mImpl = std::make_shared<Memorize_OpImpl>(*this);
+    }
+    mOutputs[0]->setBackend(name, device);
 }
 void Aidge::Memorize_Op::forward() {
...
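Put together, the new `Memorize_OpImpl` methods encode the following schedule (a reading of the code above, not additional behaviour):

    // step 0      : getNbRequiredData(0) == NoneElts -> only input 1 (the
    //               initialization tensor) is consumed; forward() copies it.
    // step 1..N   : getNbRequiredData(1) == NoneElts -> only input 0 (the
    //               value memorized on the previous cycle) is consumed.
    // EndStep > 0 : once ScheduleStep >= EndStep, getRequiredMemory() returns
    //               NoneElts for output 1, letting the scheduler wind the loop
    //               down; the AIDGE_ASSERTs guard against overrunning it.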
@@ -12,15 +12,19 @@
 #include "aidge/backend/OperatorImpl.hpp"
 #include "aidge/operator/Move.hpp"
+void Aidge::Move_OpImpl::forward() {
+    const Move_Op& op = dynamic_cast<const Move_Op&>(mOp);
+    op.getOutput(0)->copyFrom(*(op.getInput(0)));
+}
+
 const std::string Aidge::Move_Op::Type = "Move";
-void Aidge::Move_Op::forward() {
-    if (mImpl) {
-        mImpl->forward();
+void Aidge::Move_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
+    if (Registrar<Move_Op>::exists({mInputs[0]->getImpl()->backend(), name})) {
+        SET_IMPL_MACRO(Move_Op, *this, {mInputs[0]->getImpl()->backend(), name});
     }
     else {
-        mOutputs[0]->copyFrom(*(mInputs[0]));
+        mImpl = std::make_shared<Move_OpImpl>(*this);
     }
-    runHooks();
+    mOutputs[0]->setBackend(name, device);
 }
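Since `Move_Op` is `Registrable` on a `std::tuple<std::string, std::string>`, a concrete kernel is looked up by (input backend, output backend) pair; the fallback above simply calls `copyFrom()`. A sketch of what a cross-backend registration might look like, assuming Aidge's usual `Registrar` constructor taking a key and a creator functor; the kernel class and backend names here are hypothetical, not part of this commit:

    // Hypothetical cross-backend Move kernel registration (illustration only).
    static Registrar<Move_Op> registrarMoveCpuCuda(
        {"cpu", "cuda"},   // key: {input backend, output backend}
        [](const Move_Op& op) -> std::unique_ptr<OperatorImpl> {
            return std::make_unique<MoveCpuToCudaImpl>(op);  // hypothetical class
        });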
@@ -20,6 +20,20 @@
 #include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Types.h"
+Aidge::Elts_t Aidge::Pop_OpImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
+    assert(mOp.getRawInput(inputIdx) && "requires valid input");
+    const Pop_Op& op = dynamic_cast<const Pop_Op&>(mOp);
+    return Elts_t::DataElts(op.getInput(inputIdx)->size()
+        / op.getInput(inputIdx)->dims()[0]);
+}
+
+void Aidge::Pop_OpImpl::forward() {
+    const Pop_Op& op = dynamic_cast<const Pop_Op&>(mOp);
+    assert(op.getInput(0) && "missing input #0");
+    const unsigned int forwardStep = op.template getAttr<PopAttr::ForwardStep>();
+    *op.getOutput(0) = op.getInput(0)->extract({forwardStep});
+}
+
 const std::string Aidge::Pop_Op::Type = "Pop";
@@ -43,12 +57,17 @@ void Aidge::Pop_Op::updateConsummerProducer() {
     this->template getAttr<PopAttr::ForwardStep>() = 0;
 }
+void Aidge::Pop_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
+    if (Registrar<Pop_Op>::exists({name})){
+        SET_IMPL_MACRO(Pop_Op, *this, name);
+    }
+    else {
+        mImpl = std::make_shared<Pop_OpImpl>(*this);
+    }
+    mOutputs[0]->setBackend(name, device);
+}
+
 void Aidge::Pop_Op::forward() {
     Operator::forward();
     ++this->template getAttr<PopAttr::ForwardStep>();
 }
-
-void Aidge::Pop_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
-    SET_IMPL_MACRO(Pop_Op, *this, name);
-    mOutputs[0]->setBackend(name, device);
-}
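To make the Pop fallback concrete (made-up shape, for illustration only):

    // For an input of dims {4, 2, 3}:
    //   getNbRequiredData() = size / dims[0] = 24 / 4 = 6 elements per step,
    //   i.e. one {2, 3} time-slice of the leading dimension.
    // forward() at step k assigns op.getInput(0)->extract({k}) -- the k-th
    // slice -- to the output; Pop_Op::forward() then increments ForwardStep.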
@@ -32,28 +32,12 @@ Aidge::Producer_Op::Producer_Op(const std::shared_ptr<Aidge::Tensor> tensor, boo
       Attributes_(attr<ProdAttr::Constant>(constant))
 {
     mOutputs[0] = tensor; // copy the pointer of the Tensor
-#ifdef PYBIND
-    if(Py_IsInitialized()) {
-        auto obj = py::cast(&(*this));
-        setImpl((mOutputs[0]->hasImpl()) ?
-            (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ?
-                Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) :
-                std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) :
-            std::make_shared<OperatorImpl>(*this, ""));
-    } else {
-        setImpl((mOutputs[0]->hasImpl()) ?
-            (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ?
-                Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) :
-                std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) :
-            std::make_shared<OperatorImpl>(*this, ""));
-    }
-#else
-    setImpl((mOutputs[0]->hasImpl()) ?
-        (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ?
-            Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) :
-            std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) :
-        std::make_shared<OperatorImpl>(*this, ""));
-#endif
+    if (mOutputs[0]->getImpl() && Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){
+        SET_IMPL_MACRO(Producer_Op, *this, mOutputs[0]->getImpl()->backend());
+    }
+    else {
+        mImpl = std::make_shared<OperatorImpl>(*this);
+    }
 }
 /**
@@ -66,57 +50,31 @@ Aidge::Producer_Op::Producer_Op(const Aidge::Producer_Op& op)
       Attributes_(op)
 {
     mOutputs[0] = std::make_shared<Tensor>(*(op.getOutput(0)));
-#ifdef PYBIND
-    if(Py_IsInitialized()) {
-        auto obj = py::cast(&(*this));
-        setImpl((mOutputs[0]->hasImpl()) ?
-            (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ?
-                Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) :
-                std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) :
-            std::make_shared<OperatorImpl>(*this, ""));
-    } else {
-        setImpl((mOutputs[0]->hasImpl()) ?
-            (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ?
-                Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) :
-                std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) :
-            std::make_shared<OperatorImpl>(*this, ""));
-    }
-#else
-    setImpl((mOutputs[0]->hasImpl()) ?
-        (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ?
-            Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) :
-            std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) :
-        std::make_shared<OperatorImpl>(*this, ""));
-#endif
-    // if (mOutputs[0]->hasImpl()) {
-    //     if (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){
-    //         setImpl(Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this));
-    //     }
-    //     else {
-    //         mImpl = std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend());
-    //     }
-    // } else {
-    //     mImpl = nullptr;
-    // }
+    if (mOutputs[0]->getImpl() && Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){
+        SET_IMPL_MACRO(Producer_Op, *this, mOutputs[0]->getImpl()->backend());
+    }
+    else {
+        mImpl = std::make_shared<OperatorImpl>(*this);
+    }
 }
 void Aidge::Producer_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
-#ifdef PYBIND
-    if(Py_IsInitialized()) {
-        auto obj = py::cast(&(*this));
-        setImpl((Registrar<Producer_Op>::exists({name})) ?
-            Registrar<Producer_Op>::create(name)(*this) :
-            std::make_shared<OperatorImpl>(*this, ""));
-    } else {
-        setImpl((Registrar<Producer_Op>::exists({name})) ?
-            Registrar<Producer_Op>::create(name)(*this) :
-            std::make_shared<OperatorImpl>(*this, ""));
-    }
-#else
-    setImpl((Registrar<Producer_Op>::exists({name})) ?
-        Registrar<Producer_Op>::create(name)(*this) :
-        std::make_shared<OperatorImpl>(*this, ""));
-#endif
+    if (Registrar<Producer_Op>::exists({name})){
+        SET_IMPL_MACRO(Producer_Op, *this, name);
+    }
+    else {
+        mImpl = std::make_shared<OperatorImpl>(*this);
+    }
     mOutputs[0]->setBackend(name, device);
 }
\ No newline at end of file
+
+void Aidge::Producer_Op::forward() {
+    if (!backend().empty()) {
+        mImpl->forward();
+    }
+    else {
+        fmt::print("Basic Producer forward() function.\n");
+    }
+
+    runHooks();
+}
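The Producer rewrite performs the same selection in a fraction of the code: the triplicated `#ifdef PYBIND` / `Py_IsInitialized()` branches collapse into `SET_IMPL_MACRO`, which already hides that distinction, and the nested ternaries become an explicit `exists()` check. One behavioural nuance: an unregistered backend now yields the plain default `OperatorImpl` with an empty backend string, rather than one tagged with the tensor's backend name. A usage sketch, assuming a backend module such as aidge_backend_cpu has registered a "cpu" Producer kernel and Tensor implementation (both outside this commit):

    auto tensor = std::make_shared<Tensor>();      // empty Tensor, no impl yet
    Producer_Op prod(tensor, /*constant=*/false);  // constructor shown in this diff
    prod.setBackend("cpu"); // registered -> SET_IMPL_MACRO installs the cpu kernel
    prod.forward();         // backend() is non-empty, so mImpl->forward() runs
    // With no registered kernel, setBackend() falls back to a plain OperatorImpl
    // whose backend() is ""; forward() would then print
    // "Basic Producer forward() function." and still run the hooks.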
@@ -23,6 +23,11 @@
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
+void Aidge::Reshape_OpImpl::forward() {
+    const Reshape_Op& op = dynamic_cast<const Reshape_Op&>(mOp);
+    op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(), op.getInput(0)->size());
+}
+
 const std::string Aidge::Reshape_Op::Type = "Reshape";
 bool Aidge::Reshape_Op::computeOutputDims(bool /*allowDataDependency*/) {
@@ -65,6 +70,11 @@ bool Aidge::Reshape_Op::computeOutputDims(bool /*allowDataDependency*/) {
 }
 void Aidge::Reshape_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
-    SET_IMPL_MACRO(Reshape_Op, *this, name);
+    if (Registrar<Reshape_Op>::exists({name})){
+        SET_IMPL_MACRO(Reshape_Op, *this, name);
+    }
+    else {
+        mImpl = std::make_shared<Reshape_OpImpl>(*this);
+    }
     mOutputs[0]->setBackend(name, device);
 }
\ No newline at end of file