diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index 944183d0b5866af6c037cf64c91c014e2f25bad6..04044ed1c77915ec10b5af5b660cf8e6b20c81b2 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -9,12 +9,11 @@ * ********************************************************************************/ -#ifndef AIDGE_OPERATORIMPL_H_ -#define AIDGE_OPERATORIMPL_H_ +#ifndef AIDGE_BACKEND_OPERATORIMPL_H_ +#define AIDGE_BACKEND_OPERATORIMPL_H_ -#include <cstddef> +#include <string> #include <vector> -#include <memory> #include "aidge/utils/Types.h" @@ -83,4 +82,4 @@ protected: }; } // namespace Aidge -#endif /* AIDGE_OPERATORIMPL_H_ */ +#endif /* AIDGE_BACKEND_OPERATORIMPL_H_ */ diff --git a/include/aidge/operator/Memorize.hpp b/include/aidge/operator/Memorize.hpp index 73433aaca51d07fc3f01682e47cc19433c5c86bf..7de34563adcaabd63ab036232d4d7b6539fd11eb 100644 --- a/include/aidge/operator/Memorize.hpp +++ b/include/aidge/operator/Memorize.hpp @@ -12,17 +12,17 @@ #ifndef AIDGE_CORE_OPERATOR_MEMORIZE_H_ #define AIDGE_CORE_OPERATOR_MEMORIZE_H_ -#include <cassert> #include <memory> +#include <string> #include <vector> -#include "aidge/utils/Registrar.hpp" -#include "aidge/operator/OperatorTensor.hpp" #include "aidge/backend/OperatorImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/graph/Node.hpp" -#include "aidge/utils/Types.h" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/utils/Registrar.hpp" #include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Types.h" namespace Aidge { enum class MemorizeAttr { ScheduleStep, ForwardStep, EndStep }; @@ -47,14 +47,19 @@ public: } /** - * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), + * but not its input tensors (the new operator has no input associated). * @param op Operator to copy. */ Memorize_Op(const Memorize_Op& op) : OperatorTensor(op), Attributes_(op) { - mImpl = op.mImpl ? Registrar<Memorize_Op>::create(op.backend())(*this) : nullptr; + if (op.mImpl) { + SET_IMPL_MACRO(Memorize_Op, *this, op.backend()); + } else { + mImpl = nullptr; + } mOutputs[1] = mOutputs[0]; } @@ -66,10 +71,7 @@ public: return std::make_shared<Memorize_Op>(*this); } - void setBackend(const std::string& name, DeviceIdx_t device = 0) override { - mImpl = Registrar<Memorize_Op>::create({name})(*this); - mOutputs[0]->setBackend(name, device); - } + void setBackend(const std::string& name, DeviceIdx_t device = 0) override final; void computeOutputDims() override; bool outputDimsForwarded() const override; @@ -98,4 +100,4 @@ const char *const EnumStrings<Aidge::MemorizeAttr>::data[] = { }; } -#endif /* AIDGE_CORE_OPERATOR_MEMORIZE_H_ */ \ No newline at end of file +#endif /* AIDGE_CORE_OPERATOR_MEMORIZE_H_ */ diff --git a/src/operator/Memorize.cpp b/src/operator/Memorize.cpp index 6e34c1a2005f551c255e9b7441e853015354337f..6e54a234d2fc78c8e8e9a43a7528709c8e51adc4 100644 --- a/src/operator/Memorize.cpp +++ b/src/operator/Memorize.cpp @@ -9,9 +9,17 @@ * ********************************************************************************/ -#include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Memorize.hpp" +#include <memory> +#include <string> +#include <vector> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Types.h" + const std::string Aidge::Memorize_Op::Type = "Memorize"; void Aidge::Memorize_Op::computeOutputDims() { @@ -33,6 +41,11 @@ void Aidge::Memorize_Op::computeOutputDims() { } } +void Aidge::Memorize_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { + mImpl = Registrar<Memorize_Op>::create({name})(*this); + mOutputs[0]->setBackend(name, device); +} + bool Aidge::Memorize_Op::outputDimsForwarded() const { // Only check the output dims bool forwarded = true; diff --git a/src/operator/Pow.cpp b/src/operator/Pow.cpp index 6b16117d6387c5de4f0d81e20b89568dde97a5b2..72a04de04fda8a432309de8b4a69b1dfb6af1370 100644 --- a/src/operator/Pow.cpp +++ b/src/operator/Pow.cpp @@ -15,6 +15,7 @@ #include <vector> #include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" #include "aidge/operator/Pow.hpp" #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp"