diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp index 8d9114c17b4a5692f04d90ceec725858ecace0a7..01d590aa7425cb62ab665c0078019a6c8ab4a66a 100644 --- a/include/aidge/operator/Concat.hpp +++ b/include/aidge/operator/Concat.hpp @@ -12,7 +12,6 @@ #ifndef AIDGE_CORE_OPERATOR_CONCAT_H_ #define AIDGE_CORE_OPERATOR_CONCAT_H_ -#include <cassert> #include <numeric> #include <vector> #include <cmath> @@ -67,29 +66,41 @@ public: return std::make_shared<Concat_Op>(*this); } + // Data operator[](const char* inputName) override final { + // std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] : + // (strcmp(inputName, "weight") ? mInputs[1] : + // (strcmp(inputName, "bias") ? mInputs[2] : + // nullptr)); + // assert((in!=nullptr) && "No such parameter"); + // return *in; + // } + void computeOutputDims() override final { - if (!mInputs.empty() && !mInputs[0]->empty()) - { - Concat_Op::Attrs attr = getStaticAttributes(); - const int& axis = static_cast<const int&>(std::get<0>(attr)); - std::size_t dimOnAxis = 0; - for(std::size_t i=0; i< mInputs.size(); ++i) - { - dimOnAxis += mInputs[i]->dims()[axis]; + // Every input is non-empty with the same number of dimensions + bool associated = (getInput(0) != nullptr); + associated &= !(getInput(0)->empty()) && (getAttr<ConcatAttr::Axis>() < getInput(0)->nbDims()); // do not compute anything if no input + auto outputDims = getInput(0)->dims(); + const auto firstInputNbDims = getInput(0) -> nbDims(); + for (IOIndex_t i = 1; i < nbInputs(); ++i) { + if (!getInput(i)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); } - std::vector<DimSize_t> outputDims; - for (std::size_t i = 0; i < mInputs[0]->nbDims(); ++i) { - if(i==axis) - outputDims.push_back(dimOnAxis); - else - outputDims.push_back(mInputs[0]->dims()[i]); + associated &= (getInput(i)->nbDims() == firstInputNbDims); + for (DimSize_t dim = 0; dim < firstInputNbDims; ++dim) { + if (dim == getAttr<ConcatAttr::Axis>()) { + outputDims[dim] += getInput(i)->dims()[dim]; + } + else { + associated &= (getInput(i)->dims()[dim] == outputDims[dim]); + } } - mOutputs[0]->resize(outputDims); + } + if (associated) { + getOutput(0)->resize(outputDims); } } - void setBackend(const std::string& name) override { mImpl = Registrar<Concat_Op>::create(name)(*this); mOutputs[0]->setBackend(name); @@ -108,14 +119,16 @@ public: } }; -inline std::shared_ptr<Node> Concat(IOIndex_t nbIn, int axis, const std::string& name = "") { +inline std::shared_ptr<Node> Concat(const IOIndex_t nbIn, const DimIdx_t axis = 0, const std::string& name = "") { return std::make_shared<Node>(std::make_shared<Concat_Op>(nbIn, axis), name); } -} // namespace Aidge +} namespace { -template <> -const char *const EnumStrings<Aidge::ConcatAttr>::data[] = {"Axis"}; + template <> + const char* const EnumStrings<Aidge::ConcatAttr>::data[] = { + "Axis" + }; } #endif /* AIDGE_CORE_OPERATOR_CONCAT_H_ */ diff --git a/python_binding/operator/pybind_Concat.cpp b/python_binding/operator/pybind_Concat.cpp index 701afa3e1280c9f0ba4eb09292e3fd2a5e8060c4..2b7e5d6b99194e914e48dc6263d0bdcd6a4a8a2f 100644 --- a/python_binding/operator/pybind_Concat.cpp +++ b/python_binding/operator/pybind_Concat.cpp @@ -23,6 +23,6 @@ void init_Concat(py::module& m) { .def("get_inputs_name", &Concat_Op::getInputsName) .def("get_outputs_name", &Concat_Op::getOutputsName); - m.def("Concat", &Concat, py::arg("axis"), py::arg("nb_in"), py::arg("name") = ""); + m.def("Concat", &Concat, py::arg("nbIn"), py::arg("axis"), py::arg("name") = ""); } } // namespace Aidge