diff --git a/include/aidge/backend/cuda.hpp b/include/aidge/backend/cuda.hpp
index 55ad6a12fe892c3ae7716391be7cf0b843283447..1cc3e975e229f77ac92d32926e2ba3291f6c0cb4 100644
--- a/include/aidge/backend/cuda.hpp
+++ b/include/aidge/backend/cuda.hpp
@@ -12,7 +12,8 @@
 #ifndef AIDGE_BACKEND_CUDA_IMPORTS_H_
 #define AIDGE_BACKEND_CUDA_IMPORTS_H_
 
 #include "aidge/backend/cuda/operator/OperatorImpl.hpp"
+#include "aidge/backend/cuda/operator/ScalingImpl.hpp"
 
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/AddImpl.hpp"
diff --git a/include/aidge/backend/cuda/operator/ScalingImpl.hpp b/include/aidge/backend/cuda/operator/ScalingImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..4939a1613f02e067f57a0be5ba43a86a10407aa9
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/ScalingImpl.hpp
@@ -0,0 +1,59 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_BACKEND_CUDA_OPERATOR_SCALINGIMPL_H_
+#define AIDGE_BACKEND_CUDA_OPERATOR_SCALINGIMPL_H_
+
+#include <array>
+#include <memory>
+#include <tuple>
+#include <vector>
+
+#include <cudnn.h>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Scaling.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+namespace Aidge {
+// CUDA backend implementation of the Scaling operator (forward pass only for now).
+class ScalingImpl_cuda : public OperatorImpl {
+private:
+    std::shared_ptr<Tensor> mInputFallback;
+    std::shared_ptr<Tensor> mOutputGradFallback;
+
+public:
+    ScalingImpl_cuda(const Scaling_Op &op) : OperatorImpl(op, "cuda") {}
+
+    static std::unique_ptr<ScalingImpl_cuda> create(const Scaling_Op &op) {
+        return std::make_unique<ScalingImpl_cuda>(op);
+    }
+
+public:
+    void forward();
+    void backward();
+
+private:
+    template <class T> void forward_(const Tensor& input);
+    template <class T> void backward_(const Tensor& output_grad);
+};
+
+namespace {
+// add cuda backend to Scaling_Op implementation registry
+static Registrar<Scaling_Op> registrarScalingImpl_cuda("cuda", Aidge::ScalingImpl_cuda::create);
+} // namespace
+
+} // namespace Aidge
+
+#endif /* AIDGE_BACKEND_CUDA_OPERATOR_SCALINGIMPL_H_ */
diff --git a/include/aidge/backend/cuda/operator/ScalingImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ScalingImpl_CUDA_kernels.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..590fccf8fa826a3c8703bf573ec20e138e60ee04
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/ScalingImpl_CUDA_kernels.hpp
@@ -0,0 +1,39 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CUDA_OPERATOR_SCALINGIMPL_FORWARD_KERNEL_H_
+#define AIDGE_CUDA_OPERATOR_SCALINGIMPL_FORWARD_KERNEL_H_
+
+#include <stdexcept>
+#include <cfloat>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include <cuda_fp16.h>
+
+#include "aidge/data/Data.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/utils/Types.h"
+
+namespace Aidge {
+
+// Launches the element-wise scaling kernel: output[i] = input[i] * scalingFactor,
+// rounded and clamped to the quantized range when nbBits > 0.
+template <class T>
+void ScalingImpl_cuda_forward_kernel(
+    const std::size_t size,
+    const float scalingFactor,
+    const int nbBits,
+    const bool isOutputUnsigned,
+    const T* input,
+    T* output);
+
+} // namespace Aidge
+#endif /* AIDGE_CUDA_OPERATOR_SCALINGIMPL_FORWARD_KERNEL_H_ */
diff --git a/src/operator/ScalingImpl.cpp b/src/operator/ScalingImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0df2772116de93326696f7cd06ece65921444b1e
--- /dev/null
+++ b/src/operator/ScalingImpl.cpp
@@ -0,0 +1,63 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <vector>
+
+#include "aidge/backend/cuda/data/TensorImpl.hpp"
+#include "aidge/backend/cuda/operator/ScalingImpl.hpp"
+#include "aidge/backend/cuda/operator/ScalingImpl_CUDA_kernels.hpp"
+#include "aidge/backend/cuda/utils/CudaContext.hpp"
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+#include "aidge/operator/Scaling.hpp"
+#include "aidge/utils/Types.h"
+
+void Aidge::ScalingImpl_cuda::forward() {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
+
+    assert(mOp.getRawInput(0) && "missing input #0");
+
+    // Cast (or move) the input to the data type and backend of the output if needed.
+    const auto& input = op.getInput(0)->refCastFrom(mInputFallback, *op.getOutput(0));
+
+    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
+        case DataType::Float64:
+            forward_<double>(input);
+            break;
+        case DataType::Float32:
+            forward_<float>(input);
+            break;
+        case DataType::Float16:
+            forward_<half>(input);
+            break;
+        default:
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
+    }
+}
+
+template <class T>
+void Aidge::ScalingImpl_cuda::forward_(const Tensor& input) {
+    const Scaling_Op& op_ = static_cast<const Scaling_Op&>(mOp);
+
+    const T* inputPtr = static_cast<const T*>(input.getImpl()->rawPtr());
+    T* outputPtr = static_cast<T*>(op_.getOutput(0)->getImpl()->rawPtr());
+
+    Aidge::ScalingImpl_cuda_forward_kernel<T>(
+        op_.getOutput(0)->size(),
+        op_.scalingFactor(),
+        op_.quantizedNbBits(),
+        op_.isOutputUnsigned(),
+        inputPtr,
+        outputPtr);
+}
+
+// TODO: backward pass not implemented yet.
+void Aidge::ScalingImpl_cuda::backward() {}
diff --git a/src/operator/ScalingImpl_CUDA_kernels.cu b/src/operator/ScalingImpl_CUDA_kernels.cu
new file mode 100644
index 0000000000000000000000000000000000000000..efcbc0ef63b511e0db6b1fa1edaf5cfbbe15d673
--- /dev/null
+++ b/src/operator/ScalingImpl_CUDA_kernels.cu
@@ -0,0 +1,95 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "aidge/backend/cuda/operator/ScalingImpl_CUDA_kernels.hpp"
+
+template<typename T>
+__device__ T round_helper(T x) {
+    return std::round(x);
+}
+
+template<>
+__device__ half round_helper<half>(half x) {
+    float x_float = __half2float(x);
+    return __float2half(std::round(x_float));
+}
+
+template<typename T>
+__device__ T clamp_helper(T x, T min, T max) {
+    return x <= min ? min : x >= max ? max : x;
+}
+
+template <class T>
+__global__ void ScalingImpl_cuda_forward_kernel_(
+    const std::size_t size,
+    const float scalingFactor,
+    const int nbBits,
+    const bool isOutputUnsigned,
+    const T* input,
+    T* output)
+{
+    const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
+    const size_t stride = blockDim.x * gridDim.x;
+
+    // Saturation bounds of the quantized output range,
+    // e.g. [-128, 127] (signed) or [0, 255] (unsigned) for nbBits = 8.
+    const T lower = isOutputUnsigned ? 0 : -(1ll << (nbBits - 1ll));
+    const T upper = isOutputUnsigned ? (1ll << nbBits) - 1ll : (1ll << (nbBits - 1ll)) - 1ll;
+
+    // Grid-stride loop: scale each element, then round and clamp it when quantization is enabled.
+    for (size_t i = index; i < size; i += stride) {
+        output[i] = input[i] * static_cast<T>(scalingFactor);
+        if(nbBits > 0) {
+            output[i] = clamp_helper(round_helper(output[i]), lower, upper);
+        }
+    }
+}
+
+template <class T>
+void Aidge::ScalingImpl_cuda_forward_kernel(
+    const std::size_t size,
+    const float scalingFactor,
+    const int nbBits,
+    const bool isOutputUnsigned,
+    const T* input,
+    T* output)
+{
+    int blockSize = 256;
+    int numBlocks = (size + blockSize - 1) / blockSize;
+    // Launch the kernel
+    ScalingImpl_cuda_forward_kernel_<<<numBlocks, blockSize>>>(size, scalingFactor, nbBits, isOutputUnsigned, input, output);
+}
+
+
+template void Aidge::ScalingImpl_cuda_forward_kernel<double>(
+    const std::size_t size,
+    const float scalingFactor,
+    const int nbBits,
+    const bool isOutputUnsigned,
+    const double* input,
+    double* output);
+
+template void Aidge::ScalingImpl_cuda_forward_kernel<float>(
+    const std::size_t size,
+    const float scalingFactor,
+    const int nbBits,
+    const bool isOutputUnsigned,
+    const float* input,
+    float* output);
+
+
+template void Aidge::ScalingImpl_cuda_forward_kernel<half>(
+    const std::size_t size,
+    const float scalingFactor,
+    const int nbBits,
+    const bool isOutputUnsigned,
+    const half* input,
+    half* output);
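
For reviewers, here is a minimal smoke-test sketch of how the new CUDA Scaling implementation could be exercised, modelled on the pattern used by other Aidge backend tests. It is not part of the change above: the Scaling factory signature, the Array1D tensor initializer and the cudaMemcpy read-back are assumptions borrowed from existing Aidge operator tests and may need adjusting.

#include <cuda_runtime_api.h>
#include <memory>

#include "aidge/backend/cuda.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Scaling.hpp"

using namespace Aidge;

// Hypothetical smoke test: scale a small tensor by 2 and quantize it to 8-bit unsigned.
void scalingCudaSmokeTest() {
    // Assumed factory signature: Scaling(scalingFactor, quantizedNbBits, isOutputUnsigned, name).
    std::shared_ptr<Node> node = Scaling(2.0f, 8, true, "scaling_cuda");
    auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());

    auto input = std::make_shared<Tensor>(Array1D<float, 4>{{-1.5f, 0.0f, 63.2f, 200.0f}});
    op->associateInput(0, input);
    op->setDataType(DataType::Float32);
    op->setBackend("cuda");
    input->setBackend("cuda");

    op->forward();

    // Read the result back from the device.
    // Expected after scaling, rounding and clamping to [0, 255]: {0, 0, 126, 255}.
    float result[4];
    cudaMemcpy(result, op->getOutput(0)->getImpl()->rawPtr(),
               4 * sizeof(float), cudaMemcpyDeviceToHost);
}

With nbBits = 8 and isOutputUnsigned = true the kernel clamps to [0, 255]; with isOutputUnsigned = false it would clamp to [-128, 127] instead, matching the lower/upper bounds computed in ScalingImpl_cuda_forward_kernel_.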