-
Olivier BICHLER authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Gather.cpp 5.69 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cstddef> // std::size_t
#include <cstdint> // std::int64_t
#include <string>
#include <vector>
#include "aidge/operator/Gather.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
// Type identifier of the Gather operator; handed to the OperatorTensor base by the constructor.
const std::string Aidge::Gather_Op::Type{"Gather"};
/**
 * @brief Build a Gather operator from its attributes.
 *
 * The operator has one mandatory data input (#0), one optional data input (#1)
 * and a single output.
 *
 * @param axis          Axis along which elements are gathered (stored as the Axis attribute).
 * @param indices       Stored as the Indices attribute.
 * @param gatheredShape Stored as the GatheredShape attribute.
 */
Aidge::Gather_Op::Gather_Op(std::int8_t axis,
                            const std::vector<int64_t>& indices,
                            const std::vector<Aidge::DimSize_t>& gatheredShape)
    : OperatorTensor(Type,
                     {InputCategory::Data, InputCategory::OptionalData},
                     1),
      mAttributes(std::make_shared<Attributes_>(attr<GatherAttr::Axis>(axis),
                                                attr<GatherAttr::Indices>(indices),
                                                attr<GatherAttr::GatheredShape>(gatheredShape)))
{
    // Start with the generic kernel; setBackend() may later replace it.
    mImpl = std::make_shared<Gather_OpImpl>(*this);
}
/**
 * @brief Copy-construct a Gather operator.
 *
 * Picks the implementation according to the backend of @p op. When @p op has
 * no backend, the generic Gather_OpImpl is used. Otherwise the implementation
 * registered for that backend is instantiated.
 *
 * NOTE(review): mAttributes is a shared_ptr copy, so the new operator shares
 * its attribute storage with @p op rather than owning a deep copy. Confirm this
 * is intended, since forwardDims() mutates Indices/GatheredShape.
 */
Aidge::Gather_Op::Gather_Op(const Aidge::Gather_Op& op)
    : OperatorTensor(op),
      mAttributes(op.mAttributes)
{
    if (op.backend().empty()) {
        mImpl = std::make_shared<Gather_OpImpl>(*this);
    } else {
        SET_IMPL_MACRO(Gather_Op, *this, op.backend());
    }
}
/**
 * @brief Polymorphic copy of this operator.
 * @return A new Gather_Op built by the copy constructor, viewed as an Operator.
 */
std::shared_ptr<Aidge::Operator> Aidge::Gather_Op::clone() const {
    auto duplicate = std::make_shared<Gather_Op>(*this);
    return duplicate;
}
void Aidge::Gather_OpImpl::forward() {
const Gather_Op& op = dynamic_cast<const Gather_Op&>(mOp);
const std::size_t axisIdx = static_cast<std::size_t>(op.axis()) + (op.axis() >= 0 ? 0 : op.getInput(0)->dims().size());
std::size_t postAxisElems = 1;
for (std::size_t i = axisIdx + 1; i < op.getInput(0)->dims().size(); ++i) {
postAxisElems *= op.getInput(0)->dims()[i];
}
std::size_t preAxisElems = 1;
for (std::size_t i = 0; i < axisIdx; ++i) {
preAxisElems *= op.getInput(0)->dims()[i];
}
std::size_t outputOffset = 0;
for (std::size_t i=0; i<preAxisElems; ++i)
{
for(std::size_t j = 0; j < op.indices().size(); ++j)
{
const std::size_t idx = op.indices()[j] >= 0 ?
static_cast<std::size_t>(op.indices()[j]) :
static_cast<std::size_t>(op.indices()[j] + static_cast<int>(op.getInput(0)->dims()[axisIdx]));
op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(i * postAxisElems * op.getInput(0)->dims()[axisIdx] + idx * postAxisElems), postAxisElems, outputOffset);
outputOffset += postAxisElems;
}
}
}
bool Aidge::Gather_Op::dimsForwarded() const {
if (getInput(1) && !getInput(1)->undefined()) {
// output dims are data dependent
return false;
}
return OperatorTensor::dimsForwarded();
}
/**
 * @brief Compute the output dims from input#0 and the gather indices.
 *
 * If the optional input#1 is connected, its values replace the Indices
 * attribute and its dims replace the GatheredShape attribute. This requires
 * @p allowDataDependency, because the output dims then depend on input data.
 * The output dims are the input#0 dims with the gathered axis replaced by
 * GatheredShape; an empty GatheredShape simply removes the axis.
 *
 * @param allowDataDependency Whether the values of input#1 may be read.
 * @return true if the output was resized, false if the inputs are not all
 *         associated yet or input#1 values may not be read.
 *
 * Throws (through AIDGE_ASSERT) when no indices are available, when the axis
 * is out of range, or when GatheredShape does not hold exactly one element
 * per index.
 */
