From 2ac231d8a2e633b4439607a6065ebe919c264d29 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Thu, 19 Oct 2023 08:35:44 +0000
Subject: [PATCH] [Upd] Add Operator to support a variable number of inputs.
 Remove templates

---
 include/aidge/aidge.hpp                 |  2 +-
 include/aidge/operator/Add.hpp          | 66 ++++++++++++-------------
 python_binding/operator/pybind_Add.cpp  | 12 ++---
 unit_tests/recipies/Test_FuseMulAdd.cpp |  4 +-
 4 files changed, 41 insertions(+), 43 deletions(-)

diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp
index 16fa9967c..8a1b50a0e 100644
--- a/include/aidge/aidge.hpp
+++ b/include/aidge/aidge.hpp
@@ -33,7 +33,7 @@
 #include "aidge/operator/Add.hpp"
 #include "aidge/operator/AvgPooling.hpp"
 #include "aidge/operator/BatchNorm.hpp"
-#include "aidge/operator/Concat.hpp"
+// #include "aidge/operator/Concat.hpp"
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/ConvDepthWise.hpp"
 #include "aidge/operator/FC.hpp"
diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp
index 65c7e8ce0..ceb058dbd 100644
--- a/include/aidge/operator/Add.hpp
+++ b/include/aidge/operator/Add.hpp
@@ -16,7 +16,7 @@
 #include <vector>
 #include <cmath>
 #include <memory>
