Skip to content
Snippets Groups Projects
Commit 2eff4d7f authored by Maxence Naud's avatar Maxence Naud Committed by Olivier BICHLER
Browse files

Add GridSample impl for 1D and 2D

parent 10ef1960
No related branches found
No related tags found
2 merge requests!93Release v0.3.0,!77Add GridSample impl for 1D and 2D
Pipeline #54150 passed
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_GRIDSAMPLEIMPL_H_
#define AIDGE_CPU_OPERATOR_GRIDSAMPLEIMPL_H_
#include <array>
#include <memory>
#include <tuple>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/GridSample.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
namespace Aidge {
// compute kernel registry for forward and backward
class GridSampleImpl1DForward_cpu
: public Registrable<GridSampleImpl1DForward_cpu,
std::tuple<DataType, DataType>,
void(const GridSample_Op&,
const std::shared_ptr<Tensor>&,
const std::shared_ptr<Tensor>&,
const std::shared_ptr<Tensor>&)> {};
class GridSampleImpl2DForward_cpu
: public Registrable<GridSampleImpl2DForward_cpu,
std::tuple<DataType, DataType>,
void(const GridSample_Op&,
const std::shared_ptr<Tensor>&,
const std::shared_ptr<Tensor>&,
const std::shared_ptr<Tensor>&)> {};
class GridSampleImpl_cpu : public OperatorImpl {
public:
GridSampleImpl_cpu(const GridSample_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<GridSampleImpl_cpu> create(const GridSample_Op &op) {
return std::make_unique<GridSampleImpl_cpu>(op);
}
public:
Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
};
namespace {
// add cpu backend to GridSample_Op<1> implementation registry
static Registrar<GridSample_Op> registrarGridSampleImpl_cpu("cpu", Aidge::GridSampleImpl_cpu::create);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_GRIDSAMPLEIMPL_H_ */
This diff is collapsed.
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/backend/cpu/operator/GridSampleImpl.hpp"
#include <functional>
#include <vector>
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/GridSampleImpl_forward_kernels.hpp"
#include "aidge/operator/GridSample.hpp"
#include "aidge/utils/Types.h"
Aidge::Elts_t Aidge::GridSampleImpl_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
// this implementation can be in-place
return Elts_t::DataElts(0);
}
void Aidge::GridSampleImpl_cpu::forward() {
const auto& op_ = static_cast<const GridSample_Op&>(mOp);
// Find the correct kernel type
const auto outputDataType = op_.getOutput(0)->dataType();
const Registrar<GridSampleImpl1DForward_cpu>::registrar_key registrarKey = {
op_.getInput(0)->dataType(),
outputDataType};
std::function<void(const GridSample_Op&,
const std::shared_ptr<Tensor>&,
const std::shared_ptr<Tensor>&,
const std::shared_ptr<Tensor>&)> kernelFunc;
const std::size_t nbSpatialFeat = op_.getInput(0)->nbDims();
switch (nbSpatialFeat)
{
case 1:
kernelFunc = Registrar<GridSampleImpl1DForward_cpu>::create(registrarKey);
break;
case 2:
kernelFunc = Registrar<GridSampleImpl2DForward_cpu>::create(registrarKey);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "No CPU {} kernel available for {} dimensions.", op_.type(), nbSpatialFeat);
break;
}
// Convert input data (no overhead if not needed!)
// TODO: right now, if needed, memory will be allocated/deallocated at each
// call to forward(). We might put the following shared_ptr as members of
// this class to avoid that.
std::shared_ptr<Tensor> input0Fallback, input1Fallback;
const auto& input0 = std::make_shared<Tensor>(op_.getInput(0)->refCastFrom(input0Fallback, *op_.getOutput(0)));
const auto& input1 = std::make_shared<Tensor>(op_.getInput(1)->refCastFrom(input1Fallback, *op_.getOutput(0)));
// Call kernel
kernelFunc(op_,
input0, // input
input1, // grid
op_.getOutput(0) // output
);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment