Commit 5e9dd771 authored by Houssem ROUIS

add Pad Operator

parent 49dfeac8
Part of 2 merge requests: !32 "version 0.2.1" and !14 "MobileNet operators".
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_BACKEND_CUDA_OPERATOR_PADIMPL_H_
#define AIDGE_BACKEND_CUDA_OPERATOR_PADIMPL_H_
#include <array>
#include <memory>
#include <tuple>
#include <vector>
#include <cudnn.h>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Pad.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
namespace Aidge {
template <DimIdx_t DIM>
class PadImpl_cuda : public OperatorImpl {
private:
    std::shared_ptr<Tensor> mInputFallback;
    // Padding parameters cached by forward() for the kernel call
    int mLeftPad, mTopPad;
    double mPadVal;
    unsigned int mPadType;

public:
    PadImpl_cuda(const Pad_Op<DIM> &op) : OperatorImpl(op, "cuda") {}

    static std::unique_ptr<PadImpl_cuda> create(const Pad_Op<DIM> &op) {
        return std::make_unique<PadImpl_cuda>(op);
    }

    void forward();

private:
    template <class T> void forward_(const Tensor& input);
};
namespace {
// add cuda backend to Pad_Op<2> implementation registry
static Registrar<Pad_Op<2>> registrarPadImpl_cuda("cuda", Aidge::PadImpl_cuda<2>::create);
} // namespace
} // namespace Aidge
#endif /* AIDGE_BACKEND_CUDA_OPERATOR_PADIMPL_H_ */
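
For reference, a minimal usage sketch (not part of the commit): it assumes the Pad<2> node factory and the associateInput/setDataType/setBackend/computeOutputDims methods of aidge_core 0.2.x; exact names and signatures may differ between Aidge versions.

#include <memory>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Pad.hpp"

void padUsageSketch(std::shared_ptr<Aidge::Tensor> inputTensor) {
    // Hypothetical example: pad a 2-D input with a 1-pixel constant border.
    auto padNode = Aidge::Pad<2>({1, 1, 1, 1}, "pad1",
                                 Aidge::PadBorderType::Constant, 0.0);
    auto op = std::static_pointer_cast<Aidge::Pad_Op<2>>(padNode->getOperator());
    op->associateInput(0, inputTensor);
    op->setDataType(Aidge::DataType::Float32);
    op->setBackend("cuda");   // resolved through the registrar above
    op->computeOutputDims();
    op->forward();            // dispatches to PadImpl_cuda<2>::forward()
}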
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CUDA_OPERATOR_PADIMPL_FORWARD_KERNEL_H_
#define AIDGE_CUDA_OPERATOR_PADIMPL_FORWARD_KERNEL_H_
#include "aidge/data/Data.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
namespace Aidge {

template <class T>
void cudaPadding(const cudaDeviceProp &deviceProp,
                 unsigned int nbOutputs,
                 unsigned int outputsWidth,
                 unsigned int outputsHeight,
                 unsigned int nbChannels,
                 unsigned int batchSize,
                 unsigned int inputWidth,
                 unsigned int inputHeight,
                 int leftPad,
                 int topPad,
                 unsigned int padType,
                 T padValue,
                 const T *input,
                 T *outputs);

} // namespace Aidge
#endif /* AIDGE_CUDA_OPERATOR_PADIMPL_FORWARD_KERNEL_H_ */
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <vector>
#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/PadImpl_CUDA_kernels.hpp"
#include "aidge/backend/cuda/operator/PadImpl.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
#include "aidge/operator/Pad.hpp"
#include "aidge/utils/Types.h"
template <Aidge::DimIdx_t DIM>
void Aidge::PadImpl_cuda<DIM>::forward()
{
    const Pad_Op<DIM> &op = static_cast<const Pad_Op<DIM> &>(mOp);
    assert(mOp.getRawInput(0) && "missing input #0");
    const auto &input = op.getInput(0)->refCastFrom(mInputFallback, *op.getOutput(0));

    // Only the begin borders are passed to the kernel:
    // paddingBorders[0] is the top border, paddingBorders[2] the left border.
    auto paddingBorders = op.template getAttr<PadAttr::BeginEndBorders>();
    mLeftPad = paddingBorders[2];
    mTopPad = paddingBorders[0];
    mPadVal = op.template getAttr<PadAttr::BorderValue>();
    mPadType = static_cast<unsigned int>(op.template getAttr<PadAttr::BorderType>());

    switch (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType())
    {
    case DataType::Float64:
        forward_<double>(input);
        break;
    case DataType::Float32:
        forward_<float>(input);
        break;
    case DataType::Float16:
        forward_<half>(input);
        break;
    default:
        AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
    }
}
template <Aidge::DimIdx_t DIM>
template <class T>
void Aidge::PadImpl_cuda<DIM>::forward_(const Tensor &input)
{
    const auto outDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims();
    const T *inputPtr = static_cast<const T *>(input.getImpl()->rawPtr());
    T *output = static_cast<T *>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());

    // Tensors are NCHW: dims()[0] = batch, [1] = channels, [2] = height, [3] = width.
    Aidge::cudaPadding(CudaContext::getDeviceProp(),
                       outDims[1],
                       outDims[3],
                       outDims[2],
                       input.dims()[1],
                       input.dims()[0],
                       input.dims()[3],
                       input.dims()[2],
                       mLeftPad,
                       mTopPad,
                       mPadType,
                       static_cast<T>(mPadVal),
                       inputPtr,
                       output);
}
// Explicit template instantiation
template class Aidge::PadImpl_cuda<2>;
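
The call above maps NCHW dimensions positionally, and the output spatial size is simply the input size plus the begin and end borders. As a shape sanity check, here is a small hypothetical helper; treating the last two border entries as bottom/right is an assumption consistent with the indices used in forward():

#include <array>
#include <cstddef>

// Hypothetical shape helper: expected NCHW output dims of a 2-D pad.
// Borders are assumed ordered {top, bottom, left, right}, matching
// paddingBorders[0] = top and paddingBorders[2] = left above.
std::array<std::size_t, 4> paddedDims(const std::array<std::size_t, 4> &in,      // {N, C, H, W}
                                      const std::array<std::size_t, 4> &borders) // {top, bottom, left, right}
{
    return {in[0],                            // batch unchanged
            in[1],                            // channels unchanged
            in[2] + borders[0] + borders[1],  // H + top + bottom
            in[3] + borders[2] + borders[3]}; // W + left + right
}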
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/backend/cuda/operator/PadImpl_CUDA_kernels.hpp"
template <typename T>
__global__ void cudaPadding_kernel(unsigned int nbOutputs,
                                   unsigned int outputWidth,
                                   unsigned int outputHeight,
                                   unsigned int nbChannels,
                                   unsigned int inputWidth,
                                   unsigned int inputHeight,
                                   int leftPad,
                                   int topPad,
                                   unsigned int padType,
                                   T padValue,
                                   const T *input,
                                   T *outputs)
{
    const unsigned int inputOffset = (blockIdx.z * blockDim.z + threadIdx.z) * nbChannels * inputWidth * inputHeight;
    const unsigned int outputOffset = (blockIdx.z * blockDim.z + threadIdx.z) * nbOutputs * outputWidth * outputHeight;

    // nbCh = nbChannels for propagate
    //      = nbOutputs for back-propagate
    const unsigned int nbCh = min(nbChannels, nbOutputs);

    for (unsigned int ch = blockIdx.x; ch < nbCh; ch += gridDim.x)
    {
        for (unsigned int oy = threadIdx.y; oy < outputHeight; oy += blockDim.y)
        {
            for (unsigned int ox = threadIdx.x; ox < outputWidth; ox += blockDim.x)
            {
                T outputValue = padValue;

                if (padType == 0) // Const padding
                {
                    int ix = (int)ox - leftPad;
                    int iy = (int)oy - topPad;

                    if (ix >= 0 && ix < (int)inputWidth && iy >= 0 && iy < (int)inputHeight)
                    {
                        outputValue = input[ix + iy * inputWidth
                                            + ch * inputWidth * inputHeight + inputOffset];
                    }
                }
                else if (padType == 1) // Edge padding
                {
                    int ix = max(0, min((int)inputWidth - 1, (int)ox - leftPad));
                    int iy = max(0, min((int)inputHeight - 1, (int)oy - topPad));

                    outputValue = input[ix + iy * inputWidth
                                        + ch * inputWidth * inputHeight + inputOffset];
                }
                else if (padType == 2) // Reflect padding
                {
                    int ix = (int)ox - leftPad;
                    int iy = (int)oy - topPad;

                    // Mirror around the borders without repeating the edge
                    // pixel: index (size - 1 + k) maps back to (size - 1 - k).
                    if (ix < 0)
                        ix = -ix;
                    if (iy < 0)
                        iy = -iy;
                    if (ix >= (int)inputWidth)
                        ix = 2 * ((int)inputWidth - 1) - ix;
                    if (iy >= (int)inputHeight)
                        iy = 2 * ((int)inputHeight - 1) - iy;

                    outputValue = input[ix + iy * inputWidth
                                        + ch * inputWidth * inputHeight + inputOffset];
                }
                else if (padType == 3) // Wrap padding
                {
                    int ix = (inputWidth + (int)ox - leftPad) % inputWidth;
                    int iy = (inputHeight + (int)oy - topPad) % inputHeight;

                    outputValue = input[ix + iy * inputWidth
                                        + ch * inputWidth * inputHeight + inputOffset];
                }

                outputs[ox + oy * outputWidth + ch * outputWidth * outputHeight + outputOffset] = outputValue;
            }
        }
    }
}
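
// For unit testing, the per-axis index mapping implemented by the kernel can
// be mirrored on the host. This is an illustrative sketch, not part of the
// commit; padType encodes Constant = 0, Edge = 1, Reflect = 2, Wrap = 3, and
// the Wrap branch uses a double modulo so it stays valid for any pad size.
inline int mapPadIndex(int o, int pad, int size, unsigned int padType)
{
    int i = o - pad;
    switch (padType) {
    case 0: // Constant: -1 signals "write the pad value"
        return (i >= 0 && i < size) ? i : -1;
    case 1: // Edge: clamp to the nearest valid index
        return i < 0 ? 0 : (i >= size ? size - 1 : i);
    case 2: // Reflect: mirror around the borders, edge pixel not repeated
        if (i < 0) i = -i;
        if (i >= size) i = 2 * (size - 1) - i;
        return i;
    default: // Wrap: periodic continuation of the input
        return (i % size + size) % size;
    }
}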
template <> // double
void Aidge::cudaPadding(const cudaDeviceProp &deviceProp,
                        unsigned int nbOutputs,
                        unsigned int outputsWidth,
                        unsigned int outputsHeight,
                        unsigned int nbChannels,
                        unsigned int batchSize,
                        unsigned int inputWidth,
                        unsigned int inputHeight,
                        int leftPad,
                        int topPad,
                        unsigned int padType,
                        double padValue,
                        const double *input,
                        double *outputs)
{
    const unsigned int maxSize = (unsigned int)deviceProp.maxThreadsPerBlock;
    const unsigned int prefMultiple = (unsigned int)deviceProp.warpSize;
    const unsigned int groupSize = (outputsWidth * outputsHeight < maxSize)
                                       ? outputsWidth * outputsHeight
                                       : maxSize;
    const unsigned int reqWidth = (unsigned int)ceilf((float)groupSize / (float)outputsWidth);
    const unsigned int groupWidth = min(prefMultiple, reqWidth);

    // One block per (channel, batch) pair; the threads of a block tile the
    // output plane, with the block width rounded to at most one warp.
    const dim3 blocksPerGrid = {nbChannels, 1, batchSize};
    const dim3 threadsPerBlocks = {groupWidth, groupSize / groupWidth, 1};

    cudaPadding_kernel<<<blocksPerGrid, threadsPerBlocks>>>(nbOutputs,
                                                            outputsWidth,
                                                            outputsHeight,
                                                            nbChannels,
                                                            inputWidth,
                                                            inputHeight,
                                                            leftPad,
                                                            topPad,
                                                            padType,
                                                            padValue,
                                                            input,
                                                            outputs);
    CHECK_CUDA_STATUS(cudaPeekAtLastError());
}
template <> // float
void Aidge::cudaPadding(const cudaDeviceProp &deviceProp,
                        unsigned int nbOutputs,
                        unsigned int outputsWidth,
                        unsigned int outputsHeight,
                        unsigned int nbChannels,
                        unsigned int batchSize,
                        unsigned int inputWidth,
                        unsigned int inputHeight,
                        int leftPad,
                        int topPad,
                        unsigned int padType,
                        float padValue,
                        const float *input,
                        float *outputs)
{
    const unsigned int maxSize = (unsigned int)deviceProp.maxThreadsPerBlock;
    const unsigned int prefMultiple = (unsigned int)deviceProp.warpSize;
    const unsigned int groupSize = (outputsWidth * outputsHeight < maxSize)
                                       ? outputsWidth * outputsHeight
                                       : maxSize;
    const unsigned int reqWidth = (unsigned int)ceilf((float)groupSize / (float)outputsWidth);
    const unsigned int groupWidth = min(prefMultiple, reqWidth);

    const dim3 blocksPerGrid = {nbChannels, 1, batchSize};
    const dim3 threadsPerBlocks = {groupWidth, groupSize / groupWidth, 1};

    cudaPadding_kernel<<<blocksPerGrid, threadsPerBlocks>>>(nbOutputs,
                                                            outputsWidth,
                                                            outputsHeight,
                                                            nbChannels,
                                                            inputWidth,
                                                            inputHeight,
                                                            leftPad,
                                                            topPad,
                                                            padType,
                                                            padValue,
                                                            input,
                                                            outputs);
    CHECK_CUDA_STATUS(cudaPeekAtLastError());
}
template <> // half
void Aidge::cudaPadding(const cudaDeviceProp &deviceProp,
                        unsigned int nbOutputs,
                        unsigned int outputsWidth,
                        unsigned int outputsHeight,
                        unsigned int nbChannels,
                        unsigned int batchSize,
                        unsigned int inputWidth,
                        unsigned int inputHeight,
                        int leftPad,
                        int topPad,
                        unsigned int padType,
                        half padValue,
                        const half *input,
                        half *outputs)
{
    const unsigned int maxSize = (unsigned int)deviceProp.maxThreadsPerBlock;
    const unsigned int prefMultiple = (unsigned int)deviceProp.warpSize;
    const unsigned int groupSize = (outputsWidth * outputsHeight < maxSize)
                                       ? outputsWidth * outputsHeight
                                       : maxSize;
    const unsigned int reqWidth = (unsigned int)ceilf((float)groupSize / (float)outputsWidth);
    const unsigned int groupWidth = min(prefMultiple, reqWidth);

    const dim3 blocksPerGrid = {nbChannels, 1, batchSize};
    const dim3 threadsPerBlocks = {groupWidth, groupSize / groupWidth, 1};

    cudaPadding_kernel<<<blocksPerGrid, threadsPerBlocks>>>(nbOutputs,
                                                            outputsWidth,
                                                            outputsHeight,
                                                            nbChannels,
                                                            inputWidth,
                                                            inputHeight,
                                                            leftPad,
                                                            topPad,
                                                            padType,
                                                            padValue,
                                                            input,
                                                            outputs);
    CHECK_CUDA_STATUS(cudaPeekAtLastError());
}
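
Finally, a hypothetical standalone driver for the float specialization, useful for checking the kernel outside of Aidge. It relies only on the CUDA runtime API and the declaration in PadImpl_CUDA_kernels.hpp; buffer shapes follow the NCHW convention with nbChannels = nbOutputs = batchSize = 1.

#include <vector>
#include <cuda_runtime.h>
#include "aidge/backend/cuda/operator/PadImpl_CUDA_kernels.hpp"

void padFloatExample()
{
    // Pad a 1x1x4x4 input to 1x1x6x6 with a 1-pixel constant border of 0.
    const unsigned int inH = 4, inW = 4, outH = 6, outW = 6;
    std::vector<float> hostIn(inH * inW, 1.0f), hostOut(outH * outW);

    float *devIn = nullptr, *devOut = nullptr;
    cudaMalloc(&devIn, hostIn.size() * sizeof(float));
    cudaMalloc(&devOut, hostOut.size() * sizeof(float));
    cudaMemcpy(devIn, hostIn.data(), hostIn.size() * sizeof(float), cudaMemcpyHostToDevice);

    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);

    // leftPad = topPad = 1, padType 0 = Constant, pad value 0.0f
    Aidge::cudaPadding(prop, 1, outW, outH, 1, 1, inW, inH, 1, 1, 0, 0.0f, devIn, devOut);

    cudaMemcpy(hostOut.data(), devOut, hostOut.size() * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(devIn);
    cudaFree(devOut);
}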