diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index 9aed8299a67ab719141b6fe199ebf3f52fb7d387..97a4ef69bd371e80c4e63303feac5e64197670b3 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -68,13 +68,7 @@ public: // } - // void checkDims() const override final { - // assert(outputDimsForwarded()); - // for (const auto& in : mInputs) { - // assert(in->dims() == mOutputs[0]->dims()); - // } - // } - + void computeOutputDims() override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Add_Op>::create(name)(*this); diff --git a/src/operator/Add.cpp b/src/operator/Add.cpp index 4e638fd86da487565a89760925e45339213fa8f9..fd9bdaa8326a3460ce1e986fb64a0a1087786a7a 100644 --- a/src/operator/Add.cpp +++ b/src/operator/Add.cpp @@ -9,8 +9,67 @@ * ********************************************************************************/ +#include <cassert> +#include <cstddef> #include <string> +#include <vector> #include "aidge/operator/Add.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" -const std::string Aidge::Add_Op::Type = "Add"; \ No newline at end of file +const std::string Aidge::Add_Op::Type = "Add"; + +void Aidge::Add_Op::computeOutputDims() { + // check inputs have been associated + bool associated = (nbInputs() > 0); // do not compute anything if no input + for (IOIndex_t i = 0; i < nbInputs(); ++i) { + if (!getInput(i)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); + } + associated &= !(getInput(i)->empty()); + } + if (associated) { + std::vector<std::vector<std::size_t>> inputsDims; + for (std::size_t i = 0; i < nbInputs(); i++) + { + inputsDims.push_back(getInput(i)->dims()); + } + + std::size_t outNbDims = 1; + + for(size_t i=0; i<inputsDims.size() ; ++i) + outNbDims = inputsDims[i].size()>outNbDims?inputsDims[i].size():outNbDims; + + std::vector<std::size_t> outDims(outNbDims, 1); + + std::vector<std::size_t>::iterator it = outDims.end(); + while (it != outDims.begin()) + { + --it; + for (size_t i = 0; i < inputsDims.size(); i++) + { + if(!inputsDims[i].empty()) + { + std::size_t dim = inputsDims[i].back(); + inputsDims[i].pop_back(); + if (*it != dim) + { + if(dim != 1) + { + if (*it != 1) + { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsopported Tensor shape for Add operation"); + } + else + { + *it = dim; + } + } + } + } + } + } + mOutputs[0]->resize(outDims); + } +}