From 75daa0dbc43ec9b80a8501934b108d14f1d35010 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Wed, 6 Sep 2023 18:40:43 +0200
Subject: [PATCH] Fixed operator clonning

---
 include/aidge/operator/Add.hpp             | 3 +++
 include/aidge/operator/AvgPooling.hpp      | 3 +++
 include/aidge/operator/BatchNorm.hpp       | 3 +++
 include/aidge/operator/Conv.hpp            | 3 +++
 include/aidge/operator/ConvDepthWise.hpp   | 3 +++
 include/aidge/operator/FC.hpp              | 3 +++
 include/aidge/operator/GenericOperator.hpp | 3 +++
 include/aidge/operator/LeakyReLU.hpp       | 3 +++
 include/aidge/operator/Matmul.hpp          | 3 +++
 include/aidge/operator/MetaOperator.hpp    | 3 +++
 include/aidge/operator/Operator.hpp        | 8 ++++++++
 include/aidge/operator/Producer.hpp        | 4 ++++
 include/aidge/operator/ReLU.hpp            | 3 +++
 include/aidge/operator/Softmax.hpp         | 3 +++
 src/graph/Node.cpp                         | 4 ++--
 15 files changed, 50 insertions(+), 2 deletions(-)

diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp
index ff3d1888c..5bc3ef0e1 100644
--- a/include/aidge/operator/Add.hpp
+++ b/include/aidge/operator/Add.hpp
@@ -47,6 +47,9 @@ public:
         }
         setDatatype(DataType::Float32);
     }
+    Operator* clone() const override {
+        return new Add_Op(*static_cast<const Add_Op*>(this));
+    }
 
     // Data operator[](const char* inputName) override final {
     //     std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] :
diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp
index bf76bd458..197388ebc 100644
--- a/include/aidge/operator/AvgPooling.hpp
+++ b/include/aidge/operator/AvgPooling.hpp
@@ -44,6 +44,9 @@ public:
     static constexpr const char *Type = "AvgPooling";
 
     AvgPooling_Op() = delete;
+    Operator* clone() const override {
+        return new AvgPooling_Op<DIM>(*static_cast<const AvgPooling_Op<DIM>*>(this));
+    }
 
     using Parameterizable_ = Parameterizable<AvgPoolingParam,
                                              std::array<DimSize_t, DIM>,
diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp
index 6861c1359..68589c654 100644
--- a/include/aidge/operator/BatchNorm.hpp
+++ b/include/aidge/operator/BatchNorm.hpp
@@ -43,6 +43,9 @@ public:
     static constexpr const char *Type = "BatchNorm";
 
     BatchNorm_Op() = delete;
+    Operator* clone() const override {
+        return new BatchNorm_Op<DIM>(*static_cast<const BatchNorm_Op<DIM>*>(this));
+    }
 
     using Parameterizable_ = Parameterizable<BatchNormParam, float, float>;
     template <BatchNormParam e>
diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp
index 1edc94b96..11526916a 100644
--- a/include/aidge/operator/Conv.hpp
+++ b/include/aidge/operator/Conv.hpp
@@ -43,6 +43,9 @@ public:
     static constexpr const char *Type = "Conv";
 
     Conv_Op() = delete;
+    Operator* clone() const override {
+        return new Conv_Op<DIM>(*static_cast<const Conv_Op<DIM>*>(this));
+    }
 
     using Parameterizable_ = Parameterizable<ConvParam, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>,
                                              DimSize_t, DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >>;
diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp
index 95a2ff55b..88013c4ea 100644
--- a/include/aidge/operator/ConvDepthWise.hpp
+++ b/include/aidge/operator/ConvDepthWise.hpp
@@ -47,6 +47,9 @@ class ConvDepthWise_Op : public Operator,
     static constexpr const char *Type = "ConvDepthWise";
 
     ConvDepthWise_Op() = delete;
+    Operator* clone() const override {
+        return new ConvDepthWise_Op<DIM>(*static_cast<const ConvDepthWise_Op<DIM>*>(this));
+    }
 
     using Parameterizable_ = Parameterizable<ConvDepthWiseParam,
                                              std::array<DimSize_t, DIM>,
diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp
index db92dc9c7..244b0322a 100644
--- a/include/aidge/operator/FC.hpp
+++ b/include/aidge/operator/FC.hpp
@@ -43,6 +43,9 @@ public:
     static constexpr const char* Type = "FC";
 
     FC_Op() = delete;
+    Operator* clone() const override {
+        return new FC_Op(*static_cast<const FC_Op*>(this));
+    }
 
     using Parameterizable_ = Parameterizable<FCParam, DimSize_t, bool>;
     template <FCParam e> using param = typename Parameterizable_::template param<e>;
diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp
index dab5df9a8..ba56746ac 100644
--- a/include/aidge/operator/GenericOperator.hpp
+++ b/include/aidge/operator/GenericOperator.hpp
@@ -48,6 +48,9 @@ class GenericOperator_Op
             mOutputs[i] = std::make_shared<Tensor>();
         }
     }
+    Operator* clone() const override {
+        return new GenericOperator_Op(*static_cast<const GenericOperator_Op*>(this));
+    }
 
     /**
      * @brief Get the Parameter object identified by its name.
diff --git a/include/aidge/operator/LeakyReLU.hpp b/include/aidge/operator/LeakyReLU.hpp
index 1dff2550a..e3476d3fd 100644
--- a/include/aidge/operator/LeakyReLU.hpp
+++ b/include/aidge/operator/LeakyReLU.hpp
@@ -41,6 +41,9 @@ public:
     static constexpr const char* Type = "LeakyReLU";
 
     LeakyReLU_Op() = delete;
+    Operator* clone() const override {
+        return new LeakyReLU_Op(*static_cast<const LeakyReLU_Op*>(this));
+    }
 
     using Parameterizable_ = Parameterizable<LeakyReLUParam, float>;
     template <LeakyReLUParam e> using param = typename Parameterizable_::template param<e>;
diff --git a/include/aidge/operator/Matmul.hpp b/include/aidge/operator/Matmul.hpp
index 639b36691..5dbae2e70 100644
--- a/include/aidge/operator/Matmul.hpp
+++ b/include/aidge/operator/Matmul.hpp
@@ -42,6 +42,9 @@ public:
     static constexpr const char* Type = "Matmul";
 
     Matmul_Op() = delete;
+    Operator* clone() const override {
+        return new Matmul_Op(*static_cast<const Matmul_Op*>(this));
+    }
 
     using Parameterizable_ = Parameterizable<MatmulParam, DimSize_t>;
     template <MatmulParam e> using param = typename Parameterizable_::template param<e>;
diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp
index 35a59b56c..f2bd00118 100644
--- a/include/aidge/operator/MetaOperator.hpp
+++ b/include/aidge/operator/MetaOperator.hpp
@@ -21,6 +21,9 @@ public:
         : Operator("MetaOp")
     {
     }
+    Operator* clone() const override {
+        return new MetaOperator(*static_cast<const MetaOperator*>(this));
+    }
     ~MetaOperator() = default;
 };
 }
diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp
index 30e1ce2a7..585096beb 100644
--- a/include/aidge/operator/Operator.hpp
+++ b/include/aidge/operator/Operator.hpp
@@ -33,8 +33,16 @@ private:
 public:
   Operator() = delete;
   Operator(const char* type) : mType(type) {}
+  virtual Operator* clone() const = 0;
   virtual ~Operator();
 
+  Operator(const Operator& op):
+    std::enable_shared_from_this<Operator>()
+  {
+    mType = op.mType;
+    // mImpl is not set right now.
+    // TODO: clone the impl as well?
+  }
 
 public:
 
diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp
index acdc69b69..e8e831e15 100644
--- a/include/aidge/operator/Producer.hpp
+++ b/include/aidge/operator/Producer.hpp
@@ -44,6 +44,10 @@ public:
         mOutput->resize(dims);
     }
 
+    Operator* clone() const override {
+        return new Producer_Op(*static_cast<const Producer_Op*>(this));
+    }
+
     Producer_Op(const std::shared_ptr<Tensor> tensor)
         : Operator(Type),
           mOutput(tensor)
diff --git a/include/aidge/operator/ReLU.hpp b/include/aidge/operator/ReLU.hpp
index 141bd3ae1..b3983557c 100644
--- a/include/aidge/operator/ReLU.hpp
+++ b/include/aidge/operator/ReLU.hpp
@@ -41,6 +41,9 @@ public:
     {
         setDatatype(DataType::Float32);
     }
+    Operator* clone() const override {
+        return new ReLU_Op(*static_cast<const ReLU_Op*>(this));
+    }
 
     void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
         assert(inputIdx == 0 && "operator supports only 1 input");
diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp
index 64e713b33..d6ba3f1fc 100644
--- a/include/aidge/operator/Softmax.hpp
+++ b/include/aidge/operator/Softmax.hpp
@@ -41,6 +41,9 @@ public:
     {
         setDatatype(DataType::Float32);
     }
+    Operator* clone() const override {
+        return new Softmax_Op(*static_cast<const Softmax_Op*>(this));
+    }
 
     void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
         assert(inputIdx == 0 && "operator supports only 1 input");
diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp
index 7a4dc54c3..a2fb2808c 100644
--- a/src/graph/Node.cpp
+++ b/src/graph/Node.cpp
@@ -332,13 +332,13 @@ Aidge::NodePtr Aidge::Node::cloneSharedOperators() const {
 Aidge::NodePtr Aidge::Node::cloneSharedProducers() const {
     std::shared_ptr<Operator> op = (op->type() == Producer_Op::Type)
         ? mOperator
-        : std::make_shared<Operator>(*mOperator);
+        : std::shared_ptr<Operator>(mOperator->clone());
 
     return std::make_shared<Node>(op, mName);
 }
 
 Aidge::NodePtr Aidge::Node::clone() const {
-    return std::make_shared<Node>(std::make_shared<Operator>(*mOperator), mName);
+    return std::make_shared<Node>(std::shared_ptr<Operator>(mOperator->clone()), mName);
 }
 
 /////////////////////////////////////////////////////////////////////////////////////////////
-- 
GitLab