/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <algorithm>  // std::copy_n
#include <cstddef>    // std::size_t
#include <cstdint>    // std::int64_t
#include <iterator>   // std::back_inserter
#include <memory>     // std::make_shared, std::shared_ptr
#include <set>        // std::set (getAvailableBackends)
#include <string>
#include <vector>

#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Gather.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"

// Registered type name of the operator, used as the Registrar key and passed
// to the OperatorTensor base in the constructors below.
const std::string Aidge::Gather_Op::Type = "Gather";


// Builds a Gather operator with one mandatory data input (input#0) and one
// optional indices input (input#1).
//  - axis: dimension along which elements are gathered; may be negative,
//    in which case it is resolved against the input rank (see forward()).
//  - indices: indices to gather; may be negative, resolved against the size
//    of the gathered axis at execution time.
//  - gatheredShape: shape of the indices tensor, spliced into the output
//    dims in forwardDims().
Aidge::Gather_Op::Gather_Op(std::int8_t axis,
              const std::vector<int64_t>& indices,
              const std::vector<Aidge::DimSize_t>& gatheredShape)
    : OperatorTensor(Type, {InputCategory::Data, InputCategory::OptionalData}, 1),
    mAttributes(std::make_shared<Attributes_>(
        attr<GatherAttr::Axis>(axis),
        attr<GatherAttr::Indices>(indices),
        attr<GatherAttr::GatheredShape>(gatheredShape)))
{
    // Default to the backend-agnostic reference implementation until a
    // backend is selected via setBackend().
    mImpl = std::make_shared<Gather_OpImpl>(*this);
}

// Copy constructor. NOTE(review): the attribute struct is shared (same
// shared_ptr), not deep-copied — mutating the copy's attributes also affects
// the original; confirm this aliasing is intended.
Aidge::Gather_Op::Gather_Op(const Aidge::Gather_Op& op)
    : OperatorTensor(op), mAttributes(op.mAttributes)
{
    if (!op.backend().empty()) {
        // Source has a backend set: instantiate the matching registered impl.
        SET_IMPL_MACRO(Gather_Op, *this, op.backend());
    }
    else {
        // No backend chosen yet: fall back to the generic implementation.
        mImpl = std::make_shared<Gather_OpImpl>(*this);
    }
}

/// @brief Polymorphic copy of this operator.
/// @return A new Gather_Op built via the copy constructor.
std::shared_ptr<Aidge::Operator> Aidge::Gather_Op::clone() const {
    auto copy = std::make_shared<Gather_Op>(*this);
    return copy;
}

// Reference (backend-agnostic) Gather: for each slice before the gather axis,
// copy the contiguous chunks selected by the indices into the output, in
// index order.
void Aidge::Gather_OpImpl::forward() {
    const Gather_Op& op = dynamic_cast<const Gather_Op&>(mOp);
    const auto& input = *op.getInput(0);
    const auto& inputDims = input.dims();

    // Resolve a possibly negative axis to a 0-based dimension index.
    const std::size_t axisIdx = (op.axis() >= 0)
        ? static_cast<std::size_t>(op.axis())
        : static_cast<std::size_t>(op.axis() + static_cast<std::int64_t>(inputDims.size()));

    // Contiguous element count after the axis (size of one copied chunk)...
    std::size_t chunkSize = 1;
    for (std::size_t d = axisIdx + 1; d < inputDims.size(); ++d) {
        chunkSize *= inputDims[d];
    }
    // ...and number of independent slices before the axis.
    std::size_t outerCount = 1;
    for (std::size_t d = 0; d < axisIdx; ++d) {
        outerCount *= inputDims[d];
    }

    const std::size_t axisDim = inputDims[axisIdx];
    const auto& indices = op.indices();
    std::size_t dstOffset = 0;
    for (std::size_t outer = 0; outer < outerCount; ++outer) {
        const std::size_t outerBase = outer * axisDim * chunkSize;
        for (const auto rawIdx : indices) {
            // Negative indices count back from the end of the gathered axis.
            const std::size_t idx = (rawIdx >= 0)
                ? static_cast<std::size_t>(rawIdx)
                : static_cast<std::size_t>(rawIdx + static_cast<int>(axisDim));
            op.getOutput(0)->getImpl()->copy(
                input.getImpl()->rawPtr(outerBase + idx * chunkSize),
                chunkSize,
                dstOffset);
            dstOffset += chunkSize;
        }
    }
}

// Dims cannot be considered forwarded when indices arrive through input#1:
// the output shape then depends on runtime data (see forwardDims()).
bool Aidge::Gather_Op::dimsForwarded() const {
    const auto indicesInput = getInput(1);
    const bool dataDependent = indicesInput && !indicesInput->undefined();
    return dataDependent ? false : OperatorTensor::dimsForwarded();
}

// Computes the output shape: the gathered axis is removed from the input
// shape and replaced by the shape of the indices tensor. If input#1 is
// connected, its content is copied into the Indices/GatheredShape attributes
// first (data-dependent, gated by allowDataDependency).
// Returns true when the output has been resized, false otherwise.
bool Aidge::Gather_Op::forwardDims(bool allowDataDependency) {
    if (!inputsAssociated()) {
        return false;
    }

    // Copy optional input#1, if present, to attribute Indices.
    if (getInput(1)) {
        if (!this->indices().empty()) {
            Log::notice("Gather_Op: ignoring non-empty Indices attribute because input#1 takes precedence");
        }

        if (!allowDataDependency) {
            Log::warn("Gather_Op: unable to forwardDims() because output dims are data dependent on input#1");
            return false;
        }

        std::shared_ptr<Tensor> fallback;
        this->gatheredShape() = getInput(1)->dims();
        this->indices().clear(); // If both are provided input would override attrs
        this->indices().reserve(getInput(1)->size());
        // Use getInput(1) (was mInputs[1]) for consistency with the rest of
        // this function; both denote the same tensor.
        const auto& indicesTensor = getInput(1)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
        std::copy_n(static_cast<int64_t*>(indicesTensor.getImpl()->hostPtr()),
                    indicesTensor.size(),
                    std::back_inserter(this->indices()));
    }

    AIDGE_ASSERT(!this->indices().empty(), "Missing input#1 or Indices attribute");

    // Compute output dims.
    std::vector<DimSize_t> outDims = getInput(0)->dims();

    // Normalize a possibly negative axis. Computing in std::size_t avoids the
    // previous implicit narrowing to int8_t (which also capped the usable
    // rank at 127).
    const std::size_t axisIdx = (this->axis() >= 0)
        ? static_cast<std::size_t>(this->axis())
        : static_cast<std::size_t>(this->axis() + static_cast<std::int64_t>(outDims.size()));
    // Guard before erase(): an out-of-range axis would otherwise be UB.
    AIDGE_ASSERT(axisIdx < outDims.size(),
                 "Gather_Op: axis out of range for input of rank {}", outDims.size());

    outDims.erase(outDims.begin() + axisIdx);
    if (!this->gatheredShape().empty())
    {
        outDims.insert(outDims.begin() + axisIdx,
                        this->gatheredShape().begin(),
                        this->gatheredShape().end());
    }
    mOutputs[0]->resize(outDims);
    return true;
}

// Selects the implementation for the given backend name and propagates the
// backend/device to the output tensor. Unregistered names fall back to the
// generic reference implementation.
void Aidge::Gather_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
    const bool hasRegisteredImpl = Registrar<Gather_Op>::exists({name});
    if (hasRegisteredImpl) {
        SET_IMPL_MACRO(Gather_Op, *this, name);
    } else {
        mImpl = std::make_shared<Gather_OpImpl>(*this);
    }
    mOutputs[0]->setBackend(name, device);
}

// Lists the backend names registered for Gather_Op (delegates to the
// Registrar's key set).
std::set<std::string> Aidge::Gather_Op::getAvailableBackends() const {
    return Registrar<Gather_Op>::getKeys();
}

/////////////////////////////////////////

// Factory helper: builds a Gather_Op with the given attributes and wraps it
// in a graph Node carrying the (possibly empty) name.
std::shared_ptr<Aidge::Node> Aidge::Gather(std::int8_t axis,
                                        const std::vector<int64_t>& indices,
                                        const std::vector<Aidge::DimSize_t>& gatheredShape,
                                        const std::string& name) {
    auto op = std::make_shared<Gather_Op>(axis, indices, gatheredShape);
    return std::make_shared<Node>(op, name);
}