-#include <array>
+#include <cstring>  // strcmp (used in associateInput); <vector> is already included above
 
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/operator/Operator.hpp"
@@ -26,24 +26,23 @@
 
 namespace Aidge {
 
-template <std::size_t NUM>
 class Add_Op : public Operator,
-    public Registrable<Add_Op<NUM>, std::string, std::unique_ptr<OperatorImpl>(const Add_Op<NUM>&)> {
-public:
+    public Registrable<Add_Op, std::string, std::unique_ptr<OperatorImpl>(const Add_Op&)> {
+private:
     // FIXME: change accessibility
-    std::array<std::shared_ptr<Tensor>, NUM> mInputs;
+    std::vector<std::shared_ptr<Tensor>> mInputs;
+    const IOIndex_t mNbInputs;  // declared before mOutput so member order matches the ctor init-lists (avoids -Wreorder)
     const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
 
 public:
     static constexpr const char* Type = "Add";
 
-    constexpr Add_Op()
-            : Operator(Type)
+    Add_Op(const IOIndex_t nbIn)
+        : Operator(Type),
+          mInputs(std::vector<std::shared_ptr<Tensor>>(nbIn)),  // filled below: the fill-ctor would copy ONE shared_ptr, aliasing a single Tensor across all inputs
+          mNbInputs(nbIn)
     {
-        assert(NUM > 0 && "Add should have at least one input");
-        for (std::size_t i = 0; i<NUM; ++i) {
-            mInputs[i] = std::make_shared<Tensor>();
-        }
+        assert(nbIn > 0 && "Add should have at least one input"); for (auto& in : mInputs) { in = std::make_shared<Tensor>(); }
         setDatatype(DataType::Float32);
     }
 
@@ -51,17 +50,16 @@ public:
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
      * @param op Operator to copy.
      */
-    Add_Op(const Add_Op<NUM>& op)
+    Add_Op(const Add_Op& op)
         : Operator(Type),
+          mInputs(op.mNbInputs),  // fresh (empty) input Tensors: per the doc above, the copy must NOT share op's inputs
+          mNbInputs(op.mNbInputs),
           mOutput(std::make_shared<Tensor>(*op.mOutput))
     {
         // cpy-ctor
-        assert(NUM > 0 && "Add should have at least one input");
-        for (std::size_t i = 0; i<NUM; ++i) {
-            mInputs[i] = std::make_shared<Tensor>();
-        }
+        assert(mNbInputs > 0 && "Add should have at least one input"); for (auto& in : mInputs) { in = std::make_shared<Tensor>(); }
         setDatatype(op.mOutput->dataType());
-        mImpl = op.mImpl ? Registrar<Add_Op<NUM>>::create(mOutput->getImpl()->backend())(*this) : nullptr;
+        mImpl = op.mImpl ? Registrar<Add_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
     }
 
     /**
@@ -82,7 +80,7 @@ public:
     // }
 
     void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
-        assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator.");
+        assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator.");
         assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
 
         mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
@@ -92,10 +90,10 @@ public:
         if (!mInputs[0]->empty()) {
             const auto expectedDims =  mInputs[0]->dims();
             std::size_t nonEmptyInputTensor = 1;
-            for (; nonEmptyInputTensor<NUM && (!mInputs[nonEmptyInputTensor]->empty()); ++nonEmptyInputTensor) {
+            for (; nonEmptyInputTensor < mNbInputs && (!mInputs[nonEmptyInputTensor]->empty()); ++nonEmptyInputTensor) {
                 assert(expectedDims == mInputs[nonEmptyInputTensor]->dims());
             }
-            if (nonEmptyInputTensor == NUM) {
+            if (nonEmptyInputTensor == mNbInputs) {
                 mOutput->resize(expectedDims);
             }
         }
@@ -103,8 +101,8 @@ public:
 
     bool outputDimsForwarded() const override final {
         std::size_t forwarded = 0;
-        for (; forwarded < NUM && (!mInputs[forwarded]->empty()); ++forwarded) {}
-        return ((forwarded==NUM) && !(mOutput->empty()));
+        for (; forwarded < mNbInputs && (!mInputs[forwarded]->empty()); ++forwarded) {}
+        return ((forwarded==mNbInputs) && !(mOutput->empty()));
     }
 
     // void checkDims() const override final {
@@ -114,13 +112,13 @@ public:
     //     }
     // }
     inline Tensor& input(const IOIndex_t inputIdx) const override final {
-        assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator.");
+        assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator.");
         return *(mInputs[inputIdx].get());
     }
     inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
 
     inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
-        assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator.");
+        assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator.");
         return mInputs[inputIdx];
     }
     inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
@@ -130,7 +128,7 @@ public:
     }
 
     std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
-        assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator.");
+        assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator.");
         return std::static_pointer_cast<Data>(mInputs[inputIdx]);
     }
     std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
@@ -141,11 +139,11 @@ public:
 
 
     void setBackend(const std::string& name) override {
-        mImpl = Registrar<Add_Op<NUM>>::create(name)(*this);
+        mImpl = Registrar<Add_Op>::create(name)(*this);
         mOutput->setBackend(name);
 
         // FIXME: temporary workaround
-        for (std::size_t i = 0; i < NUM; ++i) {
+        for (std::size_t i = 0; i < mNbInputs; ++i) {
             mInputs[i]->setBackend(name);
         }
     }
@@ -154,15 +152,16 @@ public:
         mOutput->setDatatype(datatype);
 
         // FIXME: temporary workaround
-        for (std::size_t i = 0; i < NUM; ++i) {
+        for (std::size_t i = 0; i < mNbInputs; ++i) {
             mInputs[i]->setDatatype(datatype);
         }
     }
 
-    inline IOIndex_t nbInputs() const noexcept override final { return NUM; }
-    inline IOIndex_t nbDataInputs() const noexcept override final { return NUM; }
+    inline IOIndex_t nbInputs() const noexcept override final { return mNbInputs; }
+    inline IOIndex_t nbDataInputs() const noexcept override final { return mNbInputs; }
     inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
-        static const std::vector<std::string> getInputsName(){
+
+    static const std::vector<std::string> getInputsName(){
         return {"data_input_0", "data_input_n"};
     }
     static const std::vector<std::string> getOutputsName(){
@@ -170,9 +169,8 @@ public:
     }
 };
 
-template <std::size_t NUM>
-inline std::shared_ptr<Node> Add(const std::string& name = "") {
-    return std::make_shared<Node>(std::make_shared<Add_Op<NUM>>(), name);
+inline std::shared_ptr<Node> Add(const IOIndex_t nbIn, const std::string& name = "") {
+    return std::make_shared<Node>(std::make_shared<Add_Op>(nbIn), name);
 }
 }
 
diff --git a/python_binding/operator/pybind_Add.cpp b/python_binding/operator/pybind_Add.cpp
index 0b2323c5c..bff795a73 100644
--- a/python_binding/operator/pybind_Add.cpp
+++ b/python_binding/operator/pybind_Add.cpp
@@ -19,15 +19,15 @@
 namespace py = pybind11;
 namespace Aidge {
 
-template <std::size_t NUM> void declare_Add(py::module &m) {
-  py::class_<Add_Op<NUM>, std::shared_ptr<Add_Op<NUM>>, Operator>(m, "AddOp", py::multiple_inheritance())
-  .def("get_inputs_name", &Add_Op<NUM>::getInputsName)
-  .def("get_outputs_name", &Add_Op<NUM>::getOutputsName);
+void declare_Add(py::module &m) {
+  py::class_<Add_Op, std::shared_ptr<Add_Op>, Operator>(m, "AddOp", py::multiple_inheritance())
+  .def("get_inputs_name", &Add_Op::getInputsName)
+  .def("get_outputs_name", &Add_Op::getOutputsName);
 
-  m.def("Add", &Add<NUM>, py::arg("name") = "");
+  m.def("Add", &Add, py::arg("nbIn"), py::arg("name") = "");
 }
 
 void init_Add(py::module &m) {
-  declare_Add<2>(m);
+  declare_Add(m);
 }
 } // namespace Aidge
diff --git a/unit_tests/recipies/Test_FuseMulAdd.cpp b/unit_tests/recipies/Test_FuseMulAdd.cpp
index 92b2b7c13..7e64f3ff5 100644
--- a/unit_tests/recipies/Test_FuseMulAdd.cpp
+++ b/unit_tests/recipies/Test_FuseMulAdd.cpp
@@ -25,9 +25,9 @@ namespace Aidge {
 TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") {
     // generate the original GraphView
     auto matmul0 = MatMul(5, "matmul0");
-    auto add0 = Add<2>("add0");
+    auto add0 = Add(2, "add0");
     auto matmul1 = MatMul(5, "matmul1");
-    auto add1 = Add<2>("add1");
+    auto add1 = Add(2, "add1");
 
     auto b0 = Producer({5}, "B0");
     auto w0 = Producer({5, 5}, "W0");
-- 
GitLab