bool Aidge::Gather_Op::forwardDims(bool allowDataDependency) {
    if (!inputsAssociated()) {
        return false;
    }

    // Copy optional input #1, if present, to attribute Indices
    if (getInput(1)) {
        if (!this->indices().empty()) {
            Log::notice("Gather_Op: ignoring non-empty Indices attribute because input#1 takes precedence");
        }
        if (!allowDataDependency) {
            Log::warn("Gather_Op: unable to forwardDims() because output dims are data dependent on input#1");
            return false;
        }

        std::shared_ptr<Tensor> fallback;
        this->gatheredShape() = getInput(1)->dims();
        // Read input#1 as int64 on the CPU. refCastFrom() presumably stores a
        // converted copy in 'fallback' when input#1 is not already in that form.
        const auto& indices = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
        const int64_t* const indicesData = static_cast<const int64_t*>(indices.getImpl()->hostPtr());
        // If both are provided, input#1 overrides the Indices attribute.
        this->indices().assign(indicesData, indicesData + indices.size());
    }
    AIDGE_ASSERT(!this->indices().empty(), "Missing input#1 or Indices attribute");

    // Compute output dims
    std::vector<DimSize_t> outDims = getInput(0)->dims();
    const std::int64_t nbDims = static_cast<std::int64_t>(outDims.size());
    const std::int64_t axis = static_cast<std::int64_t>(this->axis());
    // An out-of-range axis would make erase()/insert() below undefined behavior.
    AIDGE_ASSERT(axis >= -nbDims && axis < nbDims,
                 "Gather_Op: axis {} is out of range for an input of rank {}", axis, nbDims);
    const auto axisPos = static_cast<std::ptrdiff_t>(axis >= 0 ? axis : axis + nbDims);

    // GatheredShape must describe exactly one output slot per index. Otherwise
    // forward() would copy indices().size() runs per slice into an output sized
    // for a different count (e.g. an empty shape with several indices).
    std::size_t gatheredSize = 1;
    for (const auto dim : this->gatheredShape()) {
        gatheredSize *= dim;
    }
    AIDGE_ASSERT(gatheredSize == this->indices().size(),
                 "Gather_Op: GatheredShape holds {} element(s) but {} indices are provided",
                 gatheredSize, this->indices().size());

    // Replace the gathered axis by GatheredShape (inserting an empty range is a no-op).
    outDims.erase(outDims.begin() + axisPos);
    outDims.insert(outDims.begin() + axisPos,
                   this->gatheredShape().begin(),
                   this->gatheredShape().end());
    mOutputs[0]->resize(outDims);
    return true;
}
/**
 * @brief Select the implementation and move the output to a backend/device.
 *
 * Uses the implementation registered for @p name when there is one, and the
 * generic Gather_OpImpl otherwise. The output tensor is then set to
 * @p name / @p device.
 */
void Aidge::Gather_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
    const bool hasRegisteredImpl = Registrar<Gather_Op>::exists({name});
    if (!hasRegisteredImpl) {
        mImpl = std::make_shared<Gather_OpImpl>(*this);
    } else {
        SET_IMPL_MACRO(Gather_Op, *this, name);
    }
    mOutputs[0]->setBackend(name, device);
}
/**
 * @brief List the backends that have a registered Gather_Op implementation.
 * @return The keys of the Gather_Op registrar.
 */
std::set<std::string> Aidge::Gather_Op::getAvailableBackends() const {
    auto registeredBackends = Registrar<Gather_Op>::getKeys();
    return registeredBackends;
}
/////////////////////////////////////////
/**
 * @brief Create a graph Node wrapping a new Gather_Op.
 *
 * @param axis          Axis along which elements are gathered.
 * @param indices       Indices attribute of the operator.
 * @param gatheredShape GatheredShape attribute of the operator.
 * @param name          Name given to the created Node.
 * @return The new Node owning the operator.
 */
std::shared_ptr<Aidge::Node> Aidge::Gather(std::int8_t axis,
                                           const std::vector<int64_t>& indices,
                                           const std::vector<Aidge::DimSize_t>& gatheredShape,
                                           const std::string& name) {
    return std::make_shared<Node>(
        std::make_shared<Gather_Op>(axis, indices, gatheredShape),
        name);
}