Commit 933a48d5 authored by Olivier BICHLER

Added Gather default implementation

parent 89d22270
3 merge requests: !119 (0.2.1), !113 (Draft: Fix slice), !104 (Make forwardDims() optional and handle data dependency)
Pipeline #43496 passed
@@ -25,6 +25,12 @@
 #include "aidge/utils/Types.h"
 namespace Aidge {
+class Gather_OpImpl : public OperatorImpl {
+public:
+    Gather_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
+    void forward() override;
+};
 enum class GatherAttr { Indices, GatheredShape, Axis };
 class Gather_Op : public OperatorTensor,
@@ -46,7 +52,9 @@ public:
           attr<GatherAttr::Indices>(indices),
           attr<GatherAttr::GatheredShape>(gatheredShape),
           attr<GatherAttr::Axis>(axis))
-    {}
+    {
+        mImpl = std::make_shared<Gather_OpImpl>(*this);
+    }
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
@@ -56,10 +64,11 @@ public:
         : OperatorTensor(op),
           Attributes_(op)
     {
-        if (op.mImpl){
+        if (!op.backend().empty()) {
             SET_IMPL_MACRO(Gather_Op, *this, op.backend());
-        } else {
-            mImpl = nullptr;
+        }
+        else {
+            mImpl = std::make_shared<Gather_OpImpl>(*this);
         }
     }
@@ -20,6 +20,35 @@
 #include "aidge/utils/Types.h"
 #include "aidge/utils/ErrorHandling.hpp"
+void Aidge::Gather_OpImpl::forward() {
+    const Gather_Op& op = dynamic_cast<const Gather_Op&>(mOp);
+    const auto axis = op.template getAttr<std::int64_t>("Axis");
+    const std::size_t axisIdx = axis >= 0 ?
+                                    axis :
+                                    static_cast<std::size_t>(axis) + op.getInput(0)->dims().size();
+    std::size_t postAxisElems = 1;
+    for (std::size_t i = axisIdx + 1; i < op.getInput(0)->dims().size(); ++i) {
+        postAxisElems *= op.getInput(0)->dims()[i];
+    }
+    std::size_t preAxisElems = 1;
+    for (std::size_t i = 0; i < axisIdx; ++i) {
+        preAxisElems *= op.getInput(0)->dims()[i];
+    }
+    const auto indices = op.template getAttr<std::vector<std::int64_t>>("Indices");
+    std::size_t outputOffset = 0;
+    for (std::size_t i = 0; i < preAxisElems; ++i)
+    {
+        for (std::size_t j = 0; j < indices.size(); ++j)
+        {
+            const std::size_t idx = indices[j] >= 0 ? indices[j] : static_cast<std::size_t>(indices[j]) + op.getInput(0)->dims()[axisIdx];
+            op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(i * postAxisElems * op.getInput(0)->dims()[axisIdx] + idx * postAxisElems), postAxisElems, outputOffset);
+            outputOffset += postAxisElems;
+        }
+    }
+}
 const std::string Aidge::Gather_Op::Type = "Gather";
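
Note on the index arithmetic in forward() above: the input is treated as preAxisElems contiguous blocks, each holding dims[axis] rows of postAxisElems elements, and for every block the rows selected by Indices (with negative values counting from the end of the axis) are copied to the output in order. The standalone sketch below reproduces that same loop structure on a plain std::vector, outside of Aidge; the function name gatherNaive and the example shapes are illustrative only and not part of the Aidge API or of this commit.

// Illustrative sketch only (not Aidge code): gather along `axis` of a
// row-major, contiguously stored tensor, mirroring Gather_OpImpl::forward().
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<float> gatherNaive(const std::vector<float>& input,
                               const std::vector<std::size_t>& dims,
                               std::size_t axis,
                               const std::vector<std::int64_t>& indices)
{
    // Number of contiguous elements after the gathered axis...
    std::size_t postAxisElems = 1;
    for (std::size_t i = axis + 1; i < dims.size(); ++i) postAxisElems *= dims[i];
    // ...and number of independent blocks before it.
    std::size_t preAxisElems = 1;
    for (std::size_t i = 0; i < axis; ++i) preAxisElems *= dims[i];

    std::vector<float> output;
    output.reserve(preAxisElems * indices.size() * postAxisElems);
    for (std::size_t i = 0; i < preAxisElems; ++i) {
        for (std::int64_t rawIdx : indices) {
            // Negative indices count from the end of the gathered axis.
            const std::size_t idx = rawIdx >= 0
                ? static_cast<std::size_t>(rawIdx)
                : static_cast<std::size_t>(rawIdx + static_cast<std::int64_t>(dims[axis]));
            // Copy one row of postAxisElems elements from the selected position.
            const std::size_t srcOffset = (i * dims[axis] + idx) * postAxisElems;
            output.insert(output.end(),
                          input.begin() + srcOffset,
                          input.begin() + srcOffset + postAxisElems);
        }
    }
    return output;
}

int main() {
    // 2x3 input, gather columns {0, 2} along axis 1 -> 2x2 output.
    const std::vector<float> input = {1, 2, 3, 4, 5, 6};
    for (float v : gatherNaive(input, {2, 3}, 1, {0, 2})) std::cout << v << ' ';
    std::cout << '\n';   // prints: 1 3 4 6
}
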
@@ -53,6 +82,11 @@ bool Aidge::Gather_Op::computeOutputDims(bool /*allowDataDependency*/) {
 }
 void Aidge::Gather_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
-    SET_IMPL_MACRO(Gather_Op, *this, name);
+    if (Registrar<Gather_Op>::exists({name})) {
+        SET_IMPL_MACRO(Gather_Op, *this, name);
+    }
+    else {
+        mImpl = std::make_shared<Gather_OpImpl>(*this);
+    }
     mOutputs[0]->setBackend(name, device);
 }
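
The change in setBackend() mirrors the one in the copy constructor: when an implementation is registered for the requested backend it is selected via SET_IMPL_MACRO, otherwise the backend-agnostic Gather_OpImpl added by this commit is used as a default. As a rough, self-contained illustration of that dispatch pattern (using a hypothetical registry, not Aidge's actual Registrar class), a sketch could look like this:

// Hypothetical sketch of a "registered backend or default implementation"
// dispatch; the Impl/registry names below are illustrative, not Aidge's API.
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Impl {
    virtual ~Impl() = default;
    virtual void forward() = 0;
};

// Backend-agnostic fallback, analogous to Gather_OpImpl.
struct DefaultImpl : Impl {
    void forward() override { std::cout << "generic fallback forward\n"; }
};

// Backend-specific implementation, analogous to a registered CPU kernel.
struct CpuImpl : Impl {
    void forward() override { std::cout << "cpu-optimized forward\n"; }
};

using ImplFactory = std::function<std::unique_ptr<Impl>()>;

std::map<std::string, ImplFactory>& registry() {
    static std::map<std::string, ImplFactory> r;
    return r;
}

std::unique_ptr<Impl> makeImpl(const std::string& backend) {
    const auto it = registry().find(backend);
    if (it != registry().end()) {
        return it->second();                 // use the registered backend implementation
    }
    return std::make_unique<DefaultImpl>();  // otherwise fall back to the default
}

int main() {
    registry()["cpu"] = [] { return std::make_unique<CpuImpl>(); };
    makeImpl("cpu")->forward();   // registered: prints "cpu-optimized forward"
    makeImpl("cuda")->forward();  // not registered: prints "generic fallback forward"
}
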