Skip to content
Snippets Groups Projects
Commit 59c97851 authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Maxence Naud
Browse files

add broadcasting for Add operator

parent 6424edc9
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!65[Add] broadcasting for Arithmetic Operators
......@@ -68,13 +68,7 @@ public:
// }
// void checkDims() const override final {
// assert(outputDimsForwarded());
// for (const auto& in : mInputs) {
// assert(in->dims() == mOutputs[0]->dims());
// }
// }
void computeOutputDims() override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Add_Op>::create(name)(*this);
......
......@@ -9,8 +9,67 @@
*
********************************************************************************/
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>
#include "aidge/operator/Add.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
const std::string Aidge::Add_Op::Type = "Add";
\ No newline at end of file
const std::string Aidge::Add_Op::Type = "Add";
void Aidge::Add_Op::computeOutputDims() {
// check inputs have been associated
bool associated = (nbInputs() > 0); // do not compute anything if no input
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
if (!getInput(i)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
}
associated &= !(getInput(i)->empty());
}
if (associated) {
std::vector<std::vector<std::size_t>> inputsDims;
for (std::size_t i = 0; i < nbInputs(); i++)
{
inputsDims.push_back(getInput(i)->dims());
}
std::size_t outNbDims = 1;
for(size_t i=0; i<inputsDims.size() ; ++i)
outNbDims = inputsDims[i].size()>outNbDims?inputsDims[i].size():outNbDims;
std::vector<std::size_t> outDims(outNbDims, 1);
std::vector<std::size_t>::iterator it = outDims.end();
while (it != outDims.begin())
{
--it;
for (size_t i = 0; i < inputsDims.size(); i++)
{
if(!inputsDims[i].empty())
{
std::size_t dim = inputsDims[i].back();
inputsDims[i].pop_back();
if (*it != dim)
{
if(dim != 1)
{
if (*it != 1)
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsopported Tensor shape for Add operation");
}
else
{
*it = dim;
}
}
}
}
}
}
mOutputs[0]->resize(outDims);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment