diff --git a/CMakeLists.txt b/CMakeLists.txt index 01ebb6f258b173aee6df867c5c5c991ec936df57..8dfd9982c7562a92e82212a6b2c9536b6fa5f451 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -75,7 +75,7 @@ target_link_libraries(${module_name} ) if( ${ENABLE_ASAN} ) - message("Building ${module_name} with ASAN.") + message("Building ${module_name} with ASAN.") set(SANITIZE_FLAGS -fsanitize=address -fno-omit-frame-pointer) target_link_libraries(${module_name} PUBLIC diff --git a/include/aidge/backend/cuda.hpp b/include/aidge/backend/cuda.hpp index 580dce246b4c43e9a82fc977103145f79ae0976e..da62b81022550a79d63fa1f20aa9429753e5ab6c 100644 --- a/include/aidge/backend/cuda.hpp +++ b/include/aidge/backend/cuda.hpp @@ -22,6 +22,8 @@ #include "aidge/backend/cuda/operator/MaxPoolingImpl.hpp" #include "aidge/backend/cuda/operator/PadImpl.hpp" #include "aidge/backend/cuda/operator/ReLUImpl.hpp" +#include "aidge/backend/cuda/operator/ShiftMaxImpl.hpp" +#include "aidge/backend/cuda/operator/ShiftGELUImpl.hpp" #include "aidge/backend/cuda/operator/ReshapeImpl.hpp" #include "aidge/backend/cuda/operator/SigmoidImpl.hpp" #include "aidge/backend/cuda/operator/SubImpl.hpp" diff --git a/include/aidge/backend/cuda/operator/ShiftGELUImpl.hpp b/include/aidge/backend/cuda/operator/ShiftGELUImpl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c4c6dc6eb57261dd230c023722a131b8858f5951 --- /dev/null +++ b/include/aidge/backend/cuda/operator/ShiftGELUImpl.hpp @@ -0,0 +1,57 @@ +/******************************************************************************** + * Copyright (c) 2024 Thales + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France + * Date: 25.06.2024 + * + ********************************************************************************/ + +#ifndef AIDGE_BACKEND_CUDA_OPERATOR_SHIFTGELUIMPL_H_ +#define AIDGE_BACKEND_CUDA_OPERATOR_SHIFTGELUIMPL_H_ + +#include <array> +#include <memory> +#include <tuple> +#include <vector> + +#include <cudnn.h> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/ShiftGELU.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +#include "aidge/backend/cuda/utils/CudaUtils.hpp" + +namespace Aidge { +class ShiftGELUImpl_cuda : public OperatorImpl { +private: + std::shared_ptr<Tensor> mInputFallback; +public: + ShiftGELUImpl_cuda(const ShiftGELU_Op &op) : OperatorImpl(op, "cuda") {} + + static std::unique_ptr<ShiftGELUImpl_cuda> create(const ShiftGELU_Op &op) { + return std::make_unique<ShiftGELUImpl_cuda>(op); + } + +public: + void forward(); + //~ShiftGELUImpl_cuda(); + +private: + template <class T> void forward_(const Tensor& input); + +}; + +namespace { +// add cuda backend to ShiftGELU_Op implementation registry +static Registrar<ShiftGELU_Op> registrarShiftGELUImpl_cuda("cuda", Aidge::ShiftGELUImpl_cuda::create); +} // namespace +} // namespace Aidge + +#endif /* AIDGE_BACKEND_CUDA_OPERATOR_SHIFTGELUIMPL_H_ */ \ No newline at end of file diff --git a/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cc259366fd3d06744e6f22b13d7ef651cb535d0a --- /dev/null +++ b/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp @@ -0,0 +1,34 @@ +/******************************************************************************** + * Copyright (c) 2024 Thales + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France + * Date: 25.06.2024 + * + ********************************************************************************/ + +#ifndef AIDGE_CUDA_OPERATOR_SHIFTGELUIMPL_FORWARD_KERNEL_H_ +#define AIDGE_CUDA_OPERATOR_SHIFTGELUIMPL_FORWARD_KERNEL_H_ + +#include <stdexcept> +#include <cfloat> +#include <cuda.h> +#include <cuda_runtime_api.h> +#include <cuda_fp16.h> + +#include "aidge/data/Data.hpp" +#include "aidge/backend/cuda/utils/CudaUtils.hpp" + +namespace Aidge { + +extern void ShiftGELULaunchKernel(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input); + +template <class T> +__global__ void ShiftGELUWholeKernel(T* input,int* quantized_tensor,int* GELUtensor,int* SumTensor, int* dims, double SF, int N, int output_bits); +} + +#endif /* AIDGE_CUDA_OPERATOR_SHIFTGELUIMPL_FORWARD_KERNEL_H_ */ \ No newline at end of file diff --git a/include/aidge/backend/cuda/operator/ShiftMaxImpl.hpp b/include/aidge/backend/cuda/operator/ShiftMaxImpl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8d72ba0b15cb3d9a91eedab2c2eab1758d0ee00f --- /dev/null +++ b/include/aidge/backend/cuda/operator/ShiftMaxImpl.hpp @@ -0,0 +1,57 @@ +/******************************************************************************** + * Copyright (c) 2024 Thales + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France + * Date: 25.06.2024 + * + ********************************************************************************/ + +#ifndef AIDGE_BACKEND_CUDA_OPERATOR_SHIFTMAXIMPL_H_ +#define AIDGE_BACKEND_CUDA_OPERATOR_SHIFTMAXIMPL_H_ + +#include <array> +#include <memory> +#include <tuple> +#include <vector> + +#include <cudnn.h> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/ShiftMax.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +#include "aidge/backend/cuda/utils/CudaUtils.hpp" + +namespace Aidge { +class ShiftMaxImpl_cuda : public OperatorImpl { +private: + std::shared_ptr<Tensor> mInputFallback; +public: + ShiftMaxImpl_cuda(const ShiftMax_Op &op) : OperatorImpl(op, "cuda") {} + + static std::unique_ptr<ShiftMaxImpl_cuda> create(const ShiftMax_Op &op) { + return std::make_unique<ShiftMaxImpl_cuda>(op); + } + +public: + void forward(); + //~ShiftMaxImpl_cuda(); + +private: + template <class T> void forward_(const Tensor& input); + +}; + +namespace { +// add cuda backend to ShiftMax_Op implementation registry +static Registrar<ShiftMax_Op> registrarShiftMaxImpl_cuda("cuda", Aidge::ShiftMaxImpl_cuda::create); +} // namespace +} // namespace Aidge + +#endif /* AIDGE_BACKEND_CUDA_OPERATOR_SHIFTMAXIMPL_H_ */ diff --git a/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e6f5205c0287c039fedfa88ff05e934c33873b8a --- /dev/null +++ b/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp @@ -0,0 +1,34 @@ +/******************************************************************************** + * Copyright (c) 2024 Thales + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France + * Date: 25.06.2024 + * + ********************************************************************************/ + +#ifndef AIDGE_CUDA_OPERATOR_SHIFTMAXIMPL_FORWARD_KERNEL_H_ +#define AIDGE_CUDA_OPERATOR_SHIFTMAXIMPL_FORWARD_KERNEL_H_ + +#include <stdexcept> +#include <cfloat> +#include <cuda.h> +#include <cuda_runtime_api.h> +#include <cuda_fp16.h> + +#include "aidge/data/Data.hpp" +#include "aidge/backend/cuda/utils/CudaUtils.hpp" + +namespace Aidge { + +extern void ShiftMaxLaunchKernel(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input); + +template <class T> +__global__ void ShiftMaxWholeKernel(T* input,int* quantized_tensor,int* factor, int* dims, double SF, int N, int output_bits,double new_SF) ; +} + +#endif /* AIDGE_CUDA_OPERATOR_SHIFTMAXIMPL_FORWARD_KERNEL_H_ */ \ No newline at end of file diff --git a/src/operator/ShiftGELUImpl.cpp b/src/operator/ShiftGELUImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..779fbb9175893f80dfa9110ee449e281fe3d5ca6 --- /dev/null +++ b/src/operator/ShiftGELUImpl.cpp @@ -0,0 +1,96 @@ +/******************************************************************************** + * Copyright (c) 2024 Thales + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France + * Date: 25.06.2024 + * + ********************************************************************************/ + +#include <cassert> +#include <chrono> // std::chrono::milliseconds +#include <numeric> // std::accumulate +#include <thread> // std::this_thread::sleep_for +#include <vector> +#include <algorithm> // For std::max +#include <cmath> // For pow +#include <typeinfo> + +#include "aidge/backend/cuda/data/TensorImpl.hpp" +#include "aidge/backend/cuda/operator/ShiftGELUImpl.hpp" +#include "aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp" +#include "aidge/backend/cuda/utils/CudaContext.hpp" +#include "aidge/backend/cuda/utils/CudaUtils.hpp" +#include "aidge/operator/ShiftGELU.hpp" +#include "aidge/utils/Types.h" + +void Aidge::ShiftGELUImpl_cuda::forward() { + + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + assert(mOp.getRawInput(0) && "missing input #0"); + const auto& input = op.getInput(0)->refCastFrom(mInputFallback, *op.getOutput(0)); + + //forward_<float>(input); + + //should template and changing type + switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) { + case DataType::Float64: + forward_<float>(input); + break; + case DataType::Float32: + forward_<float>(input); + break; + case DataType::Float16: + forward_<float>(input); + break; + default: + AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda"); + } +} + +template<class T> +void Aidge::ShiftGELUImpl_cuda::forward_(const Tensor& input) +{ + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + const T * input_raw = static_cast<const T*>(input.getImpl()->rawPtr()); + + int N = 15; + int output_bits = 8; + + size_t size = input.size(); + std::vector<DimSize_t> dims_input = input.dims(); + + double min = std::numeric_limits<double>::max(); + double max = std::numeric_limits<double>::min(); + for(std::size_t i = 0; i < dims_input[0]; i++) { + for(std::size_t j = 0; j < dims_input[1]; j++) { + for(std::size_t k = 0; k < dims_input[2]; k++) { + for(std::size_t l = 0; l < dims_input[3]; l++) { + std::vector<std::size_t> coordIdx = {i, j, k, l}; + std::size_t newFlatIdx = input.getIdx(coordIdx); + if (newFlatIdx < min) { + min = newFlatIdx; + } + if (newFlatIdx > max) { + max = newFlatIdx; + } + } + } + } + } + + double m = std::max(std::abs(min), std::abs(max)); + + // Calculate the normalization factor + double normalization_factor = static_cast<double>(1 << (output_bits - 1)) - 1; + + // Return the normalized maximum + double final_sf = m / normalization_factor; + T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()); + double new_SF = 1/std::pow(2,2*output_bits-1); // Le nouveau scaling factor renvoyé par la fonction shiftmax, utilisé pour déquantifier le tenseur renvoyé + ShiftGELULaunchKernel(input_raw, output, final_sf,N, output_bits, size, dims_input); +} \ No newline at end of file diff --git a/src/operator/ShiftGELUImpl_CUDA_kernels.cu b/src/operator/ShiftGELUImpl_CUDA_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..c51af4e51c1270bbe0b5c7cfe658a9d84614c58a --- /dev/null +++ b/src/operator/ShiftGELUImpl_CUDA_kernels.cu @@ -0,0 +1,147 @@ +/******************************************************************************** + * Copyright (c) 2024 Thales + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France + * Date: 25.06.2024 + * + ********************************************************************************/ +#define MAX(X,Y) (((X) > (Y)) ? (X) : (Y)) +#define CLAMP(X) (((X) < (0)) ? (0) : (X)) + +#include <stdio.h> +#include <cuda_runtime.h> + +#include "aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp" + +__device__ inline int ExpShift(int I,int N, double SF) +{ + int Ip = I + (I >> 1) - (I >> 4); + int I0 = floorf(-1.0/SF); + Ip = MAX(Ip,N*I0); + int q = floorf(Ip / (I0)); + int r = Ip -(I0*q); + int Ib = r/2 - I0; + Ib = CLAMP(Ib * powf(2,N-q));//BitShift? + return (int)Ib; +} + +namespace Aidge{ + +template <class T> +__global__ void ShiftGELUWholeKernel(T* input,int* quantized_tensor,int* GELUtensor,int* SumTensor, int* dims, double SF, int N, int output_bits) { + /* + * Kernels du Forward de GeLU + * Input => Tenseur représentant l'entrée (non quantifiée (flottant)) (pointeur vers le bloc de mémoire de type T) + * quantized_tensor => pointeur vers un bloc mémoire vide alloué sur le GPU + * geLUTensor => pointeur vers un bloc mémoire vide alloué sur le GPU + * SumTensor => pointeur vers un bloc mémoire vide alloué sur le GPU + * dims => int[4] sous forme de pointeur qui représente les 4 dimensions du tenseurs + * SF => Scaling Factor + * N => precision du Softmax arithmétique (plus N est grand plus l'opération est précise mais plus elle nécessite un nombre de bit elevé) + * output_bits => précision en bit souhaité (8 pour int8 par exemple) + */ + int x = blockIdx.x * blockDim.x + threadIdx.x; // Dim1 + int y = blockIdx.y * blockDim.y + threadIdx.y; // Dim2 + int z = blockIdx.z * blockDim.z + threadIdx.z; // Dim3 + + double SF_sig = SF * 1.702;// SF multiplié par une constante utilisé dans l'algo + double Final_SF = SF / powf(2,(output_bits-1)); + + if (x < dims[0] && y < dims[1] && z < dims[2]) { + int maxIdx = x * dims[1] * dims[2] * dims[3] + y * dims[2] * dims[3] + z * dims[3]; + for (int i = 0; i < dims[3]; i++) { //Quantization (1thread per last dim of tensor) + int idx = maxIdx + i; + quantized_tensor[idx] = roundf(input[idx] / SF); + } + int maxVal = quantized_tensor[maxIdx]; + for (int i = 1; i < dims[3]; i++) { // Computing max value + int idx = maxIdx + i; + maxVal = MAX(maxVal, quantized_tensor[idx]); + } + int Max_Exp = ExpShift(-maxVal,N,SF_sig); + for (int i = 0; i < dims[3]; i++) { //Exponential (artihmetic) + int idx = maxIdx + i; + GELUtensor[idx] = ExpShift(quantized_tensor[idx] - maxVal,N,SF_sig); + if(GELUtensor[idx] > INT_MAX - Max_Exp) { + SumTensor[idx] = 1; + } + else + { + SumTensor[idx] = floorf(INT_MAX/(GELUtensor[idx] + Max_Exp)); + } + //SigMoidInt. + SumTensor[idx] = floorf((GELUtensor[idx] * SumTensor[idx]) >> (31 - output_bits + 1)); + quantized_tensor[idx] *= SumTensor[idx]; + input[idx] = quantized_tensor[idx] * Final_SF; + } + } +} + +//TODO Template +void ShiftGELULaunchKernel(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) { + + double new_SF = 1/std::pow(2,2*output_bits-1); // Le nouveau scaling factor renvoyé par la fonction shiftgelu, utilisé pour déquantifier le tenseur renvoyé + + int dims_input_cuda[4]; + if (dims_input.size() >= 4) { + // Fixed-size array to store the first 4 elements + + // Copy the first 4 elements from dims_input to dims_input2 + for (std::size_t i = 0; i < 4; ++i) { + dims_input_cuda[i] = static_cast<int>(dims_input[i]); + } + } + + + float* input_cuda_tensor; + cudaMalloc(&input_cuda_tensor,size*sizeof(float)); //| + cudaMemcpy(input_cuda_tensor,input,size*sizeof(float),cudaMemcpyHostToDevice); + //| + int* quantized_tensor; //| + cudaMalloc(&quantized_tensor,size*sizeof(int)); //| + //| => Allocation des blocs mémoire sur le GPU + int* GELUtensor; + cudaMalloc(&GELUtensor,size*sizeof(int)); + + int* SumTensor; + cudaMalloc(&SumTensor,size*sizeof(int)); //| + //| + int* dims; //| + cudaMalloc(&dims,4*sizeof(int)); + + cudaMemcpy(dims,dims_input_cuda,4*sizeof(int),cudaMemcpyHostToDevice); + + dim3 threadsPerBlock(10, 10, 10); //| Calculs du nombre de thread par blocs et du nombre de bloc a lancé en parrallèle sur + dim3 numBlocks((dims_input[0] + threadsPerBlock.x - 1) / threadsPerBlock.x, //| le GPU pour en fonctions des dimensions du tenseur en entrée + (dims_input[1] + threadsPerBlock.y - 1) / threadsPerBlock.y, //| + (dims_input[2] + threadsPerBlock.z - 1) / threadsPerBlock.z); //| + + ShiftGELUWholeKernel<float><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor,GELUtensor,SumTensor, dims, SF,N,8);//Lancement du Kernel + cudaDeviceSynchronize(); //Attente de la fin d'execution du kernel. Ligne très importante puisque sans celle ci le CPU continue l'execution du programme sans attendre le retour du GPU + + cudaError_t err = cudaGetLastError(); + if(err != cudaSuccess) + { + printf("Erreur CUDA: %s\n", cudaGetErrorString(err)); + } //Checks des possibles erreurs sur le GPU + + + //float* ControlFinal = (float*)malloc(size*sizeof(float)); + //cudaMemcpy(ControlFinal,input_cuda_tensor,size*sizeof(float),cudaMemcpyDeviceToHost); + //MyTensor<float> control(ControlFinal,x.dims); + + cudaMemcpy(output,input_cuda_tensor,size*sizeof(float),cudaMemcpyDeviceToHost); + + cudaFree(quantized_tensor); //| + cudaFree(GELUtensor); //| + cudaFree(SumTensor); //| + cudaFree(dims); //| => Free sur GPU et CPU (tout ce qui a été malloc et cudaMalloc en gros) + cudaFree(input_cuda_tensor);//| +} + +} \ No newline at end of file diff --git a/src/operator/ShiftMaxImpl.cpp b/src/operator/ShiftMaxImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..470bac840969eb114d7b7e34e51fb9e733d41ae9 --- /dev/null +++ b/src/operator/ShiftMaxImpl.cpp @@ -0,0 +1,96 @@ +/******************************************************************************** + * Copyright (c) 2024 Thales + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France + * Date: 25.06.2024 + * + ********************************************************************************/ + +#include <cassert> +#include <chrono> // std::chrono::milliseconds +#include <numeric> // std::accumulate +#include <thread> // std::this_thread::sleep_for +#include <vector> +#include <algorithm> // For std::max +#include <cmath> // For pow +#include <typeinfo> + +#include "aidge/backend/cuda/data/TensorImpl.hpp" +#include "aidge/backend/cuda/operator/ShiftMaxImpl.hpp" +#include "aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp" +#include "aidge/backend/cuda/utils/CudaContext.hpp" +#include "aidge/backend/cuda/utils/CudaUtils.hpp" +#include "aidge/operator/ShiftMax.hpp" +#include "aidge/utils/Types.h" + +void Aidge::ShiftMaxImpl_cuda::forward() { + + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + assert(mOp.getRawInput(0) && "missing input #0"); + const auto& input = op.getInput(0)->refCastFrom(mInputFallback, *op.getOutput(0)); + + //forward_<float>(input); + + //should template and changing type + switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) { + case DataType::Float64: + forward_<float>(input); + break; + case DataType::Float32: + forward_<float>(input); + break; + case DataType::Float16: + forward_<float>(input); + break; + default: + AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda"); + } +} + +template<class T> +void Aidge::ShiftMaxImpl_cuda::forward_(const Tensor& input) +{ + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + const T * input_raw = static_cast<const T*>(input.getImpl()->rawPtr()); + + int N = 15; + int output_bits = 8; + + size_t size = input.size(); + std::vector<DimSize_t> dims_input = input.dims(); + + double min = std::numeric_limits<double>::max(); + double max = std::numeric_limits<double>::min(); + for(std::size_t i = 0; i < dims_input[0]; i++) { + for(std::size_t j = 0; j < dims_input[1]; j++) { + for(std::size_t k = 0; k < dims_input[2]; k++) { + for(std::size_t l = 0; l < dims_input[3]; l++) { + std::vector<std::size_t> coordIdx = {i, j, k, l}; + std::size_t newFlatIdx = input.getIdx(coordIdx); + if (newFlatIdx < min) { + min = newFlatIdx; + } + if (newFlatIdx > max) { + max = newFlatIdx; + } + } + } + } + } + + double m = std::max(std::abs(min), std::abs(max)); + + // Calculate the normalization factor + double normalization_factor = static_cast<double>(1 << (output_bits - 1)) - 1; + + // Return the normalized maximum + double final_sf = m / normalization_factor; + T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()); + double new_SF = 1/std::pow(2,2*output_bits-1); // Le nouveau scaling factor renvoyé par la fonction shiftmax, utilisé pour déquantifier le tenseur renvoyé + ShiftMaxLaunchKernel(input_raw, output, final_sf,N, output_bits, size, dims_input); +} \ No newline at end of file diff --git a/src/operator/ShiftMaxImpl_CUDA_kernels.cu b/src/operator/ShiftMaxImpl_CUDA_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..c4c619cdde1a1ee123f19a3d75f3f14bc542bc52 --- /dev/null +++ b/src/operator/ShiftMaxImpl_CUDA_kernels.cu @@ -0,0 +1,152 @@ +/******************************************************************************** + * Copyright (c) 2024 Thales + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France + * Date: 25.06.2024 + * + ********************************************************************************/ +#define MAX(X,Y) (((X) > (Y)) ? (X) : (Y)) +#define CLAMP(X) (((X) < (0)) ? (0) : (X)) + +#include <stdio.h> +#include <cuda_runtime.h> + +#include "aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp" + +__device__ inline int ExpShift(int I,int N, double SF) +{ + int Ip = I + (I >> 1) - (I >> 4); + int I0 = floorf(-1.0/SF); + Ip = MAX(Ip,N*I0); + int q = floorf(Ip / (I0)); + int r = Ip -(I0*q); + int Ib = r/2 - I0; + Ib = CLAMP(Ib * powf(2,N-q));//BitShift? + return (int)Ib; +} + +namespace Aidge{ + +template <class T> +__global__ void ShiftMaxWholeKernel(T* input,int* quantized_tensor,int* factor, int* dims, double SF, int N, int output_bits,double new_SF) +/* + * Kernels du Forward de Shiftmax + * Input => Tenseur représentant l'entrée (non quantifiée (flottant)) (pointeur vers le bloc de mémoire de type T) + * quantized_tensor => pointeur vers un bloc mémoire vide alloué sur le GPU + * factor => pointeur vers un bloc mémoire vide alloué sur le GPU + * dims => int[4] sous forme de pointeur qui représente les 4 dimensions du tenseurs + * SF => Scaling Factor + * N => precision du Softmax arithmétique (plus N est grand plus l'opération est précise mais plus elle nécessite un nombre de bit elevé) + * output_bits => précision en bit souhaité (8 pour int8 par exemple) + * new_SF => Nouveau SF pour déquantifier le tenseur + */ +{ + int x = blockIdx.x * blockDim.x + threadIdx.x; // Dim1 + int y = blockIdx.y * blockDim.y + threadIdx.y; // Dim2 + int z = blockIdx.z * blockDim.z + threadIdx.z; // Dim3 + int sum = 0; + /* + * x,y et z représente les indices des dimensions 1,2 et 3 du tenseur, toutes les combinaisons possible de x,y et z + * sont appelés en parralèle ce qui permet le speedup GPU + * pour iterer dans la derniere dimensions on utilise les boucles "for" ci dessous + * */ + if (x < dims[0] && y < dims[1] && z < dims[2]) { + int maxIdx = x * dims[1] * dims[2] * dims[3] + y * dims[2] * dims[3] + z * dims[3]; + for (int i = 0; i < dims[3]; i++) { //Quantization (1thread per last dim of tensor) + int idx = maxIdx + i; + quantized_tensor[idx] = roundf(input[idx] / SF); + } + int maxVal = quantized_tensor[maxIdx]; + for (int i = 1; i < dims[3]; i++) { // max value par dimensions 4 + int idx = maxIdx + i; + maxVal = MAX(maxVal, quantized_tensor[idx]); + } + for (int i = 0; i < dims[3]; i++) { //Expo (artihmetic) + int idx = maxIdx + i; + quantized_tensor[idx] = ExpShift(quantized_tensor[idx]-maxVal,N,SF); + } + for (int i = 0; i < dims[3]; i++) { // Sum et clamp quand dépassement de valeur + int idx = maxIdx + i; + if(quantized_tensor[idx] > 0 && sum > INT_MAX - quantized_tensor[idx])//CLAMP(2**31-1) + { + sum = INT_MAX; + break; + } + else { + sum += quantized_tensor[idx]; + } + } + factor[x * dims[1] * dims[2] + y * dims[2] + z] = floorf(INT_MAX/sum); + for(int i= 0; i < dims[3]; ++i) //bitshift pour quantifier sur 8 bits + { + int idx = maxIdx + i; + quantized_tensor[idx] = (quantized_tensor[idx] * factor[x * dims[1] * dims[2] + y * dims[2] + z]) >> (31-(2*output_bits-1)); + input[idx] =quantized_tensor[idx]*new_SF; + } + } +} + +//TODO Template +void ShiftMaxLaunchKernel(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) { + + double new_SF = 1/std::pow(2,2*output_bits-1); // Le nouveau scaling factor renvoyé par la fonction shiftmax, utilisé pour déquantifier le tenseur renvoyé + + int dims_input_cuda[4]; + if (dims_input.size() >= 4) { + // Fixed-size array to store the first 4 elements + + // Copy the first 4 elements from dims_input to dims_input2 + for (std::size_t i = 0; i < 4; ++i) { + dims_input_cuda[i] = static_cast<int>(dims_input[i]); + } + } + + + float* input_cuda_tensor; + cudaMalloc(&input_cuda_tensor,size*sizeof(float)); //| + cudaMemcpy(input_cuda_tensor,input,size*sizeof(float),cudaMemcpyHostToDevice); + //| + int* quantized_tensor; //| + cudaMalloc(&quantized_tensor,size*sizeof(int)); //| + //| => Allocation des blocs mémoire sur le GPU + int* factor; //| + cudaMalloc(&factor,size*sizeof(int)); //| + //| + int* dims; //| + cudaMalloc(&dims,4*sizeof(int)); + + cudaMemcpy(dims,dims_input_cuda,4*sizeof(int),cudaMemcpyHostToDevice); + + dim3 threadsPerBlock(10, 10, 10); //| Calculs du nombre de thread par blocs et du nombre de bloc a lancé en parrallèle sur + dim3 numBlocks((dims_input[0] + threadsPerBlock.x - 1) / threadsPerBlock.x, //| le GPU pour en fonctions des dimensions du tenseur en entrée + (dims_input[1] + threadsPerBlock.y - 1) / threadsPerBlock.y, //| + (dims_input[2] + threadsPerBlock.z - 1) / threadsPerBlock.z); //| + + ShiftMaxWholeKernel<float><<<numBlocks,threadsPerBlock>>>(input_cuda_tensor,quantized_tensor,factor,dims,SF,N,output_bits,new_SF);//Lancement du Kernel + cudaDeviceSynchronize(); //Attente de la fin d'execution du kernel. Ligne très importante puisque sans celle ci le CPU continue l'execution du programme sans attendre le retour du GPU + + cudaError_t err = cudaGetLastError(); + if(err != cudaSuccess) + { + printf("Erreur CUDA: %s\n", cudaGetErrorString(err)); + } //Checks des possibles erreurs sur le GPU + + + //float* ControlFinal = (float*)malloc(size*sizeof(float)); + //cudaMemcpy(ControlFinal,input_cuda_tensor,size*sizeof(float),cudaMemcpyDeviceToHost); + //MyTensor<float> control(ControlFinal,x.dims); + + cudaMemcpy(output,input_cuda_tensor,size*sizeof(float),cudaMemcpyDeviceToHost); + + cudaFree(quantized_tensor); //| + cudaFree(factor); //| + cudaFree(dims); //| => Free sur GPU et CPU (tout ce qui a été malloc et cudaMalloc en gros) + cudaFree(input_cuda_tensor);//| +} + +} \ No newline at end of file diff --git a/unit_tests/Test_ShiftGELUImpl.cpp b/unit_tests/Test_ShiftGELUImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d5382c29cef00587eaed3dcd789352c4e8263d31 --- /dev/null +++ b/unit_tests/Test_ShiftGELUImpl.cpp @@ -0,0 +1,131 @@ +/******************************************************************************** + * Copyright (c) 2024 Thales + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France + * Date: 25.06.2024 + * + ********************************************************************************/ + +#include <array> + +#include <catch2/catch_test_macros.hpp> + +#include "Test_cuda.hpp" + +#include "aidge/data/Tensor.hpp" + +#include "aidge/backend/cpu.hpp" +#include "aidge/backend/cuda.hpp" + +using namespace Aidge; + +TEST_CASE("[gpu/operator] ShiftGELU(forward)", "[ShiftGELU][GPU]") { + SECTION("4D Tensor") { + std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<float,2,2,2,10> { + { + { + { + {0.96, 0.48, 0.54, 0.49, 0.59, 0.93, 0.00, 0.00, 0.61, 0.61}, + {0.85, 0.06, 0.11, 0.87, 0.55, 0.12, 0.80, 0.48, 0.41, 0.16} + }, + { + {0.24, 0.46, 0.97, 0.19, 0.65, 0.12, 0.44, 1.00, 0.37, 0.09}, + {0.44, 0.64, 0.21, 0.58, 0.05, 0.24, 0.56, 0.07, 0.49, 0.79} + } + }, + { + { + {0.00, 0.13, 0.55, 0.42, 0.49, 0.28, 0.52, 0.55, 0.34, 0.85}, + {0.98, 0.32, 0.09, 0.05, 0.37, 0.47, 0.63, 0.13, 0.70, 0.02} + }, + { + {0.69, 0.13, 0.74, 0.61, 0.25, 0.87, 0.46, 0.40, 0.81, 0.06}, + {0.89, 0.32, 0.61, 0.24, 0.70, 0.23, 0.09, 0.03, 0.14, 0.80} + } + } + } + }); + + std::shared_ptr<Tensor> output_shiftGELU = std::make_shared<Tensor>(Array4D<float,2,2,2,10> { + { + { + { + { 0.991388f, 0.413078f, 0.413078f, 0.413078f, 0.413078f, 0.413078f, 0.0f, 0.0f, 0.413078f, 0.413078f }, + { 0.413078f, 0.0f, 0.0f, 0.413078f, 0.413078f, 0.0f, 0.413078f, 0.413078f, 0.413078f, 0.0f } + }, + { + { 0.0f, 0.413078f, 0.991388f, 0.0f, 0.413078f, 0.0f, 0.413078f, 0.991388f, 0.413078f, 0.0f }, + { 0.413078f, 0.413078f, 0.0f, 0.413078f, 0.0f, 0.0f, 0.413078f, 0.0f, 0.413078f, 0.413078f } + } + }, + { + { + { 0.0f, 0.0f, 0.413078f, 0.413078f, 0.413078f, 0.0f, 0.413078f, 0.413078f, 0.413078f, 0.413078f }, + { 0.991388f, 0.413078f, 0.0f, 0.0f, 0.413078f, 0.413078f, 0.413078f, 0.0f, 0.413078f, 0.0f} + }, + { + { 0.413078f, 0.0f, 0.413078f, 0.413078f, 0.0f, 0.413078f, 0.413078f, 0.413078f, 0.413078f, 0.0f }, + { 0.413078f, 0.413078f, 0.413078f, 0.0f, 0.413078f, 0.0f, 0.0f, 0.0f, 0.0f, 0.413078f } + } + } + } + }); + + std::shared_ptr<Tensor> output_GELU = std::make_shared<Tensor>(Array4D<float, 2, 2, 2, 10> { + { + { + { + { 0.7982f, 0.3285f, 0.3809f, 0.3371f, 0.4262f, 0.7661f, 0.0000f, 0.0000f, 0.4447f, 0.4447f }, + { 0.6820f, 0.0314f, 0.0598f, 0.7028f, 0.3899f, 0.0657f, 0.6305f, 0.3285f, 0.2702f, 0.0902f } + }, + { + { 0.1428f, 0.3115f, 0.8090f, 0.1093f, 0.4824f, 0.0657f, 0.2948f, 0.8413f, 0.2384f, 0.0482f }, + { 0.2948f, 0.4729f, 0.1225f, 0.4170f, 0.0260f, 0.1428f, 0.3989f, 0.0370f, 0.3371f, 0.6203f } + } + }, + { + { + { 0.0000f, 0.0717f, 0.3899f, 0.2784f, 0.3371f, 0.1709f, 0.3632f, 0.3899f, 0.2152f, 0.6820f }, + { 0.8197f, 0.2002f, 0.0482f, 0.0260f, 0.2384f, 0.3200f, 0.4635f, 0.0717f, 0.5306f, 0.0102f } + }, + { + { 0.5209f, 0.0717f, 0.5701f, 0.4447f, 0.1497f, 0.7028f, 0.3115f, 0.2622f, 0.6407f, 0.0314f }, + { 0.7238f, 0.2002f, 0.4447f, 0.1428f, 0.5306f, 0.1359f, 0.0482f, 0.0154f, 0.0778f, 0.6305f } + } + } + } + }); //value given by torch nn GELU + + std::shared_ptr<Node> myShiftGELU = ShiftGELU(); + auto op = std::static_pointer_cast<OperatorTensor>(myShiftGELU -> getOperator()); + op->associateInput(0,input0); + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forward(); + + float* computedOutput = new float[output_shiftGELU->size()](); + cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * output_shiftGELU->size(), cudaMemcpyDeviceToHost); + + for(int i = 0; i < output_shiftGELU->size(); i++){ + const float targetOutput = *(static_cast<float*>(output_shiftGELU->getImpl()->rawPtr()) + i); + REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6); + } + + float sum = 0.0; + for(int i = 0; i < output_GELU->size(); i++){ + const float targetOutput = *(static_cast<float*>(output_GELU->getImpl()->rawPtr()) + i); + sum += fabs(computedOutput[i] - targetOutput); + } + sum = sum / output_GELU->size(); + std::cout << sum << "\n"; + REQUIRE(sum < 1.5e-1); + + delete[] computedOutput; + } + +} \ No newline at end of file diff --git a/unit_tests/Test_ShiftMaxImpl.cpp b/unit_tests/Test_ShiftMaxImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a9f8b2a665d9b77f8ba6ce7c4d251f9b6c0da166 --- /dev/null +++ b/unit_tests/Test_ShiftMaxImpl.cpp @@ -0,0 +1,128 @@ +/******************************************************************************** + * Copyright (c) 2024 Thales + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France + * Date: 25.06.2024 + * + ********************************************************************************/ + +#include <array> + +#include <catch2/catch_test_macros.hpp> + +#include "Test_cuda.hpp" + +#include "aidge/data/Tensor.hpp" + +#include "aidge/backend/cpu.hpp" +#include "aidge/backend/cuda.hpp" + +using namespace Aidge; + +TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") { + SECTION("4D Tensor") { + std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<float,2,2,2,10> { + { + { + { + {0.96, 0.48, 0.54, 0.49, 0.59, 0.93, 0.00, 0.00, 0.61, 0.61}, + {0.85, 0.06, 0.11, 0.87, 0.55, 0.12, 0.80, 0.48, 0.41, 0.16} + }, + { + {0.24, 0.46, 0.97, 0.19, 0.65, 0.12, 0.44, 1.00, 0.37, 0.09}, + {0.44, 0.64, 0.21, 0.58, 0.05, 0.24, 0.56, 0.07, 0.49, 0.79} + } + }, + { + { + {0.00, 0.13, 0.55, 0.42, 0.49, 0.28, 0.52, 0.55, 0.34, 0.85}, + {0.98, 0.32, 0.09, 0.05, 0.37, 0.47, 0.63, 0.13, 0.70, 0.02} + }, + { + {0.69, 0.13, 0.74, 0.61, 0.25, 0.87, 0.46, 0.40, 0.81, 0.06}, + {0.89, 0.32, 0.61, 0.24, 0.70, 0.23, 0.09, 0.03, 0.14, 0.80} + } + } + } + }); + std::shared_ptr<Tensor> output_shiftmax = std::make_shared<Tensor>(Array4D<float,2,2,2,10> { + { + { + { + { 0.111084f, 0.111084f, 0.111084f, 0.111084f, 0.111084f, 0.111084f, 0.055542f, 0.055542f, 0.111084f, 0.111084f }, + { 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f } + }, + { + { 0.0624695f, 0.124969f, 0.124969f, 0.0624695f, 0.124969f, 0.0624695f, 0.124969f, 0.124969f, 0.124969f, 0.0624695f }, + { 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f } + } + }, + { + { + { 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f }, + { 0.124969f, 0.124969f, 0.0624695f, 0.0624695f, 0.124969f, 0.124969f, 0.124969f, 0.0624695f, 0.124969f, 0.0624695f } + }, + { + { 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f }, + { 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f, 0.0999756f } + } + } + } + }); + std::shared_ptr<Tensor> output_softmax = std::make_shared<Tensor>(Array4D<float, 2, 2, 2, 10> { + { + { + { + { 0.1484f, 0.0918f, 0.0975f, 0.0928f, 0.1025f, 0.1440f, 0.0568f, 0.0568f, 0.1046f, 0.1046f }, + { 0.1436f, 0.0652f, 0.0685f, 0.1465f, 0.1064f, 0.0692f, 0.1366f, 0.0992f, 0.0925f, 0.0721f } + }, + { + { 0.0768f, 0.0957f, 0.1593f, 0.0730f, 0.1157f, 0.0681f, 0.0938f, 0.1642f, 0.0874f, 0.0661f }, + { 0.1005f, 0.1227f, 0.0798f, 0.1156f, 0.0680f, 0.0823f, 0.1133f, 0.0694f, 0.1056f, 0.1426f } + } + }, + { + { + { 0.0645f, 0.0734f, 0.1118f, 0.0981f, 0.1052f, 0.0853f, 0.1085f, 0.1118f, 0.0906f, 0.1509f }, + { 0.1743f, 0.0901f, 0.0716f, 0.0688f, 0.0947f, 0.1047f, 0.1228f, 0.0745f, 0.1317f, 0.0667f } + }, + { + { 0.1164f, 0.0665f, 0.1224f, 0.1075f, 0.0750f, 0.1394f, 0.0925f, 0.0871f, 0.1313f, 0.0620f }, + { 0.1551f, 0.0877f, 0.1172f, 0.0810f, 0.1283f, 0.0802f, 0.0697f, 0.0656f, 0.0733f, 0.1418f } + } + } + } + }); //softmax value given by torch softmax + + std::shared_ptr<Node> myShiftMax = ShiftMax(); + auto op = std::static_pointer_cast<OperatorTensor>(myShiftMax -> getOperator()); + op->associateInput(0,input0); + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forward(); + + float* computedOutput = new float[output_shiftmax->size()](); + cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * output_shiftmax->size(), cudaMemcpyDeviceToHost); + + for(int i = 0; i < output_shiftmax->size(); i++){ + const float targetOutput = *(static_cast<float*>(output_shiftmax->getImpl()->rawPtr()) + i); + REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6); + } + + float sum = 0.0; + for(int i = 0; i < output_softmax->size(); i++){ + const float targetOutput = *(static_cast<float*>(output_softmax->getImpl()->rawPtr()) + i); + sum += fabs(computedOutput[i] - targetOutput); + } + sum = sum / output_softmax->size(); + REQUIRE(sum < 4e-2); + + delete[] computedOutput; + } + +} \ No newline at end of file