diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index 6a9056723df133fef62e56f969d39d8f69390a76..1fc9168da120ba87c916b1a6a346997be69184b4 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -23,7 +23,7 @@ class Operator; class OperatorImpl { public: - OperatorImpl(const Operator& op, const std::string& backend); + OperatorImpl(const Operator& op, const std::string& backend = ""); virtual void forward(); virtual void backward(); diff --git a/include/aidge/operator/Cast.hpp b/include/aidge/operator/Cast.hpp index bbc776a1175a1fc29d08c3872649a6b7aac2f04f..6efbc0a214dde3ca969226f734b5ee903fe5ab50 100644 --- a/include/aidge/operator/Cast.hpp +++ b/include/aidge/operator/Cast.hpp @@ -24,13 +24,20 @@ #include "aidge/utils/Types.h" namespace Aidge { +class Cast_OpImpl : public OperatorImpl { +public: + Cast_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {} + void forward() override; +}; class Cast_Op : public OperatorTensor, public Registrable<Cast_Op, std::string, std::unique_ptr<OperatorImpl>(const Cast_Op&)> { public: static const std::string Type; - Cast_Op() : OperatorTensor(Type, 1, 0, 1) {} + Cast_Op() : OperatorTensor(Type, 1, 0, 1) { + mImpl = std::make_shared<Cast_OpImpl>(*this); + } /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). @@ -39,10 +46,11 @@ public: Cast_Op(const Cast_Op& op) : OperatorTensor(op) { - if (op.mImpl) { + if (!op.backend().empty()) { SET_IMPL_MACRO(Cast_Op, *this, op.backend()); - } else { - mImpl = nullptr; + } + else { + mImpl = std::make_shared<Cast_OpImpl>(*this); } } @@ -56,8 +64,6 @@ public: void setBackend(const std::string& name, DeviceIdx_t device = 0) override; - void forward() override; - static const std::vector<std::string> getInputsName(){ return {"data_input"}; } diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp index 97c477db591a29987f88b0c58beaf128169624ea..32a519dbc750361c7ad1b6686d37a0766faf696e 100644 --- a/include/aidge/operator/Concat.hpp +++ b/include/aidge/operator/Concat.hpp @@ -26,6 +26,12 @@ #include "aidge/utils/Types.h" namespace Aidge { +class Concat_OpImpl : public OperatorImpl { +public: + Concat_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {} + void forward() override; +}; + enum class ConcatAttr { Axis }; class Concat_Op : public OperatorTensor, @@ -45,6 +51,7 @@ public: if (nbIn == 0) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Add operator should have at least one input."); } + mImpl = std::make_shared<Concat_OpImpl>(*this); } /** @@ -55,10 +62,11 @@ public: : OperatorTensor(op), Attributes_(op) { - if (op.mImpl){ + if (!op.backend().empty()) { SET_IMPL_MACRO(Concat_Op, *this, op.backend()); - }else{ - mImpl = nullptr; + } + else { + mImpl = std::make_shared<Concat_OpImpl>(*this); } } diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 6208ea0a920e6b088dfb60ca49237d5f6664b08e..49885f9fdae05a55552869a6543ef1810aa1dfae 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -37,7 +37,7 @@ public: GenericOperator_Op(const std::string& type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut) : OperatorTensor(type, nbData, nbParam, nbOut) { - mImpl = std::make_shared<OperatorImpl>(*this, ""); + mImpl = std::make_shared<OperatorImpl>(*this); } /** diff --git a/include/aidge/operator/Identity.hpp b/include/aidge/operator/Identity.hpp index 51c70eae573513ef0c897bf6f71371512c467a0b..f49711837b9b5c5126dce18a1864b2ed156af6f4 100644 --- a/include/aidge/operator/Identity.hpp +++ b/include/aidge/operator/Identity.hpp @@ -42,7 +42,7 @@ public: Identity_Op() : OperatorTensor(Type, 1, 0, 1) { - mImpl = std::make_shared<OperatorImpl>(*this, ""); + mImpl = std::make_shared<OperatorImpl>(*this); } /** diff --git a/include/aidge/operator/Memorize.hpp b/include/aidge/operator/Memorize.hpp index 89d2652834101a0cfb4038c610d54c151a3760f4..6f668a94238cf7ba62b3cf2776729ff9f41e5b1a 100644 --- a/include/aidge/operator/Memorize.hpp +++ b/include/aidge/operator/Memorize.hpp @@ -25,6 +25,15 @@ #include "aidge/utils/Types.h" namespace Aidge { +class Memorize_OpImpl : public OperatorImpl { +public: + Memorize_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {} + Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override final; + Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override final; + void updateConsummerProducer() override; + void forward() override; +}; + enum class MemorizeAttr { ScheduleStep, ForwardStep, EndStep }; class Memorize_Op : public OperatorTensor, diff --git a/include/aidge/operator/Move.hpp b/include/aidge/operator/Move.hpp index 3652cf9697c6bcfea4befe4cdcdf5b9efff8b70c..e9bcaa871619828a50dcd407d39744e7983fe2c4 100644 --- a/include/aidge/operator/Move.hpp +++ b/include/aidge/operator/Move.hpp @@ -24,13 +24,20 @@ #include "aidge/utils/Types.h" namespace Aidge { +class Move_OpImpl : public OperatorImpl { +public: + Move_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {} + void forward() override; +}; class Move_Op : public OperatorTensor, public Registrable<Move_Op, std::tuple<std::string, std::string>, std::unique_ptr<OperatorImpl>(const Move_Op&)> { public: static const std::string Type; - Move_Op() : OperatorTensor(Type, 1, 0, 1) {} + Move_Op() : OperatorTensor(Type, 1, 0, 1) { + mImpl = std::make_shared<Move_OpImpl>(*this); + } /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). @@ -39,7 +46,12 @@ public: Move_Op(const Move_Op& op) : OperatorTensor(op) { - mImpl = op.mImpl ? Registrar<Move_Op>::create({mInputs[0]->getImpl()->backend(), mOutputs[0]->getImpl()->backend()})(*this) : nullptr; + if (!op.backend().empty()) { + SET_IMPL_MACRO(Move_Op, *this, {op.getInput(0)->getImpl()->backend(), op.backend()}); + } + else { + mImpl = std::make_shared<Move_OpImpl>(*this); + } } /** @@ -50,14 +62,7 @@ public: return std::make_shared<Move_Op>(*this); } - void setBackend(const std::string& name, DeviceIdx_t device = 0) override { - if (mInputs[0]->getImpl() && Registrar<Move_Op>::exists({mInputs[0]->getImpl()->backend(), name})) { - mImpl = Registrar<Move_Op>::create({mInputs[0]->getImpl()->backend(), name})(*this); - } - mOutputs[0]->setBackend(name, device); - } - - void forward() override; + void setBackend(const std::string& name, DeviceIdx_t device = 0) override; static const std::vector<std::string> getInputsName(){ return {"data_input"}; diff --git a/include/aidge/operator/Pop.hpp b/include/aidge/operator/Pop.hpp index c584390ca6b8b151020f8d858e6c2d94683328d1..372faff6a89e364e71f77df3bd4573705ab86fed 100644 --- a/include/aidge/operator/Pop.hpp +++ b/include/aidge/operator/Pop.hpp @@ -24,6 +24,13 @@ #include "aidge/utils/Types.h" namespace Aidge { +class Pop_OpImpl : public OperatorImpl { +public: + Pop_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {} + Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override; + void forward() override; +}; + enum class PopAttr { ForwardStep }; class Pop_Op : public OperatorTensor, @@ -39,7 +46,9 @@ public: Pop_Op() : OperatorTensor(Type, 1, 0, 1), Attributes_(attr<PopAttr::ForwardStep>(0)) - {} + { + mImpl = std::make_shared<Pop_OpImpl>(*this); + } /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). @@ -49,10 +58,11 @@ public: : OperatorTensor(op), Attributes_(op) { - if (op.mImpl){ + if (!op.backend().empty()) { SET_IMPL_MACRO(Pop_Op, *this, op.backend()); - } else { - mImpl = nullptr; + } + else { + mImpl = std::make_shared<Pop_OpImpl>(*this); } } diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index 79a116e4a0a2267084ae3d8961b924a596c2d5e0..e21aa9aea936568dfb5d5ddd40779dc0acc06160 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -47,7 +47,7 @@ public: Attributes_(attr<ProdAttr::Constant>(constant)) { mOutputs[0]->resize(dims); - mImpl = std::make_shared<OperatorImpl>(*this, ""); + mImpl = std::make_shared<OperatorImpl>(*this); } /** @@ -102,9 +102,8 @@ public: return {"data_output"}; } - void forward() override final { - fmt::print("Basic Producer forward() function.\n"); - } + void forward() override final; + void backward() override final { fmt::print("Basic Producer backward() function.\n"); } diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp index 8f1482019a4c45160125bf0dbff1479d02f62e49..bf0f7ee3492cf4e52b903401af57c701d24f9190 100644 --- a/include/aidge/operator/Reshape.hpp +++ b/include/aidge/operator/Reshape.hpp @@ -23,6 +23,11 @@ #include "aidge/utils/Types.h" namespace Aidge { +class Reshape_OpImpl : public OperatorImpl { +public: + Reshape_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {} + void forward() override; +}; enum class ReshapeAttr { Shape }; @@ -42,7 +47,9 @@ public: Reshape_Op(const std::vector<std::int64_t>& shape) : OperatorTensor(Type, 1, 0, 1), Attributes_(attr<ReshapeAttr::Shape>(shape)) - {} + { + mImpl = std::make_shared<Reshape_OpImpl>(*this); + } /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). @@ -52,10 +59,11 @@ public: : OperatorTensor(op), Attributes_(op) { - if (op.mImpl){ + if (!op.backend().empty()) { SET_IMPL_MACRO(Reshape_Op, *this, op.backend()); - } else { - mImpl = nullptr; + } + else { + mImpl = std::make_shared<Reshape_OpImpl>(*this); } } diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index a6d1d7a9eb5d88dedaf73564847b0f4fbd797c43..b0acdaff7cb75afec78f0564fb95c98f2b32f47b 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -129,16 +129,16 @@ void declare_registrable(py::module& m, const std::string& class_name){ * cyril.moineau@cea.fr */ #ifdef PYBIND -#define SET_IMPL_MACRO(T_Op, op, backend_name) \ +#define SET_IMPL_MACRO(T_Op, op, ...) \ if(Py_IsInitialized()) { \ auto obj = py::cast(&(op)); \ - (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ + (op).setImpl(Registrar<T_Op>::create(__VA_ARGS__)(op)); \ } else { \ - (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ + (op).setImpl(Registrar<T_Op>::create(__VA_ARGS__)(op)); \ } #else -#define SET_IMPL_MACRO(T_Op, op, backend_name) \ - (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); +#define SET_IMPL_MACRO(T_Op, op, ...) \ + (op).setImpl(Registrar<T_Op>::create(__VA_ARGS__)(op)); #endif } diff --git a/src/operator/Cast.cpp b/src/operator/Cast.cpp index 4f1ac55898b11668ba1c2f5299f8e1ca1d4e5df1..f1c8e25e17c80d58d444a1ddddbaa428b2fc4c41 100644 --- a/src/operator/Cast.cpp +++ b/src/operator/Cast.cpp @@ -20,22 +20,19 @@ #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" -const std::string Aidge::Cast_Op::Type = "Cast"; - -void Aidge::Cast_Op::forward() { - if (mImpl) { - mImpl->forward(); - } - else { - mOutputs[0]->copyCast(*(mInputs[0])); - } - - runHooks(); +void Aidge::Cast_OpImpl::forward() { + const Cast_Op& op = dynamic_cast<const Cast_Op&>(mOp); + op.getOutput(0)->copyCast(*(op.getInput(0))); } +const std::string Aidge::Cast_Op::Type = "Cast"; + void Aidge::Cast_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { if (Registrar<Cast_Op>::exists({name})) { SET_IMPL_MACRO(Cast_Op, *this, name); } + else { + mImpl = std::make_shared<Cast_OpImpl>(*this); + } mOutputs[0]->setBackend(name, device); } diff --git a/src/operator/Concat.cpp b/src/operator/Concat.cpp index d2bfd17ba29cde3a89e114d57cb6d860cdbc2fee..929000a5f4ceeb4c073b8edd919ac976fc651ae2 100644 --- a/src/operator/Concat.cpp +++ b/src/operator/Concat.cpp @@ -18,6 +18,45 @@ #include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Types.h" +void Aidge::Concat_OpImpl::forward() { + const Concat_Op& op = dynamic_cast<const Concat_Op&>(mOp); + const DimSize_t axis = op.template getAttr<DimSize_t>("Axis"); + + assert(op.getInput(0) && "missing input in Concat operator"); + DataType datatypeFirstInput = op.getInput(0)->dataType(); + for (IOIndex_t i = 1; i < mOp.nbInputs(); ++i) { + assert(op.getInput(i) && "missing input in Concat operator"); + assert(op.getInput(i)->dataType() == datatypeFirstInput); + } + + DimSize_t outputAxisValue = 0; + for (IOIndex_t i = 0; i < mOp.nbInputs(); ++i) { + outputAxisValue += op.getInput(i)->dims()[axis]; + } + + DimSize_t prodDimLower = 1; + for (DimIdx_t i = 0; i < axis; ++i) { + prodDimLower *= op.getInput(0)->dims()[i]; + } + DimSize_t prodDimHigher = 1; + for (DimIdx_t i = axis + 1; static_cast<std::size_t>(i) < op.getInput(0)->dims().size(); + ++i) { + prodDimHigher *= op.getInput(0)->dims()[i]; + } + + std::size_t oIndexStart = 0; + std::size_t oIndex = 0; + for (std::size_t inputId = 0; inputId < op.nbInputs(); ++inputId) { + oIndex = oIndexStart; + const DimSize_t iOffset = prodDimHigher*op.getInput(inputId)->dims()[axis]; + for (std::size_t iIndex = 0; iIndex < prodDimLower; ++iIndex) { + op.getOutput(0)->getImpl()->copy(op.getInput(inputId)->getImpl()->rawPtr(iIndex*iOffset), iOffset, oIndex); + oIndex += prodDimHigher*outputAxisValue; + } + oIndexStart += op.getInput(inputId)->dims()[axis]*prodDimHigher; + } +} + const std::string Aidge::Concat_Op::Type = "Concat"; bool Aidge::Concat_Op::computeOutputDims(bool /*allowDataDependency*/) { @@ -54,6 +93,11 @@ bool Aidge::Concat_Op::computeOutputDims(bool /*allowDataDependency*/) { } void Aidge::Concat_Op::setBackend(const std::string& name, DeviceIdx_t device) { - SET_IMPL_MACRO(Concat_Op, *this, name); + if (Registrar<Concat_Op>::exists({name})) { + SET_IMPL_MACRO(Concat_Op, *this, name); + } + else { + mImpl = std::make_shared<Concat_OpImpl>(*this); + } mOutputs[0]->setBackend(name, device); } diff --git a/src/operator/Memorize.cpp b/src/operator/Memorize.cpp index 3490a5f6dda864b6f0e645b43e072ddffef3522d..4e802816a13fcfb0fbfa266ca79baac3e6423a3b 100644 --- a/src/operator/Memorize.cpp +++ b/src/operator/Memorize.cpp @@ -20,8 +20,73 @@ #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Types.h" +Aidge::Elts_t Aidge::Memorize_OpImpl::getNbRequiredData( + Aidge::IOIndex_t inputIdx) const +{ + const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp); + const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>(); + + if (scheduleStep == 0 && inputIdx == 0) { + // No data input is required for the initial step. + // Initialization data is required however. + return Elts_t::NoneElts(); + } + else if (scheduleStep > 0 && inputIdx == 1) { + // No initialization data is required after the initial step. + return Elts_t::NoneElts(); + } + else { + return OperatorImpl::getNbRequiredData(inputIdx); + } +} + +Aidge::Elts_t Aidge::Memorize_OpImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx, + const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const { + assert(mOp.getRawOutput(outputIdx) && "requires valid output"); + + const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp); + const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>(); + const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>(); + + if (endStep > 0 && outputIdx == 1 && scheduleStep >= endStep) { + return Elts_t::NoneElts(); + } + else { + return Elts_t::DataElts(op.getOutput(outputIdx)->size()); + } +} + +void Aidge::Memorize_OpImpl::updateConsummerProducer() { + OperatorImpl::updateConsummerProducer(); + + const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp); + const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>(); + const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>(); + AIDGE_ASSERT(endStep == 0 || scheduleStep <= endStep, "cannot update consumer producer anymore, number of cycles exceeded"); +} + +void Aidge::Memorize_OpImpl::forward() { + const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp); + const unsigned int forwardStep = op.template getAttr<MemorizeAttr::ForwardStep>(); + const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>(); + AIDGE_ASSERT(endStep == 0 || forwardStep <= endStep, "cannot forward anymore, number of cycles exceeded"); + + if (forwardStep == 0) { + op.getOutput(0)->getImpl()->copy(op.getInput(1)->getImpl()->rawPtr(), op.getInput(1)->size()); + } + else { + op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(), op.getInput(0)->size()); + } +} + const std::string Aidge::Memorize_Op::Type = "Memorize"; +void Aidge::Memorize_Op::updateConsummerProducer() { + Operator::updateConsummerProducer(); + ++this->template getAttr<MemorizeAttr::ScheduleStep>(); + this->template getAttr<MemorizeAttr::ForwardStep>() = 0; +} + bool Aidge::Memorize_Op::computeOutputDims(bool /*allowDataDependency*/) { for (size_t i = 0; i < 2; ++i) { if (!getInput(i)) { @@ -45,11 +110,6 @@ bool Aidge::Memorize_Op::computeOutputDims(bool /*allowDataDependency*/) { return false; } -void Aidge::Memorize_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { - mImpl = Registrar<Memorize_Op>::create({name})(*this); - mOutputs[0]->setBackend(name, device); -} - bool Aidge::Memorize_Op::outputDimsForwarded() const { // Only check the output dims bool forwarded = true; @@ -60,10 +120,14 @@ bool Aidge::Memorize_Op::outputDimsForwarded() const { return forwarded; } -void Aidge::Memorize_Op::updateConsummerProducer() { - Operator::updateConsummerProducer(); - ++this->template getAttr<MemorizeAttr::ScheduleStep>(); - this->template getAttr<MemorizeAttr::ForwardStep>() = 0; +void Aidge::Memorize_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { + if (Registrar<Memorize_Op>::exists({name})){ + SET_IMPL_MACRO(Memorize_Op, *this, name); + } + else { + mImpl = std::make_shared<Memorize_OpImpl>(*this); + } + mOutputs[0]->setBackend(name, device); } void Aidge::Memorize_Op::forward() { diff --git a/src/operator/Move.cpp b/src/operator/Move.cpp index d8776e32fca909663bafe3fae3ebf9f5616c69c9..0f635ea655676e488343bb55d9de6423a997af7d 100644 --- a/src/operator/Move.cpp +++ b/src/operator/Move.cpp @@ -12,15 +12,19 @@ #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Move.hpp" +void Aidge::Move_OpImpl::forward() { + const Move_Op& op = dynamic_cast<const Move_Op&>(mOp); + op.getOutput(0)->copyFrom(*(op.getInput(0))); +} + const std::string Aidge::Move_Op::Type = "Move"; -void Aidge::Move_Op::forward() { - if (mImpl) { - mImpl->forward(); +void Aidge::Move_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { + if (Registrar<Move_Op>::exists({mInputs[0]->getImpl()->backend(), name})) { + SET_IMPL_MACRO(Move_Op, *this, {mInputs[0]->getImpl()->backend(), name}); } else { - mOutputs[0]->copyFrom(*(mInputs[0])); + mImpl = std::make_shared<Move_OpImpl>(*this); } - - runHooks(); + mOutputs[0]->setBackend(name, device); } diff --git a/src/operator/Pop.cpp b/src/operator/Pop.cpp index 9e7b36025055399ecf803995d9e87e645debbfe4..6f09d402af8c6416f3f0b444cd48328b4f5a2031 100644 --- a/src/operator/Pop.cpp +++ b/src/operator/Pop.cpp @@ -20,6 +20,20 @@ #include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Types.h" +Aidge::Elts_t Aidge::Pop_OpImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { + assert(mOp.getRawInput(inputIdx) && "requires valid input"); + + const Pop_Op& op = dynamic_cast<const Pop_Op&>(mOp); + return Elts_t::DataElts(op.getInput(inputIdx)->size() + / op.getInput(inputIdx)->dims()[0]); +} + +void Aidge::Pop_OpImpl::forward() { + const Pop_Op& op = dynamic_cast<const Pop_Op&>(mOp); + assert(op.getInput(0) && "missing input #0"); + const unsigned int forwardStep = op.template getAttr<PopAttr::ForwardStep>(); + *op.getOutput(0) = op.getInput(0)->extract({forwardStep}); +} const std::string Aidge::Pop_Op::Type = "Pop"; @@ -43,12 +57,17 @@ void Aidge::Pop_Op::updateConsummerProducer() { this->template getAttr<PopAttr::ForwardStep>() = 0; } +void Aidge::Pop_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { + if (Registrar<Pop_Op>::exists({name})){ + SET_IMPL_MACRO(Pop_Op, *this, name); + } + else { + mImpl = std::make_shared<Pop_OpImpl>(*this); + } + mOutputs[0]->setBackend(name, device); +} + void Aidge::Pop_Op::forward() { Operator::forward(); ++this->template getAttr<PopAttr::ForwardStep>(); } - -void Aidge::Pop_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { - SET_IMPL_MACRO(Pop_Op, *this, name); - mOutputs[0]->setBackend(name, device); -} diff --git a/src/operator/Producer.cpp b/src/operator/Producer.cpp index 38bbbc14846f8f4356602b1d3a66058439bb37d0..f384c10138500f454720395e7387c331d67440b6 100644 --- a/src/operator/Producer.cpp +++ b/src/operator/Producer.cpp @@ -32,28 +32,12 @@ Aidge::Producer_Op::Producer_Op(const std::shared_ptr<Aidge::Tensor> tensor, boo Attributes_(attr<ProdAttr::Constant>(constant)) { mOutputs[0] = tensor; // copy the pointer of the Tensor -#ifdef PYBIND - if(Py_IsInitialized()) { - auto obj = py::cast(&(*this)); - setImpl((mOutputs[0]->hasImpl()) ? - (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ? - Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : - std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) : - std::make_shared<OperatorImpl>(*this, "")); - } else { - setImpl((mOutputs[0]->hasImpl()) ? - (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ? - Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : - std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) : - std::make_shared<OperatorImpl>(*this, "")); + if (mOutputs[0]->getImpl() && Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){ + SET_IMPL_MACRO(Producer_Op, *this, mOutputs[0]->getImpl()->backend()); + } + else { + mImpl = std::make_shared<OperatorImpl>(*this); } -#else - setImpl((mOutputs[0]->hasImpl()) ? - (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ? - Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : - std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) : - std::make_shared<OperatorImpl>(*this, "")); -#endif } /** @@ -66,57 +50,31 @@ Aidge::Producer_Op::Producer_Op(const Aidge::Producer_Op& op) Attributes_(op) { mOutputs[0] = std::make_shared<Tensor>(*(op.getOutput(0))); -#ifdef PYBIND - if(Py_IsInitialized()) { - auto obj = py::cast(&(*this)); - setImpl((mOutputs[0]->hasImpl()) ? - (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ? - Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : - std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) : - std::make_shared<OperatorImpl>(*this, "")); - } else { - setImpl((mOutputs[0]->hasImpl()) ? - (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ? - Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : - std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) : - std::make_shared<OperatorImpl>(*this, "")); - } -#else - setImpl((mOutputs[0]->hasImpl()) ? - (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()}) ? - Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : - std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend())) : - std::make_shared<OperatorImpl>(*this, "")); -#endif - // if (mOutputs[0]->hasImpl()) { - // if (Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){ - // setImpl(Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this)); - // } - // else { - // mImpl = std::make_shared<OperatorImpl>(*this, mOutputs[0]->getImpl()->backend()); - // } - - // } else { - // mImpl = nullptr; - // } + if (mOutputs[0]->getImpl() && Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})){ + SET_IMPL_MACRO(Producer_Op, *this, mOutputs[0]->getImpl()->backend()); + } + else { + mImpl = std::make_shared<OperatorImpl>(*this); + } } void Aidge::Producer_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { -#ifdef PYBIND - if(Py_IsInitialized()) { - auto obj = py::cast(&(*this)); - setImpl((Registrar<Producer_Op>::exists({name})) ? - Registrar<Producer_Op>::create(name)(*this) : - std::make_shared<OperatorImpl>(*this, "")); - } else { - setImpl((Registrar<Producer_Op>::exists({name})) ? - Registrar<Producer_Op>::create(name)(*this) : - std::make_shared<OperatorImpl>(*this, "")); - } -#else - setImpl((Registrar<Producer_Op>::exists({name})) ? - Registrar<Producer_Op>::create(name)(*this) : - std::make_shared<OperatorImpl>(*this, "")); -#endif + if (Registrar<Producer_Op>::exists({name})){ + SET_IMPL_MACRO(Producer_Op, *this, name); + } + else { + mImpl = std::make_shared<OperatorImpl>(*this); + } mOutputs[0]->setBackend(name, device); -} \ No newline at end of file +} + +void Aidge::Producer_Op::forward() { + if (!backend().empty()) { + mImpl->forward(); + } + else { + fmt::print("Basic Producer forward() function.\n"); + } + + runHooks(); +} diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index 4ae7b121799775c8e22956c1b5b73c0aa59dbcb6..8431971dadf896add25eb04d1e66e25f0ad3e953 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -23,6 +23,11 @@ #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" +void Aidge::Reshape_OpImpl::forward() { + const Reshape_Op& op = dynamic_cast<const Reshape_Op&>(mOp); + op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(), op.getInput(0)->size()); +} + const std::string Aidge::Reshape_Op::Type = "Reshape"; bool Aidge::Reshape_Op::computeOutputDims(bool /*allowDataDependency*/) { @@ -65,6 +70,11 @@ bool Aidge::Reshape_Op::computeOutputDims(bool /*allowDataDependency*/) { } void Aidge::Reshape_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { - SET_IMPL_MACRO(Reshape_Op, *this, name); + if (Registrar<Reshape_Op>::exists({name})){ + SET_IMPL_MACRO(Reshape_Op, *this, name); + } + else { + mImpl = std::make_shared<Reshape_OpImpl>(*this); + } mOutputs[0]->setBackend(name, device); -} \ No newline at end of file +}