From 75daa0dbc43ec9b80a8501934b108d14f1d35010 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Wed, 6 Sep 2023 18:40:43 +0200 Subject: [PATCH] Fixed operator clonning --- include/aidge/operator/Add.hpp | 3 +++ include/aidge/operator/AvgPooling.hpp | 3 +++ include/aidge/operator/BatchNorm.hpp | 3 +++ include/aidge/operator/Conv.hpp | 3 +++ include/aidge/operator/ConvDepthWise.hpp | 3 +++ include/aidge/operator/FC.hpp | 3 +++ include/aidge/operator/GenericOperator.hpp | 3 +++ include/aidge/operator/LeakyReLU.hpp | 3 +++ include/aidge/operator/Matmul.hpp | 3 +++ include/aidge/operator/MetaOperator.hpp | 3 +++ include/aidge/operator/Operator.hpp | 8 ++++++++ include/aidge/operator/Producer.hpp | 4 ++++ include/aidge/operator/ReLU.hpp | 3 +++ include/aidge/operator/Softmax.hpp | 3 +++ src/graph/Node.cpp | 4 ++-- 15 files changed, 50 insertions(+), 2 deletions(-) diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index ff3d1888c..5bc3ef0e1 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -47,6 +47,9 @@ public: } setDatatype(DataType::Float32); } + Operator* clone() const override { + return new Add_Op(*static_cast<const Add_Op*>(this)); + } // Data operator[](const char* inputName) override final { // std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] : diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp index bf76bd458..197388ebc 100644 --- a/include/aidge/operator/AvgPooling.hpp +++ b/include/aidge/operator/AvgPooling.hpp @@ -44,6 +44,9 @@ public: static constexpr const char *Type = "AvgPooling"; AvgPooling_Op() = delete; + Operator* clone() const override { + return new AvgPooling_Op<DIM>(*static_cast<const AvgPooling_Op<DIM>*>(this)); + } using Parameterizable_ = Parameterizable<AvgPoolingParam, std::array<DimSize_t, DIM>, diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index 6861c1359..68589c654 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -43,6 +43,9 @@ public: static constexpr const char *Type = "BatchNorm"; BatchNorm_Op() = delete; + Operator* clone() const override { + return new BatchNorm_Op<DIM>(*static_cast<const BatchNorm_Op<DIM>*>(this)); + } using Parameterizable_ = Parameterizable<BatchNormParam, float, float>; template <BatchNormParam e> diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index 1edc94b96..11526916a 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -43,6 +43,9 @@ public: static constexpr const char *Type = "Conv"; Conv_Op() = delete; + Operator* clone() const override { + return new Conv_Op<DIM>(*static_cast<const Conv_Op<DIM>*>(this)); + } using Parameterizable_ = Parameterizable<ConvParam, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >>; diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp index 95a2ff55b..88013c4ea 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -47,6 +47,9 @@ class ConvDepthWise_Op : public Operator, static constexpr const char *Type = "ConvDepthWise"; ConvDepthWise_Op() = delete; + Operator* clone() const override { + return new ConvDepthWise_Op<DIM>(*static_cast<const ConvDepthWise_Op<DIM>*>(this)); + } using Parameterizable_ = Parameterizable<ConvDepthWiseParam, std::array<DimSize_t, DIM>, diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index db92dc9c7..244b0322a 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -43,6 +43,9 @@ public: static constexpr const char* Type = "FC"; FC_Op() = delete; + Operator* clone() const override { + return new FC_Op(*static_cast<const FC_Op*>(this)); + } using Parameterizable_ = Parameterizable<FCParam, DimSize_t, bool>; template <FCParam e> using param = typename Parameterizable_::template param<e>; diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index dab5df9a8..ba56746ac 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -48,6 +48,9 @@ class GenericOperator_Op mOutputs[i] = std::make_shared<Tensor>(); } } + Operator* clone() const override { + return new GenericOperator_Op(*static_cast<const GenericOperator_Op*>(this)); + } /** * @brief Get the Parameter object identified by its name. diff --git a/include/aidge/operator/LeakyReLU.hpp b/include/aidge/operator/LeakyReLU.hpp index 1dff2550a..e3476d3fd 100644 --- a/include/aidge/operator/LeakyReLU.hpp +++ b/include/aidge/operator/LeakyReLU.hpp @@ -41,6 +41,9 @@ public: static constexpr const char* Type = "LeakyReLU"; LeakyReLU_Op() = delete; + Operator* clone() const override { + return new LeakyReLU_Op(*static_cast<const LeakyReLU_Op*>(this)); + } using Parameterizable_ = Parameterizable<LeakyReLUParam, float>; template <LeakyReLUParam e> using param = typename Parameterizable_::template param<e>; diff --git a/include/aidge/operator/Matmul.hpp b/include/aidge/operator/Matmul.hpp index 639b36691..5dbae2e70 100644 --- a/include/aidge/operator/Matmul.hpp +++ b/include/aidge/operator/Matmul.hpp @@ -42,6 +42,9 @@ public: static constexpr const char* Type = "Matmul"; Matmul_Op() = delete; + Operator* clone() const override { + return new Matmul_Op(*static_cast<const Matmul_Op*>(this)); + } using Parameterizable_ = Parameterizable<MatmulParam, DimSize_t>; template <MatmulParam e> using param = typename Parameterizable_::template param<e>; diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 35a59b56c..f2bd00118 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -21,6 +21,9 @@ public: : Operator("MetaOp") { } + Operator* clone() const override { + return new MetaOperator(*static_cast<const MetaOperator*>(this)); + } ~MetaOperator() = default; }; } diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 30e1ce2a7..585096beb 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -33,8 +33,16 @@ private: public: Operator() = delete; Operator(const char* type) : mType(type) {} + virtual Operator* clone() const = 0; virtual ~Operator(); + Operator(const Operator& op): + std::enable_shared_from_this<Operator>() + { + mType = op.mType; + // mImpl is not set right now. + // TODO: clone the impl as well? + } public: diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index acdc69b69..e8e831e15 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -44,6 +44,10 @@ public: mOutput->resize(dims); } + Operator* clone() const override { + return new Producer_Op(*static_cast<const Producer_Op*>(this)); + } + Producer_Op(const std::shared_ptr<Tensor> tensor) : Operator(Type), mOutput(tensor) diff --git a/include/aidge/operator/ReLU.hpp b/include/aidge/operator/ReLU.hpp index 141bd3ae1..b3983557c 100644 --- a/include/aidge/operator/ReLU.hpp +++ b/include/aidge/operator/ReLU.hpp @@ -41,6 +41,9 @@ public: { setDatatype(DataType::Float32); } + Operator* clone() const override { + return new ReLU_Op(*static_cast<const ReLU_Op*>(this)); + } void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx == 0 && "operator supports only 1 input"); diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp index 64e713b33..d6ba3f1fc 100644 --- a/include/aidge/operator/Softmax.hpp +++ b/include/aidge/operator/Softmax.hpp @@ -41,6 +41,9 @@ public: { setDatatype(DataType::Float32); } + Operator* clone() const override { + return new Softmax_Op(*static_cast<const Softmax_Op*>(this)); + } void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx == 0 && "operator supports only 1 input"); diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 7a4dc54c3..a2fb2808c 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -332,13 +332,13 @@ Aidge::NodePtr Aidge::Node::cloneSharedOperators() const { Aidge::NodePtr Aidge::Node::cloneSharedProducers() const { std::shared_ptr<Operator> op = (op->type() == Producer_Op::Type) ? mOperator - : std::make_shared<Operator>(*mOperator); + : std::shared_ptr<Operator>(mOperator->clone()); return std::make_shared<Node>(op, mName); } Aidge::NodePtr Aidge::Node::clone() const { - return std::make_shared<Node>(std::make_shared<Operator>(*mOperator), mName); + return std::make_shared<Node>(std::shared_ptr<Operator>(mOperator->clone()), mName); } ///////////////////////////////////////////////////////////////////////////////////////////// -- GitLab