diff --git a/include/aidge/backend/cuda.hpp b/include/aidge/backend/cuda.hpp index dd790290aa8b25097cbc3a0109b57279a7552777..4a706fd0132072220e803c7852d798b3afaa8257 100644 --- a/include/aidge/backend/cuda.hpp +++ b/include/aidge/backend/cuda.hpp @@ -41,6 +41,7 @@ #include "aidge/backend/cuda/operator/ReduceMeanImpl.hpp" #include "aidge/backend/cuda/operator/ReduceSumImpl.hpp" #include "aidge/backend/cuda/operator/ReLUImpl.hpp" +#include "aidge/backend/cuda/operator/ResizeImpl.hpp" #include "aidge/backend/cuda/operator/RoundImpl.hpp" #include "aidge/backend/cuda/operator/ShiftGELUImpl.hpp" #include "aidge/backend/cuda/operator/ShiftMaxImpl.hpp" diff --git a/include/aidge/backend/cuda/data/Interpolation.cuh b/include/aidge/backend/cuda/data/Interpolation.cuh new file mode 100644 index 0000000000000000000000000000000000000000..316272ac6991a54855148f8e26041b1f65407343 --- /dev/null +++ b/include/aidge/backend/cuda/data/Interpolation.cuh @@ -0,0 +1,179 @@ +/******************************************************************************** + * Copyright (c) 2025 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#ifndef AIDGE_CUDA_DATA_INTERPOLATION_H_ +#define AIDGE_CUDA_DATA_INTERPOLATION_H_ + +#include "aidge/data/Interpolation.hpp" +namespace Aidge { +namespace InterpolationCUDA { + + /** + * @brief Computes the approximate input coordinates corresponding to given output coordinates, + * according to the specified coordinate transformation mode. + * + * This function performs an "untransform" step for spatial interpolation, mapping output-space + * integer coordinates (`coordOut`) back into approximate input-space floating-point coordinates + * (`coordInApprox`) based on the transformation strategy provided. + * + * @param coordOut Pointer to the output coordinates (integer values). + * @param inputDims Pointer to the dimensions of the input tensor. + * @param outputDims Pointer to the dimensions of the output tensor. + * @param coordTransfoMode The coordinate transformation mode (e.g., AlignCorners, HalfPixel, etc.). + * @param roi Pointer to region of interest values, used only for TFCropAndResize mode; + * expected to be of size 2 * rank. + * @param coordInApprox Pointer to the output array where the approximate input coordinates will be stored. + * @param rank The dimensionality of the spatial domain (e.g., 2 for 2D, 3 for 3D). + */ + __device__ void untransformCoordinates( + const int* coordOut, + const int* inputDims, + const int* outputDims, + Aidge::Interpolation::CoordinateTransformation coordTransfoMode, + const float* roi, + float* coordInApprox, + int rank); + + /** + * @brief Retrieves neighboring input tensor values around a set of continuous coordinates for interpolation. + * + * This function gathers neighboring values from a multidimensional tensor for use in interpolation, + * depending on the interpolation mode (e.g., linear or cubic), scaling, padding strategy, and + * anti-aliasing settings. The gathered values and their coordinates are stored in output buffers. + * + * @tparam T The data type of the input and output tensor values. + * @param tensorValues Pointer to the input tensor values (flattened array). + * @param tensorDims Pointer to the dimensions of the input tensor. 
+ * @param coords Pointer to the floating-point coordinates in the input space. + * @param scales Pointer to scaling factors per dimension. + * @param rank The number of spatial dimensions in the tensor. + * @param mode Interpolation mode (e.g., Linear, Cubic). + * @param paddingMode Padding mode used for out-of-bound accesses (e.g., Zero, Edge). + * @param antialiasing Whether antialiasing is enabled (affects kernel footprint). + * @param outValues Pointer to the buffer where retrieved neighbor values will be stored. + * @param outCoords Pointer to the buffer where the corresponding coordinates of neighbors will be stored. + * Output shape is [maxNeighbours x rank]. + * @param outCount Pointer to a single integer where the number of valid neighbors will be written. + * @param maxNeighbours Maximum number of neighbors to retrieve (capacity of the output buffers). + */ + template <typename T> + __device__ void retrieveNeighboursKernel( + const T* tensorValues, + const int* tensorDims, + float* coords, + const float* scales, + int rank, + Aidge::Interpolation::Mode mode, + Aidge::PadBorderType paddingMode, + bool antialiasing, + T* outValues, + int* outCoords, + int* outCount, + int maxNeighbours + ); + + /** + * @brief Performs N-dimensional linear interpolation given neighbor points and their values. + * + * This function computes a weighted average of neighbor values based on linear interpolation + * weights derived from their distance to a target coordinate. It optionally applies antialiasing + * scaling to the distances. The weights are normalized to ensure smooth interpolation. + * + * @tparam T The data type of the input and output values. + * @param coordToInterpolate Pointer to the floating-point coordinate to interpolate at (length = rank). + * @param scales Pointer to scale factors per dimension. + * @param pointsCoords Pointer to neighbor coordinates (flattened array of size coordsNbr × rank). + * @param pointValues Pointer to values corresponding to neighbor coordinates. + * @param coordsNbr Number of neighboring points used in the interpolation. + * @param rank Number of spatial dimensions. + * @param antialiasing Whether to apply antialiasing by scaling coordinate deltas. + * @return T The interpolated value at the target coordinate. + */ + template <typename T> + __device__ T interpolateLinear( + const float* coordToInterpolate, + const float* scales, + const int* pointsCoords, + const T* pointValues, + int coordsNbr, + int rank, + bool antialiasing + ); + + /** + * @brief Performs N-dimensional cubic interpolation given neighbor points and their values. + * + * This function computes a weighted average of neighbor values based on cubic interpolation weights, + * optionally applying antialiasing and excluding neighbors outside input dimensions. + * Static dimensions (with size 1 and scale 1) are ignored in the weight calculation. + * + * @tparam T The data type of the input and output values. + * @param coordToInterpolate Pointer to the floating-point coordinate to interpolate at (length = rank). + * @param scales Pointer to scale factors per dimension. + * @param pointsCoords Pointer to neighbor coordinates (flattened array of size coordsNbr × rank). + * @param pointValues Pointer to values corresponding to neighbor coordinates. + * @param coordsNbr Number of neighboring points used in the interpolation. + * @param rank Number of spatial dimensions. + * @param inputDims Pointer to input tensor dimensions. 
+ * @param a Cubic interpolation parameter (often -0.75 for Catmull-Rom). + * @param antialiasing Whether to apply antialiasing in weight calculation. + * @param excludeOutside Whether to exclude neighbors outside input dimensions. + * @return T The interpolated value at the target coordinate. + */ + template <typename T> + __device__ T interpolateCubic( + float* coordToInterpolate, + const float* scales, + int* pointsCoords, + T* pointValues, + int coordsNbr, + int rank, + const int* inputDims, + float a, + bool antialiasing, + bool excludeOutside + ); + + /** + * @brief Dispatches to the appropriate interpolation method (linear or cubic) based on the mode. + * + * This function selects the interpolation method and computes the interpolated value + * at the given coordinate using either cubic or linear interpolation. Returns zero + * for unsupported interpolation modes. + * + * @tparam T The data type of the input and output values. + * @param coordToInterpolate Pointer to the floating-point coordinate to interpolate at (length = rank). + * @param scales Pointer to scale factors per dimension. + * @param pointsCoords Pointer to neighbor coordinates (flattened array of size coordsNbr × rank). + * @param pointValues Pointer to values corresponding to neighbor coordinates. + * @param coordsNbr Number of neighboring points used in the interpolation. + * @param rank Number of spatial dimensions. + * @param mode Interpolation mode (Linear or Cubic). + * @param cubicCoeffA Cubic interpolation coefficient (used only for cubic mode). + * @param antialiasing Whether to apply antialiasing in interpolation. + * @param excludeOutside Whether to exclude neighbors outside input dimensions (only cubic mode). + * @param inputDims Pointer to input tensor dimensions. + * @return T The interpolated value at the target coordinate. + */ + template <typename T> + __device__ T interpolate(float* coordToInterpolate, + const float* scales, + int* pointsCoords, + T* pointValues, + int coordsNbr, + int rank, + Aidge::Interpolation::Mode mode, + float cubicCoeffA, + bool antialiasing, + bool excludeOutside, + const int* inputDims); +} // namespace InterpolationCUDA +} // namespace Aidge +#endif /*AIDGE_CUDA_DATA_INTERPOLATION_H_*/ \ No newline at end of file diff --git a/include/aidge/backend/cuda/operator/ResizeImpl.hpp b/include/aidge/backend/cuda/operator/ResizeImpl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4614ddefe41fc1c125e8e6abec9052a951501ced --- /dev/null +++ b/include/aidge/backend/cuda/operator/ResizeImpl.hpp @@ -0,0 +1,57 @@ +/******************************************************************************** + * Copyright (c) 2025 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_BACKEND_CUDA_OPERATOR_RESIZEIMPL_H_ +#define AIDGE_BACKEND_CUDA_OPERATOR_RESIZEIMPL_H_ + +#include <array> +#include <memory> +#include <tuple> +#include <vector> + +#include <cudnn.h> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Resize.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +#include "aidge/backend/cuda/utils/CudaUtils.hpp" + +namespace Aidge { +// Operator implementation entry point for the backend +class ResizeImpl_cuda : public OperatorImpl { +public: + ResizeImpl_cuda(const Resize_Op& op) : OperatorImpl(op, "cuda") {} + + static std::unique_ptr<ResizeImpl_cuda> create(const Resize_Op& op) { + return std::make_unique<ResizeImpl_cuda>(op); + } + + virtual std::vector<ImplSpec> getAvailableImplSpecs() const override { + return { + {DataType::Float64}, + {DataType::Float32}, + {DataType::Float16}, + }; + } + + void forward() override; + +private: + template <class T> void forward_(); +}; + +// Implementation entry point registration to Operator +REGISTRAR(Resize_Op, "cuda", Aidge::ResizeImpl_cuda::create); +} // namespace Aidge + +#endif /* AIDGE_BACKEND_CUDA_OPERATOR_RESIZEIMPL_H_ */ diff --git a/include/aidge/backend/cuda/operator/ResizeImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ResizeImpl_CUDA_kernels.hpp new file mode 100644 index 0000000000000000000000000000000000000000..25bb17064638bd637a07bd079e055988badc72ff --- /dev/null +++ b/include/aidge/backend/cuda/operator/ResizeImpl_CUDA_kernels.hpp @@ -0,0 +1,48 @@ +/******************************************************************************** + * Copyright (c) 2025 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CUDA_OPERATOR_RESIZEIMPL_KERNELS_H_ +#define AIDGE_CUDA_OPERATOR_RESIZEIMPL_KERNELS_H_ + +#include <stdexcept> +#include <cfloat> +#include <cuda.h> +#include <cuda_runtime_api.h> +#include <cuda_fp16.h> + +#include "aidge/data/Data.hpp" +#include "aidge/data/Interpolation.hpp" +#include "aidge/backend/cuda/utils/CudaUtils.hpp" + +namespace Aidge { + + +template <class T> +void resizeForward(const T* input, T* output,const std::vector<float> &roi, + const std::vector<int>& inputDims, const std::vector<int>& outputDims, + const std::vector<int>& inputStrides, const std::vector<int>& outputStrides, + const std::vector<float> &scales, + const Aidge::Interpolation::CoordinateTransformation coordTransfoMode, + const Aidge::Interpolation::Mode interpMode, + const Aidge::PadBorderType paddingMode, + float cubic_coeff_a, + float extrapolationVal, + bool antialiasing, + bool excludeOutside, + int size); + +} +#endif /* AIDGE_CUDA_OPERATOR_RESIZEIMPL_KERNELS_H_ */ + + + + + diff --git a/src/data/Interpolation.cu b/src/data/Interpolation.cu new file mode 100644 index 0000000000000000000000000000000000000000..04a8ab13bdfa8bd93b74ca7ee3d2330e42b56920 --- /dev/null +++ b/src/data/Interpolation.cu @@ -0,0 +1,535 @@ +/******************************************************************************** + * Copyright (c) 2025 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#include <cuda_fp16.h> +#include <math.h> // For isnan() + +#include "aidge/backend/cuda/data/Interpolation.cuh" + +#include "aidge/data/Interpolation.hpp" +template <typename T> +__device__ float toFloat(T val); + +template <> +__device__ float toFloat(float val) { + return val; +} + +template <> +__device__ float toFloat(double val) { + return static_cast<float>(val); +} + +template <> +__device__ float toFloat(__half val) { + return __half2float(val); +} + +template <typename T> +__device__ T fromFloat(float val); + +template <> +__device__ float fromFloat<float>(float val) { + return val; +} + +template <> +__device__ double fromFloat<double>(float val) { + return static_cast<double>(val); +} + +template <> +__device__ __half fromFloat<__half>(float val) { + return __float2half(val); +} + +__device__ void Aidge::InterpolationCUDA::untransformCoordinates( + const int* coordOut, + const int* inputDims, + const int* outputDims, + Aidge::Interpolation::CoordinateTransformation coordTransfoMode, + const float* roi, + float* coordInApprox, + int rank) +{ + for (int i = 0; i < rank; ++i) { + float scale = static_cast<float>(outputDims[i]) / + static_cast<float>(inputDims[i]); + + switch (coordTransfoMode) { + case Aidge::Interpolation::CoordinateTransformation::AlignCorners: + if (inputDims[i] == 1 || outputDims[i] == 1) { + coordInApprox[i] = 0.0f; + } else { + coordInApprox[i] = coordOut[i] * static_cast<float>(inputDims[i] - 1) / + static_cast<float>(outputDims[i] - 1); + } + break; + + case Aidge::Interpolation::CoordinateTransformation::Asymmetric: + coordInApprox[i] = coordOut[i] / scale; + break; + + case Aidge::Interpolation::CoordinateTransformation::HalfPixel: + case 
Aidge::Interpolation::CoordinateTransformation::HalfPixelSymmetric: + coordInApprox[i] = (coordOut[i] + 0.5f) / scale - 0.5f; + break; + + case Aidge::Interpolation::CoordinateTransformation::PytorchHalfPixel: + coordInApprox[i] = static_cast<float>(outputDims[i]) > 1 + ? (coordOut[i] + 0.5f) / scale - 0.5f + : 0.0f; + break; + + case Aidge::Interpolation::CoordinateTransformation::TFHalfPixelForNN: + coordInApprox[i] = (coordOut[i] + 0.5f) / scale; + break; + + case Aidge::Interpolation::CoordinateTransformation::TFCropAndResize: + { + float roiStart = roi[i]; + float roiEnd = roi[i + rank]; + if (outputDims[i] > 1) { + coordInApprox[i] = roiStart * (inputDims[i] - 1) + + coordOut[i] * (roiEnd - roiStart) * + static_cast<float>(inputDims[i] - 1) / + static_cast<float>(outputDims[i] - 1); + } else { + coordInApprox[i] = 0.5f * (roiStart + roiEnd) * (inputDims[i] - 1); + } + + if (coordInApprox[i] < 0.0f || coordInApprox[i] > static_cast<float>(inputDims[i] - 1)) { + coordInApprox[i] = NAN; + } + break; + } + } + } +} + +template <typename T> +__device__ void Aidge::InterpolationCUDA::retrieveNeighboursKernel( + const T* tensorValues, + const int* tensorDims, + float* coords, + const float* scales, + int rank, + Aidge::Interpolation::Mode mode, + Aidge::PadBorderType paddingMode, + bool antialiasing, + T* outValues, + int* outCoords, + int* outCount, + int maxNeighbours +) { + int count = 0; + + // Offset buffer per dimension + const int maxKernelSize = 16; + int offsetRanges[10][maxKernelSize]; + int offsetSizes[10]; + for (int dim = 0; dim < rank; ++dim) { + if (mode == Aidge::Interpolation::Mode::Cubic) { + float scale = scales[dim]; + int kernelWidth = 4; + + int center = (int)floorf(coords[dim]); + + int kernel_size; + int start; + + if (antialiasing && scale < 1.0f) { + float footprint = kernelWidth / scale; + int radius = (int)ceilf((footprint - 1.0f) / 2.0f); + kernel_size = 2 * radius + 1; + offsetSizes[dim] = kernel_size; + + start = center - radius; + for (int i = 0; i < kernel_size; ++i) { + offsetRanges[dim][i] = start + i; + } + } else { + // Default cubic kernel + kernel_size = 4; + offsetSizes[dim] = kernel_size; + + start = center - 1; + for (int i = 0; i < kernel_size; ++i) { + offsetRanges[dim][i] = start + i; + } + } + } else if(mode == Aidge::Interpolation::Mode::Linear && antialiasing) { + float scale = scales[dim]; + + // Only scale height and width (assuming batch and channels are unscaled) + if (dim >= 2 && scale < 1.0f) { + int kernel_size = max(2, (int)ceilf(2.0f / scale)); + offsetSizes[dim] = kernel_size; + + int center = (int)floorf(coords[dim]); + int start = center - (kernel_size - 1) / 2; + + for (int k = 0; k < kernel_size; ++k) { + offsetRanges[dim][k] = start + k; + } + } else { + offsetSizes[dim] = 2; + int center = (int)floorf(coords[dim]); + for (int k = 0; k < 2; ++k) { + offsetRanges[dim][k] = center + k; + } + } + } + else { + offsetSizes[dim] = 2; + offsetRanges[dim][0] = (int)ceilf(coords[dim]); + offsetRanges[dim][1] = (int)floorf(coords[dim]); + } + + } + + int loopCounts[10] = {0}; + + while (true) { + // Build coordinate + int currentCoords[10]; + for (int i = 0; i < rank; ++i) { + currentCoords[i] = offsetRanges[i][loopCounts[i]]; + } + + // Handle padding + bool valid = true; + int clampedCoords[10]; + for (int i = 0; i < rank; ++i) { + int coord = currentCoords[i]; + switch (paddingMode) { + case Aidge::PadBorderType::Edge: + coord = min(max(coord, 0), tensorDims[i] - 1); + break; + case Aidge::PadBorderType::Zero: + if (coord < 0 || coord 
>= tensorDims[i]) { + valid = false; + } + break; + default: + valid = false; // Unsupported padding + break; + } + clampedCoords[i] = coord; + } + + if (valid && count < maxNeighbours) { + int flatIdx = 0; + int stride = 1; + for (int i = rank - 1; i >= 0; --i) { + flatIdx += clampedCoords[i] * stride; + stride *= tensorDims[i]; + } + + // Store output + for (int j = 0; j < rank; ++j) { + outCoords[count * rank + j] = currentCoords[j]; + } + outValues[count] = tensorValues[flatIdx]; + ++count; + } + + // Next coordinate combination + int dim = rank - 1; + while (dim >= 0) { + loopCounts[dim]++; + if (loopCounts[dim] < offsetSizes[dim]) break; + loopCounts[dim] = 0; + --dim; + } + if (dim < 0) break; + } + + *outCount = count; +} + +template <typename T> +__device__ T Aidge::InterpolationCUDA::interpolateLinear( + const float* coordToInterpolate, + const float* scales, + const int* pointsCoords, + const T* pointValues, + int coordsNbr, + int rank, + bool antialiasing +) { + float resultAccum = 0.0f; + float totalWeight = 0.0f; + + for (int i = 0; i < coordsNbr; ++i) { + float weight = 1.0f; + bool skip = false; + + for (int d = 0; d < rank; ++d) { + float dx = coordToInterpolate[d] - static_cast<float>(pointsCoords[i * rank + d]); + float scaledDx = antialiasing ? dx * fminf(scales[d], 1.0f) : dx; + + float w = fmaxf(0.0f, 1.0f - fabsf(scaledDx)); + if (w == 0.0f) { + skip = true; + break; + } + weight *= w; + } + + if (skip) continue; + + resultAccum += toFloat(pointValues[i]) * weight; + totalWeight += weight; + } + + float finalValue = (totalWeight > 0.0f) ? resultAccum / totalWeight : 0.0f; + return fromFloat<T>(finalValue); +} + +template <typename T> +__device__ float cubicWeight(float x, float a) { + x = fabsf(x); + if (x < 1.0f) { + return ((a + 2.0f) * x * x * x - (a + 3.0f) * x * x + 1.0f); + } else if (x < 2.0f) { + return (a * x * x * x - 5.0f * a * x * x + 8.0f * a * x - 4.0f * a); + } else { + return 0.0f; + } +} + +template <typename T> +__device__ float cubicWeightAntialiased(float x, float scale, float a) { + float s = fminf(scale, 1.0f); + x *= s; + x = fabsf(x); + float x2 = x * x; + float x3 = x2 * x; + if (x <= 1.0f) { + return ((a + 2.0f) * x3 - (a + 3.0f) * x2 + 1.0f); + } else if (x < 2.0f) { + return (a * x3 - 5.0f * a * x2 + 8.0f * a * x - 4.0f * a); + } else { + return 0.0f; + } +} + +template <typename T> +__device__ T Aidge::InterpolationCUDA::interpolateCubic( + float* coordToInterpolate, + const float* scales, + int* pointsCoords, + T* pointValues, + int coordsNbr, + int rank, + const int* inputDims, + float a, + bool antialiasing, + bool excludeOutside +) { + float resultAccum = 0.0f; + float totalWeight = 0.0f; + + for (int i = 0; i < coordsNbr; ++i) { + float w = 1.0f; + bool skip = false; + + for (int d = 0; d < rank; ++d) { + int coord = pointsCoords[i * rank + d]; + float x = coordToInterpolate[d]; + + // Skip static dimensions + if (inputDims[d] == 1 && scales[d] == 1.0f) { + continue; + } + + if (excludeOutside && (coord < 0 || coord >= inputDims[d])) { + skip = true; + break; + } + + float dx = x - static_cast<float>(coord); + float weight = antialiasing + ? cubicWeightAntialiased<T>(dx, scales[d], a) + : cubicWeight<T>(dx, a); + + w *= weight; + } + + if (skip || w == 0.0f) continue; + + resultAccum += toFloat(pointValues[i]) * w; + totalWeight += w; + } + + float finalValue = (totalWeight > 0.0f) ? 
resultAccum / totalWeight : 0.0f; + return fromFloat<T>(finalValue); +} + +template <typename T> +__device__ T Aidge::InterpolationCUDA::interpolate(float* coordToInterpolate, + const float* scales, + int* pointsCoords, + T* pointValues, + int coordsNbr, + int rank, + Aidge::Interpolation::Mode mode, + float cubicCoeffA, + bool antialiasing, + bool excludeOutside, + const int* inputDims) { + switch (mode) { + case Aidge::Interpolation::Mode::Cubic: return interpolateCubic<T>(coordToInterpolate, scales, pointsCoords, pointValues, coordsNbr, rank, inputDims, cubicCoeffA, antialiasing, excludeOutside); + case Aidge::Interpolation::Mode::Linear: return interpolateLinear<T>(coordToInterpolate, scales, pointsCoords, pointValues, coordsNbr, rank, antialiasing); + default: return T(0); + } +} + +/////// templates instantiations + +template __device__ void Aidge::InterpolationCUDA::retrieveNeighboursKernel( + const double* tensorValues, + const int* tensorDims, + float* coords, + const float* scales, + int rank, + Aidge::Interpolation::Mode mode, + Aidge::PadBorderType paddingMode, + bool antialiasing, + double* outValues, + int* outCoords, + int* outCount, + int maxNeighbours); +template __device__ void Aidge::InterpolationCUDA::retrieveNeighboursKernel( + const float* tensorValues, + const int* tensorDims, + float* coords, + const float* scales, + int rank, + Aidge::Interpolation::Mode mode, + Aidge::PadBorderType paddingMode, + bool antialiasing, + float* outValues, + int* outCoords, + int* outCount, + int maxNeighbours); +template __device__ void Aidge::InterpolationCUDA::retrieveNeighboursKernel( + const half* tensorValues, + const int* tensorDims, + float* coords, + const float* scales, + int rank, + Aidge::Interpolation::Mode mode, + Aidge::PadBorderType paddingMode, + bool antialiasing, + half* outValues, + int* outCoords, + int* outCount, + int maxNeighbours); + +template __device__ double Aidge::InterpolationCUDA::interpolateLinear( + const float* coordToInterpolate, + const float* scales, + const int* pointsCoords, + const double* pointValues, + int coordsNbr, + int rank, + bool antialiasing +); +template __device__ float Aidge::InterpolationCUDA::interpolateLinear( + const float* coordToInterpolate, + const float* scales, + const int* pointsCoords, + const float* pointValues, + int coordsNbr, + int rank, + bool antialiasing +); +template __device__ half Aidge::InterpolationCUDA::interpolateLinear( + const float* coordToInterpolate, + const float* scales, + const int* pointsCoords, + const half* pointValues, + int coordsNbr, + int rank, + bool antialiasing +); + +template __device__ double Aidge::InterpolationCUDA::interpolateCubic( + float* coordToInterpolate, + const float* scales, + int* pointsCoords, + double* pointValues, + int coordsNbr, + int rank, + const int* inputDims, + float a, + bool antialiasing, + bool excludeOutside +); +template __device__ float Aidge::InterpolationCUDA::interpolateCubic( + float* coordToInterpolate, + const float* scales, + int* pointsCoords, + float* pointValues, + int coordsNbr, + int rank, + const int* inputDims, + float a, + bool antialiasing, + bool excludeOutside +); +template __device__ half Aidge::InterpolationCUDA::interpolateCubic( + float* coordToInterpolate, + const float* scales, + int* pointsCoords, + half* pointValues, + int coordsNbr, + int rank, + const int* inputDims, + float a, + bool antialiasing, + bool excludeOutside +); + +template __device__ double Aidge::InterpolationCUDA::interpolate(float* coordToInterpolate, + const float* 
scales, + int* pointsCoords, + double* pointValues, + int coordsNbr, + int rank, + Aidge::Interpolation::Mode mode, + float cubicCoeffA, + bool antialiasing, + bool excludeOutside, + const int* inputDims); +template __device__ float Aidge::InterpolationCUDA::interpolate(float* coordToInterpolate, + const float* scales, + int* pointsCoords, + float* pointValues, + int coordsNbr, + int rank, + Aidge::Interpolation::Mode mode, + float cubicCoeffA, + bool antialiasing, + bool excludeOutside, + const int* inputDims); +template __device__ half Aidge::InterpolationCUDA::interpolate(float* coordToInterpolate, + const float* scales, + int* pointsCoords, + half* pointValues, + int coordsNbr, + int rank, + Aidge::Interpolation::Mode mode, + float cubicCoeffA, + bool antialiasing, + bool excludeOutside, + const int* inputDims); \ No newline at end of file diff --git a/src/operator/ResizeImpl.cpp b/src/operator/ResizeImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5a397149dc6cfdc9e630460eca664a736ce1d7c7 --- /dev/null +++ b/src/operator/ResizeImpl.cpp @@ -0,0 +1,96 @@ +/******************************************************************************** + * Copyright (c) 2025 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <algorithm> +#include <cassert> +#include <numeric> +#include <vector> +#include <iostream> + +#include <cuda_fp16.h> +#include "aidge/backend/cuda/data/TensorImpl.hpp" +#include "aidge/backend/cuda/operator/ResizeImpl.hpp" +#include "aidge/backend/cuda/operator/ResizeImpl_CUDA_kernels.hpp" +#include "aidge/backend/cuda/utils/CudaContext.hpp" +#include "aidge/backend/cuda/utils/CudaContext.hpp" +#include "aidge/backend/cuda/utils/CudaUtils.hpp" +#include "aidge/operator/Resize.hpp" +#include "aidge/utils/Types.h" + +void Aidge::ResizeImpl_cuda::forward() { + const Resize_Op& op = static_cast<const Resize_Op&>(mOp); + // Check inputs + AIDGE_ASSERT(op.getInput(0), "missing input in Resize operator"); + AIDGE_ASSERT(op.getInput(0)->hasImpl(), "cannot run Resize forward because the 0-th input has no implementation."); + switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) { + case DataType::Float64: + forward_<double>(); + break; + case DataType::Float32: + forward_<float>(); + break; + case DataType::Float16: + forward_<half>(); + break; + default: + AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda"); + } +} + +template <class T> +void Aidge::ResizeImpl_cuda::forward_() +{ + const Resize_Op& op = static_cast<const Resize_Op&>(mOp); + // int size = op.getInput(0)->size(); + const T* inputPtr = static_cast<T*>(op.getInput(0)->getImpl()->rawPtr()); + T* outputPtr = static_cast<T*>(op.getOutput(0)->getImpl()->rawPtr()); + std::vector<int> inputDimsInt; + inputDimsInt.reserve(op.getInput(0)->dims().size()); + for (size_t val : op.getInput(0)->dims()) { + inputDimsInt.push_back(static_cast<int>(val)); + } + std::vector<int> outputDimsInt; + outputDimsInt.reserve(op.getOutput(0)->dims().size()); + for (size_t val : op.getOutput(0)->dims()) { + outputDimsInt.push_back(static_cast<int>(val)); + } + std::vector<int> outputStrides(op.getOutput(0)->nbDims(), 1); + if(op.getOutput(0)->nbDims()>1) { + for (int i = 
op.getOutput(0)->nbDims()-2; i >= 0; i--) { + outputStrides[i] = outputStrides[i+1] * op.getOutput(0)->dims()[i+1]; + } + } + std::vector<int> inputStrides(op.getInput(0)->nbDims(), 1); + if(op.getInput(0)->nbDims()>1) { + for (int i = op.getInput(0)->nbDims()-2; i >= 0; i--) { + inputStrides[i] = inputStrides[i+1] * op.getInput(0)->dims()[i+1]; + } + } + std::vector<float> scales; + scales.reserve(op.getInput(0)->dims().size()); + + for (size_t i = 0; i < op.getInput(0)->dims().size(); ++i) { + scales.push_back(static_cast<float>(op.getOutput(0)->dims()[i]) / static_cast<float>(op.getInput(0)->dims()[i])); + } + + Aidge::resizeForward<T>(inputPtr,outputPtr, op.roi(), + inputDimsInt, outputDimsInt, + inputStrides, outputStrides, + scales, + op.coordinateTransformationMode(), + op.interpolationMode(), + op.paddingMode(), + op.cubic_coeff_a(), + op.extrapolation_val(), + op.antialiasing(), + op.exclude_outside(), + static_cast<int>(op.getOutput(0)->size())); +} diff --git a/src/operator/ResizeImpl_CUDA_kernels.cu b/src/operator/ResizeImpl_CUDA_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..4008f03d17779a53f5d400f580debf947d6efe27 --- /dev/null +++ b/src/operator/ResizeImpl_CUDA_kernels.cu @@ -0,0 +1,216 @@ +/******************************************************************************** + * Copyright (c) 2025 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#include <cuda_fp16.h> +#include <math.h> // For isnan() +#include <thrust/device_vector.h> +#include <type_traits> + +#include "aidge/backend/cuda/operator/ResizeImpl_CUDA_kernels.hpp" +#include "aidge/backend/cuda/data/Interpolation.cuh" +#include "aidge/data/Interpolation.hpp" + + +template <class T> +__global__ void resizeKernel(const T* input, T* output, + const float* roi, + const int* inputDims, const int* outputDims, + const int* inputStrides, const int* outputStrides, + const float* scales, + Aidge::Interpolation::CoordinateTransformation coordTransfoMode, + Aidge::Interpolation::Mode interpMode, + Aidge::PadBorderType paddingMode, + float cubic_coeff_a, + float extrapolationVal, + bool antialiasing, + bool excludeOutside, + int rank, + int totalElements) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= totalElements) return; + + // Compute output coordinates from flat index + int coordOut[CUDNN_DIM_MAX]; + int tmp = idx; + for (int i = 0; i < rank; ++i) { + coordOut[i] = tmp / outputStrides[i]; + tmp %= outputStrides[i]; + } + + + // Compute input coordinates (float) using provided coordinate transformation + float coordInApprox[CUDNN_DIM_MAX]; + + Aidge::InterpolationCUDA::untransformCoordinates(&coordOut[0], + &inputDims[0], + &outputDims[0], + coordTransfoMode, + &roi[0], + &coordInApprox[0], + rank); + + for (int i = 0; i < rank; ++i) { + if (isnan(coordInApprox[i])) { + output[idx] = static_cast<T>(extrapolationVal); + return; + } + } + + // For rounding-style modes + if ((interpMode == Aidge::Interpolation::Mode::Ceil) || + (interpMode == Aidge::Interpolation::Mode::Floor) || + (interpMode == Aidge::Interpolation::Mode::RoundPreferCeil) || + (interpMode == Aidge::Interpolation::Mode::RoundPreferFloor)) { + + for (int i = 0; i < rank; ++i) { + if (interpMode == Aidge::Interpolation::Mode::Ceil) { + 
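+                    // Nearest-neighbour lookup: Ceil snaps the continuous coordinate up; the branches below round down (Floor) or to the nearest integer (RoundPreferCeil/RoundPreferFloor).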
coordInApprox[i] = ceilf(coordInApprox[i]); + } else if (interpMode == Aidge::Interpolation::Mode::Floor) { + coordInApprox[i] = floorf(coordInApprox[i]); + } else if (interpMode == Aidge::Interpolation::Mode::RoundPreferCeil) { + coordInApprox[i] = floorf(coordInApprox[i] + 0.5f); + } else { + coordInApprox[i] = ceilf(coordInApprox[i] - 0.5f); + } + } + + // Convert to integer coordinates + int coordIn[CUDNN_DIM_MAX]; + bool inBounds = true; + for (int i = 0; i < rank; ++i) { + int val = static_cast<int>(coordInApprox[i]); + if (val < 0 || val >= inputDims[i]) { + if (paddingMode == Aidge::PadBorderType::Edge) { + val = max(0, min(val, inputDims[i] - 1)); + } else { + inBounds = false; + break; + } + } + coordIn[i] = val; + } + + if (!inBounds) { + output[idx] = static_cast<T>(extrapolationVal); + return; + } + + // Compute flat input index + int inputIdx = 0; + for (int i = 0; i < rank; ++i) { + inputIdx += coordIn[i] * inputStrides[i]; + } + + output[idx] = input[inputIdx]; + } else { + constexpr int maxNeighbourCount = 2048; + int neighboursCount = 0; + int neighboursCoords[maxNeighbourCount * 4]; // local array in CUDA kernel + T neighbours[maxNeighbourCount]; // local array in CUDA kernel + Aidge::InterpolationCUDA::retrieveNeighboursKernel<T>( + input, + inputDims, + &coordInApprox[0], + scales, + rank, + interpMode, + paddingMode, + antialiasing, + &neighbours[0], + &neighboursCoords[0], + &neighboursCount, + maxNeighbourCount // <-- capacity of buffer + ); + + auto value = Aidge::InterpolationCUDA::interpolate<T>( + &coordInApprox[0], + scales, + &neighboursCoords[0], + &neighbours[0], + neighboursCount, + rank, + interpMode, + cubic_coeff_a, + antialiasing, + excludeOutside, + inputDims + ); + output[idx] = value; + } +} +template <class T> +void Aidge::resizeForward(const T* input, T* output,const std::vector<float> &roi, + const std::vector<int>& inputDims, const std::vector<int>& outputDims, + const std::vector<int>& inputStrides, const std::vector<int>& outputStrides, + const std::vector<float> &scales, + const Aidge::Interpolation::CoordinateTransformation coordTransfoMode, + const Aidge::Interpolation::Mode interpMode, + const Aidge::PadBorderType paddingMode, + float cubic_coeff_a, + float extrapolationVal, + bool antialiasing, + bool excludeOutside, + int size) { + int blockSize = 256; + int numBlocks = (size + blockSize - 1) / blockSize; + const thrust::device_vector<int> d_input_shape = inputDims; + const thrust::device_vector<int> d_output_shape = outputDims; + const thrust::device_vector<int> d_input_strides = inputStrides; + const thrust::device_vector<int> d_output_strides = outputStrides; + const thrust::device_vector<float> d_roi = roi; + const thrust::device_vector<float> d_scales = scales; + resizeKernel<<<numBlocks, blockSize>>>(input, output, thrust::raw_pointer_cast(d_roi.data()), + thrust::raw_pointer_cast(d_input_shape.data()), thrust::raw_pointer_cast(d_output_shape.data()), + thrust::raw_pointer_cast(d_input_strides.data()), thrust::raw_pointer_cast(d_output_strides.data()), + thrust::raw_pointer_cast(d_scales.data()), + coordTransfoMode, interpMode, paddingMode, cubic_coeff_a, extrapolationVal, antialiasing, excludeOutside, outputDims.size(), size); + CHECK_CUDA_STATUS(cudaGetLastError()); + CHECK_CUDA_STATUS(cudaDeviceSynchronize()); +}; + +template void Aidge::resizeForward<double>(const double* input, double* output,const std::vector<float> &roi, + const std::vector<int>& inputDims, const std::vector<int>& outputDims, + const std::vector<int>& 
inputStrides, const std::vector<int>& outputStrides, + const std::vector<float> &scales, + const Aidge::Interpolation::CoordinateTransformation coordTransfoMode, + const Aidge::Interpolation::Mode interpMode, + const Aidge::PadBorderType paddingMode, + float cubic_coeff_a, + float extrapolationVal, + bool antialiasing, + bool excludeOutside, + int size); + +template void Aidge::resizeForward<float>(const float* input, float* output,const std::vector<float> &roi, + const std::vector<int>& inputDims, const std::vector<int>& outputDims, + const std::vector<int>& inputStrides, const std::vector<int>& outputStrides, + const std::vector<float> &scales, + const Aidge::Interpolation::CoordinateTransformation coordTransfoMode, + const Aidge::Interpolation::Mode interpMode, + const Aidge::PadBorderType paddingMode, + float cubic_coeff_a, + float extrapolationVal, + bool antialiasing, + bool excludeOutside, + int size); + +template void Aidge::resizeForward<half>(const half* input, half* output,const std::vector<float> &roi, + const std::vector<int>& inputDims, const std::vector<int>& outputDims, + const std::vector<int>& inputStrides, const std::vector<int>& outputStrides, + const std::vector<float> &scales, + const Aidge::Interpolation::CoordinateTransformation coordTransfoMode, + const Aidge::Interpolation::Mode interpMode, + const Aidge::PadBorderType paddingMode, + float cubic_coeff_a, + float extrapolationVal, + bool antialiasing, + bool excludeOutside, + int size); + diff --git a/unit_tests/Test_ResizeImpl.cpp b/unit_tests/Test_ResizeImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0e2b73ca3a04e5045a65e022168a4036b747a021 --- /dev/null +++ b/unit_tests/Test_ResizeImpl.cpp @@ -0,0 +1,523 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#define alltests +#include <chrono> // std::micro, std::chrono::time_point, + // std::chrono::system_clock +#include <cmath> // std::fabs +#include <cstddef> // std::size_t +#include <cstdint> // std::uint16_t +#include <functional> // std::multiplies +#include <memory> +#include <numeric> // std::accumulate +#include <random> // std::random_device, std::mt19937 + // std::uniform_int_distribution, std::uniform_real_distribution + +#include <catch2/catch_test_macros.hpp> +#include <cuda.h> +#include <fmt/core.h> + +#include "aidge/backend/cpu/data/TensorImpl.hpp" +#include "aidge/backend/cuda/data/TensorImpl.hpp" +#include "aidge/data/Data.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Resize.hpp" +#include "aidge/utils/ArrayHelpers.hpp" +#include "aidge/utils/TensorUtils.hpp" + +using namespace Aidge; + + +TEST_CASE("[gpu/operator] Resize(forward)") { + Log::setConsoleLevel(Log::Level::Debug); +#ifdef alltests + SECTION("Nearest") { + SECTION("Ceil") { + std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>(Array4D<float, 1, 1, 2, 2>{{ + { + { + { 1.0, 2.0}, + { 3.0, 4.0} + } + } + }}); + input_tensor->setBackend("cuda"); + Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 4, 4>{{ + { + { + { 1.0, 1.0, 1.0, 2.0}, + { 1.0, 1.0, 1.0, 2.0}, + { 1.0, 1.0, 1.0, 2.0}, + { 3.0, 3.0, 3.0, 4.0} + } + } + }}); + + std::vector<float> scales = {1.0f, 1.0f, 2.0f, 2.0f}; + auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Floor); + auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator()); + op->associateInput(0, input_tensor); + + + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forwardDims(true); + op->forward(); + + std::shared_ptr<Tensor> outputFallback; + const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor); + REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor)); + } + } + SECTION("1-sized input tensor (upscaling)") { + std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>(Array4D<float, 1, 1, 1, 1>{{{{{0.417022}}}}}); + input_tensor->setBackend("cuda"); + + Tensor expectedOutput = Tensor(Array4D<float, 1, 1, 2, 2>{ + {{{{0.417022, 0.417022}, {0.417022, 0.417022}}}}}); + std::vector<std::size_t> sizes = {1, 1, 2, 2}; + auto resize_node = Resize({}, {}, sizes, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Linear); + auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator()); + op->associateInput(0, input_tensor); + + + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forwardDims(true); + op->forward(); + + std::shared_ptr<Tensor> outputFallback; + const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expectedOutput); + REQUIRE(approxEq<float>(cudaOutput, expectedOutput)); + } + SECTION("Cubic Interpolation") { + std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>(Array4D<float, 1, 1, 4, 4>{{{{ + {1.0f, 2.0f, 3.0f, 4.0f}, + {5.0f, 6.0f, 7.0f, 8.0f}, + {9.0f, 10.0f, 11.0f, 12.0f}, + {13.0f, 14.0f, 15.0f, 16.0f}}} + }}); + input_tensor->setBackend("cuda"); + Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 8, 8>{{{{ + { 0.47265625, 0.76953125, 1.24609375, 1.87500000, 2.28125000, 2.91015625, 3.38671875, 
3.68359375 }, + { 1.66015625, 1.95703125, 2.43359375, 3.06250000, 3.46875000, 4.09765625, 4.57421875, 4.87109375 }, + { 3.56640625, 3.86328125, 4.33984375, 4.96875000, 5.37500000, 6.00390625, 6.48046875, 6.77734375 }, + { 6.08203125, 6.37890625, 6.85546875, 7.48437500, 7.89062500, 8.51953125, 8.99609375, 9.29296875 }, + { 7.70703125, 8.00390625, 8.48046875, 9.10937500, 9.51562500, 10.14453125, 10.62109375, 10.91796875 }, + { 10.22265625, 10.51953125, 10.99609375, 11.62500000, 12.03125000, 12.66015625, 13.13671875, 13.43359375 }, + { 12.12890625, 12.42578125, 12.90234375, 13.53125000, 13.93750000, 14.56640625, 15.04296875, 15.33984375 }, + { 13.31640625, 13.61328125, 14.08984375, 14.71875000, 15.12500000, 15.75390625, 16.23046875, 16.52734375} + }}}}); + + std::vector<float> scales = {1.0f, 1.0f, 2.0f, 2.0f}; // Adjust according to the expected interpolation scale + auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Cubic); + auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator()); + op->associateInput(0, input_tensor); + + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forwardDims(true); + op->forward(); + + std::shared_ptr<Tensor> outputFallback; + const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor); + REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor)); + } + SECTION("CoordinateTransformation") { + std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>(Array4D<float, 1, 1, 4, 4>{{ + { + { + { 1, 2, 3, 4}, + { 5, 6, 7, 8}, + { 9, 10, 11, 12}, + { 13, 14, 15, 16} + } + } + }}); + input_tensor->setBackend("cuda"); + std::vector<float> scales = {1.0, 1.0, 1.5, 1.5}; + SECTION("half_pixel") { + Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 6, 6>{{ + { + { + {0.56597281, 1.0590283, 1.8657429, 2.4398167, 3.2465289, 3.7395852}, + {2.5381956, 3.03125, 3.837965, 4.4120383, 5.21875, 5.7118063}, + {5.7650542, 6.2581067, 7.0648241, 7.638896, 8.4456081, 8.9386654}, + {8.0613489, 8.5544004, 9.3611183, 9.9351892, 10.7419, 11.234958}, + {11.288198, 11.781251, 12.587969, 13.162039, 13.96875, 14.461807}, + {13.260421, 13.753473, 14.560195, 15.134262, 15.940972, 16.434032} + } + } + }}); + + auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Cubic); + auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator()); + op->associateInput(0, input_tensor); + + + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forwardDims(true); + op->forward(); + + std::shared_ptr<Tensor> outputFallback; + const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor); + REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor)); + } + SECTION("half_pixel_symmetric") { + Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 6, 6>{{ + { + { + {0.56597364, 1.0590289, 1.8657435, 2.4398165, 3.2465293, 3.7395852}, + {2.5381951, 3.03125, 3.837965, 4.4120388, 5.21875, 5.7118063}, + {5.7650528, 6.2581067, 7.0648241, 7.6388974, 8.4456081, 8.9386654}, + {8.0613451, 8.5543985, 9.3611164, 9.9351892, 10.741899, 11.234958}, + {11.288198, 11.781251, 12.587969, 13.162041, 13.96875, 14.461807}, + {13.260421, 13.753473, 14.560195, 15.134266, 15.940972, 16.434032} + } + } + }}); + + auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixelSymmetric, Interpolation::Mode::Cubic); + auto op = 
std::static_pointer_cast<Resize_Op>(resize_node->getOperator()); + op->associateInput(0, input_tensor); + + + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forwardDims(true); + op->forward(); + + std::shared_ptr<Tensor> outputFallback; + const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor); + REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor)); + } + SECTION("pytorch_half_pixel") { + Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 6, 6>{{ + { + { + {0.56597281, 1.0590283, 1.8657429, 2.4398167, 3.2465289, 3.7395852}, + {2.5381956, 3.03125, 3.837965, 4.4120383, 5.21875, 5.7118063}, + {5.7650542, 6.2581067, 7.0648241, 7.638896, 8.4456081, 8.9386654}, + {8.0613489, 8.5544004, 9.3611183, 9.9351892, 10.7419, 11.234958}, + {11.288198, 11.781251, 12.587969, 13.162039, 13.96875, 14.461807}, + {13.260421, 13.753473, 14.560195, 15.134262, 15.940972, 16.434032} + } + } + }}); + + auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::PytorchHalfPixel, Interpolation::Mode::Cubic); + auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator()); + op->associateInput(0, input_tensor); + + + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forwardDims(true); + op->forward(); + + std::shared_ptr<Tensor> outputFallback; + const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor); + REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor)); + } + SECTION("align_corners") { + Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 6, 6>{{ + { + { + {1, 1.5039997, 2.2480016, 2.7520003, 3.4959996, 4}, + {3.0159998, 3.5199983, 4.2640014, 4.7680001, 5.5119972, 6.0159988}, + {5.9920053, 6.4960032, 7.2400088, 7.7440076, 8.4880047, 8.9920063}, + {8.0080013, 8.5119991, 9.2560062, 9.760005, 10.504001, 11.008001}, + {10.983998, 11.487995, 12.232005, 12.736004, 13.479996, 13.983998}, + {13, 13.503995, 14.248007, 14.752007, 15.495996, 16} + + } + } + }}); + + auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::AlignCorners, Interpolation::Mode::Cubic); + auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator()); + op->associateInput(0, input_tensor); + + + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forwardDims(true); + op->forward(); + + std::shared_ptr<Tensor> outputFallback; + const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor); + REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor)); + } + SECTION("asymmetric") { + Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 6, 6>{{ + { + { + {1, 1.5740744, 2.3703694, 3, 3.740741, 4.1111126}, + {3.2962966, 3.8703718, 4.666666, 5.2962976, 6.0370393, 6.4074111}, + {6.4814782, 7.0555553, 7.8518462, 8.4814777, 9.2222204, 9.5925922}, + {9, 9.5740776, 10.370367, 11, 11.740744, 12.111115}, + {11.962964, 12.537044, 13.333332, 13.962964, 14.70371, 15.074081}, + {13.444449, 14.018529, 14.814817, 15.44445, 16.185196, 16.555569} + } + } + }}); + + auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::Asymmetric, Interpolation::Mode::Cubic); + auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator()); + op->associateInput(0, input_tensor); + + + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forwardDims(true); + op->forward(); + + std::shared_ptr<Tensor> outputFallback; + const auto& cudaOutput = 
op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor); + REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor)); + } + SECTION("tf_half_pixel_for_nn") { + Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 6, 6>{{ + { + { + {2.2962933, 3.037035, 3.6666636, 4.4629622, 5.037034, 5.0925913}, + {5.2592578, 6, 6.6296282, 7.4259291, 8, 8.0555582}, + {7.7777748, 8.5185175, 9.1481438, 9.9444475, 10.518517, 10.574076}, + {10.962964, 11.703708, 12.333335, 13.129641, 13.70371, 13.759269}, + {13.259255, 14, 14.629625, 15.425932, 16, 16.055561}, + {13.48148, 14.222226, 14.851851, 15.648157, 16.222225, 16.277786} + } + } + }}); + + auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::TFHalfPixelForNN, Interpolation::Mode::Cubic); + auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator()); + op->associateInput(0, input_tensor); + + + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forwardDims(true); + op->forward(); + + std::shared_ptr<Tensor> outputFallback; + const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor); + REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor)); + } + } + SECTION("Antialiasing") { + SECTION("Linear") { + std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>(Array4D<float, 1, 1, 6, 6>{{{{ + {0.8910093, 0.06583797, 0.9854186, 0.37615213, 0.5850754, 0.2838311}, + {0.48583397, 0.2452433, 0.88194895, 0.83299613, 0.59694654, 0.6168722}, + {0.13210197, 0.62511665, 0.12456352, 0.1806763, 0.9501198, 0.275962}, + {0.06826835, 0.06309556, 0.02638639, 0.09780658, 0.83262056, 0.63483626}, + {0.21528919, 0.19012423, 0.6685329, 0.61682564, 0.34627828, 0.11396276}, + {0.9585102, 0.50689447, 0.18671302, 0.9953287, 0.857435, 0.8502074}}} + }}); + input_tensor->setBackend("cuda"); + Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 3, 3>{{{{ + {0.42198, 0.76913, 0.52068}, + {0.22215, 0.10736, 0.67338}, + {0.4677, 0.61685, 0.54197 } + }}}}); + Tensor expected_out_tensor_antialiasing = Tensor(Array4D<float, 1, 1, 3, 3>{{{{ + {0.50311, 0.61554, 0.50743}, + {0.22988, 0.33839, 0.54707}, + {0.45242, 0.53262, 0.61474} + }}}}); + + std::vector<float> scales = {1.0f, 1.0f, 0.5f, 0.5f}; // Adjust according to the expected interpolation scale + auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Linear); + auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator()); + op->associateInput(0, input_tensor); + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forwardDims(true); + op->forward(); + + auto resize_node_antialiasing = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Linear, -.75f, 0.0f, PadBorderType::Edge,AspectRatio::Stretch, true); + auto op_antialiasing = std::static_pointer_cast<Resize_Op>(resize_node_antialiasing->getOperator()); + op_antialiasing->associateInput(0, input_tensor); + op_antialiasing->setDataType(DataType::Float32); + op_antialiasing->setBackend("cuda"); + op_antialiasing->forwardDims(true); + op_antialiasing->forward(); + + std::shared_ptr<Tensor> outputFallback; + const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor); + // REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor)); + CHECK(approxEq<float>(cudaOutput, expected_out_tensor, 1e-5f, 1e-5f) == true); + const auto& cudaOutputAntialiasing = 
op_antialiasing->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor_antialiasing); + REQUIRE(approxEq<float>(cudaOutputAntialiasing, expected_out_tensor_antialiasing, 1e-5f, 1e-5f) == true); + } + #endif + SECTION("Cubic") { + std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>(Array4D<float, 1, 1, 6, 6>{{{{ + {0.8910093, 0.06583797, 0.9854186, 0.37615213, 0.5850754, 0.2838311}, + {0.48583397, 0.2452433, 0.88194895, 0.83299613, 0.59694654, 0.6168722}, + {0.13210197, 0.62511665, 0.12456352, 0.1806763, 0.9501198, 0.275962}, + {0.06826835, 0.06309556, 0.02638639, 0.09780658, 0.83262056, 0.63483626}, + {0.21528919, 0.19012423, 0.6685329, 0.61682564, 0.34627828, 0.11396276}, + {0.9585102, 0.50689447, 0.18671302, 0.9953287, 0.857435, 0.8502074}}} + }}); + input_tensor->setBackend("cuda"); + Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 3, 3>{{{{ + { 0.33780938, 0.92826414, 0.5099976 }, + { 0.25023192, -0.14104399, 0.80935836 }, + { 0.4686063, 0.71120584, 0.46677577 } + }}}}); + Tensor expected_out_tensor_antialiasing = Tensor(Array4D<float, 1, 1, 3, 3>{{{{ + { 0.50422025, 0.6695719 , 0.50645566 }, + { 0.16404048, 0.29127517, 0.574788 }, + { 0.41621327, 0.54006433, 0.61721706 } + }}}}); + + std::vector<float> scales = {1.0f, 1.0f, 0.5f, 0.5f}; // Adjust according to the expected interpolation scale + auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Cubic); + auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator()); + op->associateInput(0, input_tensor); + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forwardDims(true); + op->forward(); + auto resize_node_antialiasing = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Cubic, -.75f, 0.0f, PadBorderType::Edge,AspectRatio::Stretch, true); + auto op_antialiasing = std::static_pointer_cast<Resize_Op>(resize_node_antialiasing->getOperator()); + op_antialiasing->associateInput(0, input_tensor); + op_antialiasing->setDataType(DataType::Float32); + op_antialiasing->setBackend("cuda"); + op_antialiasing->forwardDims(true); + op_antialiasing->forward(); + + std::shared_ptr<Tensor> outputFallback; + const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor); + REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor, 1e-5f, 1e-5f) == true); + + const auto& cudaOutputAntialiasing = op_antialiasing->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor_antialiasing); + REQUIRE(approxEq<float>(cudaOutputAntialiasing, expected_out_tensor_antialiasing, 1e-5f, 1e-5f) == true); + } + #ifdef alltests + } + SECTION("ROI") { + std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>(Array4D<float, 1, 1, 4, 4>{{{{ + {0.0f, 1.0f, 2.0f, 3.0f}, + {4.0f, 5.0f, 6.0f, 7.0f}, + {8.0f, 9.0f, 10.0f, 11.0f}, + {12.0f, 13.0f, 14.0f, 15.0f}}} + }}); + input_tensor->setBackend("cuda"); + Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 2, 2>{{{{ + { 0.0, 3.0 }, + { 12.0, 15.0 } + }}}}); + Tensor expected_out_tensor_roi = Tensor(Array4D<float, 1, 1, 2, 2>{{{{ + { 0.0, 1.5 }, + { 6.0, 7.5 } + }}}}); + + std::vector<float> scales = {1.0f, 1.0f, 0.5f, 0.5f}; + std::vector<float> roi = {0, 0, 0, 0, 1.0f, 1.0f, 0.5f, 0.5f}; + auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::TFCropAndResize, Interpolation::Mode::Linear); + auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator()); + 
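+            // resize_node above is built without an explicit ROI and serves as the reference; resize_node_roi below applies the ROI.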
op->associateInput(0, input_tensor); + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forwardDims(true); + op->forward(); + + auto resize_node_roi = Resize(roi, scales, {}, {}, Interpolation::CoordinateTransformation::TFCropAndResize, Interpolation::Mode::Linear); + auto op_roi = std::static_pointer_cast<Resize_Op>(resize_node_roi->getOperator()); + op_roi->associateInput(0, input_tensor); + op_roi->setDataType(DataType::Float32); + op_roi->setBackend("cuda"); + op_roi->forwardDims(true); + op_roi->forward(); + + std::shared_ptr<Tensor> outputFallback; + const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor); + REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor)); + const auto& cudaOutputROI = op_roi->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor_roi); + REQUIRE(approxEq<float>(cudaOutputROI, expected_out_tensor_roi)); + } + SECTION("Extrapolation_Value") { + std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>(Array4D<float, 1, 1, 2, 2>{{{{ + {0.0f, 1.0f}, + {2.0f, 3.0f}}} + }}); + input_tensor->setBackend("cuda"); + Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 4, 4>{{{{ + { 99.0f, 99.0f, 99.0f, 99.0f }, + { 99.0f, 0.0f, 1.0f, 99.0f }, + { 99.0f, 2.0f, 3.0f, 99.0f }, + { 99.0f, 99.0f, 99.0f, 99.0f } + }}}}); + std::vector<std::size_t> sizes = {1, 1, 4, 4}; + std::vector<float> roi = {0.0f, 0.0f, -0.5f, -0.5f, 1.0f, 1.0f, 1.5f, 1.5f}; + auto resize_node = Resize(roi, {}, sizes, {}, Interpolation::CoordinateTransformation::TFCropAndResize, Interpolation::Mode::RoundPreferFloor, -.75f, 99.0f); + auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator()); + op->associateInput(0, input_tensor); + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forwardDims(true); + op->forward(); + + std::shared_ptr<Tensor> outputFallback; + const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor); + REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor)); + } + SECTION("Exclude_Outside") { + std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>(Array4D<float, 1, 1, 3, 3>{{{{ + {1.0f, 2.0f, 3.0f}, + {4.0f, 5.0f, 6.0f}, + {7.0f, 8.0f, 9.0f}}} + }}); + input_tensor->setBackend("cuda"); + + Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 4, 4>{{{{ + {0.7128906, 1.3144531, 2.2548828, 2.8564453}, + {2.5175781, 3.1191406, 4.0595703, 4.661133}, + {5.338867, 5.9404297, 6.8808594, 7.482422}, + {7.1435547, 7.745117, 8.685547, 9.287109} + }}}}); + Tensor expected_out_tensor_excludeOut = Tensor(Array4D<float, 1, 1, 4, 4>{{{{ + {0.6793893, 1.2565644, 2.2625191, 2.8396947}, + {2.4109144, 2.98809, 3.9940448, 4.5712204}, + {5.4287796, 6.0059547, 7.01191, 7.589085}, + {7.160305, 7.7374797, 8.743436, 9.32061} + }}}}); + + std::vector<std::size_t> sizes = {1, 1, 4, 4}; + auto resize_node = Resize({}, {}, sizes, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Cubic,-.75f, 0.0f, PadBorderType::Edge,AspectRatio::Stretch, false ); + auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator()); + op->associateInput(0, input_tensor); + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forwardDims(true); + op->forward(); + + auto resize_node_excludeOut = Resize({}, {}, sizes, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Cubic, -.75f, 0.0f, PadBorderType::Edge,AspectRatio::Stretch, false, true); + auto op_excludeOut = 
std::static_pointer_cast<Resize_Op>(resize_node_excludeOut->getOperator()); + op_excludeOut->associateInput(0, input_tensor); + op_excludeOut->setDataType(DataType::Float32); + op_excludeOut->setBackend("cuda"); + op_excludeOut->forwardDims(true); + op_excludeOut->forward(); + + std::shared_ptr<Tensor> outputFallback; + const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor); + REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor)); + const auto& cudaOutputExcOut = op_excludeOut->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor_excludeOut); + REQUIRE(approxEq<float>(cudaOutputExcOut, expected_out_tensor_excludeOut)); + } + #endif +} \ No newline at end of file
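
For reference, the standalone host-side snippet below (not part of the patch) mirrors, on a 1-D example, the HalfPixel coordinate untransform and the normalised linear weights that the device code above implements in untransformCoordinates(), retrieveNeighboursKernel() and interpolateLinear(). It is only an illustrative sketch of the formulas, independent of the Aidge API, and assumes Edge-style clamping for out-of-range neighbours.

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    const int inLen = 4;
    const int outLen = 8;
    const float in[4] = {1.f, 2.f, 3.f, 4.f};
    // Scale is output/input, as computed in ResizeImpl_cuda::forward_()
    const float scale = static_cast<float>(outLen) / static_cast<float>(inLen);

    for (int xOut = 0; xOut < outLen; ++xOut) {
        // HalfPixel untransform: coordIn = (coordOut + 0.5) / scale - 0.5
        const float xIn = (xOut + 0.5f) / scale - 0.5f;

        // Two nearest neighbours; values are fetched with Edge clamping,
        // weights are computed from the unclamped coordinates
        // (matching retrieveNeighboursKernel / interpolateLinear)
        const int lo = static_cast<int>(std::floor(xIn));
        const int hi = static_cast<int>(std::ceil(xIn));
        const int loClamped = std::min(std::max(lo, 0), inLen - 1);
        const int hiClamped = std::min(std::max(hi, 0), inLen - 1);

        // Linear weight: w = max(0, 1 - |dx|), then normalised over the neighbours
        const float wLo = std::max(0.f, 1.f - std::fabs(xIn - static_cast<float>(lo)));
        const float wHi = std::max(0.f, 1.f - std::fabs(xIn - static_cast<float>(hi)));
        const float sum = wLo + wHi;
        const float value = (sum > 0.f)
                                ? (wLo * in[loClamped] + wHi * in[hiClamped]) / sum
                                : in[loClamped];

        std::printf("out[%d] = %.4f   (xIn = %.3f)\n", xOut, value, xIn);
    }
    return 0;
}

The renormalisation by the total weight is the same as in interpolateLinear(), so samples whose neighbourhood is partly clamped to the border still stay within the input value range.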