diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 903b6362adf3db0c867dc419086e0cb6ddaa65c7..aa10dea195bb231ed701318cef6e2bfd04a3b7ff 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -26,39 +26,47 @@ namespace Aidge { class Operator : public std::enable_shared_from_this<Operator> { protected: - std::shared_ptr<OperatorImpl> mImpl; // implementation of the operator - std::map<std::string, std::shared_ptr<Hook>> mHooks; + std::shared_ptr<OperatorImpl> mImpl; // implementation of the operator + std::map<std::string, std::shared_ptr<Hook>> mHooks; private: - std::string mType; + std::string mType; + const IOIndex_t mNbData; + const IOIndex_t mNbAttr; + const IOIndex_t mNbOut; public: - Operator() = delete; - Operator(const char* type) : mType(type) {} - virtual std::shared_ptr<Operator> clone() const = 0; - virtual ~Operator(); - - Operator(const Operator& op): - std::enable_shared_from_this<Operator>() - { - mType = op.mType; - mImpl = nullptr; - // Implementation is never cloned. It is up to the non-abstract Operator copy-constructor to create a new implementation matching the copied Operator implementation. - // See https://gitlab.eclipse.org/eclipse/aidge/aidge_core/-/merge_requests/8#note_1214050 for the discussion. - // Hooks are not copied. - } + Operator() = delete; + Operator(const char* type, const IOIndex_t nbData, const IOIndex_t nbAttr, const IOIndex_t nbOut) + : mType(type), + mNbData(nbData), + mNbAttr(nbAttr), + mNbOut(nbOut) + { + // ctor + } + virtual std::shared_ptr<Operator> clone() const = 0; + virtual ~Operator(); + + Operator(const Operator& op): + std::enable_shared_from_this<Operator>(), + mNbData(op.mNbData), + mNbAttr(op.mNbAttr), + mNbOut(op.mNbOut) + { + mType = op.mType; + mImpl = nullptr; + // Implementation is never cloned. It is up to the non-abstract Operator copy-constructor to create a new implementation matching the copied Operator implementation. + // See https://gitlab.eclipse.org/eclipse/aidge/aidge_core/-/merge_requests/8#note_1214050 for the discussion. + // Hooks are not copied. + } public: - virtual void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) = 0; - virtual void computeOutputDims() = 0; - virtual bool outputDimsForwarded() const = 0; + virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>* data) = 0; + virtual std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const = 0; - virtual std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const = 0; - virtual Tensor& input(const IOIndex_t /*inputIdx*/) const = 0; virtual std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const = 0; - virtual std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const = 0; - virtual Tensor& output(const IOIndex_t /*outputIdx*/) const = 0; std::shared_ptr<Hook> getHook(std::string hookName) { return mHooks[hookName]; @@ -121,10 +129,12 @@ public: return mType; } - virtual IOIndex_t nbInputs() const noexcept = 0; - virtual IOIndex_t nbDataInputs() const noexcept = 0; - virtual IOIndex_t nbOutputs() const noexcept = 0; - static const std::vector<std::string> getInputsName(){ + inline IOIndex_t nbInputs() const noexcept { return mNbData+mNbAttr; }; + inline IOIndex_t nbData() const noexcept { return mNbData; }; + inline IOIndex_t nbAttr() const noexcept { return mNbAttr; }; + inline IOIndex_t nbOutputs() const noexcept { return mNbOut; }; + + static const std::vector<std::string> getInputsName(){ return {}; } static const std::vector<std::string> getOutputsName(){ diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp index 09a17a428e1de91c0318f710e6f097573cf529a6..f6143f12536400bb7573a72d2596726a277d7148 100644 --- a/src/operator/Operator.cpp +++ b/src/operator/Operator.cpp @@ -48,8 +48,12 @@ void Aidge::Operator::runHooks() const { } } void Aidge::Operator::forward() { - mImpl->forward(); - runHooks(); + if(mImpl) { + mImpl->forward(); + runHooks(); + } else { + printf("backward: No implementation is linked.\n"); + } } void Aidge::Operator::backward() { mImpl->backward(); }