From 59c978510a231585f8b7c3f04c55742a211d51e7 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Thu, 11 Jan 2024 10:14:41 +0100 Subject: [PATCH] add broadcasting for Add operator --- include/aidge/operator/Add.hpp | 8 +---- src/operator/Add.cpp | 61 +++++++++++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 8 deletions(-) diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index 9aed8299a..97a4ef69b 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -68,13 +68,7 @@ public: // } - // void checkDims() const override final { - // assert(outputDimsForwarded()); - // for (const auto& in : mInputs) { - // assert(in->dims() == mOutputs[0]->dims()); - // } - // } - + void computeOutputDims() override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Add_Op>::create(name)(*this); diff --git a/src/operator/Add.cpp b/src/operator/Add.cpp index 4e638fd86..fd9bdaa83 100644 --- a/src/operator/Add.cpp +++ b/src/operator/Add.cpp @@ -9,8 +9,67 @@ * ********************************************************************************/ +#include <cassert> +#include <cstddef> #include <string> +#include <vector> #include "aidge/operator/Add.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" -const std::string Aidge::Add_Op::Type = "Add"; \ No newline at end of file +const std::string Aidge::Add_Op::Type = "Add"; + +void Aidge::Add_Op::computeOutputDims() { + // check inputs have been associated + bool associated = (nbInputs() > 0); // do not compute anything if no input + for (IOIndex_t i = 0; i < nbInputs(); ++i) { + if (!getInput(i)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); + } + associated &= !(getInput(i)->empty()); + } + if (associated) { + std::vector<std::vector<std::size_t>> inputsDims; + for (std::size_t i = 0; i < nbInputs(); i++) + { + inputsDims.push_back(getInput(i)->dims()); + } + + std::size_t outNbDims = 1; + + for(size_t i=0; i<inputsDims.size() ; ++i) + outNbDims = inputsDims[i].size()>outNbDims?inputsDims[i].size():outNbDims; + + std::vector<std::size_t> outDims(outNbDims, 1); + + std::vector<std::size_t>::iterator it = outDims.end(); + while (it != outDims.begin()) + { + --it; + for (size_t i = 0; i < inputsDims.size(); i++) + { + if(!inputsDims[i].empty()) + { + std::size_t dim = inputsDims[i].back(); + inputsDims[i].pop_back(); + if (*it != dim) + { + if(dim != 1) + { + if (*it != 1) + { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsopported Tensor shape for Add operation"); + } + else + { + *it = dim; + } + } + } + } + } + } + mOutputs[0]->resize(outDims); + } +} -- GitLab