From 712235192f91f42307770accdb8a9742839e88a5 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Tue, 21 Nov 2023 15:21:01 +0100 Subject: [PATCH] fix python binding of concat by adding nb_in attr --- include/aidge/operator/Concat.hpp | 52 +++++++++++------------ python_binding/operator/pybind_Concat.cpp | 2 +- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp index 2daf876b9..7a090e2cd 100644 --- a/include/aidge/operator/Concat.hpp +++ b/include/aidge/operator/Concat.hpp @@ -48,11 +48,15 @@ public: using Attributes_ = StaticAttributes<ConcatAttr, int>; template <ConcatAttr e> using attr = typename Attributes_::template attr<e>; - Concat_Op(int axis) + Concat_Op(int axis, IOIndex_t nbIn) : Operator(Type), - Attributes_( - attr<ConcatAttr::Axis>(axis)) + mNbIn(nbIn), + Attributes_(attr<ConcatAttr::Axis>(axis)) { + mInputs = std::vector<std::shared_ptr<Tensor>>(nbIn); + for (std::size_t i = 0; i < nbIn; ++i) { + mInputs[i] = std::make_shared<Tensor>(); + } setDatatype(DataType::Float32); } @@ -67,12 +71,12 @@ public: mOutput(std::make_shared<Tensor>(*op.mOutput)) { // cpy-ctor - setDatatype(op.mOutput->dataType()); mImpl = op.mImpl ? Registrar<Concat_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; - mInputs = std::vector<std::shared_ptr<Tensor>>(mNbIn); - for (std::size_t i = 0; i < mNbIn; ++i) { + mInputs = std::vector<std::shared_ptr<Tensor>>(op.mNbIn); + for (std::size_t i = 0; i < op.mNbIn; ++i) { mInputs[i] = std::make_shared<Tensor>(); } + setDatatype(op.mOutput->dataType()); } /** @@ -84,30 +88,25 @@ public: } void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { - // assert(inputIdx < mNbIn && "operators supports only x inputs"); - - if (strcmp(data->type(), Tensor::Type) == 0) { - // TODO: associate input only if of type Tensor, otherwise do nothing - if(inputIdx<mInputs.size()) - mInputs.insert( mInputs.begin() + inputIdx, std::dynamic_pointer_cast<Tensor>(data)); - else - mInputs.emplace_back(std::dynamic_pointer_cast<Tensor>(data)); - - mNbIn = mInputs.size(); - } + assert(inputIdx < mNbIn && "index out of bound"); + assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } void computeOutputDims() override final { if (!mInputs.empty() && !mInputs[0]->empty()) { - // mOutput->resize(mInputs[0]->dims()); - Concat_Op::Attrs attr = getStaticAttributes(); const int& axis = static_cast<const int&>(std::get<0>(attr)); + std::size_t dimOnAxis = 0; + for(std::size_t i=0; i<mNbIn; ++i) + { + dimOnAxis += mInputs[i]->dims()[axis]; + } std::vector<DimSize_t> outputDims; for (std::size_t i = 0; i < mInputs[0]->nbDims(); ++i) { if(i==axis) - outputDims.push_back(mInputs.size() * mInputs[0]->dims()[i]); + outputDims.push_back(dimOnAxis); else outputDims.push_back(mInputs[0]->dims()[i]); } @@ -121,8 +120,7 @@ public: inline Tensor& input(const IOIndex_t inputIdx) const override final { - assert((inputIdx < mNbIn) && "input index out of range for this instance of GenericOperator"); - printf("Info: using input() on a GenericOperator.\n"); + assert((inputIdx < mNbIn) && "input index out of range for this instance of Concat operator"); return *mInputs[inputIdx]; } inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } @@ -133,7 +131,7 @@ public: return mInputs[inputIdx]; } inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { - assert((outputIdx == 0) && "Concat Operator has only 1 output"); + assert((outputIdx == 0) && "Concat operator has only 1 output"); (void) outputIdx; // avoid unused warning return mOutput; } @@ -143,7 +141,7 @@ public: return std::static_pointer_cast<Data>(mInputs[inputIdx]); } std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { - assert(outputIdx == 0 && "operator supports only 1 output"); + assert(outputIdx == 0 && "Concat operator supports only 1 output"); (void) outputIdx; // avoid unused warning return std::static_pointer_cast<Data>(mOutput); } @@ -172,15 +170,15 @@ public: inline IOIndex_t nbDataInputs() const noexcept override final { return mNbIn; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } static const std::vector<std::string> getInputsName(){ - return {"data_input"}; + return {"data_input"}; //TODO fix input names cannot access mNbIn bacause of static type } static const std::vector<std::string> getOutputsName(){ return {"data_output"}; } }; -inline std::shared_ptr<Node> Concat(int axis, const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Concat_Op>(axis), name); +inline std::shared_ptr<Node> Concat(int axis, IOIndex_t nbIn, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Concat_Op>(axis, nbIn), name); } } // namespace Aidge diff --git a/python_binding/operator/pybind_Concat.cpp b/python_binding/operator/pybind_Concat.cpp index a3f78a4d1..9e587f0f0 100644 --- a/python_binding/operator/pybind_Concat.cpp +++ b/python_binding/operator/pybind_Concat.cpp @@ -23,6 +23,6 @@ void init_Concat(py::module& m) { .def("get_inputs_name", &Concat_Op::getInputsName) .def("get_outputs_name", &Concat_Op::getOutputsName); - m.def("Concat", &Concat, py::arg("axis"), py::arg("name") = ""); + m.def("Concat", &Concat, py::arg("axis"), py::arg("nb_in"), py::arg("name") = ""); } } // namespace Aidge -- GitLab