From 712235192f91f42307770accdb8a9742839e88a5 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Tue, 21 Nov 2023 15:21:01 +0100
Subject: [PATCH] fix python binding of concat by adding nb_in attr

---
 include/aidge/operator/Concat.hpp         | 52 +++++++++++------------
 python_binding/operator/pybind_Concat.cpp |  2 +-
 2 files changed, 26 insertions(+), 28 deletions(-)

diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp
index 2daf876b9..7a090e2cd 100644
--- a/include/aidge/operator/Concat.hpp
+++ b/include/aidge/operator/Concat.hpp
@@ -48,11 +48,15 @@ public:
     
     using Attributes_ = StaticAttributes<ConcatAttr, int>;
     template <ConcatAttr e> using attr = typename Attributes_::template attr<e>;
-    Concat_Op(int axis)
+    Concat_Op(int axis, IOIndex_t nbIn)
             : Operator(Type),
-            Attributes_(
-                attr<ConcatAttr::Axis>(axis))
+              mNbIn(nbIn),
+              Attributes_(attr<ConcatAttr::Axis>(axis))
     {
+        mInputs = std::vector<std::shared_ptr<Tensor>>(nbIn);
+        for (std::size_t i = 0; i < nbIn; ++i) {
+            mInputs[i] = std::make_shared<Tensor>();
+        }
         setDatatype(DataType::Float32);
     }
 
@@ -67,12 +71,12 @@ public:
           mOutput(std::make_shared<Tensor>(*op.mOutput))
     {
         // cpy-ctor
-        setDatatype(op.mOutput->dataType());
         mImpl = op.mImpl ? Registrar<Concat_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
-        mInputs = std::vector<std::shared_ptr<Tensor>>(mNbIn);
-        for (std::size_t i = 0; i < mNbIn; ++i) {
+        mInputs = std::vector<std::shared_ptr<Tensor>>(op.mNbIn);
+        for (std::size_t i = 0; i < op.mNbIn; ++i) {
             mInputs[i] = std::make_shared<Tensor>();
         }
+        setDatatype(op.mOutput->dataType());
     }
 
     /**
@@ -84,30 +88,25 @@ public:
     }
 
     void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
-        // assert(inputIdx < mNbIn && "operators supports only x inputs");
-        
-        if (strcmp(data->type(), Tensor::Type) == 0) {
-            // TODO: associate input only if of type Tensor, otherwise do nothing
-            if(inputIdx<mInputs.size())
-                mInputs.insert( mInputs.begin() + inputIdx, std::dynamic_pointer_cast<Tensor>(data));
-            else
-                mInputs.emplace_back(std::dynamic_pointer_cast<Tensor>(data));
-
-            mNbIn = mInputs.size();
-        }
+        assert(inputIdx < mNbIn && "index out of bound");
+        assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
+        mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
     }
 
     void computeOutputDims() override final {
         if (!mInputs.empty() && !mInputs[0]->empty())
         {
-            // mOutput->resize(mInputs[0]->dims());
-
             Concat_Op::Attrs attr = getStaticAttributes();
             const int& axis = static_cast<const int&>(std::get<0>(attr));
+            std::size_t dimOnAxis = 0;
+            for(std::size_t i=0; i<mNbIn; ++i)
+            {
+                dimOnAxis += mInputs[i]->dims()[axis];
+            }
             std::vector<DimSize_t> outputDims;
             for (std::size_t i = 0; i < mInputs[0]->nbDims(); ++i) {
                 if(i==axis)
-                    outputDims.push_back(mInputs.size() * mInputs[0]->dims()[i]);
+                    outputDims.push_back(dimOnAxis);
                 else
                     outputDims.push_back(mInputs[0]->dims()[i]);
             }
@@ -121,8 +120,7 @@ public:
 
 
     inline Tensor& input(const IOIndex_t inputIdx) const override final {
-        assert((inputIdx < mNbIn) && "input index out of range for this instance of GenericOperator");
-        printf("Info: using input() on a GenericOperator.\n");
+        assert((inputIdx < mNbIn) && "input index out of range for this instance of Concat operator");
         return *mInputs[inputIdx];
     }
     inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
@@ -133,7 +131,7 @@ public:
         return mInputs[inputIdx];
     }
     inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
-        assert((outputIdx == 0) && "Concat Operator has only 1 output");
+        assert((outputIdx == 0) && "Concat operator has only 1 output");
         (void) outputIdx; // avoid unused warning
         return mOutput;
     }
@@ -143,7 +141,7 @@ public:
         return std::static_pointer_cast<Data>(mInputs[inputIdx]);
     }
     std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
-        assert(outputIdx == 0 && "operator supports only 1 output");
+        assert(outputIdx == 0 && "Concat operator supports only 1 output");
         (void) outputIdx; // avoid unused warning
         return std::static_pointer_cast<Data>(mOutput);
     }
@@ -172,15 +170,15 @@ public:
     inline IOIndex_t nbDataInputs() const noexcept override final { return mNbIn; }
     inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
     static const std::vector<std::string> getInputsName(){
-        return {"data_input"};
+        return {"data_input"}; //TODO fix input names cannot access mNbIn bacause of static type
     }
     static const std::vector<std::string> getOutputsName(){
         return {"data_output"};
     }
 };
 
-inline std::shared_ptr<Node> Concat(int axis, const std::string& name = "") {
-    return std::make_shared<Node>(std::make_shared<Concat_Op>(axis), name);
+inline std::shared_ptr<Node> Concat(int axis, IOIndex_t nbIn, const std::string& name = "") {
+    return std::make_shared<Node>(std::make_shared<Concat_Op>(axis, nbIn), name);
 }
 } // namespace Aidge
 
diff --git a/python_binding/operator/pybind_Concat.cpp b/python_binding/operator/pybind_Concat.cpp
index a3f78a4d1..9e587f0f0 100644
--- a/python_binding/operator/pybind_Concat.cpp
+++ b/python_binding/operator/pybind_Concat.cpp
@@ -23,6 +23,6 @@ void init_Concat(py::module& m) {
     .def("get_inputs_name", &Concat_Op::getInputsName)
     .def("get_outputs_name", &Concat_Op::getOutputsName);
 
-    m.def("Concat", &Concat, py::arg("axis"), py::arg("name") = "");
+    m.def("Concat", &Concat, py::arg("axis"), py::arg("nb_in"), py::arg("name") = "");
 }
 }  // namespace Aidge
-- 
GitLab