Commit 44c25e9d authored by Olivier BICHLER

Merge branch 'feat_145_GridSample' into 'dev'

Add GridSample impl for 1D and 2D

See merge request !77
parents 10ef1960 2eff4d7f
Merge requests: !93 Release v0.3.0, !77 Add GridSample impl for 1D and 2D
Pipeline #54174 passed
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_GRIDSAMPLEIMPL_H_
#define AIDGE_CPU_OPERATOR_GRIDSAMPLEIMPL_H_
#include <array>
#include <memory>
#include <tuple>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/GridSample.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
namespace Aidge {
// compute kernel registries for the forward pass (1D and 2D)
class GridSampleImpl1DForward_cpu
    : public Registrable<GridSampleImpl1DForward_cpu,
                         std::tuple<DataType, DataType>,
                         void(const GridSample_Op&,
                              const std::shared_ptr<Tensor>&,
                              const std::shared_ptr<Tensor>&,
                              const std::shared_ptr<Tensor>&)> {};

class GridSampleImpl2DForward_cpu
    : public Registrable<GridSampleImpl2DForward_cpu,
                         std::tuple<DataType, DataType>,
                         void(const GridSample_Op&,
                              const std::shared_ptr<Tensor>&,
                              const std::shared_ptr<Tensor>&,
                              const std::shared_ptr<Tensor>&)> {};
class GridSampleImpl_cpu : public OperatorImpl {
public:
    GridSampleImpl_cpu(const GridSample_Op& op) : OperatorImpl(op, "cpu") {}

    static std::unique_ptr<GridSampleImpl_cpu> create(const GridSample_Op& op) {
        return std::make_unique<GridSampleImpl_cpu>(op);
    }

public:
    Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
    void forward() override;
};
namespace {
// add cpu backend to GridSample_Op implementation registry
static Registrar<GridSample_Op> registrarGridSampleImpl_cpu("cpu", Aidge::GridSampleImpl_cpu::create);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_GRIDSAMPLEIMPL_H_ */
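The forward kernels live in GridSampleImpl_forward_kernels.hpp, which is not reproduced here. As a hypothetical sketch (kernel name, template parameters, and body below are illustrative, not the actual Aidge code), a float32 1D kernel would plug into the GridSampleImpl1DForward_cpu registry declared above roughly as follows:

// Hypothetical sketch only: shows how a kernel could be registered against the
// GridSampleImpl1DForward_cpu registry declared above. All names are illustrative.
#include "aidge/backend/cpu/operator/GridSampleImpl.hpp"

namespace Aidge {
template <class I, class O>
void GridSampleImpl1D_cpu_forward_kernel_sketch(const GridSample_Op& op,
                                                const std::shared_ptr<Tensor>& input,
                                                const std::shared_ptr<Tensor>& grid,
                                                const std::shared_ptr<Tensor>& output) {
    // Sample `input` at the normalized coordinates stored in `grid`, honoring the
    // operator's interpolation mode, padding mode and align_corners attributes,
    // and write the result into `output`.
    (void)op; (void)input; (void)grid; (void)output;
}

namespace {
// Key {input DataType, output DataType} matches the registrarKey built in
// GridSampleImpl_cpu::forward().
static Registrar<GridSampleImpl1DForward_cpu> registrarGridSampleImpl1DForward_cpu_Float32_sketch(
    {DataType::Float32, DataType::Float32},
    GridSampleImpl1D_cpu_forward_kernel_sketch<float, float>);
}  // namespace
}  // namespace Aidge

Keying the registry on {input DataType, output DataType} lets forward() pick a type-specialized kernel at run time without the dispatcher having to know the concrete element types.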
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/backend/cpu/operator/GridSampleImpl.hpp"
#include <functional>
#include <vector>
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/GridSampleImpl_forward_kernels.hpp"
#include "aidge/operator/GridSample.hpp"
#include "aidge/utils/Types.h"
Aidge::Elts_t Aidge::GridSampleImpl_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
    // this implementation can be in-place
    return Elts_t::DataElts(0);
}
void Aidge::GridSampleImpl_cpu::forward() {
    const auto& op_ = static_cast<const GridSample_Op&>(mOp);

    // Find the correct kernel type
    const auto outputDataType = op_.getOutput(0)->dataType();
    const Registrar<GridSampleImpl1DForward_cpu>::registrar_key registrarKey = {
        op_.getInput(0)->dataType(),
        outputDataType};

    std::function<void(const GridSample_Op&,
                       const std::shared_ptr<Tensor>&,
                       const std::shared_ptr<Tensor>&,
                       const std::shared_ptr<Tensor>&)> kernelFunc;
    // Input layout is (N, C, spatial dims...), so the number of spatial
    // dimensions is nbDims() - 2.
    const std::size_t nbSpatialFeat = op_.getInput(0)->nbDims() - 2;
    switch (nbSpatialFeat) {
        case 1:
            kernelFunc = Registrar<GridSampleImpl1DForward_cpu>::create(registrarKey);
            break;
        case 2:
            kernelFunc = Registrar<GridSampleImpl2DForward_cpu>::create(registrarKey);
            break;
        default:
            AIDGE_THROW_OR_ABORT(std::runtime_error, "No CPU {} kernel available for {} spatial dimensions.", op_.type(), nbSpatialFeat);
            break;
    }
    // Convert input data (no overhead if not needed!)
    // TODO: right now, if needed, memory will be allocated/deallocated at each
    // call to forward(). We might put the following shared_ptr as members of
    // this class to avoid that.
    std::shared_ptr<Tensor> input0Fallback, input1Fallback;
    const auto& input0 = std::make_shared<Tensor>(op_.getInput(0)->refCastFrom(input0Fallback, *op_.getOutput(0)));
    const auto& input1 = std::make_shared<Tensor>(op_.getInput(1)->refCastFrom(input1Fallback, *op_.getOutput(0)));

    // Call kernel
    kernelFunc(op_,
               input0,           // input
               input1,           // grid
               op_.getOutput(0)  // output
    );
}
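For reference, the arithmetic that a bilinear 2D grid sample performs for a single output value can be sketched in standalone form (zero padding outside the image, align_corners=false, following ONNX GridSample semantics). This is an illustration only, not the Aidge kernel from GridSampleImpl_forward_kernels.hpp:

// Standalone illustration of bilinear 2D grid sampling for one output element.
// Assumes a single-channel (H, W) image flattened row-major, zero padding and
// align_corners=false; not taken from the Aidge kernels.
#include <cmath>
#include <cstddef>
#include <vector>

float gridSample2dBilinear(const std::vector<float>& in, std::size_t H, std::size_t W,
                           float x, float y) {
    // Map normalized grid coordinates in [-1, 1] to pixel space (align_corners=false).
    const float ix = ((x + 1.0f) * W - 1.0f) * 0.5f;
    const float iy = ((y + 1.0f) * H - 1.0f) * 0.5f;
    // Zero padding: out-of-bounds reads return 0.
    const auto at = [&](long long row, long long col) -> float {
        if (row < 0 || col < 0 ||
            row >= static_cast<long long>(H) || col >= static_cast<long long>(W)) {
            return 0.0f;
        }
        return in[static_cast<std::size_t>(row) * W + static_cast<std::size_t>(col)];
    };
    const long long x0 = static_cast<long long>(std::floor(ix));
    const long long y0 = static_cast<long long>(std::floor(iy));
    const float wx = ix - static_cast<float>(x0);
    const float wy = iy - static_cast<float>(y0);
    // Weighted average of the four neighbouring pixels.
    return at(y0,     x0    ) * (1.0f - wx) * (1.0f - wy)
         + at(y0,     x0 + 1) * wx          * (1.0f - wy)
         + at(y0 + 1, x0    ) * (1.0f - wx) * wy
         + at(y0 + 1, x0 + 1) * wx          * wy;
}

The registered kernels presumably apply this kind of per-element computation across the whole output tensor, for the interpolation and padding modes configured on GridSample_Op.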