Skip to content
Snippets Groups Projects
Commit ee7ad790 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added Gather default implementation

parent e7a524b0
No related branches found
No related tags found
2 merge requests!610.2.2,!54Make forwardDims() optional and handle data dependency & moved several operators impl to core
Pipeline #43495 failed
......@@ -21,7 +21,6 @@
#include "aidge/backend/cpu/operator/DivImpl.hpp"
#include "aidge/backend/cpu/operator/ErfImpl.hpp"
#include "aidge/backend/cpu/operator/FCImpl.hpp"
#include "aidge/backend/cpu/operator/GatherImpl.hpp"
#include "aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp"
#include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp"
#include "aidge/backend/cpu/operator/MatMulImpl.hpp"
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_GATHERIMPL_H_
#define AIDGE_CPU_OPERATOR_GATHERIMPL_H_
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Gather.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include <memory>
#include <vector>
namespace Aidge {
// class Gather_Op;
// compute kernel registry for forward and backward
class GatherImplForward_cpu
: public Registrable<GatherImplForward_cpu, std::tuple<DataType, DataType>, void(const typename Gather_Op::Attrs&, const std::vector<DimSize_t>&, const void*, void*)> {
};
class GatherImplBackward_cpu
: public Registrable<GatherImplBackward_cpu, std::tuple<DataType, DataType>, void(const typename Gather_Op::Attrs&, const std::vector<DimSize_t>&, const void*, void*)> {
};
class GatherImpl_cpu : public OperatorImpl {
public:
GatherImpl_cpu(const Gather_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<GatherImpl_cpu> create(const Gather_Op& op) {
return std::make_unique<GatherImpl_cpu>(op);
}
void forward() override;
};
namespace {
static Registrar<Gather_Op> registrarGatherImpl_cpu("cpu", Aidge::GatherImpl_cpu::create);
}
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_GATHERIMPL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_GATHERIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_GATHERIMPL_FORWARD_KERNEL_H_
#include "aidge/utils/Registrar.hpp"
#include <cstddef>
#include <cmath>
#include "aidge/data/Data.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/operator/GatherImpl.hpp"
namespace Aidge {
template <class I, class O>
void GatherImpl_cpu_forward_kernel(const typename Gather_Op::Attrs& attrs, const std::vector<DimSize_t>& inputDims, const void* input_, void* output_)
{
const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_);
const std::size_t axisIdx = std::get<2>(attrs)>=0 ?
std::get<2>(attrs) :
static_cast<std::size_t>(std::get<2>(attrs)) + inputDims.size();
std::size_t postAxisElems = 1;
for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i) {
postAxisElems *= inputDims[i];
}
std::size_t preAxisElems = 1;
for (std::size_t i = 0; i < axisIdx; ++i) {
preAxisElems *= inputDims[i];
}
const std::vector<std::int64_t> indices = std::get<0>(attrs);
for (std::size_t i=0; i<preAxisElems; ++i)
{
for(std::size_t j=0; j<indices.size(); ++j)
{
const std::size_t idx = indices[j] >= 0 ? indices[j] : static_cast<std::size_t>(indices[j]) + inputDims[axisIdx];
const I* startPtr = std::next(input, i * postAxisElems * inputDims[axisIdx] + idx * postAxisElems);
std::copy_n(startPtr, postAxisElems, output);
output += postAxisElems;
}
}
}
namespace {
static Registrar<GatherImplForward_cpu> registrarGatherImplForward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::GatherImpl_cpu_forward_kernel<float, float>);
static Registrar<GatherImplForward_cpu> registrarGatherImplForward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::GatherImpl_cpu_forward_kernel<int, int>);
static Registrar<GatherImplForward_cpu> registrarGatherImplForward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::GatherImpl_cpu_forward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_GATHERIMPL_FORWARD_KERNEL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/backend/cpu/operator/GatherImpl.hpp"
#include <memory>
#include <vector>
#include "aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Gather.hpp"
#include "aidge/utils/Types.h"
void Aidge::GatherImpl_cpu::forward() {
const Gather_Op& op = static_cast<const Gather_Op&>(mOp);
auto kernelFunc = Registrar<GatherImplForward_cpu>::create({
op.getInput(0)->dataType(),
op.getOutput(0)->dataType()
});
// Call kernel
kernelFunc(dynamic_cast<const Gather_Op&>(mOp).getStaticAttributes(),
op.getInput(0)->dims(),
op.getInput(0)->getImpl()->rawPtr(),
op.getOutput(0)->getImpl()->rawPtr()
);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment