diff --git a/include/aidge/backend/cuda.hpp b/include/aidge/backend/cuda.hpp index e8b8772dc92a4a7b5cd8849fa08c62606149d8cc..d5e9d1654f0a4fe894ed0e965a25b32c9e5caa06 100644 --- a/include/aidge/backend/cuda.hpp +++ b/include/aidge/backend/cuda.hpp @@ -37,4 +37,9 @@ #include "aidge/backend/cuda/operator/SubImpl.hpp" #include "aidge/backend/cuda/operator/TanhImpl.hpp" +#include "aidge/backend/cuda/operator/ShiftMaxImpl.hpp" +#include "aidge/backend/cuda/operator/ShiftGELUImpl.hpp" +#include "aidge/backend/cuda/operator/ILayerNormImpl.hpp" + + #endif /* AIDGE_BACKEND_CUDA_IMPORTS_H_ */ diff --git a/include/aidge/backend/cuda/operator/ILayerNormImpl.hpp b/include/aidge/backend/cuda/operator/ILayerNormImpl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..742401de7903f19ab4d8f51a153b0e864f21dd47 --- /dev/null +++ b/include/aidge/backend/cuda/operator/ILayerNormImpl.hpp @@ -0,0 +1,65 @@ +/******************************************************************************** + * Copyright (c) 2024 Thales + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France + * Date: 10.09.2024 + * + ********************************************************************************/ + +#ifndef AIDGE_BACKEND_CUDA_OPERATOR_ILAYERNORMIMPL_H_ +#define AIDGE_BACKEND_CUDA_OPERATOR_ILAYERNORMIMPL_H_ + +#include <array> +#include <memory> +#include <tuple> +#include <vector> + +#include <cudnn.h> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/ILayerNorm.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +#include "aidge/backend/cuda/utils/CudaUtils.hpp" + +namespace Aidge { +class ILayerNormImpl_cuda : public OperatorImpl { +public: + ILayerNormImpl_cuda(const ILayerNorm_Op &op) : OperatorImpl(op, "cuda") {} + + static std::unique_ptr<ILayerNormImpl_cuda> create(const ILayerNorm_Op &op) { + return std::make_unique<ILayerNormImpl_cuda>(op); + } + + virtual std::set<ImplSpec> getAvailableImplSpecs() const override { + return { + {DataType::Float64}, + {DataType::Float32}, + {DataType::Float16}, + }; + } + + void forward() override; + void backward() override; + +private: + std::shared_ptr<Tensor> mInput0Fallback; + std::shared_ptr<Tensor> mInput1Fallback; + std::shared_ptr<Tensor> mInput2Fallback; + std::shared_ptr<Tensor> mOutputGradFallback; + + template <class T> void forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2); + template <class T> void backward_(const Tensor& output_grad); +}; + +// Implementation entry point registration to Operator +REGISTRAR(ILayerNorm_Op, "cuda", Aidge::ILayerNormImpl_cuda::create); +} // namespace Aidge + +#endif /* AIDGE_BACKEND_CUDA_OPERATOR_ILAYERNORMIMPL_H_ */ diff --git a/include/aidge/backend/cuda/operator/ILayerNormImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ILayerNormImpl_CUDA_kernels.hpp new file mode 100644 index 0000000000000000000000000000000000000000..aa54029ea29bc46809f227038a1a23d91bc161ee --- /dev/null +++ b/include/aidge/backend/cuda/operator/ILayerNormImpl_CUDA_kernels.hpp @@ -0,0 +1,92 @@ +/******************************************************************************** + * Copyright (c) 2024 Thales + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which 
is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France + * Date: 10.09.2024 + * + ********************************************************************************/ + +#ifndef AIDGE_CUDA_OPERATOR_ILAYERNORMIMPL_FORWARD_KERNEL_H_ +#define AIDGE_CUDA_OPERATOR_ILAYERNORMIMPL_FORWARD_KERNEL_H_ + +#include <stdexcept> +#include <cfloat> +#include <cuda.h> +#include <cuda_runtime_api.h> +#include <cuda_fp16.h> + +#include "aidge/data/Data.hpp" +#include "aidge/backend/cuda/utils/CudaUtils.hpp" + +namespace Aidge { + +/** + * @brief Compute the forward for ILayerNorm + * @param input: Input tensor + * @param SF: Scaling factor of input tensor + * @param dims: Dimensions of input tensor + * @param quantized_tensor: Quantized output tensor + * @param square_tensor: Tensor use for computation + * @param weight: weight of ILayerNorm layer + * @param bias: bias of ILayerNorm layer + * @param new_SF: Scaling factor of output that can be use to dequantify +*/ +template <class T> +__global__ void ILayerNormforward_(T* input, double SF, int* dims, int* quantized_tensor,long long int* square_tensor, T* weight, T* biase, double new_SF); + +/** + * @brief Wrapper function to execute ILayerNormforward_ + * @note Output correspond to the non-quantized tensor, to obtain the quantized tensor we need to copy quantized_tensor and not input_cuda_tensor + * @param input: Input tensor + * @param output: Output tensor (not quantized) + * @param SF: Scaling factor of input tensor + * @param weight_raw: weight of ILayerNorm layer + * @param bias_raw: bias of ILayerNorm layer + * @param size: Number of elements in the input tensor + * @param dims: Dimensions of input tensor +*/ +template <class T> +void ILayerNormforward(const T* input, T* output, double SF, const T* weight_raw, const T* bias_raw, size_t size, std::vector<long unsigned int> dims_input); + +/** + * @brief Compute the backward for ILayerNorm + * @param output_grad: Gradient of output tensor + * @param input_tensor: Input tensor + * @param output_tensor: Output tensor obtained after forward + * @param mean: Arithmetic mean of input tensor + * @param var: Arithmetic variance of input tensor + * @param weight: weight of ILayerNorm layer + * @param bias: bias of ILayerNorm layer + * @param input_grad: Gradient of input tensor + * @param weight_grad: Gradient of ILayerNorm weight + * @param bias_grad: Gradient of ILayerNorm bias + * @param size: Number of elements in the input tensor +*/ +template <class T> +__global__ void ILayerNormbackward_(T* output_grad, T* input_tensor, T* output_tensor, T* mean, T* var, T* weight, T* bias, T* input_grad, T* weight_grad, T* bias_grad, int size); + +/** + * @brief Wrapper function to execute ILayerNormbackward_ + * @param input_tensor: Input tensor + * @param output_grad: Gradient of output tensor + * @param output_tensor: Output tensor obtained after forward + * @param mean: Arithmetic mean of input tensor + * @param var: Arithmetic variance of input tensor + * @param weight: weight of ILayerNorm layer + * @param bias: bias of ILayerNorm layer + * @param input_grad: Gradient of input tensor + * @param weight_grad: Gradient of ILayerNorm weight + * @param bias_grad: Gradient of ILayerNorm bias + * @param size: Number of elements in the input tensor +*/ +template <class T> +void ILayerNormbackward(const T* input_tensor, const T* output_grad, const T* output_tensor,const T* mean,const T* var, const T* weight, const 
T* bias, T* input_grad, T* weight_grad, T* bias_grad, size_t size); + +} + +#endif /* AIDGE_CUDA_OPERATOR_ILAYERNORMIMPL_FORWARD_KERNEL_H_ */ \ No newline at end of file diff --git a/include/aidge/backend/cuda/operator/ShiftGELUImpl.hpp b/include/aidge/backend/cuda/operator/ShiftGELUImpl.hpp index 6eee6c12ce5d4efaa4dbec3f99dc35951c8087eb..f83b41ae139482cdb0cd1060846c77ba78fcc0ee 100644 --- a/include/aidge/backend/cuda/operator/ShiftGELUImpl.hpp +++ b/include/aidge/backend/cuda/operator/ShiftGELUImpl.hpp @@ -29,12 +29,11 @@ #include "aidge/backend/cuda/utils/CudaUtils.hpp" namespace Aidge { -// Operator implementation entry point for the backend class ShiftGELUImpl_cuda : public OperatorImpl { public: - ShiftGELUImpl_cuda(const ShiftGELU_Op& op) : OperatorImpl(op, "cuda") {} + ShiftGELUImpl_cuda(const ShiftGELU_Op &op) : OperatorImpl(op, "cuda") {} - static std::unique_ptr<ShiftGELUImpl_cuda> create(const ShiftGELU_Op& op) { + static std::unique_ptr<ShiftGELUImpl_cuda> create(const ShiftGELU_Op &op) { return std::make_unique<ShiftGELUImpl_cuda>(op); } @@ -46,12 +45,17 @@ public: }; } + void forward() override; + void backward() override; private: std::shared_ptr<Tensor> mInputFallback; + std::shared_ptr<Tensor> mOutputGradFallback; template <class T> void forward_(const Tensor& input); + template <class T> void backward_(const Tensor& output_grad); + }; // Implementation entry point registration to Operator diff --git a/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp index ab982fa2cb44833f687368a3704bf855444a661e..14268521451a631ccb9194d44ed7543af8d494f5 100644 --- a/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp +++ b/include/aidge/backend/cuda/operator/ShiftGELUImpl_CUDA_kernels.hpp @@ -25,10 +25,54 @@ namespace Aidge { -extern void ShiftGELULaunchKernel(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input); +/** + * @brief Compute the forward for ShiftGELU + * @param input: Input tensor + * @param quantized_tensor: Quantized output tensor + * @param GELUtensor: Pointer to an empty memory block allocated on the GPU (just use for computation) + * @param SumTensor: Pointer to an empty memory block allocated on the GPU (just use for computation) + * @param dims: Dimensions of input tensor + * @param SF: Scaling factor of input tensor + * @param N: Arithmetic precision, currently set at 15 like I-ViT (the greater the N, the more precise the operation, but the greater the number of bits required) + * @param output_bits: Desired bit precision (8 for int8, for example) +*/ +template <class T> +__global__ void ShiftGELUforward_(T* input,int* quantized_tensor,int* GELUtensor,int* SumTensor, int* dims, double SF, int N, int output_bits); + +/** + * @brief Wrapper function to execute ShiftGELUforward_ + * @note Output correspond to the non-quantized tensor, to obtain the quantized tensor we need to copy quantized_tensor and not input_cuda_tensor + * @param input: Input tensor + * @param output: Output tensor (not quantized) + * @param SF: Scaling factor of input tensor + * @param N: Arithmetic precision, currently set at 15 like I-ViT (the greater the N, the more precise the operation, but the greater the number of bits required) + * @param output_bits: Desired bit precision (8 for int8, for example) + * @param size: Number of elements in the input tensor + * @param dims_input: Dimensions of input tensor +*/ +template 
<class T> +void ShiftGELUforward(const T* input, T* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input); +/** + * @brief Compute the backward for ShiftGELU + * @param input_grad: Gradient of input tensor (that we want to obtain) + * @param output_tensor: Output tensor obtained after forward + * @param output_grad: Gradient of output tensor + * @param size: Number of elements in the input tensor +*/ template <class T> -__global__ void ShiftGELUWholeKernel(T* input,int* quantized_tensor,int* GELUtensor,int* SumTensor, int* dims, double SF, int N, int output_bits); +__global__ void ShiftGELUbackward_(T* input_grad, const T* output_tensor, const T* output_grad, int size); + +/** + * @brief Wrapper function to execute ShiftGELUbackward_ + * @param output_tensor: Output tensor obtained after forward + * @param output_grad: Gradient of output tensor + * @param input_grad: Gradient of input tensor (that we want to obtain) + * @param size: Number of elements in the input tensor +*/ +template <class T> +void ShiftGELUbackward(const T* output_tensor, const T* output_grad, T* input_grad, size_t size); + } -#endif /* AIDGE_CUDA_OPERATOR_SHIFTGELUIMPL_KERNELS_H_ */ \ No newline at end of file +#endif /* AIDGE_CUDA_OPERATOR_SHIFTGELUIMPL_FORWARD_KERNEL_H_ */ diff --git a/include/aidge/backend/cuda/operator/ShiftMaxImpl.hpp b/include/aidge/backend/cuda/operator/ShiftMaxImpl.hpp index bce533158e3a8fffdf798a07df5cc9735a836fa8..707b5616fde120f7e8ef38e6dc9f1552cfdb0d59 100644 --- a/include/aidge/backend/cuda/operator/ShiftMaxImpl.hpp +++ b/include/aidge/backend/cuda/operator/ShiftMaxImpl.hpp @@ -29,12 +29,11 @@ #include "aidge/backend/cuda/utils/CudaUtils.hpp" namespace Aidge { -// Operator implementation entry point for the backend class ShiftMaxImpl_cuda : public OperatorImpl { public: - ShiftMaxImpl_cuda(const ShiftMax_Op& op) : OperatorImpl(op, "cuda") {} + ShiftMaxImpl_cuda(const ShiftMax_Op &op) : OperatorImpl(op, "cuda") {} - static std::unique_ptr<ShiftMaxImpl_cuda> create(const ShiftMax_Op& op) { + static std::unique_ptr<ShiftMaxImpl_cuda> create(const ShiftMax_Op &op) { return std::make_unique<ShiftMaxImpl_cuda>(op); } @@ -47,11 +46,15 @@ public: } void forward() override; + void backward() override; private: std::shared_ptr<Tensor> mInputFallback; + std::shared_ptr<Tensor> mOutputGradFallback; template <class T> void forward_(const Tensor& input); + template <class T> void backward_(const Tensor& output_grad); + }; // Implementation entry point registration to Operator diff --git a/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp b/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp index 59c8c0ba19764cfd1cba2eaddaa75b0ce4de3e43..037a7cbb6362a8eca5a9e6f5a277b29a6a6bd907 100644 --- a/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp +++ b/include/aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp @@ -25,10 +25,55 @@ namespace Aidge { -extern void ShiftMaxLaunchKernel(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input); +/** + * @brief Compute the forward for ShiftMax + * @param input: Input tensor + * @param quantized_tensor: Quantized output tensor + * @param factor: Pointer to an empty memory block allocated on the GPU (just use for computation) + * @param dims: Dimensions of input tensor + * @param SF: Scaling factor of input tensor + * @param N: Arithmetic precision, currently set at 15 like I-ViT (the greater the N, the more 
precise the operation, but the greater the number of bits required) + * @param output_bits: Desired bit precision (8 for int8, for example) + * @param new_SF: Scaling factor of output that can be use to dequantify +*/ +template <class T> +__global__ void ShiftMaxforward_(T* input,int* quantized_tensor,int* factor, int* dims, double SF, int N, int output_bits,double new_SF); + +/** + * @brief Wrapper function to execute ShiftMaxforward_ + * @note Output correspond to the non-quantized tensor, to obtain the quantized tensor we need to copy quantized_tensor and not input_cuda_tensor + * @param input: Input tensor + * @param output: Output tensor (not quantized) + * @param SF: Scaling factor of input tensor + * @param N: Arithmetic precision, currently set at 15 like I-ViT (the greater the N, the more precise the operation, but the greater the number of bits required) + * @param output_bits: Desired bit precision (8 for int8, for example) + * @param size: Number of elements in the input tensor + * @param dims_input: Dimensions of input tensor +*/ +template <class T> +void ShiftMaxforward(const T* input, T* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input); +/** + * @brief Compute the backward for ShiftMax + * @param input_grad: Gradient of input tensor (that we want to obtain) + * @param output_tensor: Output tensor obtained after forward + * @param output_grad: Gradient of output tensor + * @param dims: Dimensions of input tensor +*/ template <class T> -__global__ void ShiftMaxWholeKernel(T* input,int* quantized_tensor,int* factor, int* dims, double SF, int N, int output_bits,double new_SF) ; +__global__ void ShiftMaxbackward_(T* input_grad, const T* output_tensor, const T* output_grad, const int* dims); + +/** + * @brief Wrapper function to execute ShiftMaxbackward_ + * @param output_tensor: Output tensor obtained after forward + * @param output_grad: Gradient of output tensor + * @param input_grad: Gradient of input tensor (that we want to obtain) + * @param size: Number of elements in the input tensor + * @param dims: Dimensions of input tensor +*/ +template <class T> +void ShiftMaxbackward(const T* output_tensor, const T* output_grad, T* input_grad, size_t size, std::vector<long unsigned int> dims); + } -#endif /* AIDGE_CUDA_OPERATOR_SHIFTMAXIMPL_KERNELS_H_ */ \ No newline at end of file +#endif /* AIDGE_CUDA_OPERATOR_SHIFTMAXIMPL_FORWARD_KERNEL_H_ */ diff --git a/src/operator/ILayerNormImpl.cpp b/src/operator/ILayerNormImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..47dd1d5d1a3f127c9e08788f605796020a7814a7 --- /dev/null +++ b/src/operator/ILayerNormImpl.cpp @@ -0,0 +1,204 @@ +/******************************************************************************** + * Copyright (c) 2024 Thales + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France + * Date: 10.09.2024 + * + ********************************************************************************/ + +#include <cassert> +#include <chrono> // std::chrono::milliseconds +#include <numeric> // std::accumulate +#include <thread> // std::this_thread::sleep_for +#include <vector> +#include <algorithm> // For std::max +#include <cmath> // For pow +#include <typeinfo> + +#include "aidge/backend/cuda/data/TensorImpl.hpp" +#include "aidge/backend/cuda/operator/ILayerNormImpl.hpp" +#include "aidge/backend/cuda/operator/ILayerNormImpl_CUDA_kernels.hpp" +#include "aidge/backend/cuda/utils/CudaContext.hpp" +#include "aidge/backend/cuda/utils/CudaUtils.hpp" +#include "aidge/operator/ILayerNorm.hpp" +#include "aidge/utils/Types.h" + +void Aidge::ILayerNormImpl_cuda::forward() { + + + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + + assert(mOp.getRawInput(0) && "missing input #0"); + assert(mOp.getRawInput(1) && "missing input #1"); + assert(mOp.getRawInput(2) && "missing input #2"); + + const auto& input0 = op.getInput(0)->refCastFrom(mInput0Fallback, *op.getOutput(0)); + const auto& input1 = op.getInput(1)->refCastFrom(mInput1Fallback, *op.getOutput(0)); + const auto& input2 = op.getInput(2)->refCastFrom(mInput2Fallback, *op.getOutput(0)); + + switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) { + case DataType::Float64: + forward_<double>(input0, input1, input2); + break; + case DataType::Float32: + forward_<float>(input0, input1, input2); + break; + default: + AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda"); + } +} + + +template<class T> +void Aidge::ILayerNormImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) +{ + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + const T * input_raw = static_cast<const T*>(input0.getImpl()->rawPtr()); + const T * weight = static_cast<const T*>(input1.getImpl()->rawPtr()); + const T * bias = static_cast<const T*>(input2.getImpl()->rawPtr()); + T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()); + + int N = 15; + int output_bits = 8; + size_t size = input0.size(); + std::vector<DimSize_t> dims_input = input0.dims(); + + // maybe find a most efficient way to compute scaling factor (a max and min function could help to retrieve scaling factor value) + + double min = std::numeric_limits<double>::max(); + double max = std::numeric_limits<double>::min(); + for(std::size_t i = 0; i < dims_input[0]; i++) { + for(std::size_t j = 0; j < dims_input[1]; j++) { + for(std::size_t k = 0; k < dims_input[2]; k++) { + for(std::size_t l = 0; l < dims_input[3]; l++) { + std::vector<std::size_t> coordIdx = {i, j, k, l}; + std::size_t newFlatIdx = input0.getIdx(coordIdx); + if (newFlatIdx < min) { + min = newFlatIdx; + } + if (newFlatIdx > max) { + max = newFlatIdx; + } + } + } + } + } + double m = std::max(std::abs(min), std::abs(max)); + double normalization_factor = static_cast<double>(1 << (output_bits - 1)) - 1; + double scaling_factor = m / normalization_factor; + + // The new scaling factor that we can use to dequantify the returned tensor (not used here) + // double new_SF = 1/std::pow(2,2*output_bits-1); + + ILayerNormforward(input_raw, output, scaling_factor, weight, bias, size, dims_input); +} + +void Aidge::ILayerNormImpl_cuda::backward() { + const OperatorTensor& op = 
static_cast<const OperatorTensor&>(mOp); + + assert(op.getOutput(0)->grad() && "missing output #0"); + + const auto& output_grad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad()); + + if (op.getInput(0)->grad()->dataType() == DataType::Float64) { + backward_<double>(output_grad); + } + else { + backward_<float>(output_grad); + } +} + +template <class T> +void Aidge::ILayerNormImpl_cuda::backward_(const Tensor& output_grad) { + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + size_t size = output_grad.size(); + std::vector<DimSize_t> dims_input = output_grad.dims(); + + const T * output = static_cast<const T*>(std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr()); + + T * input_grad = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr()); + T * weight_grad = static_cast<T*>(op.getInput(1)->grad()->getImpl()->rawPtr()); + T * bias_grad = static_cast<T*>(op.getInput(2)->grad()->getImpl()->rawPtr()); + + const T * input = static_cast<const T*>(op.getInput(0)->getImpl()->rawPtr()); + const T * weight = static_cast<const T*>(op.getInput(1)->getImpl()->rawPtr()); + const T * bias = static_cast<const T*>(op.getInput(2)->getImpl()->rawPtr()); + + // maybe find a most efficient way to compute mean and variance tensor + + std::vector<std::vector<std::vector<std::vector<T>>>> means(dims_input[0], + std::vector<std::vector<std::vector<T>>>(dims_input[1], + std::vector<std::vector<T>>(dims_input[2], + std::vector<T>(dims_input[3], 0.0f)))); + + for (std::size_t i = 0; i < dims_input[0]; i++) { + for (std::size_t j = 0; j < dims_input[1]; j++) { + for (std::size_t k = 0; k < dims_input[2]; k++) { + T sum = 0.0f; + + for (std::size_t l = 0; l < dims_input[3]; l++) { + std::vector<std::size_t> coordIdx = {i, j, k, l}; + sum += output_grad.getIdx(coordIdx); + } + for (std::size_t l = 0; l < dims_input[3]; l++) { + std::vector<std::size_t> coordIdx = {i, j, k, l}; + means[i][j][k][l] = sum / static_cast<T>(dims_input[3]); + } + } + } + } + std::vector<T> flat_means; + + for (const auto &vec3d : means) { + for (const auto &vec2d : vec3d) { + for (const auto &vec1d : vec2d) { + flat_means.insert(flat_means.end(), vec1d.begin(), vec1d.end()); + } + } + } + + std::vector<std::vector<std::vector<std::vector<T>>>> vars(dims_input[0], + std::vector<std::vector<std::vector<T>>>(dims_input[1], + std::vector<std::vector<T>>(dims_input[2], + std::vector<T>(dims_input[3], 0.0f)))); + + for (std::size_t i = 0; i < dims_input[0]; i++) { + for (std::size_t j = 0; j < dims_input[1]; j++) { + for (std::size_t k = 0; k < dims_input[2]; k++) { + T sum_sq_diff = 0.0f; + + for (std::size_t l = 0; l < dims_input[3]; l++) { + std::vector<std::size_t> coordIdx = {i, j, k, l}; + T value = static_cast<T>(output_grad.getIdx(coordIdx)); + T diff = value - means[i][j][k][l]; + sum_sq_diff += diff * diff; + } + T variance = sum_sq_diff / static_cast<T>(dims_input[3]); + for (std::size_t l = 0; l < dims_input[3]; l++) { + vars[i][j][k][l] = variance; + } + } + } + } + + std::vector<T> flat_vars; + + for (const auto &vec3d : vars) { + for (const auto &vec2d : vec3d) { + for (const auto &vec1d : vec2d) { + flat_vars.insert(flat_vars.end(), vec1d.begin(), vec1d.end()); + } + } + } + + const T* mean_ = flat_means.data(); + const T* var_ = flat_vars.data(); + const T * output_grad_raw = static_cast<const T*>(output_grad.getImpl()->rawPtr()); + + ILayerNormbackward(output, output_grad_raw, input, mean_, var_, weight, bias, input_grad, weight_grad, bias_grad, 
size); +} diff --git a/src/operator/ILayerNormImpl_CUDA_kernels.cu b/src/operator/ILayerNormImpl_CUDA_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..fafdc176fdad6a6130c9bc4374d75f8a773f2c16 --- /dev/null +++ b/src/operator/ILayerNormImpl_CUDA_kernels.cu @@ -0,0 +1,335 @@ +/******************************************************************************** + * Copyright (c) 2024 Thales + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France + * Date: 10.09.2024 + * + ********************************************************************************/ + +#define MAX(X,Y) (((X) > (Y)) ? (X) : (Y)) +#define CLAMP(X) (((X) < (0)) ? (0) : (X)) + +#include <stdio.h> +#include <cuda_runtime.h> + +#include "aidge/backend/cuda/operator/ILayerNormImpl_CUDA_kernels.hpp" + +namespace Aidge{ + +template <class T> +__global__ void ILayerNormforward_(T* input, double SF, int* dims, int* quantized_tensor,long long int* square_tensor, T* weight, T* biase, double new_SF) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + int z = blockIdx.z * blockDim.z + threadIdx.z; + + int k = 1 << 16; + long long int sum = 0; + if (x < dims[0] && y < dims[1] && z < dims[2]) { + int maxIdx = x * dims[1] * dims[2] * dims[3] + y * dims[2] * dims[3] + z * dims[3]; + int val; + int mean_val = 0; + for (int i = 0; i < dims[3]; i++) { + int idx = maxIdx + i; + val = roundf(input[idx] / SF); + quantized_tensor[idx] = val; + mean_val += val; + } + for (int i = 0; i < dims[3]; i++) { + int idx = maxIdx + i; + quantized_tensor[idx] -= (mean_val/dims[3]) ; + square_tensor[idx] = (quantized_tensor[idx] * quantized_tensor[idx]); // I-ViT code implementation + //square_tensor[idx] = (quantized_tensor[idx] * quantized_tensor[idx])/dims[3]; // I-ViT paper implementation + } + for (int i = 0; i < dims[3]; i++) { + int idx = maxIdx + i; + sum += square_tensor[idx]; + biase[i] = (biase[i]/weight[i])/new_SF; + weight[i] = weight[i] * new_SF; + } + for(int h = 0; h < 10 ; h++) + { + k = floorf((k + floorf(sum / k))/2); + } + int factor = (((1 << 31) - 1) / k); + for (int i = 0; i < dims[3]; i++) { + int idx = maxIdx + i; + square_tensor[idx] = (biase[idx]/weight[idx])/new_SF; + quantized_tensor[idx] = (quantized_tensor[idx] * factor / 2) + biase[maxIdx]; + input[idx] = quantized_tensor[idx] * new_SF; + } + + } +} + +template <> +void ILayerNormforward<float>(const float* input, float* output, double SF, const float* weight_raw, const float* bias_raw, size_t size, std::vector<long unsigned int> dims_input) +{ + int dims_input_cuda[4] = {1, 1, 1, 1}; + for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) { + dims_input_cuda[i] = static_cast<int>(dims_input[i]); + } + + double new_SF = std::sqrt(dims_input_cuda[3]) / (1 << 30); + + float* input_cuda_tensor; + cudaMalloc(&input_cuda_tensor,size*sizeof(float)); + cudaMemcpy(input_cuda_tensor,input, size * sizeof(float),cudaMemcpyHostToDevice); + + int *quantized_tensor; + cudaMalloc(&quantized_tensor, size * sizeof(int)); + + int *dims; + cudaMalloc(&dims, 4 * sizeof(int)); + cudaMemcpy(dims, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice); + + float *weight; + cudaMalloc(&weight,dims_input_cuda[3]*sizeof(float)); + 
cudaMemcpy(weight,weight_raw,dims_input_cuda[3]*sizeof(float),cudaMemcpyHostToDevice); + + float *bias; + cudaMalloc(&bias,dims_input_cuda[3]*sizeof(float)); + cudaMemcpy(bias,bias_raw,dims_input_cuda[3]*sizeof(float),cudaMemcpyHostToDevice); + + long long int* Squaretensor; + cudaMalloc(&Squaretensor,(size)*sizeof(long long int)); + + dim3 threadsPerBlock(10, 10, 10); + dim3 numBlocks((dims_input_cuda[0] + threadsPerBlock.x - 1) / threadsPerBlock.x, + (dims_input_cuda[1] + threadsPerBlock.y - 1) / threadsPerBlock.y, + (dims_input_cuda[2] + threadsPerBlock.z - 1) / threadsPerBlock.z); + + ILayerNormforward_<float><<<numBlocks,threadsPerBlock>>>(input_cuda_tensor,SF,dims,quantized_tensor,Squaretensor,weight,bias,new_SF); + cudaDeviceSynchronize(); + + cudaError_t err = cudaGetLastError(); + if(err != cudaSuccess) + { + std::cerr << "CUDA Error: " << cudaGetErrorString(err) << std::endl; + } + cudaMemcpy(output,input_cuda_tensor, (size ) * sizeof(float), cudaMemcpyDeviceToHost); + + + cudaFree(input_cuda_tensor); + cudaFree(weight); + cudaFree(bias); + cudaFree(dims); + cudaFree(quantized_tensor); +} + +template <> +void ILayerNormforward<double>(const double* input, double* output, double SF, const double* weight_raw, const double* bias_raw, size_t size, std::vector<long unsigned int> dims_input) +{ + int dims_input_cuda[4] = {1, 1, 1, 1}; + for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) { + dims_input_cuda[i] = static_cast<int>(dims_input[i]); + } + + double new_SF = std::sqrt(dims_input_cuda[3]) / (1 << 30); + + double* input_cuda_tensor; + cudaMalloc(&input_cuda_tensor,size*sizeof(double)); + cudaMemcpy(input_cuda_tensor,input, size * sizeof(double),cudaMemcpyHostToDevice); + + int *quantized_tensor; + cudaMalloc(&quantized_tensor, size * sizeof(int)); + + int *dims; + cudaMalloc(&dims, 4 * sizeof(int)); + cudaMemcpy(dims, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice); + + double *weight; + cudaMalloc(&weight,dims_input_cuda[3]*sizeof(double)); + cudaMemcpy(weight,weight_raw,dims_input_cuda[3]*sizeof(double),cudaMemcpyHostToDevice); + + double *bias; + cudaMalloc(&bias,dims_input_cuda[3]*sizeof(double)); + cudaMemcpy(bias,bias_raw,dims_input_cuda[3]*sizeof(double),cudaMemcpyHostToDevice); + + long long int* Squaretensor; + cudaMalloc(&Squaretensor,(size)*sizeof(long long int)); + + dim3 threadsPerBlock(10, 10, 10); + dim3 numBlocks((dims_input_cuda[0] + threadsPerBlock.x - 1) / threadsPerBlock.x, + (dims_input_cuda[1] + threadsPerBlock.y - 1) / threadsPerBlock.y, + (dims_input_cuda[2] + threadsPerBlock.z - 1) / threadsPerBlock.z); + + ILayerNormforward_<double><<<numBlocks,threadsPerBlock>>>(input_cuda_tensor,SF,dims,quantized_tensor,Squaretensor,weight,bias,new_SF); + cudaDeviceSynchronize(); + + cudaError_t err = cudaGetLastError(); + if(err != cudaSuccess) + { + std::cerr << "CUDA Error: " << cudaGetErrorString(err) << std::endl; + } + + cudaMemcpy(output,input_cuda_tensor, (size ) * sizeof(double), cudaMemcpyDeviceToHost); + + cudaFree(input_cuda_tensor); + cudaFree(weight); + cudaFree(bias); + cudaFree(dims); + cudaFree(quantized_tensor); +} + +template <class T> +__global__ void ILayerNormbackward_(T* output_grad, T* input_tensor, T* output_tensor, T* mean, T* var, T* weight, T* bias, T* input_grad, T* weight_grad, T* bias_grad, int size) +{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < size) { + T d_norm = output_grad[i] * weight[i]; + T d_var = d_norm * (input_tensor[i] - mean[i]) * -0.5 * powf(var[i] + 1e-6, -1.5); + T d_mean = 
d_norm * -1 / sqrtf(var[i] + 1e-6) + d_var * -2 * mean[i] / size; + T d_input = d_norm / sqrtf(var[i] + 1e-6) + d_var * 2 * (input_tensor[i] - mean[i]) / size + d_mean / size; + + input_grad[i] = d_input; + weight_grad[i] = output_grad[i] * output_tensor[i]; + bias_grad[i] = output_grad[i]; + } +} + +template <> +void ILayerNormbackward<float>(const float* input_tensor, const float* output_grad, const float* output_tensor,const float* mean,const float* var, const float* weight, const float* bias, float* input_grad, float* weight_grad, float* bias_grad, size_t size) +{ + float* input_cuda_tensor; + cudaMalloc(&input_cuda_tensor,size*sizeof(float)); + cudaMemcpy(input_cuda_tensor,input_tensor,size*sizeof(float),cudaMemcpyHostToDevice); + + float* output_grad_; + cudaMalloc(&output_grad_,size*sizeof(float)); + cudaMemcpy(output_grad_,output_grad,size*sizeof(float),cudaMemcpyHostToDevice); + + float* output_tensor_; + cudaMalloc(&output_tensor_,size*sizeof(float)); + cudaMemcpy(output_tensor_,output_tensor,size*sizeof(float),cudaMemcpyHostToDevice); + + float* mean_; + cudaMalloc(&mean_,size*sizeof(float)); + cudaMemcpy(mean_,mean,size*sizeof(float),cudaMemcpyHostToDevice); + + float* var_; + cudaMalloc(&var_,size*sizeof(float)); + cudaMemcpy(var_,var,size*sizeof(float),cudaMemcpyHostToDevice); + + float* weight_; + cudaMalloc(&weight_,size*sizeof(float)); + cudaMemcpy(weight_,weight,size*sizeof(float),cudaMemcpyHostToDevice); + + float* bias_; + cudaMalloc(&bias_,size*sizeof(float)); + cudaMemcpy(bias_,bias,size*sizeof(float),cudaMemcpyHostToDevice); + + + float* input_grad_; + cudaMalloc(&input_grad_,size*sizeof(float)); + + float* weight_grad_; + cudaMalloc(&weight_grad_,size*sizeof(float)); + + float* bias_grad_; + cudaMalloc(&bias_grad_,size*sizeof(float)); + + + dim3 threadParBlock(256); + dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x); + + ILayerNormbackward_<<<Blocks,threadParBlock>>>(output_grad_,input_cuda_tensor,output_tensor_,mean_,var_,weight_,bias_,input_grad_, weight_grad_, bias_grad_, size); + + cudaDeviceSynchronize(); + cudaError_t err = cudaGetLastError(); + if(err != cudaSuccess) + { + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + } + + cudaMemcpy(input_grad , input_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(weight_grad , weight_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(bias_grad , bias_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost); + + cudaFree(input_cuda_tensor); + cudaFree(output_grad_); + cudaFree(mean_); + cudaFree(var_); + cudaFree(weight_); + cudaFree(bias_); + cudaFree(input_grad_); + cudaFree(weight_grad_); + cudaFree(bias_grad_); + +} + +template <> +void ILayerNormbackward<double>(const double* input_tensor, const double* output_grad, const double* output_tensor,const double* mean,const double* var, const double* weight, const double* bias, double* input_grad, double* weight_grad, double* bias_grad, size_t size) +{ + double* input_cuda_tensor; + cudaMalloc(&input_cuda_tensor,size*sizeof(double)); + cudaMemcpy(input_cuda_tensor,input_tensor,size*sizeof(double),cudaMemcpyHostToDevice); + + double* output_grad_; + cudaMalloc(&output_grad_,size*sizeof(double)); + cudaMemcpy(output_grad_,output_grad,size*sizeof(double),cudaMemcpyHostToDevice); + + double* output_tensor_; + cudaMalloc(&output_tensor_,size*sizeof(double)); + cudaMemcpy(output_tensor_,output_tensor,size*sizeof(double),cudaMemcpyHostToDevice); + + double* mean_; + cudaMalloc(&mean_,size*sizeof(double)); + 
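// mean and var were computed on the host in ILayerNormImpl_cuda::backward_, one value per element, broadcast along the last dimension +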
cudaMemcpy(mean_,mean,size*sizeof(double),cudaMemcpyHostToDevice); + + double* var_; + cudaMalloc(&var_,size*sizeof(double)); + cudaMemcpy(var_,var,size*sizeof(double),cudaMemcpyHostToDevice); + + double* weight_; + cudaMalloc(&weight_,size*sizeof(double)); + cudaMemcpy(weight_,weight,size*sizeof(double),cudaMemcpyHostToDevice); + + double* bias_; + cudaMalloc(&bias_,size*sizeof(double)); + cudaMemcpy(bias_,bias,size*sizeof(double),cudaMemcpyHostToDevice); + + + double* input_grad_; + cudaMalloc(&input_grad_,size*sizeof(double)); + + double* weight_grad_; + cudaMalloc(&weight_grad_,size*sizeof(double)); + + double* bias_grad_; + cudaMalloc(&bias_grad_,size*sizeof(double)); + + + dim3 threadParBlock(256); + dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x); + + ILayerNormbackward_<<<Blocks,threadParBlock>>>(output_grad_,input_cuda_tensor,output_tensor_,mean_,var_,weight_,bias_,input_grad_, weight_grad_, bias_grad_, size); + + cudaDeviceSynchronize(); + cudaError_t err = cudaGetLastError(); + if(err != cudaSuccess) + { + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + } + + + cudaMemcpy(input_grad , input_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost); + cudaMemcpy(weight_grad , weight_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost); + cudaMemcpy(bias_grad , bias_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost); + + cudaFree(input_cuda_tensor); + cudaFree(output_grad_); + cudaFree(mean_); + cudaFree(var_); + cudaFree(weight_); + cudaFree(bias_); + cudaFree(input_grad_); + cudaFree(weight_grad_); + cudaFree(bias_grad_); +} + +} \ No newline at end of file diff --git a/src/operator/ShiftGELUImpl.cpp b/src/operator/ShiftGELUImpl.cpp index 779fbb9175893f80dfa9110ee449e281fe3d5ca6..c2774804d04a422aefd0c66ed0d1fc1d949b1f06 100644 --- a/src/operator/ShiftGELUImpl.cpp +++ b/src/operator/ShiftGELUImpl.cpp @@ -34,19 +34,13 @@ void Aidge::ShiftGELUImpl_cuda::forward() { assert(mOp.getRawInput(0) && "missing input #0"); const auto& input = op.getInput(0)->refCastFrom(mInputFallback, *op.getOutput(0)); - //forward_<float>(input); - - //should template and changing type switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) { case DataType::Float64: - forward_<float>(input); + forward_<double>(input); break; case DataType::Float32: forward_<float>(input); break; - case DataType::Float16: - forward_<float>(input); - break; default: AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda"); } @@ -57,13 +51,15 @@ void Aidge::ShiftGELUImpl_cuda::forward_(const Tensor& input) { const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); const T * input_raw = static_cast<const T*>(input.getImpl()->rawPtr()); + T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()); int N = 15; int output_bits = 8; - size_t size = input.size(); std::vector<DimSize_t> dims_input = input.dims(); + // maybe find a most efficient way to compute scaling factor (a max and min function could help to retrieve scaling factor value) + double min = std::numeric_limits<double>::max(); double max = std::numeric_limits<double>::min(); for(std::size_t i = 0; i < dims_input[0]; i++) { @@ -84,13 +80,40 @@ void Aidge::ShiftGELUImpl_cuda::forward_(const Tensor& input) } double m = std::max(std::abs(min), std::abs(max)); - - // Calculate the normalization factor double normalization_factor = static_cast<double>(1 << (output_bits - 1)) - 1; + double scaling_factor = m / normalization_factor; + + // The 
new scaling factor that we can use to dequantify the returned tensor (not used here) + // double new_SF = 1/std::pow(2,2*output_bits-1); + + ShiftGELUforward(input_raw, output, scaling_factor,N, output_bits, size, dims_input); +} + +void Aidge::ShiftGELUImpl_cuda::backward() { + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + + assert(op.getOutput(0)->grad() && "missing output #0"); + + const auto& output_grad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad()); + + if (op.getInput(0)->grad()->dataType() == DataType::Float64) { + backward_<double>(output_grad); + } + else { + backward_<float>(output_grad); + } +} + +template <class T> +void Aidge::ShiftGELUImpl_cuda::backward_(const Tensor& output_grad) { + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + const T * input = static_cast<const T*>(std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr()); + + size_t size = output_grad.size(); + + T * output = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr()); + + const T * output_grad_raw = static_cast<const T*>(output_grad.getImpl()->rawPtr()); + ShiftGELUbackward(input, output_grad_raw, output, size); - // Return the normalized maximum - double final_sf = m / normalization_factor; - T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()); - double new_SF = 1/std::pow(2,2*output_bits-1); // Le nouveau scaling factor renvoyé par la fonction shiftmax, utilisé pour déquantifier le tenseur renvoyé - ShiftGELULaunchKernel(input_raw, output, final_sf,N, output_bits, size, dims_input); } \ No newline at end of file diff --git a/src/operator/ShiftGELUImpl_CUDA_kernels.cu b/src/operator/ShiftGELUImpl_CUDA_kernels.cu index c51af4e51c1270bbe0b5c7cfe658a9d84614c58a..aabd89c04e960f9f19eca69247173168d3eaf71e 100644 --- a/src/operator/ShiftGELUImpl_CUDA_kernels.cu +++ b/src/operator/ShiftGELUImpl_CUDA_kernels.cu @@ -26,45 +26,35 @@ __device__ inline int ExpShift(int I,int N, double SF) int q = floorf(Ip / (I0)); int r = Ip -(I0*q); int Ib = r/2 - I0; - Ib = CLAMP(Ib * powf(2,N-q));//BitShift? 
+ Ib = CLAMP(Ib * powf(2,N-q)); return (int)Ib; } namespace Aidge{ template <class T> -__global__ void ShiftGELUWholeKernel(T* input,int* quantized_tensor,int* GELUtensor,int* SumTensor, int* dims, double SF, int N, int output_bits) { - /* - * Kernels du Forward de GeLU - * Input => Tenseur représentant l'entrée (non quantifiée (flottant)) (pointeur vers le bloc de mémoire de type T) - * quantized_tensor => pointeur vers un bloc mémoire vide alloué sur le GPU - * geLUTensor => pointeur vers un bloc mémoire vide alloué sur le GPU - * SumTensor => pointeur vers un bloc mémoire vide alloué sur le GPU - * dims => int[4] sous forme de pointeur qui représente les 4 dimensions du tenseurs - * SF => Scaling Factor - * N => precision du Softmax arithmétique (plus N est grand plus l'opération est précise mais plus elle nécessite un nombre de bit elevé) - * output_bits => précision en bit souhaité (8 pour int8 par exemple) - */ - int x = blockIdx.x * blockDim.x + threadIdx.x; // Dim1 - int y = blockIdx.y * blockDim.y + threadIdx.y; // Dim2 - int z = blockIdx.z * blockDim.z + threadIdx.z; // Dim3 - - double SF_sig = SF * 1.702;// SF multiplié par une constante utilisé dans l'algo +__global__ void ShiftGELUforward_(T* input,int* quantized_tensor,int* GELUtensor,int* SumTensor, int* dims, double SF, int N, int output_bits) { + + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + int z = blockIdx.z * blockDim.z + threadIdx.z; + + double SF_sig = SF * 1.702; double Final_SF = SF / powf(2,(output_bits-1)); if (x < dims[0] && y < dims[1] && z < dims[2]) { int maxIdx = x * dims[1] * dims[2] * dims[3] + y * dims[2] * dims[3] + z * dims[3]; - for (int i = 0; i < dims[3]; i++) { //Quantization (1thread per last dim of tensor) + for (int i = 0; i < dims[3]; i++) { int idx = maxIdx + i; quantized_tensor[idx] = roundf(input[idx] / SF); } int maxVal = quantized_tensor[maxIdx]; - for (int i = 1; i < dims[3]; i++) { // Computing max value + for (int i = 1; i < dims[3]; i++) { int idx = maxIdx + i; maxVal = MAX(maxVal, quantized_tensor[idx]); } int Max_Exp = ExpShift(-maxVal,N,SF_sig); - for (int i = 0; i < dims[3]; i++) { //Exponential (artihmetic) + for (int i = 0; i < dims[3]; i++) { int idx = maxIdx + i; GELUtensor[idx] = ExpShift(quantized_tensor[idx] - maxVal,N,SF_sig); if(GELUtensor[idx] > INT_MAX - Max_Exp) { @@ -74,7 +64,6 @@ __global__ void ShiftGELUWholeKernel(T* input,int* quantized_tensor,int* GELUten { SumTensor[idx] = floorf(INT_MAX/(GELUtensor[idx] + Max_Exp)); } - //SigMoidInt. 
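+ // SigmoidInt step of I-ViT's ShiftGELU: rescale by the reciprocal of the sum, then shift down to the requested output_bits precision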
SumTensor[idx] = floorf((GELUtensor[idx] * SumTensor[idx]) >> (31 - output_bits + 1)); quantized_tensor[idx] *= SumTensor[idx]; input[idx] = quantized_tensor[idx] * Final_SF; @@ -82,66 +71,186 @@ __global__ void ShiftGELUWholeKernel(T* input,int* quantized_tensor,int* GELUten } } -//TODO Template -void ShiftGELULaunchKernel(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) { +template <> +void ShiftGELUforward<float>(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) { - double new_SF = 1/std::pow(2,2*output_bits-1); // Le nouveau scaling factor renvoyé par la fonction shiftgelu, utilisé pour déquantifier le tenseur renvoyé + double new_SF = 1/std::pow(2,2*output_bits-1); int dims_input_cuda[4]; if (dims_input.size() >= 4) { - // Fixed-size array to store the first 4 elements - - // Copy the first 4 elements from dims_input to dims_input2 for (std::size_t i = 0; i < 4; ++i) { dims_input_cuda[i] = static_cast<int>(dims_input[i]); } } - float* input_cuda_tensor; - cudaMalloc(&input_cuda_tensor,size*sizeof(float)); //| + cudaMalloc(&input_cuda_tensor,size*sizeof(float)); cudaMemcpy(input_cuda_tensor,input,size*sizeof(float),cudaMemcpyHostToDevice); - //| - int* quantized_tensor; //| - cudaMalloc(&quantized_tensor,size*sizeof(int)); //| - //| => Allocation des blocs mémoire sur le GPU + + int* quantized_tensor; + cudaMalloc(&quantized_tensor,size*sizeof(int)); + int* GELUtensor; cudaMalloc(&GELUtensor,size*sizeof(int)); int* SumTensor; - cudaMalloc(&SumTensor,size*sizeof(int)); //| - //| - int* dims; //| + cudaMalloc(&SumTensor,size*sizeof(int)); + + int* dims; cudaMalloc(&dims,4*sizeof(int)); cudaMemcpy(dims,dims_input_cuda,4*sizeof(int),cudaMemcpyHostToDevice); - dim3 threadsPerBlock(10, 10, 10); //| Calculs du nombre de thread par blocs et du nombre de bloc a lancé en parrallèle sur - dim3 numBlocks((dims_input[0] + threadsPerBlock.x - 1) / threadsPerBlock.x, //| le GPU pour en fonctions des dimensions du tenseur en entrée - (dims_input[1] + threadsPerBlock.y - 1) / threadsPerBlock.y, //| - (dims_input[2] + threadsPerBlock.z - 1) / threadsPerBlock.z); //| + dim3 threadsPerBlock(10, 10, 10); + dim3 numBlocks((dims_input[0] + threadsPerBlock.x - 1) / threadsPerBlock.x, + (dims_input[1] + threadsPerBlock.y - 1) / threadsPerBlock.y, + (dims_input[2] + threadsPerBlock.z - 1) / threadsPerBlock.z); - ShiftGELUWholeKernel<float><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor,GELUtensor,SumTensor, dims, SF,N,8);//Lancement du Kernel - cudaDeviceSynchronize(); //Attente de la fin d'execution du kernel. 
Ligne très importante puisque sans celle ci le CPU continue l'execution du programme sans attendre le retour du GPU + ShiftGELUforward_<float><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor,GELUtensor,SumTensor, dims, SF,N,8); + cudaDeviceSynchronize(); cudaError_t err = cudaGetLastError(); if(err != cudaSuccess) { - printf("Erreur CUDA: %s\n", cudaGetErrorString(err)); - } //Checks des possibles erreurs sur le GPU + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + } + + cudaMemcpy(output,input_cuda_tensor,size*sizeof(float),cudaMemcpyDeviceToHost); + cudaFree(quantized_tensor); + cudaFree(GELUtensor); + cudaFree(SumTensor); + cudaFree(dims); + cudaFree(input_cuda_tensor); +} - //float* ControlFinal = (float*)malloc(size*sizeof(float)); - //cudaMemcpy(ControlFinal,input_cuda_tensor,size*sizeof(float),cudaMemcpyDeviceToHost); - //MyTensor<float> control(ControlFinal,x.dims); +template <> +void ShiftGELUforward<double>(const double* input, double* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) { - cudaMemcpy(output,input_cuda_tensor,size*sizeof(float),cudaMemcpyDeviceToHost); + double new_SF = 1/std::pow(2,2*output_bits-1); + + int dims_input_cuda[4]; + if (dims_input.size() >= 4) { + for (std::size_t i = 0; i < 4; ++i) { + dims_input_cuda[i] = static_cast<int>(dims_input[i]); + } + } + + double* input_cuda_tensor; + cudaMalloc(&input_cuda_tensor,size*sizeof(double)); + cudaMemcpy(input_cuda_tensor,input,size*sizeof(double),cudaMemcpyHostToDevice); + + int* quantized_tensor; + cudaMalloc(&quantized_tensor,size*sizeof(int)); + + int* GELUtensor; + cudaMalloc(&GELUtensor,size*sizeof(int)); + + int* SumTensor; + cudaMalloc(&SumTensor,size*sizeof(int)); + + int* dims; + cudaMalloc(&dims,4*sizeof(int)); + + cudaMemcpy(dims,dims_input_cuda,4*sizeof(int),cudaMemcpyHostToDevice); + + dim3 threadsPerBlock(10, 10, 10); + dim3 numBlocks((dims_input[0] + threadsPerBlock.x - 1) / threadsPerBlock.x, + (dims_input[1] + threadsPerBlock.y - 1) / threadsPerBlock.y, + (dims_input[2] + threadsPerBlock.z - 1) / threadsPerBlock.z); - cudaFree(quantized_tensor); //| - cudaFree(GELUtensor); //| - cudaFree(SumTensor); //| - cudaFree(dims); //| => Free sur GPU et CPU (tout ce qui a été malloc et cudaMalloc en gros) - cudaFree(input_cuda_tensor);//| + ShiftGELUforward_<double><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor,GELUtensor,SumTensor, dims, SF,N,8); + cudaDeviceSynchronize(); + + cudaError_t err = cudaGetLastError(); + if(err != cudaSuccess) + { + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + } + + cudaMemcpy(output,input_cuda_tensor,size*sizeof(double),cudaMemcpyDeviceToHost); + + cudaFree(quantized_tensor); + cudaFree(GELUtensor); + cudaFree(SumTensor); + cudaFree(dims); + cudaFree(input_cuda_tensor); +} + +template <class T> +__global__ void ShiftGELUbackward_(T* input_grad, const T* output_tensor, const T* output_grad, int size) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + float x = output_tensor[index]; + float grad = output_grad[index]; + + float cdf = 0.5 * (1.0 + tanh(sqrt(2.0 / M_PI) * (x + 0.044715 * pow(x, 3)))); + float pdf = exp(-0.5 * x * x) / sqrt(2.0 * M_PI); + float dx = pdf + x * cdf; + float backprop_grad = grad * dx; + input_grad[index] = backprop_grad; + } +} + +template <> +void ShiftGELUbackward<float>(const float* output_tensor, const float* output_grad, float* input_grad, size_t size) +{ + float* output_cuda_tensor; + 
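// device copies of the forward output and the incoming gradient; the backward kernel evaluates a floating-point GELU derivative on them +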
cudaMalloc(&output_cuda_tensor,size*sizeof(float)); + cudaMemcpy(output_cuda_tensor,output_tensor,size*sizeof(float),cudaMemcpyHostToDevice); + + float* output_grad_; + cudaMalloc(&output_grad_,size*sizeof(float)); + cudaMemcpy(output_grad_,output_grad,size*sizeof(float),cudaMemcpyHostToDevice); + + float *input_grad_; + cudaMalloc(&input_grad_, size * sizeof(float)); + + dim3 threadParBlock(256); + dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x); + + ShiftGELUbackward_<float><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,size); + cudaDeviceSynchronize(); + cudaError_t err = cudaGetLastError(); + if(err != cudaSuccess) + { + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + } + cudaMemcpy(input_grad,input_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost); + cudaFree(output_cuda_tensor); + cudaFree(input_grad_); + cudaFree(output_grad_); +} + +template <> +void ShiftGELUbackward<double>(const double* output_tensor, const double* output_grad, double* input_grad, size_t size) +{ + double* output_cuda_tensor; + cudaMalloc(&output_cuda_tensor,size*sizeof(double)); + cudaMemcpy(output_cuda_tensor,output_tensor,size*sizeof(double),cudaMemcpyHostToDevice); + + double* output_grad_; + cudaMalloc(&output_grad_,size*sizeof(double)); + cudaMemcpy(output_grad_,output_grad,size*sizeof(double),cudaMemcpyHostToDevice); + + double *input_grad_; + cudaMalloc(&input_grad_, size * sizeof(double)); + + dim3 threadParBlock(256); + dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x); + + ShiftGELUbackward_<double><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,size); + cudaDeviceSynchronize(); + cudaError_t err = cudaGetLastError(); + if(err != cudaSuccess) + { + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + } + cudaMemcpy(input_grad,input_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost); + cudaFree(output_cuda_tensor); + cudaFree(input_grad_); + cudaFree(output_grad_); } } \ No newline at end of file diff --git a/src/operator/ShiftMaxImpl.cpp b/src/operator/ShiftMaxImpl.cpp index 470bac840969eb114d7b7e34e51fb9e733d41ae9..1134cc5d6b99e53eb492c82e32d811bc0bcba0e0 100644 --- a/src/operator/ShiftMaxImpl.cpp +++ b/src/operator/ShiftMaxImpl.cpp @@ -34,19 +34,13 @@ void Aidge::ShiftMaxImpl_cuda::forward() { assert(mOp.getRawInput(0) && "missing input #0"); const auto& input = op.getInput(0)->refCastFrom(mInputFallback, *op.getOutput(0)); - //forward_<float>(input); - - //should template and changing type switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) { case DataType::Float64: - forward_<float>(input); + forward_<double>(input); break; case DataType::Float32: forward_<float>(input); break; - case DataType::Float16: - forward_<float>(input); - break; default: AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda"); } @@ -57,13 +51,15 @@ void Aidge::ShiftMaxImpl_cuda::forward_(const Tensor& input) { const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); const T * input_raw = static_cast<const T*>(input.getImpl()->rawPtr()); + T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()); int N = 15; int output_bits = 8; - size_t size = input.size(); std::vector<DimSize_t> dims_input = input.dims(); + // maybe find a most efficient way to compute scaling factor (a max and min function could help to retrieve scaling factor value) + double min = std::numeric_limits<double>::max(); double max = 
std::numeric_limits<double>::min(); for(std::size_t i = 0; i < dims_input[0]; i++) { @@ -84,13 +80,42 @@ void Aidge::ShiftMaxImpl_cuda::forward_(const Tensor& input) } double m = std::max(std::abs(min), std::abs(max)); - - // Calculate the normalization factor double normalization_factor = static_cast<double>(1 << (output_bits - 1)) - 1; + double scaling_factor = m / normalization_factor; + + // The new scaling factor that we can use to dequantify the returned tensor (not used here) + // double new_SF = 1/std::pow(2,2*output_bits-1); + + ShiftMaxforward(input_raw, output, scaling_factor,N, output_bits, size, dims_input); +} + + +void Aidge::ShiftMaxImpl_cuda::backward() { + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + + assert(op.getOutput(0)->grad() && "missing output #0"); + + const auto& output_grad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad()); + + if (op.getInput(0)->grad()->dataType() == DataType::Float64) { + backward_<double>(output_grad); + } + else { + backward_<float>(output_grad); + } +} + +template <class T> +void Aidge::ShiftMaxImpl_cuda::backward_(const Tensor& output_grad) { + const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp); + const T * output_tensor = static_cast<const T*>(std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr()); + + size_t size = output_grad.size(); + std::vector<DimSize_t> dims_output = output_grad.dims(); + + T * input_grad = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr()); + + const T * output_grad_raw = static_cast<const T*>(output_grad.getImpl()->rawPtr()); + ShiftMaxbackward(output_tensor, output_grad_raw, input_grad, size, dims_output); - // Return the normalized maximum - double final_sf = m / normalization_factor; - T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()); - double new_SF = 1/std::pow(2,2*output_bits-1); // Le nouveau scaling factor renvoyé par la fonction shiftmax, utilisé pour déquantifier le tenseur renvoyé - ShiftMaxLaunchKernel(input_raw, output, final_sf,N, output_bits, size, dims_input); } \ No newline at end of file diff --git a/src/operator/ShiftMaxImpl_CUDA_kernels.cu b/src/operator/ShiftMaxImpl_CUDA_kernels.cu index c4c619cdde1a1ee123f19a3d75f3f14bc542bc52..ba3cfcb51e02fb0befbf9f7c1fc054e73a2a7157 100644 --- a/src/operator/ShiftMaxImpl_CUDA_kernels.cu +++ b/src/operator/ShiftMaxImpl_CUDA_kernels.cu @@ -26,53 +26,38 @@ __device__ inline int ExpShift(int I,int N, double SF) int q = floorf(Ip / (I0)); int r = Ip -(I0*q); int Ib = r/2 - I0; - Ib = CLAMP(Ib * powf(2,N-q));//BitShift? 
+ Ib = CLAMP(Ib * powf(2,N-q)); return (int)Ib; } namespace Aidge{ template <class T> -__global__ void ShiftMaxWholeKernel(T* input,int* quantized_tensor,int* factor, int* dims, double SF, int N, int output_bits,double new_SF) -/* - * Kernels du Forward de Shiftmax - * Input => Tenseur représentant l'entrée (non quantifiée (flottant)) (pointeur vers le bloc de mémoire de type T) - * quantized_tensor => pointeur vers un bloc mémoire vide alloué sur le GPU - * factor => pointeur vers un bloc mémoire vide alloué sur le GPU - * dims => int[4] sous forme de pointeur qui représente les 4 dimensions du tenseurs - * SF => Scaling Factor - * N => precision du Softmax arithmétique (plus N est grand plus l'opération est précise mais plus elle nécessite un nombre de bit elevé) - * output_bits => précision en bit souhaité (8 pour int8 par exemple) - * new_SF => Nouveau SF pour déquantifier le tenseur - */ +__global__ void ShiftMaxforward_(T* input,int* quantized_tensor,int* factor, int* dims, double SF, int N, int output_bits,double new_SF) { - int x = blockIdx.x * blockDim.x + threadIdx.x; // Dim1 - int y = blockIdx.y * blockDim.y + threadIdx.y; // Dim2 - int z = blockIdx.z * blockDim.z + threadIdx.z; // Dim3 + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + int z = blockIdx.z * blockDim.z + threadIdx.z; int sum = 0; - /* - * x,y et z représente les indices des dimensions 1,2 et 3 du tenseur, toutes les combinaisons possible de x,y et z - * sont appelés en parralèle ce qui permet le speedup GPU - * pour iterer dans la derniere dimensions on utilise les boucles "for" ci dessous - * */ + if (x < dims[0] && y < dims[1] && z < dims[2]) { int maxIdx = x * dims[1] * dims[2] * dims[3] + y * dims[2] * dims[3] + z * dims[3]; - for (int i = 0; i < dims[3]; i++) { //Quantization (1thread per last dim of tensor) + for (int i = 0; i < dims[3]; i++) { int idx = maxIdx + i; quantized_tensor[idx] = roundf(input[idx] / SF); } int maxVal = quantized_tensor[maxIdx]; - for (int i = 1; i < dims[3]; i++) { // max value par dimensions 4 + for (int i = 1; i < dims[3]; i++) { int idx = maxIdx + i; maxVal = MAX(maxVal, quantized_tensor[idx]); } - for (int i = 0; i < dims[3]; i++) { //Expo (artihmetic) + for (int i = 0; i < dims[3]; i++) { int idx = maxIdx + i; quantized_tensor[idx] = ExpShift(quantized_tensor[idx]-maxVal,N,SF); } - for (int i = 0; i < dims[3]; i++) { // Sum et clamp quand dépassement de valeur + for (int i = 0; i < dims[3]; i++) { int idx = maxIdx + i; - if(quantized_tensor[idx] > 0 && sum > INT_MAX - quantized_tensor[idx])//CLAMP(2**31-1) + if(quantized_tensor[idx] > 0 && sum > INT_MAX - quantized_tensor[idx]) { sum = INT_MAX; break; @@ -82,7 +67,7 @@ __global__ void ShiftMaxWholeKernel(T* input,int* quantized_tensor,int* factor, } } factor[x * dims[1] * dims[2] + y * dims[2] + z] = floorf(INT_MAX/sum); - for(int i= 0; i < dims[3]; ++i) //bitshift pour quantifier sur 8 bits + for(int i= 0; i < dims[3]; ++i) { int idx = maxIdx + i; quantized_tensor[idx] = (quantized_tensor[idx] * factor[x * dims[1] * dims[2] + y * dims[2] + z]) >> (31-(2*output_bits-1)); @@ -91,62 +76,211 @@ __global__ void ShiftMaxWholeKernel(T* input,int* quantized_tensor,int* factor, } } -//TODO Template -void ShiftMaxLaunchKernel(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) { +template <> +void ShiftMaxforward<float>(const float* input, float* output, double SF, int N, int output_bits, size_t size, 
std::vector<long unsigned int> dims_input) {
+
+    double new_SF = 1 / std::pow(2, 2 * output_bits - 1); // New scaling factor, used to dequantize the returned tensor
+
+    int dims_input_cuda[4] = {1, 1, 1, 1};
+    for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) {
+        dims_input_cuda[i] = static_cast<int>(dims_input[i]);
+    }
+
+    // Allocate memory on the GPU
+    float* input_cuda_tensor;
+    cudaMalloc(&input_cuda_tensor, size * sizeof(float));
+    cudaMemcpy(input_cuda_tensor, input, size * sizeof(float), cudaMemcpyHostToDevice);
+
+    int* quantized_tensor;
+    cudaMalloc(&quantized_tensor, size * sizeof(int));
+
+    int* factor;
+    cudaMalloc(&factor, size * sizeof(int));
+
+    int* dims;
+    cudaMalloc(&dims, 4 * sizeof(int));
+    cudaMemcpy(dims, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice);
+
+    // Calculate grid and block dimensions
+    dim3 threadsPerBlock(10, 10, 10);
+    dim3 numBlocks(
+        (dims_input_cuda[0] + threadsPerBlock.x - 1) / threadsPerBlock.x,
+        (dims_input_cuda[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,
+        (dims_input_cuda[2] + threadsPerBlock.z - 1) / threadsPerBlock.z
+    );
+
+    // Launch the templated ShiftMaxforward_ kernel defined above
+    ShiftMaxforward_<float><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor, factor, dims, SF, N, output_bits, new_SF);
+    cudaDeviceSynchronize();
+
+    // Check for CUDA errors
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess) {
+        std::cerr << "CUDA Error: " << cudaGetErrorString(err) << std::endl;
+    }
+
+    // Copy the result back to host
+    cudaMemcpy(output, input_cuda_tensor, size * sizeof(float), cudaMemcpyDeviceToHost);
+
+    // Free allocated memory on GPU
+    cudaFree(quantized_tensor);
+    cudaFree(factor);
+    cudaFree(dims);
+    cudaFree(input_cuda_tensor);
+}
+
+template <>
+void ShiftMaxforward<double>(const double* input, double* output, double SF, int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
+
+    double new_SF = 1 / std::pow(2, 2 * output_bits - 1);
+
+    int dims_input_cuda[4] = {1, 1, 1, 1};
+    for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) {
+        dims_input_cuda[i] = static_cast<int>(dims_input[i]);
+    }
+
+    // Allocate memory on the GPU
+    double* input_cuda_tensor;
+    cudaMalloc(&input_cuda_tensor, size * sizeof(double));
+    cudaMemcpy(input_cuda_tensor, input, size * sizeof(double), cudaMemcpyHostToDevice);
+
+    int* quantized_tensor;
+    cudaMalloc(&quantized_tensor, size * sizeof(int));
+
+    int* factor;
+    cudaMalloc(&factor, size * sizeof(int));
+
+    int* dims;
+    cudaMalloc(&dims, 4 * sizeof(int));
+    cudaMemcpy(dims, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice);
+
+    // Calculate grid and block dimensions
+    dim3 threadsPerBlock(10, 10, 10);
+    dim3 numBlocks(
+        (dims_input_cuda[0] + threadsPerBlock.x - 1) / threadsPerBlock.x,
+        (dims_input_cuda[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,
+        (dims_input_cuda[2] + threadsPerBlock.z - 1) / threadsPerBlock.z
+    );
+
+    // Launch the templated ShiftMaxforward_ kernel defined above
+    ShiftMaxforward_<double><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor, factor, dims, SF, N, output_bits, new_SF);
+    cudaDeviceSynchronize();
+
+    // Check for CUDA errors
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess) {
+        std::cerr << "CUDA Error: " << cudaGetErrorString(err) << std::endl;
+    }
 
-    double new_SF = 1/std::pow(2,2*output_bits-1); // Le nouveau scaling factor renvoyé par la fonction shiftmax, utilisé pour déquantifier le tenseur renvoyé
 
+    // Copy 
the result back to host + cudaMemcpy(output, input_cuda_tensor, size * sizeof(double), cudaMemcpyDeviceToHost); + + // Free allocated memory on GPU + cudaFree(quantized_tensor); + cudaFree(factor); + cudaFree(dims); + cudaFree(input_cuda_tensor); +} - int dims_input_cuda[4]; - if (dims_input.size() >= 4) { - // Fixed-size array to store the first 4 elements - // Copy the first 4 elements from dims_input to dims_input2 - for (std::size_t i = 0; i < 4; ++i) { - dims_input_cuda[i] = static_cast<int>(dims_input[i]); +template <class T> +__global__ void ShiftMaxbackward_(T* input_grad, const T* output_tensor, const T* output_grad, const int* dims) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < dims[0] * dims[1] * dims[2] * dims[3]) { + int w = (index / dims[3]) % dims[2]; + int h = (index / dims[3] / dims[2]) % dims[1]; + int n = index / dims[3] / dims[2] / dims[1]; + + float sum = 0.0f; + for (int i = 0; i < dims[3]; ++i) { + sum += output_tensor[n * dims[1] * dims[2] * dims[3] + h * dims[2] * dims[3] + w * dims[3] + i] * output_grad[n * dims[1] * dims[2] * dims[3] + h * dims[2] * dims[3] + w * dims[3] + i]; } - } + input_grad[index] = output_tensor[index] * (output_grad[index] - sum); + } +} + +template <> +void ShiftMaxbackward<float>(const float* output_tensor, const float* output_grad, float* input_grad, size_t size, std::vector<long unsigned int> dims) +{ + int dims_input_cuda[4] = {1, 1, 1, 1}; + for (std::size_t i = 0; i < std::min(dims.size(), size_t(4)); ++i) { + dims_input_cuda[i] = static_cast<int>(dims[i]); + } + float* output_cuda_tensor; + cudaMalloc(&output_cuda_tensor,size*sizeof(float)); + cudaMemcpy(output_cuda_tensor,output_tensor,size*sizeof(float),cudaMemcpyHostToDevice); - float* input_cuda_tensor; - cudaMalloc(&input_cuda_tensor,size*sizeof(float)); //| - cudaMemcpy(input_cuda_tensor,input,size*sizeof(float),cudaMemcpyHostToDevice); - //| - int* quantized_tensor; //| - cudaMalloc(&quantized_tensor,size*sizeof(int)); //| - //| => Allocation des blocs mémoire sur le GPU - int* factor; //| - cudaMalloc(&factor,size*sizeof(int)); //| - //| - int* dims; //| - cudaMalloc(&dims,4*sizeof(int)); - - cudaMemcpy(dims,dims_input_cuda,4*sizeof(int),cudaMemcpyHostToDevice); - - dim3 threadsPerBlock(10, 10, 10); //| Calculs du nombre de thread par blocs et du nombre de bloc a lancé en parrallèle sur - dim3 numBlocks((dims_input[0] + threadsPerBlock.x - 1) / threadsPerBlock.x, //| le GPU pour en fonctions des dimensions du tenseur en entrée - (dims_input[1] + threadsPerBlock.y - 1) / threadsPerBlock.y, //| - (dims_input[2] + threadsPerBlock.z - 1) / threadsPerBlock.z); //| - - ShiftMaxWholeKernel<float><<<numBlocks,threadsPerBlock>>>(input_cuda_tensor,quantized_tensor,factor,dims,SF,N,output_bits,new_SF);//Lancement du Kernel - cudaDeviceSynchronize(); //Attente de la fin d'execution du kernel. 
Ligne très importante puisque sans celle ci le CPU continue l'execution du programme sans attendre le retour du GPU + float* output_grad_; + cudaMalloc(&output_grad_,size*sizeof(float)); + cudaMemcpy(output_grad_,output_grad,size*sizeof(float),cudaMemcpyHostToDevice); + + float *input_grad_; + cudaMalloc(&input_grad_, size * sizeof(float)); + + int *dims_; + cudaMalloc(&dims_, 4 * sizeof(int)); + cudaMemcpy(dims_, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice); + + dim3 threadParBlock(256); + dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x); + ShiftMaxbackward_<float><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,dims_); + cudaDeviceSynchronize(); cudaError_t err = cudaGetLastError(); if(err != cudaSuccess) { - printf("Erreur CUDA: %s\n", cudaGetErrorString(err)); - } //Checks des possibles erreurs sur le GPU + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + } + + cudaMemcpy(input_grad, input_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost); + cudaFree(output_cuda_tensor); + cudaFree(input_grad_); + cudaFree(dims_); + cudaFree(output_grad_); +} + +template <> +void ShiftMaxbackward<double>(const double* output_tensor, const double* output_grad, double* input_grad, size_t size, std::vector<long unsigned int> dims) +{ + int dims_input_cuda[4] = {1, 1, 1, 1}; + for (std::size_t i = 0; i < std::min(dims.size(), size_t(4)); ++i) { + dims_input_cuda[i] = static_cast<int>(dims[i]); + } + + double* output_cuda_tensor; + cudaMalloc(&output_cuda_tensor,size*sizeof(double)); + cudaMemcpy(output_cuda_tensor,output_tensor,size*sizeof(double),cudaMemcpyHostToDevice); + + double* output_grad_; + cudaMalloc(&output_grad_,size*sizeof(double)); + cudaMemcpy(output_grad_,output_grad,size*sizeof(double),cudaMemcpyHostToDevice); + double *input_grad_; + cudaMalloc(&input_grad_, size * sizeof(double)); - //float* ControlFinal = (float*)malloc(size*sizeof(float)); - //cudaMemcpy(ControlFinal,input_cuda_tensor,size*sizeof(float),cudaMemcpyDeviceToHost); - //MyTensor<float> control(ControlFinal,x.dims); + int *dims_; + cudaMalloc(&dims_, 4 * sizeof(int)); + cudaMemcpy(dims_, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice); - cudaMemcpy(output,input_cuda_tensor,size*sizeof(float),cudaMemcpyDeviceToHost); + dim3 threadParBlock(256); + dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x); - cudaFree(quantized_tensor); //| - cudaFree(factor); //| - cudaFree(dims); //| => Free sur GPU et CPU (tout ce qui a été malloc et cudaMalloc en gros) - cudaFree(input_cuda_tensor);//| + ShiftMaxbackward_<double><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,dims_); + cudaDeviceSynchronize(); + cudaError_t err = cudaGetLastError(); + if(err != cudaSuccess) + { + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + } + + cudaMemcpy(input_grad,input_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost); + cudaFree(output_cuda_tensor); + cudaFree(input_grad_); + cudaFree(dims_); + cudaFree(output_grad_); } + + } \ No newline at end of file diff --git a/unit_tests/Test_ILayerNormImpl.cpp b/unit_tests/Test_ILayerNormImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0487b7c4716596e0d2e7bcbdaf812358be4de3bf --- /dev/null +++ b/unit_tests/Test_ILayerNormImpl.cpp @@ -0,0 +1,201 @@ +/******************************************************************************** + * Copyright (c) 2024 Thales + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public 
License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * Author: Lucas RAKOTOARIVONY, Thales Research & Technology France + * Date: 10.09.2024 + * + ********************************************************************************/ + +#include <array> + +#include <catch2/catch_test_macros.hpp> + +#include "Test_cuda.hpp" + +#include "aidge/data/Tensor.hpp" + +#include "aidge/backend/cpu.hpp" +#include "aidge/backend/cuda.hpp" + +using namespace Aidge; + +TEST_CASE("[gpu/operator] ILayerNorm(forward)", "[ILayerNorm][GPU]") { + SECTION("4D Tensor") { + std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<float,2,2,2,10> { + { + { + { + {0.96, 0.48, 0.54, 0.49, 0.59, 0.93, 0.00, 0.00, 0.61, 0.61}, + {0.85, 0.06, 0.11, 0.87, 0.55, 0.12, 0.80, 0.48, 0.41, 0.16} + }, + { + {0.24, 0.46, 0.97, 0.19, 0.65, 0.12, 0.44, 1.00, 0.37, 0.09}, + {0.44, 0.64, 0.21, 0.58, 0.05, 0.24, 0.56, 0.07, 0.49, 0.79} + } + }, + { + { + {0.00, 0.13, 0.55, 0.42, 0.49, 0.28, 0.52, 0.55, 0.34, 0.85}, + {0.98, 0.32, 0.09, 0.05, 0.37, 0.47, 0.63, 0.13, 0.70, 0.02} + }, + { + {0.69, 0.13, 0.74, 0.61, 0.25, 0.87, 0.46, 0.40, 0.81, 0.06}, + {0.89, 0.32, 0.61, 0.24, 0.70, 0.23, 0.09, 0.03, 0.14, 0.80} + } + } + } + }); + + std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<float, 10>{{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}); + std::shared_ptr<Tensor> myWeight = std::make_shared<Tensor>(Array1D<float, 10>{{0.1617684f, 0.3833238f ,-0.6842308f ,-0.4342245f ,-0.4717381f ,-0.1776187f, -0.2728751f, -0.4638580f, 0.2936697f, -0.9011016f}}); + + myWeight->setBackend("cuda"); + myBias->setBackend("cuda"); + + std::shared_ptr<Node> myILayerNorm = ILayerNorm(); + auto op = std::static_pointer_cast<OperatorTensor>(myILayerNorm -> getOperator()); + + op -> associateInput(1, myWeight); + op -> associateInput(2, myBias); + + input0->setBackend("cuda"); + + op -> associateInput(0,input0); + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + op->forward(); + + // expected output + std::shared_ptr<Tensor> output_ilayernorm = std::make_shared<Tensor>(Array4D<float,2,2,2,10> { + { + { + { + {9.8821178e-02, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02}, + {4.9410585e-02, 0.0000000e+00, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00} + }, + { + {0.0000000e+00, 4.9410585e-02, 9.8821178e-02, 0.0000000e+00, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 9.8821178e-02, 4.9410585e-02, 0.0000000e+00}, + {4.9410585e-02, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 0.0000000e+00, 0.0000000e+00, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02} + } + }, + { + { + {0.0000000e+00, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02}, + {9.8821178e-02, 4.9410585e-02, 0.0000000e+00, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 0.0000000e+00} + }, + { + {4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00}, + {4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.9410585e-02} + } + } + } + }); + + + float* computedOutput = new float[output_ilayernorm->size()](); + cudaMemcpy(computedOutput, 
op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * output_ilayernorm->size(), cudaMemcpyDeviceToHost); + + //test if forward result are as expected + for(int i = 0; i < output_ilayernorm->size(); i++){ + const float targetOutput = *(static_cast<float*>(output_ilayernorm->getImpl()->rawPtr()) + i); + REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6); + } + + } + +} + +TEST_CASE("[gpu/operator] ILayerNorm(backward)", "[ILayerNorm][GPU]") + +{ + std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { //NCHW + { + { + { + {1.46650600, 1.24083233, -0.33106008, -0.15137172, 0.06625678, -1.8326609, 0.53444749, -0.05167147}, + }, + }, + } + }); + + std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { //NCHW + { + { + { + {0.96, 0.54, 0.22, -0.15, 0.17, 0.26, -0.85, 0.5}, + }, + }, + } + }); + + std::shared_ptr<Tensor> myWeight = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { //NCHW + { + { + { + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + }, + }, + } + }); + + + myWeight->setBackend("cuda"); + myBias->setBackend("cuda"); + + std::shared_ptr<Node> myILayerNorm = ILayerNorm(); + auto op = std::static_pointer_cast<OperatorTensor>(myILayerNorm -> getOperator()); + + op -> associateInput(1, myWeight); + op -> associateInput(2, myBias); + + input0->setBackend("cuda"); + + op -> associateInput(0,input0); + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + myILayerNorm->forward(); + + std::shared_ptr<Tensor> myOutputGrad = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { + { + { + { + { 1.34347093, 0.90813798, 0.39607167, 1.20428133, 0.16845724, 0.48487359, 0.40748054, -0.21790814}, + }, + }, + } + }); + + + myOutputGrad->setBackend("cuda"); + std::shared_ptr<Tensor> predictedOutput = op->getOutput(0); + std::shared_ptr<Tensor> input = op->getInput(0); + predictedOutput->setGrad(myOutputGrad); + REQUIRE_NOTHROW(myILayerNorm->backward()); + + std::shared_ptr<Tensor> expectedInputGradILayerNorm = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { + { + { + { + { 0.467678, 0.310749, 0.1129, 0.351786, 0.0507252, 0.101587, 0.130249, -0.0646476}, + }, + }, + } + }); + + + float *computedInputGradCuda = new float[myOutputGrad->size()](); + cudaMemcpy(computedInputGradCuda, op->getInput(0)->grad()->getImpl()->rawPtr(), sizeof(float) * myOutputGrad->size(), cudaMemcpyDeviceToHost); + + //test if backward result are as expected + for(int i = 0; i < expectedInputGradILayerNorm->size(); i++){ + const float targetOutput = *(static_cast<float*>(expectedInputGradILayerNorm->getImpl()->rawPtr()) + i); + REQUIRE(fabs(computedInputGradCuda[i] - targetOutput) < 2e-6); + } + + delete[] computedInputGradCuda; +} diff --git a/unit_tests/Test_ShiftGELUImpl.cpp b/unit_tests/Test_ShiftGELUImpl.cpp index d5382c29cef00587eaed3dcd789352c4e8263d31..86e747e735eccb397caa8062f52c2561e8ef759d 100644 --- a/unit_tests/Test_ShiftGELUImpl.cpp +++ b/unit_tests/Test_ShiftGELUImpl.cpp @@ -51,6 +51,7 @@ TEST_CASE("[gpu/operator] ShiftGELU(forward)", "[ShiftGELU][GPU]") { } }); + //expected output of shiftgelu forward operator std::shared_ptr<Tensor> output_shiftGELU = std::make_shared<Tensor>(Array4D<float,2,2,2,10> { { { @@ -76,6 +77,7 @@ TEST_CASE("[gpu/operator] ShiftGELU(forward)", "[ShiftGELU][GPU]") { } }); + //expected output of GELU forward operator (computed with PyTorch) std::shared_ptr<Tensor> output_GELU = std::make_shared<Tensor>(Array4D<float, 2, 2, 2, 10> { { { @@ -99,7 +101,7 @@ TEST_CASE("[gpu/operator] ShiftGELU(forward)", "[ShiftGELU][GPU]") { } 
} } - }); //value given by torch nn GELU + }); std::shared_ptr<Node> myShiftGELU = ShiftGELU(); auto op = std::static_pointer_cast<OperatorTensor>(myShiftGELU -> getOperator()); @@ -111,21 +113,108 @@ TEST_CASE("[gpu/operator] ShiftGELU(forward)", "[ShiftGELU][GPU]") { float* computedOutput = new float[output_shiftGELU->size()](); cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * output_shiftGELU->size(), cudaMemcpyDeviceToHost); + //test if forward result are as expected for(int i = 0; i < output_shiftGELU->size(); i++){ const float targetOutput = *(static_cast<float*>(output_shiftGELU->getImpl()->rawPtr()) + i); REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6); } + //measure difference between GELU and shiftgelu float sum = 0.0; for(int i = 0; i < output_GELU->size(); i++){ const float targetOutput = *(static_cast<float*>(output_GELU->getImpl()->rawPtr()) + i); sum += fabs(computedOutput[i] - targetOutput); } sum = sum / output_GELU->size(); - std::cout << sum << "\n"; REQUIRE(sum < 1.5e-1); delete[] computedOutput; } -} \ No newline at end of file +} + +TEST_CASE("[gpu/operator] ShiftGELU(backward)", "[ShiftGELU][GPU]") + +{ + + std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { //NCHW + { + { + { + {1.46650600, 1.24083233, -0.33106008, -0.15137172, 0.06625678, -1.8326609, 0.53444749, -0.05167147}, + }, + }, + } + }); + + input0->setBackend("cuda"); + + std::shared_ptr<Node> myShiftGELU = ShiftGELU(); + auto op = std::static_pointer_cast<OperatorTensor>(myShiftGELU->getOperator()); + op->associateInput(0, input0); + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + myShiftGELU->forward(); + + std::shared_ptr<Tensor> myOutputGrad = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { + { + { + { + { 1.34347093, 0.90813798, 0.39607167, 1.20428133, 0.16845724, 0.48487359, 0.40748054, -0.21790814}, + }, + }, + } + }); + + + myOutputGrad->setBackend("cuda"); + std::shared_ptr<Tensor> predictedOutput = op->getOutput(0); + std::shared_ptr<Tensor> input = op->getInput(0); + predictedOutput->setGrad(myOutputGrad); + REQUIRE_NOTHROW(myShiftGELU->backward()); + + //expected output of shiftgelu backward operator + std::shared_ptr<Tensor> expectedInputGradShiftGELU = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { + { + { + { + { 1.88094, 1.09182, 0.134203, 0.439603, 0.0696628, 0.173469, 0.254718, -0.084009}, + }, + }, + } + }); + + //expected output of gelu backward operator (computed with PyTorch) + std::shared_ptr<Tensor> expectedInputGradGELU = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { + { + { + { + { 1.5159, 1.0188, 0.0971, 0.4578, 0.0931, -0.0499, 0.3620, -0.1000}, + }, + }, + } + }); + + + float *computedGradCuda = new float[myOutputGrad->size()](); + + cudaMemcpy(computedGradCuda, input->grad()->getImpl()->rawPtr(), sizeof(float) * myOutputGrad->size(), cudaMemcpyDeviceToHost); + + //test if backward result are as expected + for(int i = 0; i < expectedInputGradShiftGELU->size(); i++){ + const float targetOutput = *(static_cast<float*>(expectedInputGradShiftGELU->getImpl()->rawPtr()) + i); + REQUIRE(fabs(computedGradCuda[i] - targetOutput) < 2e-6); + } + + //measure difference between gelu and shifgelu + float sum = 0.0; + for(int i = 0; i < expectedInputGradGELU->size(); i++){ + const float targetOutput = *(static_cast<float*>(expectedInputGradGELU->getImpl()->rawPtr()) + i); + sum += fabs(computedGradCuda[i] - targetOutput); + } + sum = sum / expectedInputGradGELU->size(); + REQUIRE(sum < 2e-1); + + + 
delete[] computedGradCuda; +} diff --git a/unit_tests/Test_ShiftMaxImpl.cpp b/unit_tests/Test_ShiftMaxImpl.cpp index a9f8b2a665d9b77f8ba6ce7c4d251f9b6c0da166..2a94a23c3a04edd72cb535ebfb6e2c538e4aeee8 100644 --- a/unit_tests/Test_ShiftMaxImpl.cpp +++ b/unit_tests/Test_ShiftMaxImpl.cpp @@ -50,6 +50,7 @@ TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") { } } }); + //expected output of shiftmax forward operator std::shared_ptr<Tensor> output_shiftmax = std::make_shared<Tensor>(Array4D<float,2,2,2,10> { { { @@ -74,6 +75,7 @@ TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") { } } }); + //expected output of softmax forward operator (computed with PyTorch) std::shared_ptr<Tensor> output_softmax = std::make_shared<Tensor>(Array4D<float, 2, 2, 2, 10> { { { @@ -97,7 +99,7 @@ TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") { } } } - }); //softmax value given by torch softmax + }); std::shared_ptr<Node> myShiftMax = ShiftMax(); auto op = std::static_pointer_cast<OperatorTensor>(myShiftMax -> getOperator()); @@ -109,11 +111,13 @@ TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") { float* computedOutput = new float[output_shiftmax->size()](); cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * output_shiftmax->size(), cudaMemcpyDeviceToHost); + //test if forward result are as expected for(int i = 0; i < output_shiftmax->size(); i++){ const float targetOutput = *(static_cast<float*>(output_shiftmax->getImpl()->rawPtr()) + i); REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6); } + //measure difference between softmax and shiftmax float sum = 0.0; for(int i = 0; i < output_softmax->size(); i++){ const float targetOutput = *(static_cast<float*>(output_softmax->getImpl()->rawPtr()) + i); @@ -125,4 +129,89 @@ TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") { delete[] computedOutput; } -} \ No newline at end of file +} + +TEST_CASE("[gpu/operator] ShiftMax(backward)", "[ShiftMax][GPU]") + +{ + + std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { //NCHW + { + { + { + {1.46650600, 1.24083233, -0.33106008, -0.15137172, 0.06625678, -1.8326609, 0.53444749, -0.05167147}, + }, + }, + } + }); + + input0->setBackend("cuda"); + + std::shared_ptr<Node> myShiftMax = ShiftMax(); + auto op = std::static_pointer_cast<OperatorTensor>(myShiftMax->getOperator()); + op->associateInput(0, input0); + op->setDataType(DataType::Float32); + op->setBackend("cuda"); + myShiftMax->forward(); + + std::shared_ptr<Tensor> myOutputGrad = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { + { + { + { + { 1.34347093, 0.90813798, 0.39607167, 1.20428133, 0.16845724, 0.48487359, 0.40748054, -0.21790814}, + }, + }, + } + }); + + + myOutputGrad->setBackend("cuda"); + std::shared_ptr<Tensor> predictedOutput = op->getOutput(0); + std::shared_ptr<Tensor> input = op->getInput(0); + predictedOutput->setGrad(myOutputGrad); + REQUIRE_NOTHROW(myShiftMax->backward()); + + //expected output of shiftmax backward operator + std::shared_ptr<Tensor> expectedInputGradShiftMax = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { + { + { + { + { 0.159378, 0.0249331, -0.0250217, 0.0262418, -0.0514701, -0.00459638, -0.0551896, -0.0739511}, + }, + }, + } + }); + + //expected output of softmax backward operator (computed with PyTorch) + std::shared_ptr<Tensor> expectedInputGradSoftmax = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { + { + { + { + { 0.1672, 0.0198, -0.0236, 0.0241, -0.0535, -0.0042, -0.0547, -0.0752}, + }, 
+            },
+        }
+    });
+
+
+    float *computedGradCuda = new float[myOutputGrad->size()]();
+
+    cudaMemcpy(computedGradCuda, input->grad()->getImpl()->rawPtr(), sizeof(float) * myOutputGrad->size(), cudaMemcpyDeviceToHost);
+
+    //test if backward results are as expected
+    for(int i = 0; i < expectedInputGradShiftMax->size(); i++){
+        const float targetOutput = *(static_cast<float*>(expectedInputGradShiftMax->getImpl()->rawPtr()) + i);
+        REQUIRE(fabs(computedGradCuda[i] - targetOutput) < 1e-6);
+    }
+
+    //measure the difference between softmax and shiftmax gradients
+    float sum = 0.0;
+    for(int i = 0; i < expectedInputGradSoftmax->size(); i++){
+        const float targetOutput = *(static_cast<float*>(expectedInputGradSoftmax->getImpl()->rawPtr()) + i);
+        sum += fabs(computedGradCuda[i] - targetOutput);
+    }
+    sum = sum / expectedInputGradSoftmax->size();
+    REQUIRE(sum < 4e-3);
+
+    delete[] computedGradCuda;
+}
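
Editor's note (illustrative, not part of the patch): the ShiftMaxbackward_ kernel added above computes the standard softmax vector-Jacobian product over the last dimension, input_grad[i] = y[i] * (output_grad[i] - sum_j y[j] * output_grad[j]), where y is the forward ShiftMax output. A minimal host-side C++ sketch of that formula is given below; it can be used to cross-check the CUDA gradients against the expected values in the tests. The helper name softmaxBackwardReference and its signature are hypothetical and assume the same NCHW, last-dimension-contiguous layout used by the kernels.

// Host-side reference for the gradient computed by ShiftMaxbackward_ (sketch only).
#include <cstddef>
#include <vector>

// dims follows the NCHW convention of the kernels: dims[3] is the softmax axis.
void softmaxBackwardReference(const std::vector<float>& y,          // forward ShiftMax output
                              const std::vector<float>& outputGrad, // dL/dy
                              std::vector<float>& inputGrad,        // dL/dx (filled by this function)
                              const int dims[4])
{
    const int last = dims[3];
    const std::size_t rows = static_cast<std::size_t>(dims[0]) * dims[1] * dims[2];
    inputGrad.resize(y.size());
    for (std::size_t r = 0; r < rows; ++r) {
        const std::size_t base = r * static_cast<std::size_t>(last);
        float dot = 0.0f; // sum_j y[j] * dL/dy[j] over the last dimension
        for (int j = 0; j < last; ++j) {
            dot += y[base + j] * outputGrad[base + j];
        }
        for (int i = 0; i < last; ++i) {
            inputGrad[base + i] = y[base + i] * (outputGrad[base + i] - dot);
        }
    }
}

Under this formula, the gradients used as expectedInputGradShiftMax in the test above should be reproducible on the host from the ShiftMax forward output and myOutputGrad, up to the quantization error of the shift-based approximation.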