Skip to content
Snippets Groups Projects
Commit 48ecdca1 authored by Olivier BICHLER's avatar Olivier BICHLER Committed by Maxence Naud
Browse files

Added Gather default implementation

parent f616e3d2
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,12 @@
#include "aidge/utils/Types.h"
namespace Aidge {
class Gather_OpImpl : public OperatorImpl {
public:
Gather_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
void forward() override;
};
enum class GatherAttr { Indices, GatheredShape, Axis };
class Gather_Op : public OperatorTensor,
......@@ -46,7 +52,9 @@ public:
attr<GatherAttr::Indices>(indices),
attr<GatherAttr::GatheredShape>(gatheredShape),
attr<GatherAttr::Axis>(axis))
{}
{
mImpl = std::make_shared<Gather_OpImpl>(*this);
}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
......@@ -56,10 +64,11 @@ public:
: OperatorTensor(op),
Attributes_(op)
{
if (op.mImpl){
if (!op.backend().empty()) {
SET_IMPL_MACRO(Gather_Op, *this, op.backend());
} else {
mImpl = nullptr;
}
else {
mImpl = std::make_shared<Gather_OpImpl>(*this);
}
}
......
......@@ -20,6 +20,35 @@
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
void Aidge::Gather_OpImpl::forward() {
const Gather_Op& op = dynamic_cast<const Gather_Op&>(mOp);
const auto axis = op.template getAttr<std::int64_t>("Axis");
const std::size_t axisIdx = axis>=0 ?
axis :
static_cast<std::size_t>(axis) + op.getInput(0)->dims().size();
std::size_t postAxisElems = 1;
for (std::size_t i = axisIdx + 1; i < op.getInput(0)->dims().size(); ++i) {
postAxisElems *= op.getInput(0)->dims()[i];
}
std::size_t preAxisElems = 1;
for (std::size_t i = 0; i < axisIdx; ++i) {
preAxisElems *= op.getInput(0)->dims()[i];
}
const auto indices = op.template getAttr<std::vector<std::int64_t>>("Indices");
std::size_t outputOffset = 0;
for (std::size_t i=0; i<preAxisElems; ++i)
{
for(std::size_t j=0; j<indices.size(); ++j)
{
const std::size_t idx = indices[j] >= 0 ? indices[j] : static_cast<std::size_t>(indices[j]) + op.getInput(0)->dims()[axisIdx];
op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(i * postAxisElems * op.getInput(0)->dims()[axisIdx] + idx * postAxisElems), postAxisElems, outputOffset);
outputOffset += postAxisElems;
}
}
}
const std::string Aidge::Gather_Op::Type = "Gather";
......@@ -53,6 +82,11 @@ bool Aidge::Gather_Op::computeOutputDims(bool /*allowDataDependency*/) {
}
void Aidge::Gather_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
SET_IMPL_MACRO(Gather_Op, *this, name);
if (Registrar<Gather_Op>::exists({name})) {
SET_IMPL_MACRO(Gather_Op, *this, name);
}
else {
mImpl = std::make_shared<Gather_OpImpl>(*this);
}
mOutputs[0]->setBackend(name, device);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment