diff --git a/include/aidge/backend/cuda.hpp b/include/aidge/backend/cuda.hpp
index dd790290aa8b25097cbc3a0109b57279a7552777..4a706fd0132072220e803c7852d798b3afaa8257 100644
--- a/include/aidge/backend/cuda.hpp
+++ b/include/aidge/backend/cuda.hpp
@@ -41,6 +41,7 @@
 #include "aidge/backend/cuda/operator/ReduceMeanImpl.hpp"
 #include "aidge/backend/cuda/operator/ReduceSumImpl.hpp"
 #include "aidge/backend/cuda/operator/ReLUImpl.hpp"
+#include "aidge/backend/cuda/operator/ResizeImpl.hpp"
 #include "aidge/backend/cuda/operator/RoundImpl.hpp"
 #include "aidge/backend/cuda/operator/ShiftGELUImpl.hpp"
 #include "aidge/backend/cuda/operator/ShiftMaxImpl.hpp"
diff --git a/include/aidge/backend/cuda/data/Interpolation.cuh b/include/aidge/backend/cuda/data/Interpolation.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..316272ac6991a54855148f8e26041b1f65407343
--- /dev/null
+++ b/include/aidge/backend/cuda/data/Interpolation.cuh
@@ -0,0 +1,179 @@
+/********************************************************************************
+ * Copyright (c) 2025 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+#ifndef AIDGE_CUDA_DATA_INTERPOLATION_H_
+#define AIDGE_CUDA_DATA_INTERPOLATION_H_
+
+#include "aidge/data/Interpolation.hpp"
+namespace Aidge {
+namespace InterpolationCUDA {
+
+    /**
+     * @brief Computes the approximate input coordinates corresponding to given output coordinates,
+     *        according to the specified coordinate transformation mode.
+     *
+     * This function performs an "untransform" step for spatial interpolation, mapping output-space
+     * integer coordinates (`coordOut`) back into approximate input-space floating-point coordinates
+     * (`coordInApprox`) based on the transformation strategy provided.
+     *
+     * @param coordOut        Pointer to the output coordinates (integer values).
+     * @param inputDims       Pointer to the dimensions of the input tensor.
+     * @param outputDims      Pointer to the dimensions of the output tensor.
+     * @param coordTransfoMode The coordinate transformation mode (e.g., AlignCorners, HalfPixel, etc.).
+     * @param roi             Pointer to region of interest values, used only for TFCropAndResize mode;
+     *                        expected to be of size 2 * rank.
+     * @param coordInApprox   Pointer to the output array where the approximate input coordinates will be stored.
+     * @param rank            The dimensionality of the spatial domain (e.g., 2 for 2D, 3 for 3D).
+     */
+    __device__ void untransformCoordinates(
+        const int* coordOut,
+        const int* inputDims,
+        const int* outputDims,
+        Aidge::Interpolation::CoordinateTransformation coordTransfoMode,
+        const float* roi,
+        float* coordInApprox,
+        int rank);
+
+    /**
+     * @brief Retrieves neighboring input tensor values around a set of continuous coordinates for interpolation.
+     *
+     * This function gathers neighboring values from a multidimensional tensor for use in interpolation,
+     * depending on the interpolation mode (e.g., linear or cubic), scaling, padding strategy, and
+     * anti-aliasing settings. The gathered values and their coordinates are stored in output buffers.
+     *
+     * @tparam T               The data type of the input and output tensor values.
+     * @param tensorValues     Pointer to the input tensor values (flattened array).
+     * @param tensorDims       Pointer to the dimensions of the input tensor.
+     * @param coords           Pointer to the floating-point coordinates in the input space.
+     * @param scales           Pointer to scaling factors per dimension.
+     * @param rank             The number of spatial dimensions in the tensor.
+     * @param mode             Interpolation mode (e.g., Linear, Cubic).
+     * @param paddingMode      Padding mode used for out-of-bound accesses (e.g., Zero, Edge).
+     * @param antialiasing     Whether antialiasing is enabled (affects kernel footprint).
+     * @param outValues        Pointer to the buffer where retrieved neighbor values will be stored.
+     * @param outCoords        Pointer to the buffer where the corresponding coordinates of neighbors will be stored.
+     *                         Output shape is [maxNeighbours x rank].
+     * @param outCount         Pointer to a single integer where the number of valid neighbors will be written.
+     * @param maxNeighbours    Maximum number of neighbors to retrieve (capacity of the output buffers).
+     */
+    template <typename T>
+    __device__ void retrieveNeighboursKernel(
+        const T* tensorValues,
+        const int* tensorDims,
+        float* coords,
+        const float* scales,
+        int rank,
+        Aidge::Interpolation::Mode mode,
+        Aidge::PadBorderType paddingMode,
+        bool antialiasing,
+        T* outValues,
+        int* outCoords,
+        int* outCount,
+        int maxNeighbours
+    );
+
+    /**
+     * @brief Performs N-dimensional linear interpolation given neighbor points and their values.
+     *
+     * This function computes a weighted average of neighbor values based on linear interpolation
+     * weights derived from their distance to a target coordinate. It optionally applies antialiasing
+     * scaling to the distances. The weights are normalized to ensure smooth interpolation.
+     *
+     * @tparam T                   The data type of the input and output values.
+     * @param coordToInterpolate  Pointer to the floating-point coordinate to interpolate at (length = rank).
+     * @param scales              Pointer to scale factors per dimension.
+     * @param pointsCoords        Pointer to neighbor coordinates (flattened array of size coordsNbr × rank).
+     * @param pointValues         Pointer to values corresponding to neighbor coordinates.
+     * @param coordsNbr           Number of neighboring points used in the interpolation.
+     * @param rank                Number of spatial dimensions.
+     * @param antialiasing        Whether to apply antialiasing by scaling coordinate deltas.
+     * @return T                  The interpolated value at the target coordinate.
+     */
+    template <typename T>
+    __device__ T interpolateLinear(
+        const float* coordToInterpolate,
+        const float* scales,
+        const int* pointsCoords,
+        const T* pointValues,
+        int coordsNbr,
+        int rank,
+        bool antialiasing
+    );
+
+    /**
+     * @brief Performs N-dimensional cubic interpolation given neighbor points and their values.
+     *
+     * This function computes a weighted average of neighbor values based on cubic interpolation weights,
+     * optionally applying antialiasing and excluding neighbors outside input dimensions.
+     * Static dimensions (with size 1 and scale 1) are ignored in the weight calculation.
+     *
+     * @tparam T                   The data type of the input and output values.
+     * @param coordToInterpolate  Pointer to the floating-point coordinate to interpolate at (length = rank).
+     * @param scales              Pointer to scale factors per dimension.
+     * @param pointsCoords        Pointer to neighbor coordinates (flattened array of size coordsNbr × rank).
+     * @param pointValues         Pointer to values corresponding to neighbor coordinates.
+     * @param coordsNbr           Number of neighboring points used in the interpolation.
+     * @param rank                Number of spatial dimensions.
+     * @param inputDims           Pointer to input tensor dimensions.
+     * @param a                   Cubic convolution coefficient (e.g. -0.75, the ONNX Resize default; -0.5 gives Catmull-Rom).
+     * @param antialiasing        Whether to apply antialiasing in weight calculation.
+     * @param excludeOutside      Whether to exclude neighbors outside input dimensions.
+     * @return T                  The interpolated value at the target coordinate.
+     */
+    template <typename T>
+    __device__ T interpolateCubic(
+        float* coordToInterpolate,
+        const float* scales,
+        int* pointsCoords,
+        T* pointValues,
+        int coordsNbr,
+        int rank,
+        const int* inputDims,
+        float a,
+        bool antialiasing,
+        bool excludeOutside
+    );
+
+    /**
+     * @brief Dispatches to the appropriate interpolation method (linear or cubic) based on the mode.
+     *
+     * This function selects the interpolation method and computes the interpolated value
+     * at the given coordinate using either cubic or linear interpolation. Returns zero
+     * for unsupported interpolation modes.
+     *
+     * @tparam T                   The data type of the input and output values.
+     * @param coordToInterpolate  Pointer to the floating-point coordinate to interpolate at (length = rank).
+     * @param scales              Pointer to scale factors per dimension.
+     * @param pointsCoords        Pointer to neighbor coordinates (flattened array of size coordsNbr × rank).
+     * @param pointValues         Pointer to values corresponding to neighbor coordinates.
+     * @param coordsNbr           Number of neighboring points used in the interpolation.
+     * @param rank                Number of spatial dimensions.
+     * @param mode                Interpolation mode (Linear or Cubic).
+     * @param cubicCoeffA         Cubic interpolation coefficient (used only for cubic mode).
+     * @param antialiasing        Whether to apply antialiasing in interpolation.
+     * @param excludeOutside      Whether to exclude neighbors outside input dimensions (only cubic mode).
+     * @param inputDims           Pointer to input tensor dimensions.
+     * @return T                  The interpolated value at the target coordinate.
+     */
+    template <typename T>
+    __device__ T interpolate(float* coordToInterpolate,
+                            const float* scales,
+                            int* pointsCoords,
+                            T* pointValues,
+                            int coordsNbr,
+                            int rank,
+                            Aidge::Interpolation::Mode mode,
+                            float cubicCoeffA,
+                            bool antialiasing,
+                            bool excludeOutside,
+                            const int* inputDims);
+} // namespace InterpolationCUDA
+} // namespace Aidge
+#endif /*AIDGE_CUDA_DATA_INTERPOLATION_H_*/
\ No newline at end of file
diff --git a/include/aidge/backend/cuda/operator/ResizeImpl.hpp b/include/aidge/backend/cuda/operator/ResizeImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..4614ddefe41fc1c125e8e6abec9052a951501ced
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/ResizeImpl.hpp
@@ -0,0 +1,57 @@
+/********************************************************************************
+ * Copyright (c) 2025 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_BACKEND_CUDA_OPERATOR_RESIZEIMPL_H_
+#define AIDGE_BACKEND_CUDA_OPERATOR_RESIZEIMPL_H_
+
+#include <array>
+#include <memory>
+#include <tuple>
+#include <vector>
+
+#include <cudnn.h>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Resize.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+namespace Aidge {
+// Resize operator implementation for the "cuda" backend.
+class ResizeImpl_cuda : public OperatorImpl {
+public:
+    ResizeImpl_cuda(const Resize_Op& op) : OperatorImpl(op, "cuda") {}
+    // Factory passed to the registrar below: builds the implementation for `op`.
+    static std::unique_ptr<ResizeImpl_cuda> create(const Resize_Op& op) {
+        return std::make_unique<ResizeImpl_cuda>(op);
+    }
+    // Implementation specs advertised by this backend: one per floating-point data type.
+    virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
+        return {
+            {DataType::Float64},
+            {DataType::Float32},
+            {DataType::Float16},
+        };
+    }
+    // Runs the operator; defined out of line.
+    void forward() override;
+
+private:
+    template <class T> void forward_();  // typed worker called by forward()
+};
+
+// Implementation entry point registration to Operator
+REGISTRAR(Resize_Op, "cuda", Aidge::ResizeImpl_cuda::create);
+}  // namespace Aidge
+
+#endif /* AIDGE_BACKEND_CUDA_OPERATOR_RESIZEIMPL_H_ */
diff --git a/include/aidge/backend/cuda/operator/ResizeImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ResizeImpl_CUDA_kernels.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..25bb17064638bd637a07bd079e055988badc72ff
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/ResizeImpl_CUDA_kernels.hpp
@@ -0,0 +1,48 @@
+/********************************************************************************
+ * Copyright (c) 2025 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CUDA_OPERATOR_RESIZEIMPL_KERNELS_H_
+#define AIDGE_CUDA_OPERATOR_RESIZEIMPL_KERNELS_H_
+
+#include <stdexcept>
+#include <cfloat>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include <cuda_fp16.h>
+
+#include "aidge/data/Data.hpp"
+#include "aidge/data/Interpolation.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+namespace Aidge {
+// Declaration of the Resize forward routine, templated on the tensor element type T.
+// NOTE(review): only the prototype is visible here; `size` is presumably the output element count - confirm.
+template <class T>
+void resizeForward(const T* input, T* output,const std::vector<float> &roi,
+                          const std::vector<int>& inputDims, const std::vector<int>& outputDims,
+                          const std::vector<int>& inputStrides, const std::vector<int>& outputStrides,
+                          const std::vector<float> &scales,
+                          const Aidge::Interpolation::CoordinateTransformation coordTransfoMode,
+                          const Aidge::Interpolation::Mode interpMode,
+                          const Aidge::PadBorderType paddingMode,
+                          float cubic_coeff_a,
+                          float extrapolationVal,
+                          bool antialiasing,
+                          bool excludeOutside,
+                          int size);
+
+}
+#endif /* AIDGE_CUDA_OPERATOR_RESIZEIMPL_KERNELS_H_ */
+
+
+
+
+
diff --git a/src/data/Interpolation.cu b/src/data/Interpolation.cu
new file mode 100644
index 0000000000000000000000000000000000000000..04a8ab13bdfa8bd93b74ca7ee3d2330e42b56920
--- /dev/null
+++ b/src/data/Interpolation.cu
@@ -0,0 +1,535 @@
+/********************************************************************************
+ * Copyright (c) 2025 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+#include <cuda_fp16.h>
+#include <math.h> // For isnan()
+
+#include "aidge/backend/cuda/data/Interpolation.cuh"
+
+#include "aidge/data/Interpolation.hpp"
+// Converts one tensor element to float; the interpolation routines of this file
+// accumulate their weighted sums in float. Only the specialisations below are
+// defined, the primary template is a declaration.
+template <typename T>
+__device__ float toFloat(T val);
+
+// float: passed through unchanged.
+template <>
+__device__ float toFloat<float>(float val) { return val; }
+
+// double: narrowed to float.
+template <>
+__device__ float toFloat<double>(double val) { return static_cast<float>(val); }
+
+// __half: widened with the __half2float intrinsic.
+template <>
+__device__ float toFloat<__half>(__half val) { return __half2float(val); }
+
+// Converts a float result back to the tensor element type T. Only the
+// specialisations below are defined, the primary template is a
+// declaration.
+template <typename T>
+__device__ T fromFloat(float val);
+
+// float: passed through unchanged.
+template <>
+__device__ float fromFloat<float>(float val) { return val; }
+
+// double: widened from float.
+template <>
+__device__ double fromFloat<double>(float val) { return static_cast<double>(val); }
+
+// __half: converted with the __float2half intrinsic.
+template <>
+__device__ __half fromFloat<__half>(float val) { return __float2half(val); }
+
+// Converts output-space indices into fractional input-space positions, one axis at a time,
+// according to coordTransfoMode. `roi` is only read in TFCropAndResize mode.
+__device__ void Aidge::InterpolationCUDA::untransformCoordinates(
+    const int* coordOut,
+    const int* inputDims,
+    const int* outputDims,
+    Aidge::Interpolation::CoordinateTransformation coordTransfoMode,
+    const float* roi,
+    float* coordInApprox,
+    int rank)
+{
+    using Transfo = Aidge::Interpolation::CoordinateTransformation;
+
+    for (int axis = 0; axis < rank; ++axis) {
+        const int inSize = inputDims[axis];
+        const int outSize = outputDims[axis];
+        const float outPos = static_cast<float>(coordOut[axis]);
+        // Resize ratio of this axis, recomputed from the two extents.
+        const float ratio = static_cast<float>(outSize) / static_cast<float>(inSize);
+        float& target = coordInApprox[axis];
+
+        if (coordTransfoMode == Transfo::AlignCorners) {
+            // An axis of extent 1 on either side always maps to position 0.
+            if (inSize == 1 || outSize == 1) {
+                target = 0.0f;
+            } else {
+                target = outPos * static_cast<float>(inSize - 1) /
+                                  static_cast<float>(outSize - 1);
+            }
+        } else if (coordTransfoMode == Transfo::Asymmetric) {
+            target = outPos / ratio;
+        } else if (coordTransfoMode == Transfo::HalfPixel ||
+                   coordTransfoMode == Transfo::HalfPixelSymmetric) {
+            target = (outPos + 0.5f) / ratio - 0.5f;
+        } else if (coordTransfoMode == Transfo::PytorchHalfPixel) {
+            // Same formula as HalfPixel, except that an output axis of extent 1 maps to 0.
+            if (static_cast<float>(outSize) > 1) {
+                target = (outPos + 0.5f) / ratio - 0.5f;
+            } else {
+                target = 0.0f;
+            }
+        } else if (coordTransfoMode == Transfo::TFHalfPixelForNN) {
+            target = (outPos + 0.5f) / ratio;
+        } else if (coordTransfoMode == Transfo::TFCropAndResize) {
+            const float roiBegin = roi[axis];
+            const float roiFinish = roi[axis + rank];
+            const float lastIn = static_cast<float>(inSize - 1);
+            if (outSize > 1) {
+                target = roiBegin * lastIn +
+                         outPos * (roiFinish - roiBegin) * lastIn /
+                         static_cast<float>(outSize - 1);
+            } else {
+                target = 0.5f * (roiBegin + roiFinish) * lastIn;
+            }
+
+            // Positions falling outside [0, inSize - 1] are flagged with NaN.
+            const bool outside = target < 0.0f || target > lastIn;
+            if (outside) {
+                target = NAN;
+            }
+        }
+    }
+}
+
+// Gathers the grid points around `coords` that interpolation `mode` needs: one list of
+// candidate indices is built per dimension, then every combination of those indices that
+// `paddingMode` accepts is appended to outCoords/outValues (at most maxNeighbours entries).
+// Windows are capped at maxKernelSize taps and rank at maxRank (fixed-size local arrays).
+template <typename T>
+__device__ void Aidge::InterpolationCUDA::retrieveNeighboursKernel(
+    const T* tensorValues,
+    const int* tensorDims,
+    float* coords,
+    const float* scales,
+    int rank,
+    Aidge::Interpolation::Mode mode,
+    Aidge::PadBorderType paddingMode,
+    bool antialiasing,
+    T* outValues,
+    int* outCoords,
+    int* outCount,
+    int maxNeighbours
+) {
+    // Capacities of the fixed-size scratch arrays below.
+    const int maxRank = 10;
+    const int maxKernelSize = 16;
+    if (rank > maxRank) {
+        *outCount = 0;  // more dimensions than the scratch arrays can index
+        return;
+    }
+    int count = 0;
+    int offsetRanges[maxRank][maxKernelSize];
+    int offsetSizes[maxRank];
+    for (int dim = 0; dim < rank; ++dim) {
+        if (mode == Aidge::Interpolation::Mode::Cubic) {
+            float scale = scales[dim];
+            int kernelWidth = 4;
+            int center = (int)floorf(coords[dim]);
+            int kernel_size;
+            int start;
+            if (antialiasing && scale < 1.0f) {
+                // Downscaling widens the footprint by 1 / scale; the radius is capped so
+                // that 2 * radius + 1 entries always fit in one offsetRanges row.
+                float footprint = kernelWidth / scale;
+                int radius = min((int)ceilf((footprint - 1.0f) / 2.0f), (maxKernelSize - 1) / 2);
+                kernel_size = 2 * radius + 1;
+                offsetSizes[dim] = kernel_size;
+                start = center - radius;
+                for (int i = 0; i < kernel_size; ++i) {
+                    offsetRanges[dim][i] = start + i;
+                }
+            } else {
+                // Default cubic kernel
+                kernel_size = 4;
+                offsetSizes[dim] = kernel_size;
+                start = center - 1;
+                for (int i = 0; i < kernel_size; ++i) {
+                    offsetRanges[dim][i] = start + i;
+                }
+            }
+        } else if(mode == Aidge::Interpolation::Mode::Linear && antialiasing) {
+            float scale = scales[dim];
+            // Only scale height and width (assuming batch and channels are unscaled)
+            if (dim >= 2 && scale < 1.0f) {
+                // Capped at maxKernelSize so the window fits in one offsetRanges row.
+                int kernel_size = min(max(2, (int)ceilf(2.0f / scale)), maxKernelSize);
+                offsetSizes[dim] = kernel_size;
+                int center = (int)floorf(coords[dim]);
+                int start = center - (kernel_size - 1) / 2;
+                for (int k = 0; k < kernel_size; ++k) {
+                    offsetRanges[dim][k] = start + k;
+                }
+            } else {
+                offsetSizes[dim] = 2;
+                int center = (int)floorf(coords[dim]);
+                for (int k = 0; k < 2; ++k) {
+                    offsetRanges[dim][k] = center + k;
+                }
+            }
+        }
+        else {
+            offsetSizes[dim] = 2;
+            offsetRanges[dim][0] = (int)ceilf(coords[dim]);
+            offsetRanges[dim][1] = (int)floorf(coords[dim]);
+        }
+    }
+
+    int loopCounts[maxRank] = {0};
+    while (true) {
+        // Build coordinate
+        int currentCoords[maxRank];
+        for (int i = 0; i < rank; ++i) {
+            currentCoords[i] = offsetRanges[i][loopCounts[i]];
+        }
+
+        // Handle padding
+        bool valid = true;
+        int clampedCoords[maxRank];
+        for (int i = 0; i < rank; ++i) {
+            int coord = currentCoords[i];
+            switch (paddingMode) {
+                case Aidge::PadBorderType::Edge:
+                    coord = min(max(coord, 0), tensorDims[i] - 1);
+                    break;
+                case Aidge::PadBorderType::Zero:
+                    if (coord < 0 || coord >= tensorDims[i]) {
+                        valid = false;
+                    }
+                    break;
+                default:
+                    valid = false;  // Unsupported padding
+                    break;
+            }
+            clampedCoords[i] = coord;
+        }
+
+        if (valid && count < maxNeighbours) {
+            int flatIdx = 0;
+            int stride = 1;
+            for (int i = rank - 1; i >= 0; --i) {
+                flatIdx += clampedCoords[i] * stride;
+                stride *= tensorDims[i];
+            }
+            // Store output: the unclamped coordinate, with the value read at the clamped one
+            for (int j = 0; j < rank; ++j) {
+                outCoords[count * rank + j] = currentCoords[j];
+            }
+            outValues[count] = tensorValues[flatIdx];
+            ++count;
+        }
+
+        // Next coordinate combination
+        int dim = rank - 1;
+        while (dim >= 0) {
+            loopCounts[dim]++;
+            if (loopCounts[dim] < offsetSizes[dim]) break;
+            loopCounts[dim] = 0;
+            --dim;
+        }
+        if (dim < 0) break;
+    }
+
+    *outCount = count;
+}
+
+// Blends the neighbour values with separable weights max(0, 1 - |dx|) per dimension and
+// divides by the sum of the weights; yields 0 when every neighbour has a zero weight.
+template <typename T>
+__device__ T Aidge::InterpolationCUDA::interpolateLinear(
+    const float* coordToInterpolate,
+    const float* scales,
+    const int* pointsCoords,
+    const T* pointValues,
+    int coordsNbr,
+    int rank,
+    bool antialiasing
+) {
+    float weightedSum = 0.0f;
+    float weightSum = 0.0f;
+    for (int n = 0; n < coordsNbr; ++n) {
+        const int* neighbour = pointsCoords + n * rank;
+        float weight = 1.0f;
+        int d = 0;
+        for (; d < rank; ++d) {
+            float delta = coordToInterpolate[d] - static_cast<float>(neighbour[d]);
+            if (antialiasing) {
+                // Distances are multiplied by min(scale, 1), i.e. shrunk when scale < 1.
+                delta = delta * fminf(scales[d], 1.0f);
+            }
+            const float axisWeight = fmaxf(0.0f, 1.0f - fabsf(delta));
+            if (axisWeight == 0.0f) {
+                break;
+            }
+            weight *= axisWeight;
+        }
+        // An early exit means one axis gave a zero weight: ignore this neighbour.
+        if (d < rank) continue;
+        weightedSum += toFloat(pointValues[n]) * weight;
+        weightSum += weight;
+    }
+    const float blended = (weightSum > 0.0f) ? weightedSum / weightSum : 0.0f;
+    return fromFloat<T>(blended);
+}
+
+// Cubic convolution weight at signed distance x for coefficient a.
+// The weight is 0 once |x| reaches 2.
+template <typename T>
+__device__ float cubicWeight(float x, float a) {
+    const float dist = fabsf(x);
+    // Inner lobe.
+    if (dist < 1.0f) return ((a + 2.0f) * dist * dist * dist - (a + 3.0f) * dist * dist + 1.0f);
+    // Outer lobe.
+    if (dist < 2.0f) return (a * dist * dist * dist - 5.0f * a * dist * dist + 8.0f * a * dist - 4.0f * a);
+    return 0.0f;
+}
+
+// Cubic convolution weight evaluated at |x| * min(scale, 1): with scale < 1 the
+// distance shrinks, so neighbours further away still receive a non-zero weight.
+template <typename T>
+__device__ float cubicWeightAntialiased(float x, float scale, float a) {
+    const float dist = fabsf(x * fminf(scale, 1.0f));
+    const float sq = dist * dist;
+    const float cu = sq * dist;
+    if (dist <= 1.0f) {
+        return ((a + 2.0f) * cu - (a + 3.0f) * sq + 1.0f);
+    }
+    if (dist < 2.0f) {
+        return (a * cu - 5.0f * a * sq + 8.0f * a * dist - 4.0f * a);
+    }
+    return 0.0f;
+}
+
+// Blends the neighbour values with separable cubic weights and divides by the sum of the
+// weights; neighbours with a zero weight (or outside the input when excludeOutside is set)
+// are ignored, and 0 is produced when no weight remains.
+template <typename T>
+__device__ T Aidge::InterpolationCUDA::interpolateCubic(
+    float* coordToInterpolate,
+    const float* scales,
+    int* pointsCoords,
+    T* pointValues,
+    int coordsNbr,
+    int rank,
+    const int* inputDims,
+    float a,
+    bool antialiasing,
+    bool excludeOutside
+) {
+    float weightedSum = 0.0f;
+    float weightSum = 0.0f;
+
+    for (int n = 0; n < coordsNbr; ++n) {
+        const int* neighbour = pointsCoords + n * rank;
+        float weight = 1.0f;
+        bool rejected = false;
+        for (int d = 0; d < rank; ++d) {
+            // Static dimensions (extent 1, scale 1) do not contribute to the weight.
+            if (inputDims[d] == 1 && scales[d] == 1.0f) {
+                continue;
+            }
+            const int gridPos = neighbour[d];
+            if (excludeOutside && (gridPos < 0 || gridPos >= inputDims[d])) {
+                rejected = true;
+                break;
+            }
+
+            const float delta = coordToInterpolate[d] - static_cast<float>(gridPos);
+            if (antialiasing) {
+                weight *= cubicWeightAntialiased<T>(delta, scales[d], a);
+            } else {
+                weight *= cubicWeight<T>(delta, a);
+            }
+        }
+        // Only neighbours that were kept and carry a non-zero weight enter the average.
+        if (rejected || weight == 0.0f) continue;
+
+        weightedSum += toFloat(pointValues[n]) * weight;
+        weightSum += weight;
+    }
+
+    const float blended = (weightSum > 0.0f) ? weightedSum / weightSum : 0.0f;
+    return fromFloat<T>(blended);
+}
+
+// Forwards to interpolateCubic or interpolateLinear depending on `mode`; any other mode gives T(0).
+template <typename T>
+__device__ T Aidge::InterpolationCUDA::interpolate(float* coordToInterpolate,
+                        const float* scales,
+                        int* pointsCoords,
+                        T* pointValues,
+                        int coordsNbr,
+                        int rank,
+                        Aidge::Interpolation::Mode mode,
+                        float cubicCoeffA,
+                        bool antialiasing,
+                        bool excludeOutside,
+                        const int* inputDims) {
+    using Mode = Aidge::Interpolation::Mode;
+    if (mode == Mode::Cubic) return interpolateCubic<T>(coordToInterpolate, scales, pointsCoords, pointValues, coordsNbr, rank, inputDims, cubicCoeffA, antialiasing, excludeOutside);
+    if (mode == Mode::Linear) return interpolateLinear<T>(coordToInterpolate, scales, pointsCoords, pointValues, coordsNbr, rank, antialiasing);
+    return T(0);
+}
+
+/////// Explicit template instantiations (double, float, half)
+
+template __device__ void Aidge::InterpolationCUDA::retrieveNeighboursKernel(
+    const double* tensorValues,
+    const int* tensorDims,
+    float* coords,
+    const float* scales,
+    int rank,
+    Aidge::Interpolation::Mode mode,
+    Aidge::PadBorderType paddingMode,
+    bool antialiasing,
+    double* outValues,
+    int* outCoords,
+    int* outCount,
+    int maxNeighbours);
+template __device__ void Aidge::InterpolationCUDA::retrieveNeighboursKernel(
+    const float* tensorValues,
+    const int* tensorDims,
+    float* coords,
+    const float* scales,
+    int rank,
+    Aidge::Interpolation::Mode mode,
+    Aidge::PadBorderType paddingMode,
+    bool antialiasing,
+    float* outValues,
+    int* outCoords,
+    int* outCount,
+    int maxNeighbours);
+template __device__ void Aidge::InterpolationCUDA::retrieveNeighboursKernel(
+    const half* tensorValues,
+    const int* tensorDims,
+    float* coords,
+    const float* scales,
+    int rank,
+    Aidge::Interpolation::Mode mode,
+    Aidge::PadBorderType paddingMode,
+    bool antialiasing,
+    half* outValues,
+    int* outCoords,
+    int* outCount,
+    int maxNeighbours);
+
+template __device__ double Aidge::InterpolationCUDA::interpolateLinear(
+    const float* coordToInterpolate,
+    const float* scales,
+    const int* pointsCoords,
+    const double* pointValues,
+    int coordsNbr,
+    int rank,
+    bool antialiasing
+);
+template __device__ float Aidge::InterpolationCUDA::interpolateLinear(
+    const float* coordToInterpolate,
+    const float* scales,
+    const int* pointsCoords,
+    const float* pointValues,
+    int coordsNbr,
+    int rank,
+    bool antialiasing
+);
+template __device__ half Aidge::InterpolationCUDA::interpolateLinear(
+    const float* coordToInterpolate,
+    const float* scales,
+    const int* pointsCoords,
+    const half* pointValues,
+    int coordsNbr,
+    int rank,
+    bool antialiasing
+);
+
+template __device__ double Aidge::InterpolationCUDA::interpolateCubic(
+    float* coordToInterpolate,
+    const float* scales,
+    int* pointsCoords,
+    double* pointValues,
+    int coordsNbr,
+    int rank,
+    const int* inputDims,
+    float a,
+    bool antialiasing,
+    bool excludeOutside
+);
+template __device__ float Aidge::InterpolationCUDA::interpolateCubic(
+    float* coordToInterpolate,
+    const float* scales,
+    int* pointsCoords,
+    float* pointValues,
+    int coordsNbr,
+    int rank,
+    const int* inputDims,
+    float a,
+    bool antialiasing,
+    bool excludeOutside
+);
+template __device__ half Aidge::InterpolationCUDA::interpolateCubic(
+    float* coordToInterpolate,
+    const float* scales,
+    int* pointsCoords,
+    half* pointValues,
+    int coordsNbr,
+    int rank,
+    const int* inputDims,
+    float a,
+    bool antialiasing,
+    bool excludeOutside
+);
+
+template __device__ double Aidge::InterpolationCUDA::interpolate(float* coordToInterpolate,
+                        const float* scales,
+                        int* pointsCoords,
+                        double* pointValues,
+                        int coordsNbr,
+                        int rank,
+                        Aidge::Interpolation::Mode mode,
+                        float cubicCoeffA,
+                        bool antialiasing,
+                        bool excludeOutside,
+                        const int* inputDims);
+template __device__ float Aidge::InterpolationCUDA::interpolate(float* coordToInterpolate,
+                        const float* scales,
+                        int* pointsCoords,
+                        float* pointValues,
+                        int coordsNbr,
+                        int rank,
+                        Aidge::Interpolation::Mode mode,
+                        float cubicCoeffA,
+                        bool antialiasing,
+                        bool excludeOutside,
+                        const int* inputDims);
+template __device__ half Aidge::InterpolationCUDA::interpolate(float* coordToInterpolate,
+                        const float* scales,
+                        int* pointsCoords,
+                        half* pointValues,
+                        int coordsNbr,
+                        int rank,
+                        Aidge::Interpolation::Mode mode,
+                        float cubicCoeffA,
+                        bool antialiasing,
+                        bool excludeOutside,
+                        const int* inputDims);
\ No newline at end of file
diff --git a/src/operator/ResizeImpl.cpp b/src/operator/ResizeImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5a397149dc6cfdc9e630460eca664a736ce1d7c7
--- /dev/null
+++ b/src/operator/ResizeImpl.cpp
@@ -0,0 +1,96 @@
+/********************************************************************************
+ * Copyright (c) 2025 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <algorithm>
+#include <cassert>
+#include <numeric>
+#include <vector>
+#include <iostream>
+
+#include <cuda_fp16.h>
+#include "aidge/backend/cuda/data/TensorImpl.hpp"
+#include "aidge/backend/cuda/operator/ResizeImpl.hpp"
+#include "aidge/backend/cuda/operator/ResizeImpl_CUDA_kernels.hpp"
+#include "aidge/backend/cuda/utils/CudaContext.hpp"
+#include "aidge/backend/cuda/utils/CudaContext.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/operator/Resize.hpp"
+#include "aidge/utils/Types.h"
+
+/**
+ * @brief Entry point of the CUDA Resize forward pass.
+ *
+ * Asserts that input #0 is connected and owns an implementation, then
+ * dispatches to the typed forward_<T>() that matches the data type of
+ * output #0. Float64, Float32 and Float16 are handled; any other data type is
+ * reported through AIDGE_THROW_OR_ABORT.
+ */
+void Aidge::ResizeImpl_cuda::forward() {
+    const auto& resizeOp = static_cast<const Resize_Op&>(mOp);
+
+    // Input #0 must exist and be backed by an implementation.
+    AIDGE_ASSERT(resizeOp.getInput(0), "missing input in Resize operator");
+    AIDGE_ASSERT(resizeOp.getInput(0)->hasImpl(), "cannot run Resize forward because the 0-th input has no implementation.");
+
+    // The output tensor's data type selects the template instantiation.
+    const auto outputType = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType();
+    if (outputType == DataType::Float64) {
+        forward_<double>();
+    } else if (outputType == DataType::Float32) {
+        forward_<float>();
+    } else if (outputType == DataType::Float16) {
+        forward_<half>();
+    } else {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
+    }
+}
+
+/**
+ * @brief Typed implementation of the Resize forward pass.
+ *
+ * Gathers on the host the input/output dimensions (narrowed to int), their
+ * row-major strides and the per-dimension output/input size ratios, then
+ * forwards everything to Aidge::resizeForward<T>() together with the
+ * operator's attributes.
+ *
+ * @tparam T Element type used to interpret the raw input and output buffers.
+ */
+template <class T>
+void Aidge::ResizeImpl_cuda::forward_()
+{
+    const auto& resizeOp = static_cast<const Resize_Op&>(mOp);
+    const auto& inDims = resizeOp.getInput(0)->dims();
+    const auto& outDims = resizeOp.getOutput(0)->dims();
+
+    const T* inputPtr = static_cast<T*>(resizeOp.getInput(0)->getImpl()->rawPtr());
+    T* outputPtr = static_cast<T*>(resizeOp.getOutput(0)->getImpl()->rawPtr());
+
+    // resizeForward() takes int metadata: narrow every dimension.
+    std::vector<int> inputDimsInt(inDims.size());
+    std::transform(inDims.begin(), inDims.end(), inputDimsInt.begin(),
+                   [](size_t dim) { return static_cast<int>(dim); });
+    std::vector<int> outputDimsInt(outDims.size());
+    std::transform(outDims.begin(), outDims.end(), outputDimsInt.begin(),
+                   [](size_t dim) { return static_cast<int>(dim); });
+
+    // Row-major strides: the last dimension has stride 1 and each outer
+    // stride is the next stride times the next dimension's size.
+    std::vector<int> outputStrides(resizeOp.getOutput(0)->nbDims(), 1);
+    for (int d = static_cast<int>(resizeOp.getOutput(0)->nbDims()) - 2; d >= 0; --d) {
+        outputStrides[d] = outputStrides[d + 1] * outDims[d + 1];
+    }
+    std::vector<int> inputStrides(resizeOp.getInput(0)->nbDims(), 1);
+    for (int d = static_cast<int>(resizeOp.getInput(0)->nbDims()) - 2; d >= 0; --d) {
+        inputStrides[d] = inputStrides[d + 1] * inDims[d + 1];
+    }
+
+    // Effective scale of each dimension: output size divided by input size.
+    std::vector<float> scales;
+    scales.reserve(inDims.size());
+    for (size_t d = 0; d < inDims.size(); ++d) {
+        const float outExtent = static_cast<float>(outDims[d]);
+        const float inExtent = static_cast<float>(inDims[d]);
+        scales.push_back(outExtent / inExtent);
+    }
+
+    Aidge::resizeForward<T>(inputPtr, outputPtr, resizeOp.roi(),
+                            inputDimsInt, outputDimsInt,
+                            inputStrides, outputStrides,
+                            scales,
+                            resizeOp.coordinateTransformationMode(),
+                            resizeOp.interpolationMode(),
+                            resizeOp.paddingMode(),
+                            resizeOp.cubic_coeff_a(),
+                            resizeOp.extrapolation_val(),
+                            resizeOp.antialiasing(),
+                            resizeOp.exclude_outside(),
+                            static_cast<int>(resizeOp.getOutput(0)->size()));
+}
diff --git a/src/operator/ResizeImpl_CUDA_kernels.cu b/src/operator/ResizeImpl_CUDA_kernels.cu
new file mode 100644
index 0000000000000000000000000000000000000000..4008f03d17779a53f5d400f580debf947d6efe27
--- /dev/null
+++ b/src/operator/ResizeImpl_CUDA_kernels.cu
@@ -0,0 +1,216 @@
+/********************************************************************************
+ * Copyright (c) 2025 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+#include <cuda_fp16.h>
+#include <math.h> // For isnan()
+#include <thrust/device_vector.h>
+#include <type_traits>
+
+#include "aidge/backend/cuda/operator/ResizeImpl_CUDA_kernels.hpp"
+#include "aidge/backend/cuda/data/Interpolation.cuh"
+#include "aidge/data/Interpolation.hpp"
+
+
+/**
+ * @brief Resize kernel: each thread produces one element of the output tensor.
+ *
+ * The flat thread index is split into output coordinates with
+ * @p outputStrides and mapped back to fractional input coordinates by
+ * Aidge::InterpolationCUDA::untransformCoordinates(). The value is then either
+ * copied from a single input element (Ceil, Floor, RoundPreferCeil and
+ * RoundPreferFloor modes) or obtained by gathering neighbours and
+ * interpolating them (every other mode).
+ *
+ * Launch layout: indexed with blockIdx.x / blockDim.x / threadIdx.x only;
+ * threads whose index is >= @p totalElements return immediately.
+ * Coordinates are held in per-thread arrays of CUDNN_DIM_MAX entries, so
+ * @p rank must not exceed CUDNN_DIM_MAX. All pointer arguments are
+ * dereferenced in device code.
+ *
+ * @param input            Input tensor values, addressed through @p inputStrides.
+ * @param output           Output tensor values, written at the flat thread index.
+ * @param roi              Forwarded to untransformCoordinates().
+ * @param inputDims        Size of each input dimension (@p rank entries).
+ * @param outputDims       Size of each output dimension, forwarded to untransformCoordinates().
+ * @param inputStrides     Stride of each input dimension (@p rank entries).
+ * @param outputStrides    Stride of each output dimension (@p rank entries).
+ * @param scales           Per-dimension scales forwarded to the neighbour retrieval and interpolation helpers.
+ * @param coordTransfoMode Coordinate transformation applied by untransformCoordinates().
+ * @param interpMode       Interpolation mode; selects the single-element or the interpolating path.
+ * @param paddingMode      Edge clamps out-of-range coordinates in the single-element path; any other
+ *                         value makes such elements take @p extrapolationVal.
+ * @param cubic_coeff_a    Forwarded to interpolate().
+ * @param extrapolationVal Value written when a source coordinate is NaN or out of range (non-Edge padding).
+ * @param antialiasing     Forwarded to the neighbour retrieval and interpolation helpers.
+ * @param excludeOutside   Forwarded to interpolate().
+ * @param rank             Number of dimensions of the tensors.
+ * @param totalElements    Number of output elements to compute.
+ */
+template <class T>
+__global__ void resizeKernel(const T* input, T* output,
+                            const float* roi,
+                            const int* inputDims, const int* outputDims,
+                            const int* inputStrides, const int* outputStrides,
+                            const float* scales,
+                            Aidge::Interpolation::CoordinateTransformation coordTransfoMode,
+                            Aidge::Interpolation::Mode interpMode,
+                            Aidge::PadBorderType paddingMode,
+                            float cubic_coeff_a,
+                            float extrapolationVal,
+                            bool antialiasing,
+                            bool excludeOutside,
+                            int rank,
+                            int totalElements) {
+    using Mode = Aidge::Interpolation::Mode;
+
+    const int outIdx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (outIdx >= totalElements) {
+        return;
+    }
+
+    // Unflatten the output index, first dimension first.
+    int outCoord[CUDNN_DIM_MAX];
+    for (int d = 0, rest = outIdx; d < rank; ++d) {
+        outCoord[d] = rest / outputStrides[d];
+        rest %= outputStrides[d];
+    }
+
+    // Fractional position of this output element in input space.
+    float srcCoord[CUDNN_DIM_MAX];
+    Aidge::InterpolationCUDA::untransformCoordinates(outCoord,
+                                                     inputDims,
+                                                     outputDims,
+                                                     coordTransfoMode,
+                                                     roi,
+                                                     srcCoord,
+                                                     rank);
+
+    // A NaN coordinate yields the extrapolation value.
+    for (int d = 0; d < rank; ++d) {
+        if (isnan(srcCoord[d])) {
+            output[outIdx] = static_cast<T>(extrapolationVal);
+            return;
+        }
+    }
+
+    const bool selectsSingleElement = (interpMode == Mode::Ceil) ||
+                                      (interpMode == Mode::Floor) ||
+                                      (interpMode == Mode::RoundPreferCeil) ||
+                                      (interpMode == Mode::RoundPreferFloor);
+
+    if (!selectsSingleElement) {
+        // Interpolating modes: gather the neighbours, then blend them.
+        constexpr int maxNeighbourCount = 2048;
+        // NOTE(review): sized for 4 coordinates per neighbour while the coordinate
+        // arrays above allow up to CUDNN_DIM_MAX dimensions - confirm against the
+        // layout written by retrieveNeighboursKernel.
+        int neighbourCoords[maxNeighbourCount * 4];
+        T neighbourValues[maxNeighbourCount];
+        int neighbourCount = 0;
+
+        Aidge::InterpolationCUDA::retrieveNeighboursKernel<T>(input,
+                                                              inputDims,
+                                                              srcCoord,
+                                                              scales,
+                                                              rank,
+                                                              interpMode,
+                                                              paddingMode,
+                                                              antialiasing,
+                                                              neighbourValues,
+                                                              neighbourCoords,
+                                                              &neighbourCount,
+                                                              maxNeighbourCount); // capacity of the buffers
+
+        output[outIdx] = Aidge::InterpolationCUDA::interpolate<T>(srcCoord,
+                                                                  scales,
+                                                                  neighbourCoords,
+                                                                  neighbourValues,
+                                                                  neighbourCount,
+                                                                  rank,
+                                                                  interpMode,
+                                                                  cubic_coeff_a,
+                                                                  antialiasing,
+                                                                  excludeOutside,
+                                                                  inputDims);
+        return;
+    }
+
+    // Single-element modes: snap every coordinate to an integer, resolve
+    // out-of-range coordinates, and accumulate the flat input index.
+    int inputIdx = 0;
+    for (int d = 0; d < rank; ++d) {
+        float snapped;
+        if (interpMode == Mode::Ceil) {
+            snapped = ceilf(srcCoord[d]);
+        } else if (interpMode == Mode::Floor) {
+            snapped = floorf(srcCoord[d]);
+        } else if (interpMode == Mode::RoundPreferCeil) {
+            snapped = floorf(srcCoord[d] + 0.5f);
+        } else {
+            snapped = ceilf(srcCoord[d] - 0.5f);
+        }
+
+        int coord = static_cast<int>(snapped);
+        if (coord < 0 || coord >= inputDims[d]) {
+            if (paddingMode != Aidge::PadBorderType::Edge) {
+                // Outside the input without edge padding: extrapolate.
+                output[outIdx] = static_cast<T>(extrapolationVal);
+                return;
+            }
+            coord = max(0, min(coord, inputDims[d] - 1));
+        }
+        inputIdx += coord * inputStrides[d];
+    }
+
+    output[outIdx] = input[inputIdx];
+}
+/**
+ * @brief Host-side launcher of resizeKernel.
+ *
+ * Copies the shape metadata (dimensions, strides, ROI, scales) into temporary
+ * device vectors, launches one thread per output element in blocks of 256
+ * threads on the default stream, and waits for completion so that launch and
+ * execution errors are reported through CHECK_CUDA_STATUS before returning.
+ *
+ * @tparam T Element type of the tensors.
+ * @param input            Input buffer; dereferenced by the kernel, so it must be device-accessible.
+ * @param output           Output buffer; written by the kernel, so it must be device-accessible.
+ * @param roi              Region-of-interest values forwarded to the kernel (may be empty).
+ * @param inputDims        Size of each input dimension.
+ * @param outputDims       Size of each output dimension; its length is the rank given to the kernel.
+ * @param inputStrides     Stride of each input dimension.
+ * @param outputStrides    Stride of each output dimension.
+ * @param scales           Per-dimension scales forwarded to the kernel.
+ * @param coordTransfoMode Coordinate transformation mode.
+ * @param interpMode       Interpolation mode.
+ * @param paddingMode      Padding mode.
+ * @param cubic_coeff_a    Cubic coefficient forwarded to the kernel.
+ * @param extrapolationVal Value forwarded to the kernel for out-of-range elements.
+ * @param antialiasing     Antialiasing flag forwarded to the kernel.
+ * @param excludeOutside   Exclude-outside flag forwarded to the kernel.
+ * @param size             Number of output elements; nothing is launched when it is <= 0.
+ */
+template <class T>
+void Aidge::resizeForward(const T* input, T* output,const std::vector<float> &roi,
+                          const std::vector<int>& inputDims, const std::vector<int>& outputDims,
+                          const std::vector<int>& inputStrides, const std::vector<int>& outputStrides,
+                          const std::vector<float> &scales,
+                          const Aidge::Interpolation::CoordinateTransformation coordTransfoMode,
+                          const Aidge::Interpolation::Mode interpMode,
+                          const Aidge::PadBorderType paddingMode,
+                          float cubic_coeff_a,
+                          float extrapolationVal,
+                          bool antialiasing,
+                          bool excludeOutside,
+                          int size) {
+    // An empty output needs no work; a zero-block grid would make the launch
+    // fail with an invalid-configuration error.
+    if (size <= 0) {
+        return;
+    }
+
+    // The kernel keeps coordinates in arrays of CUDNN_DIM_MAX entries and
+    // indexes every per-dimension array with the output rank: reject metadata
+    // that would make it read or write out of bounds.
+    const auto rank = outputDims.size();
+    if (rank > CUDNN_DIM_MAX || inputDims.size() != rank ||
+        inputStrides.size() != rank || outputStrides.size() != rank) {
+        CHECK_CUDA_STATUS(cudaErrorInvalidValue);
+        return; // not reached when the macro raises; never launch with bad metadata
+    }
+
+    const int blockSize = 256;
+    // Ceil-division written so that it cannot overflow for large sizes.
+    const int numBlocks = (size - 1) / blockSize + 1;
+
+    // Device copies of the host metadata; released when this function returns.
+    const thrust::device_vector<int> d_input_shape = inputDims;
+    const thrust::device_vector<int> d_output_shape = outputDims;
+    const thrust::device_vector<int> d_input_strides = inputStrides;
+    const thrust::device_vector<int> d_output_strides = outputStrides;
+    const thrust::device_vector<float> d_roi = roi;
+    const thrust::device_vector<float> d_scales = scales;
+
+    resizeKernel<<<numBlocks, blockSize>>>(input, output, thrust::raw_pointer_cast(d_roi.data()),
+                                        thrust::raw_pointer_cast(d_input_shape.data()), thrust::raw_pointer_cast(d_output_shape.data()),
+                                        thrust::raw_pointer_cast(d_input_strides.data()), thrust::raw_pointer_cast(d_output_strides.data()),
+                                        thrust::raw_pointer_cast(d_scales.data()),
+                                        coordTransfoMode, interpMode, paddingMode, cubic_coeff_a, extrapolationVal, antialiasing, excludeOutside,
+                                        static_cast<int>(rank), size);
+    // Launch-configuration errors show up here...
+    CHECK_CUDA_STATUS(cudaGetLastError());
+    // ...and execution errors once the kernel has finished.
+    CHECK_CUDA_STATUS(cudaDeviceSynchronize());
+}
+
+// Explicit instantiations of Aidge::resizeForward for the element types
+// double, float and half; the signatures differ only in the input/output
+// pointer element type.
+template void Aidge::resizeForward<double>(const double* input, double* output,const std::vector<float> &roi,
+                                        const std::vector<int>& inputDims, const std::vector<int>& outputDims,
+                                        const std::vector<int>& inputStrides, const std::vector<int>& outputStrides,
+                                        const std::vector<float> &scales,
+                                        const Aidge::Interpolation::CoordinateTransformation coordTransfoMode,
+                                        const Aidge::Interpolation::Mode interpMode,
+                                        const Aidge::PadBorderType paddingMode,
+                                        float cubic_coeff_a,
+                                        float extrapolationVal,
+                                        bool antialiasing,
+                                        bool excludeOutside,
+                                        int size);
+
+template void Aidge::resizeForward<float>(const float* input, float* output,const std::vector<float> &roi,
+                                        const std::vector<int>& inputDims, const std::vector<int>& outputDims,
+                                        const std::vector<int>& inputStrides, const std::vector<int>& outputStrides,
+                                        const std::vector<float> &scales,
+                                        const Aidge::Interpolation::CoordinateTransformation coordTransfoMode,
+                                        const Aidge::Interpolation::Mode interpMode,
+                                        const Aidge::PadBorderType paddingMode,
+                                        float cubic_coeff_a,
+                                        float extrapolationVal,
+                                        bool antialiasing,
+                                        bool excludeOutside,
+                                        int size);
+
+template void Aidge::resizeForward<half>(const half* input, half* output,const std::vector<float> &roi,
+                                        const std::vector<int>& inputDims, const std::vector<int>& outputDims,
+                                        const std::vector<int>& inputStrides, const std::vector<int>& outputStrides,
+                                        const std::vector<float> &scales,
+                                        const Aidge::Interpolation::CoordinateTransformation coordTransfoMode,
+                                        const Aidge::Interpolation::Mode interpMode,
+                                        const Aidge::PadBorderType paddingMode,
+                                        float cubic_coeff_a,
+                                        float extrapolationVal,
+                                        bool antialiasing,
+                                        bool excludeOutside,
+                                        int size);
+
diff --git a/unit_tests/Test_ResizeImpl.cpp b/unit_tests/Test_ResizeImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0e2b73ca3a04e5045a65e022168a4036b747a021
--- /dev/null
+++ b/unit_tests/Test_ResizeImpl.cpp
@@ -0,0 +1,523 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+#define alltests
+#include <chrono>      // std::micro, std::chrono::time_point,
+                       // std::chrono::system_clock
+#include <cmath>       // std::fabs
+#include <cstddef>     // std::size_t
+#include <cstdint>     // std::uint16_t
+#include <functional>  // std::multiplies
+#include <memory>
+#include <numeric>     // std::accumulate
+#include <random>      // std::random_device, std::mt19937
+                       // std::uniform_int_distribution, std::uniform_real_distribution
+
+#include <catch2/catch_test_macros.hpp>
+#include <cuda.h>
+#include <fmt/core.h>
+
+#include "aidge/backend/cpu/data/TensorImpl.hpp"
+#include "aidge/backend/cuda/data/TensorImpl.hpp"
+#include "aidge/data/Data.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/Node.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/operator/Resize.hpp"
+#include "aidge/utils/ArrayHelpers.hpp"
+#include "aidge/utils/TensorUtils.hpp"
+
+using namespace Aidge;
+
+
+TEST_CASE("[gpu/operator] Resize(forward)") {
+    Log::setConsoleLevel(Log::Level::Debug);
+#ifdef alltests
+    SECTION("Nearest") {
+        // Named after the mode it exercises: the node below is built with
+        // Interpolation::Mode::Floor and the expected tensor holds the floor-rounded
+        // result (2x upscaling of a 2x2 input, first row/column repeated three times).
+        SECTION("Floor") {
+            std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>(Array4D<float, 1, 1, 2, 2>{{
+                {
+                    {
+                        { 1.0, 2.0},
+                        { 3.0, 4.0}
+                    }
+                }
+            }});
+            input_tensor->setBackend("cuda");
+            Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 4, 4>{{
+                {
+                    {
+                        { 1.0, 1.0, 1.0, 2.0},
+                        { 1.0, 1.0, 1.0, 2.0},
+                        { 1.0, 1.0, 1.0, 2.0},
+                        { 3.0, 3.0, 3.0, 4.0}
+                    }
+                }
+            }});
+
+            std::vector<float> scales = {1.0f, 1.0f, 2.0f, 2.0f};
+            auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Floor);
+            auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator());
+            op->associateInput(0, input_tensor);
+
+
+            op->setDataType(DataType::Float32);
+            op->setBackend("cuda");
+            op->forwardDims(true);
+            op->forward();
+
+            // Bring the CUDA output back to the expected tensor's backend/type before comparing.
+            std::shared_ptr<Tensor> outputFallback;
+            const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor);
+            REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor));
+        }
+    }
+    SECTION("1-sized input tensor (upscaling)") {
+        std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>(Array4D<float, 1, 1, 1, 1>{{{{{0.417022}}}}});
+        input_tensor->setBackend("cuda");
+
+        Tensor expectedOutput = Tensor(Array4D<float, 1, 1, 2, 2>{
+            {{{{0.417022, 0.417022}, {0.417022, 0.417022}}}}});
+        std::vector<std::size_t> sizes = {1, 1, 2, 2};
+        auto resize_node = Resize({}, {}, sizes, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Linear);
+        auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator());
+        op->associateInput(0, input_tensor);
+
+
+        op->setDataType(DataType::Float32);
+        op->setBackend("cuda");
+        op->forwardDims(true);
+        op->forward();
+
+        std::shared_ptr<Tensor> outputFallback;
+        const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expectedOutput);
+        REQUIRE(approxEq<float>(cudaOutput, expectedOutput));
+    }
+    SECTION("Cubic Interpolation") {
+        std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>(Array4D<float, 1, 1, 4, 4>{{{{
+            {1.0f, 2.0f, 3.0f, 4.0f},
+            {5.0f, 6.0f, 7.0f, 8.0f},
+            {9.0f, 10.0f, 11.0f, 12.0f},
+            {13.0f, 14.0f, 15.0f, 16.0f}}}
+        }});
+        input_tensor->setBackend("cuda");
+        Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 8, 8>{{{{
+            { 0.47265625, 0.76953125, 1.24609375, 1.87500000, 2.28125000, 2.91015625, 3.38671875, 3.68359375 },
+            { 1.66015625, 1.95703125, 2.43359375, 3.06250000, 3.46875000, 4.09765625, 4.57421875, 4.87109375 },
+            { 3.56640625, 3.86328125, 4.33984375, 4.96875000, 5.37500000, 6.00390625, 6.48046875, 6.77734375 },
+            { 6.08203125, 6.37890625, 6.85546875, 7.48437500, 7.89062500, 8.51953125, 8.99609375, 9.29296875 },
+            { 7.70703125, 8.00390625, 8.48046875, 9.10937500, 9.51562500, 10.14453125, 10.62109375, 10.91796875 },
+            { 10.22265625, 10.51953125, 10.99609375, 11.62500000, 12.03125000, 12.66015625, 13.13671875, 13.43359375 },
+            { 12.12890625, 12.42578125, 12.90234375, 13.53125000, 13.93750000, 14.56640625, 15.04296875, 15.33984375 },
+            { 13.31640625, 13.61328125, 14.08984375, 14.71875000, 15.12500000, 15.75390625, 16.23046875, 16.52734375}
+        }}}});      
+
+        std::vector<float> scales = {1.0f, 1.0f, 2.0f, 2.0f}; // Adjust according to the expected interpolation scale
+        auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Cubic);
+        auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator());
+        op->associateInput(0, input_tensor);
+    
+        op->setDataType(DataType::Float32);
+        op->setBackend("cuda");
+        op->forwardDims(true);
+        op->forward();
+
+        std::shared_ptr<Tensor> outputFallback;
+        const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor);
+        REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor));
+    }
+ SECTION("CoordinateTransformation") {
+        std::shared_ptr<Tensor> input_tensor = std::make_shared<Tensor>(Array4D<float, 1, 1, 4, 4>{{
+            {
+                {
+                    {  1,  2,  3,  4},
+                    {  5,  6,  7,  8},
+                    {  9, 10, 11, 12},
+                    { 13, 14, 15, 16}
+                }
+            }
+        }});
+        input_tensor->setBackend("cuda");
+        std::vector<float> scales = {1.0, 1.0, 1.5, 1.5};
+        SECTION("half_pixel") {
+            Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 6, 6>{{
+                {
+                    {
+                        {0.56597281, 1.0590283, 1.8657429, 2.4398167, 3.2465289, 3.7395852},
+                        {2.5381956, 3.03125, 3.837965, 4.4120383, 5.21875, 5.7118063},
+                        {5.7650542, 6.2581067, 7.0648241, 7.638896, 8.4456081, 8.9386654},
+                        {8.0613489, 8.5544004, 9.3611183, 9.9351892, 10.7419, 11.234958},
+                        {11.288198, 11.781251, 12.587969, 13.162039, 13.96875, 14.461807},
+                        {13.260421, 13.753473, 14.560195, 15.134262, 15.940972, 16.434032}                       
+                    }
+                }
+            }});
+
+            auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Cubic);
+            auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator());
+            op->associateInput(0, input_tensor);
+
+
+            op->setDataType(DataType::Float32);
+            op->setBackend("cuda");
+            op->forwardDims(true);
+            op->forward();
+
+            std::shared_ptr<Tensor> outputFallback;
+            const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor);
+            REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor));
+        }
+        SECTION("half_pixel_symmetric") {
+            Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 6, 6>{{
+                {
+                    {
+                        {0.56597364, 1.0590289, 1.8657435, 2.4398165, 3.2465293, 3.7395852},
+                        {2.5381951, 3.03125, 3.837965, 4.4120388, 5.21875, 5.7118063},
+                        {5.7650528, 6.2581067, 7.0648241, 7.6388974, 8.4456081, 8.9386654},
+                        {8.0613451, 8.5543985, 9.3611164, 9.9351892, 10.741899, 11.234958},
+                        {11.288198, 11.781251, 12.587969, 13.162041, 13.96875, 14.461807},
+                        {13.260421, 13.753473, 14.560195, 15.134266, 15.940972, 16.434032}           
+                    }
+                }
+            }});
+
+            auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixelSymmetric, Interpolation::Mode::Cubic);
+            auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator());
+            op->associateInput(0, input_tensor);
+
+
+            op->setDataType(DataType::Float32);
+            op->setBackend("cuda");
+            op->forwardDims(true);
+            op->forward();
+
+            std::shared_ptr<Tensor> outputFallback;
+            const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor);
+            REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor));
+        }
+        SECTION("pytorch_half_pixel") {
+            Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 6, 6>{{
+                {
+                    {
+                        {0.56597281, 1.0590283, 1.8657429, 2.4398167, 3.2465289, 3.7395852},
+                        {2.5381956, 3.03125, 3.837965, 4.4120383, 5.21875, 5.7118063},
+                        {5.7650542, 6.2581067, 7.0648241, 7.638896, 8.4456081, 8.9386654},
+                        {8.0613489, 8.5544004, 9.3611183, 9.9351892, 10.7419, 11.234958},
+                        {11.288198, 11.781251, 12.587969, 13.162039, 13.96875, 14.461807},
+                        {13.260421, 13.753473, 14.560195, 15.134262, 15.940972, 16.434032}           
+                    }
+                }
+            }});
+
+            auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::PytorchHalfPixel, Interpolation::Mode::Cubic);
+            auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator());
+            op->associateInput(0, input_tensor);
+
+
+            op->setDataType(DataType::Float32);
+            op->setBackend("cuda");
+            op->forwardDims(true);
+            op->forward();
+
+            std::shared_ptr<Tensor> outputFallback;
+            const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor);
+            REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor));
+        }
+        SECTION("align_corners") {
+            Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 6, 6>{{
+                {
+                    {
+                        {1, 1.5039997, 2.2480016, 2.7520003, 3.4959996, 4},
+                        {3.0159998, 3.5199983, 4.2640014, 4.7680001, 5.5119972, 6.0159988},
+                        {5.9920053, 6.4960032, 7.2400088, 7.7440076, 8.4880047, 8.9920063},
+                        {8.0080013, 8.5119991, 9.2560062, 9.760005, 10.504001, 11.008001},
+                        {10.983998, 11.487995, 12.232005, 12.736004, 13.479996, 13.983998},
+                        {13, 13.503995, 14.248007, 14.752007, 15.495996, 16}
+           
+                    }
+                }
+            }});
+
+            auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::AlignCorners, Interpolation::Mode::Cubic);
+            auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator());
+            op->associateInput(0, input_tensor);
+
+
+            op->setDataType(DataType::Float32);
+            op->setBackend("cuda");
+            op->forwardDims(true);
+            op->forward();
+
+            std::shared_ptr<Tensor> outputFallback;
+            const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor);
+            REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor));
+        }
+        SECTION("asymmetric") {
+            Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 6, 6>{{
+                {
+                    {
+                        {1, 1.5740744, 2.3703694, 3, 3.740741, 4.1111126},
+                        {3.2962966, 3.8703718, 4.666666, 5.2962976, 6.0370393, 6.4074111},
+                        {6.4814782, 7.0555553, 7.8518462, 8.4814777, 9.2222204, 9.5925922},
+                        {9, 9.5740776, 10.370367, 11, 11.740744, 12.111115},
+                        {11.962964, 12.537044, 13.333332, 13.962964, 14.70371, 15.074081},
+                        {13.444449, 14.018529, 14.814817, 15.44445, 16.185196, 16.555569}                                       
+                    }
+                }
+            }});
+
+            auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::Asymmetric, Interpolation::Mode::Cubic);
+            auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator());
+            op->associateInput(0, input_tensor);
+
+
+            op->setDataType(DataType::Float32);
+            op->setBackend("cuda");
+            op->forwardDims(true);
+            op->forward();
+
+            std::shared_ptr<Tensor> outputFallback;
+            const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor);
+            REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor));
+        }
+        SECTION("tf_half_pixel_for_nn") {
+            Tensor expected_out_tensor = Tensor(Array4D<float, 1, 1, 6, 6>{{
+                {
+                    {
+                        {2.2962933, 3.037035, 3.6666636, 4.4629622, 5.037034, 5.0925913},
+                        {5.2592578, 6, 6.6296282, 7.4259291, 8, 8.0555582},
+                        {7.7777748, 8.5185175, 9.1481438, 9.9444475, 10.518517, 10.574076},
+                        {10.962964, 11.703708, 12.333335, 13.129641, 13.70371, 13.759269},
+                        {13.259255, 14, 14.629625, 15.425932, 16, 16.055561},
+                        {13.48148, 14.222226, 14.851851, 15.648157, 16.222225, 16.277786}           
+                    }
+                }
+            }});
+
+            auto resize_node = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::TFHalfPixelForNN, Interpolation::Mode::Cubic);
+            auto op = std::static_pointer_cast<Resize_Op>(resize_node->getOperator());
+            op->associateInput(0, input_tensor);
+
+
+            op->setDataType(DataType::Float32);
+            op->setBackend("cuda");
+            op->forwardDims(true);
+            op->forward();
+
+            std::shared_ptr<Tensor> outputFallback;
+            const auto& cudaOutput = op->getOutput(0)->refCastFrom(outputFallback, expected_out_tensor);
+            REQUIRE(approxEq<float>(cudaOutput, expected_out_tensor));
+        }
+    }
+ SECTION("Antialiasing") {
+        SECTION("Linear") {
+            // Halves the last two axes of a 1x1x6x6 CUDA tensor (HalfPixel, Linear), once with
+            // the default Resize arguments and once with the trailing antialiasing flag set.
+            auto input = std::make_shared<Tensor>(Array4D<float, 1, 1, 6, 6>{{{{
+                {0.8910093,  0.06583797, 0.9854186,  0.37615213, 0.5850754,  0.2838311},
+                {0.48583397, 0.2452433,  0.88194895, 0.83299613, 0.59694654, 0.6168722},
+                {0.13210197, 0.62511665, 0.12456352, 0.1806763,  0.9501198, 0.275962},
+                {0.06826835, 0.06309556, 0.02638639, 0.09780658, 0.83262056, 0.63483626},
+                {0.21528919, 0.19012423, 0.6685329,  0.61682564, 0.34627828, 0.11396276},
+                {0.9585102,  0.50689447, 0.18671302, 0.9953287,  0.857435,  0.8502074}}}
+            }});
+            input->setBackend("cuda");
+            Tensor expectedPlain(Array4D<float, 1, 1, 3, 3>{{{{
+                {0.42198, 0.76913, 0.52068},
+                {0.22215, 0.10736, 0.67338},
+                {0.4677,  0.61685, 0.54197 }
+            }}}});
+            Tensor expectedAntialiased(Array4D<float, 1, 1, 3, 3>{{{{
+                {0.50311, 0.61554, 0.50743},
+                {0.22988, 0.33839, 0.54707},
+                {0.45242, 0.53262, 0.61474}
+            }}}});
+            const std::vector<float> scales = {1.0f, 1.0f, 0.5f, 0.5f}; // one factor per axis
+
+            // Wires `input` into the node's Resize operator, runs it on CUDA and returns the operator.
+            const auto runOnCuda = [&input](const auto& node) {
+                auto resizeOp = std::static_pointer_cast<Resize_Op>(node->getOperator());
+                resizeOp->associateInput(0, input);
+                resizeOp->setDataType(DataType::Float32);
+                resizeOp->setBackend("cuda");
+                resizeOp->forwardDims(true);
+                resizeOp->forward();
+                return resizeOp;
+            };
+            const auto plainNode = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Linear);
+            const auto plainOp = runOnCuda(plainNode);
+            const auto antialiasedNode = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Linear, -0.75f, 0.0f, PadBorderType::Edge, AspectRatio::Stretch, true);
+            const auto antialiasedOp = runOnCuda(antialiasedNode);
+            std::shared_ptr<Tensor> hostCopy;
+            const auto& plainResult = plainOp->getOutput(0)->refCastFrom(hostCopy, expectedPlain);
+            // CHECK rather than REQUIRE: a mismatch is reported but the antialiased comparison still runs.
+            CHECK(approxEq<float>(plainResult, expectedPlain, 1e-5f, 1e-5f));
+            const auto& antialiasedResult = antialiasedOp->getOutput(0)->refCastFrom(hostCopy, expectedAntialiased);
+            REQUIRE(approxEq<float>(antialiasedResult, expectedAntialiased, 1e-5f, 1e-5f));
+        }
+    #endif
+        SECTION("Cubic") {
+            // Same 6x6 -> 3x3 downscale as a cubic resize, with and without the antialiasing flag.
+            // NOTE(review): the `#endif` / `#ifdef alltests` around this section seem to make it the only one built by default - confirm.
+            auto input = std::make_shared<Tensor>(Array4D<float, 1, 1, 6, 6>{{{{
+                {0.8910093,  0.06583797, 0.9854186,  0.37615213, 0.5850754,  0.2838311},
+                {0.48583397, 0.2452433,  0.88194895, 0.83299613, 0.59694654, 0.6168722},
+                {0.13210197, 0.62511665, 0.12456352, 0.1806763,  0.9501198, 0.275962},
+                {0.06826835, 0.06309556, 0.02638639, 0.09780658, 0.83262056, 0.63483626},
+                {0.21528919, 0.19012423, 0.6685329,  0.61682564, 0.34627828, 0.11396276},
+                {0.9585102,  0.50689447, 0.18671302, 0.9953287,  0.857435,  0.8502074}}}
+            }});
+            input->setBackend("cuda");
+            Tensor expectedPlain(Array4D<float, 1, 1, 3, 3>{{{{
+                { 0.33780938,  0.92826414,  0.5099976 },
+                { 0.25023192, -0.14104399,  0.80935836 },
+                { 0.4686063,   0.71120584,  0.46677577 }
+            }}}});
+            Tensor expectedAntialiased(Array4D<float, 1, 1, 3, 3>{{{{
+                { 0.50422025, 0.6695719 ,  0.50645566 },
+                { 0.16404048, 0.29127517, 0.574788 },
+                { 0.41621327, 0.54006433, 0.61721706 }
+            }}}});
+            const std::vector<float> scales = {1.0f, 1.0f, 0.5f, 0.5f}; // one factor per axis
+
+            // Wires `input` into the node's Resize operator, runs it on CUDA and returns the operator.
+            const auto runOnCuda = [&input](const auto& node) {
+                auto resizeOp = std::static_pointer_cast<Resize_Op>(node->getOperator());
+                resizeOp->associateInput(0, input);
+                resizeOp->setDataType(DataType::Float32);
+                resizeOp->setBackend("cuda");
+                resizeOp->forwardDims(true);
+                resizeOp->forward();
+                return resizeOp;
+            };
+            const auto plainNode = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Cubic);
+            const auto plainOp = runOnCuda(plainNode);
+            const auto antialiasedNode = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Cubic, -0.75f, 0.0f, PadBorderType::Edge, AspectRatio::Stretch, true);
+            const auto antialiasedOp = runOnCuda(antialiasedNode);
+            std::shared_ptr<Tensor> hostCopy;
+            const auto& plainResult = plainOp->getOutput(0)->refCastFrom(hostCopy, expectedPlain);
+            REQUIRE(approxEq<float>(plainResult, expectedPlain, 1e-5f, 1e-5f));
+            const auto& antialiasedResult = antialiasedOp->getOutput(0)->refCastFrom(hostCopy, expectedAntialiased);
+            REQUIRE(approxEq<float>(antialiasedResult, expectedAntialiased, 1e-5f, 1e-5f));
+        }
+    #ifdef alltests
+    }
+    SECTION("ROI") {
+        // TFCropAndResize + Linear on a 4x4 ramp: one run with an empty roi, one run restricted by `roi`.
+        auto input = std::make_shared<Tensor>(Array4D<float, 1, 1, 4, 4>{{{{
+            {0.0f, 1.0f, 2.0f, 3.0f},
+            {4.0f, 5.0f, 6.0f, 7.0f},
+            {8.0f, 9.0f, 10.0f, 11.0f},
+            {12.0f, 13.0f, 14.0f, 15.0f}}}
+        }});
+        input->setBackend("cuda");
+        Tensor expectedWithoutRoi(Array4D<float, 1, 1, 2, 2>{{{{
+            {  0.0,  3.0 },
+            { 12.0, 15.0 }
+        }}}});
+        Tensor expectedWithRoi(Array4D<float, 1, 1, 2, 2>{{{{
+            {  0.0,  1.5 },
+            { 6.0, 7.5 }
+        }}}});
+        const std::vector<float> scales = {1.0f, 1.0f, 0.5f, 0.5f};
+        const std::vector<float> roi = {0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.5f, 0.5f};
+
+        // Wires `input` into the node's Resize operator, runs it on CUDA and returns the operator.
+        const auto runOnCuda = [&input](const auto& node) {
+            auto resizeOp = std::static_pointer_cast<Resize_Op>(node->getOperator());
+            resizeOp->associateInput(0, input);
+            resizeOp->setDataType(DataType::Float32);
+            resizeOp->setBackend("cuda");
+            resizeOp->forwardDims(true);
+            resizeOp->forward();
+            return resizeOp;
+        };
+        const auto plainNode = Resize({}, scales, {}, {}, Interpolation::CoordinateTransformation::TFCropAndResize, Interpolation::Mode::Linear);
+        const auto plainOp = runOnCuda(plainNode);
+        const auto roiNode = Resize(roi, scales, {}, {}, Interpolation::CoordinateTransformation::TFCropAndResize, Interpolation::Mode::Linear);
+        const auto roiOp = runOnCuda(roiNode);
+
+        std::shared_ptr<Tensor> hostCopy;
+        const auto& plainResult = plainOp->getOutput(0)->refCastFrom(hostCopy, expectedWithoutRoi);
+        REQUIRE(approxEq<float>(plainResult, expectedWithoutRoi));
+        const auto& roiResult = roiOp->getOutput(0)->refCastFrom(hostCopy, expectedWithRoi);
+        REQUIRE(approxEq<float>(roiResult, expectedWithRoi));
+    }
+    SECTION("Extrapolation_Value") {
+        // Resizes a 2x2 map to 4x4 through a roi spanning [-0.5, 1.5]; the expected border cells hold the 99.0f passed to Resize.
+        auto input = std::make_shared<Tensor>(Array4D<float, 1, 1, 2, 2>{{{{
+            {0.0f, 1.0f},
+            {2.0f, 3.0f}}}
+        }});
+        input->setBackend("cuda");
+        Tensor expected(Array4D<float, 1, 1, 4, 4>{{{{
+            {  99.0f,  99.0f, 99.0f, 99.0f },
+            {  99.0f,   0.0f,  1.0f, 99.0f },
+            {  99.0f,   2.0f,  3.0f, 99.0f },
+            {  99.0f,  99.0f, 99.0f, 99.0f }
+        }}}});
+        const std::vector<std::size_t> targetDims = {1, 1, 4, 4};
+        const std::vector<float> cropRegion = {0.0f, 0.0f, -0.5f, -0.5f, 1.0f, 1.0f, 1.5f, 1.5f};
+        const auto node = Resize(cropRegion, {}, targetDims, {}, Interpolation::CoordinateTransformation::TFCropAndResize, Interpolation::Mode::RoundPreferFloor, -0.75f, 99.0f);
+        const auto resizeOp = std::static_pointer_cast<Resize_Op>(node->getOperator());
+        resizeOp->associateInput(0, input);
+        resizeOp->setDataType(DataType::Float32);
+        resizeOp->setBackend("cuda");
+        resizeOp->forwardDims(true);
+        resizeOp->forward();
+        std::shared_ptr<Tensor> hostCopy;
+        const auto& result = resizeOp->getOutput(0)->refCastFrom(hostCopy, expected);
+        REQUIRE(approxEq<float>(result, expected));
+    }
+    SECTION("Exclude_Outside") {
+        // Cubic HalfPixel upscale of a 3x3 ramp to 4x4, run twice: the second node appends one extra
+        // `true` argument (presumably exclude_outside - confirm against the Resize factory signature).
+        auto input = std::make_shared<Tensor>(Array4D<float, 1, 1, 3, 3>{{{{
+            {1.0f, 2.0f, 3.0f},
+            {4.0f, 5.0f, 6.0f},
+            {7.0f, 8.0f, 9.0f}}}
+        }});
+        input->setBackend("cuda");
+        Tensor expectedDefault(Array4D<float, 1, 1, 4, 4>{{{{
+            {0.7128906, 1.3144531, 2.2548828, 2.8564453},
+            {2.5175781, 3.1191406, 4.0595703, 4.661133},
+            {5.338867,  5.9404297, 6.8808594, 7.482422},
+            {7.1435547, 7.745117,  8.685547,  9.287109}
+        }}}});
+        Tensor expectedExcluding(Array4D<float, 1, 1, 4, 4>{{{{
+            {0.6793893, 1.2565644, 2.2625191, 2.8396947},
+            {2.4109144, 2.98809,   3.9940448, 4.5712204},
+            {5.4287796, 6.0059547, 7.01191,   7.589085},
+            {7.160305,  7.7374797, 8.743436,  9.32061}
+        }}}});
+        const std::vector<std::size_t> targetDims = {1, 1, 4, 4};
+
+        // Wires `input` into the node's Resize operator, runs it on CUDA and returns the operator.
+        const auto runOnCuda = [&input](const auto& node) {
+            auto resizeOp = std::static_pointer_cast<Resize_Op>(node->getOperator());
+            resizeOp->associateInput(0, input);
+            resizeOp->setDataType(DataType::Float32);
+            resizeOp->setBackend("cuda");
+            resizeOp->forwardDims(true);
+            resizeOp->forward();
+            return resizeOp;
+        };
+        const auto defaultNode = Resize({}, {}, targetDims, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Cubic, -0.75f, 0.0f, PadBorderType::Edge, AspectRatio::Stretch, false);
+        const auto defaultOp = runOnCuda(defaultNode);
+        const auto excludingNode = Resize({}, {}, targetDims, {}, Interpolation::CoordinateTransformation::HalfPixel, Interpolation::Mode::Cubic, -0.75f, 0.0f, PadBorderType::Edge, AspectRatio::Stretch, false, true);
+        const auto excludingOp = runOnCuda(excludingNode);
+
+        std::shared_ptr<Tensor> hostCopy;
+        const auto& defaultResult = defaultOp->getOutput(0)->refCastFrom(hostCopy, expectedDefault);
+        REQUIRE(approxEq<float>(defaultResult, expectedDefault));
+        const auto& excludingResult = excludingOp->getOutput(0)->refCastFrom(hostCopy, expectedExcluding);
+        REQUIRE(approxEq<float>(excludingResult, expectedExcluding));
+    }
+    #endif
+}
\ No newline at end of file