diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index ff3d1888c3bc70b61a3d4da42908d40de2d1d73e..5bc3ef0e1b0e6a330cd54a8ab5d2d552c7180e95 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -47,6 +47,9 @@ public: } setDatatype(DataType::Float32); } + Operator* clone() const override { + return new Add_Op(*static_cast<const Add_Op*>(this)); + } // Data operator[](const char* inputName) override final { // std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] : diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp index bf76bd45893b43043b81cd6563c500be27c66b42..197388ebc89d3644aedb48fb3e10b72250df5ac6 100644 --- a/include/aidge/operator/AvgPooling.hpp +++ b/include/aidge/operator/AvgPooling.hpp @@ -44,6 +44,9 @@ public: static constexpr const char *Type = "AvgPooling"; AvgPooling_Op() = delete; + Operator* clone() const override { + return new AvgPooling_Op<DIM>(*static_cast<const AvgPooling_Op<DIM>*>(this)); + } using Parameterizable_ = Parameterizable<AvgPoolingParam, std::array<DimSize_t, DIM>, diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index 6861c1359737f3f344f0c7d9b2d12c9ff35b88ad..68589c654ad21f1c65d63bf89e2e848d6a6e1211 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -43,6 +43,9 @@ public: static constexpr const char *Type = "BatchNorm"; BatchNorm_Op() = delete; + Operator* clone() const override { + return new BatchNorm_Op<DIM>(*static_cast<const BatchNorm_Op<DIM>*>(this)); + } using Parameterizable_ = Parameterizable<BatchNormParam, float, float>; template <BatchNormParam e> diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index 1edc94b96763cc163646037a8bd069023511df67..11526916a3ecf98e988b0b330fc84e9e81b6878c 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -43,6 +43,9 @@ public: static constexpr const char *Type = "Conv"; Conv_Op() = delete; + Operator* clone() const override { + return new Conv_Op<DIM>(*static_cast<const Conv_Op<DIM>*>(this)); + } using Parameterizable_ = Parameterizable<ConvParam, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >>; diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp index 95a2ff55b70dbed9299fb3dca98fb9b0e700d210..88013c4ea65b0f7ab874dcf2d04e45d9957881a9 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -47,6 +47,9 @@ class ConvDepthWise_Op : public Operator, static constexpr const char *Type = "ConvDepthWise"; ConvDepthWise_Op() = delete; + Operator* clone() const override { + return new ConvDepthWise_Op<DIM>(*static_cast<const ConvDepthWise_Op<DIM>*>(this)); + } using Parameterizable_ = Parameterizable<ConvDepthWiseParam, std::array<DimSize_t, DIM>, diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index db92dc9c735416d250fa32e2f9010b21b8f808c0..244b0322a34883b10b9f68b716a546b4b43857ff 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -43,6 +43,9 @@ public: static constexpr const char* Type = "FC"; FC_Op() = delete; + Operator* clone() const override { + return new FC_Op(*static_cast<const FC_Op*>(this)); + } using Parameterizable_ = Parameterizable<FCParam, DimSize_t, bool>; template <FCParam e> using param = typename Parameterizable_::template param<e>; diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index dab5df9a8f2d1e7d2cd680703d70e38d564c2564..ba56746acbda1421661843596af05a17f1014638 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -48,6 +48,9 @@ class GenericOperator_Op mOutputs[i] = std::make_shared<Tensor>(); } } + Operator* clone() const override { + return new GenericOperator_Op(*static_cast<const GenericOperator_Op*>(this)); + } /** * @brief Get the Parameter object identified by its name. diff --git a/include/aidge/operator/LeakyReLU.hpp b/include/aidge/operator/LeakyReLU.hpp index 1dff2550a42245351afab5b8bb1a708a8d0d8c0b..e3476d3fdde5802e9dc536674a6a1c5eb5a016b1 100644 --- a/include/aidge/operator/LeakyReLU.hpp +++ b/include/aidge/operator/LeakyReLU.hpp @@ -41,6 +41,9 @@ public: static constexpr const char* Type = "LeakyReLU"; LeakyReLU_Op() = delete; + Operator* clone() const override { + return new LeakyReLU_Op(*static_cast<const LeakyReLU_Op*>(this)); + } using Parameterizable_ = Parameterizable<LeakyReLUParam, float>; template <LeakyReLUParam e> using param = typename Parameterizable_::template param<e>; diff --git a/include/aidge/operator/Matmul.hpp b/include/aidge/operator/Matmul.hpp index 639b366912060b3e085510f312d94568e6b65f03..5dbae2e70a189d1cf42c4df337da0d6fe98f72d2 100644 --- a/include/aidge/operator/Matmul.hpp +++ b/include/aidge/operator/Matmul.hpp @@ -42,6 +42,9 @@ public: static constexpr const char* Type = "Matmul"; Matmul_Op() = delete; + Operator* clone() const override { + return new Matmul_Op(*static_cast<const Matmul_Op*>(this)); + } using Parameterizable_ = Parameterizable<MatmulParam, DimSize_t>; template <MatmulParam e> using param = typename Parameterizable_::template param<e>; diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 35a59b56cbf5c10a78116f81de96a8baddc03ff0..f2bd0011874ba2ef28d9b41eb06d272a8e0ddd24 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -21,6 +21,9 @@ public: : Operator("MetaOp") { } + Operator* clone() const override { + return new MetaOperator(*static_cast<const MetaOperator*>(this)); + } ~MetaOperator() = default; }; } diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 30e1ce2a7f664485077282405ec60ddf49513cb5..585096bebfb7dea6f491d550c841a20c4a9d532e 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -33,8 +33,16 @@ private: public: Operator() = delete; Operator(const char* type) : mType(type) {} + virtual Operator* clone() const = 0; virtual ~Operator(); + Operator(const Operator& op): + std::enable_shared_from_this<Operator>() + { + mType = op.mType; + // mImpl is not set right now. + // TODO: clone the impl as well? + } public: diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index acdc69b69ab86b25a11d889980b9236e41928316..e8e831e150557173918d82d2769e6e08f34ef262 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -44,6 +44,10 @@ public: mOutput->resize(dims); } + Operator* clone() const override { + return new Producer_Op(*static_cast<const Producer_Op*>(this)); + } + Producer_Op(const std::shared_ptr<Tensor> tensor) : Operator(Type), mOutput(tensor) diff --git a/include/aidge/operator/ReLU.hpp b/include/aidge/operator/ReLU.hpp index 141bd3ae12c7875a90d2549a24e5c141f3ff6aba..b3983557cb7ca8519094fdab820ca7c6a2d6b4c5 100644 --- a/include/aidge/operator/ReLU.hpp +++ b/include/aidge/operator/ReLU.hpp @@ -41,6 +41,9 @@ public: { setDatatype(DataType::Float32); } + Operator* clone() const override { + return new ReLU_Op(*static_cast<const ReLU_Op*>(this)); + } void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx == 0 && "operator supports only 1 input"); diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp index 64e713b331bbbbf612ee5102ba0ea82fb108350e..d6ba3f1fcf64f3fa3f48bccd23ef21b5deada26e 100644 --- a/include/aidge/operator/Softmax.hpp +++ b/include/aidge/operator/Softmax.hpp @@ -41,6 +41,9 @@ public: { setDatatype(DataType::Float32); } + Operator* clone() const override { + return new Softmax_Op(*static_cast<const Softmax_Op*>(this)); + } void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx == 0 && "operator supports only 1 input"); diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 7a4dc54c30cf0f843c421effe355da65b4d89815..a2fb2808c4d392fdd4486f4240d46f5f525f7d72 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -332,13 +332,13 @@ Aidge::NodePtr Aidge::Node::cloneSharedOperators() const { Aidge::NodePtr Aidge::Node::cloneSharedProducers() const { std::shared_ptr<Operator> op = (op->type() == Producer_Op::Type) ? mOperator - : std::make_shared<Operator>(*mOperator); + : std::shared_ptr<Operator>(mOperator->clone()); return std::make_shared<Node>(op, mName); } Aidge::NodePtr Aidge::Node::clone() const { - return std::make_shared<Node>(std::make_shared<Operator>(*mOperator), mName); + return std::make_shared<Node>(std::shared_ptr<Operator>(mOperator->clone()), mName); } /////////////////////////////////////////////////////////////////////////////////////////////