Commit ec78b3aa authored by Olivier BICHLER

Merge branch 'main' into 'dev'

Integration of I-ViT operators

See merge request !34
parents 0b246d6a 4eb31361
Showing 1637 additions and 170 deletions
@@ -37,4 +37,9 @@
#include "aidge/backend/cuda/operator/SubImpl.hpp"
#include "aidge/backend/cuda/operator/TanhImpl.hpp"
#include "aidge/backend/cuda/operator/ShiftMaxImpl.hpp"
#include "aidge/backend/cuda/operator/ShiftGELUImpl.hpp"
#include "aidge/backend/cuda/operator/ILayerNormImpl.hpp"
#endif /* AIDGE_BACKEND_CUDA_IMPORTS_H_ */
/********************************************************************************
* Copyright (c) 2024 Thales
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
* Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
* Date: 10.09.2024
*
********************************************************************************/
#ifndef AIDGE_BACKEND_CUDA_OPERATOR_ILAYERNORMIMPL_H_
#define AIDGE_BACKEND_CUDA_OPERATOR_ILAYERNORMIMPL_H_
#include <array>
#include <memory>
#include <tuple>
#include <vector>
#include <cudnn.h>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/ILayerNorm.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
namespace Aidge {
class ILayerNormImpl_cuda : public OperatorImpl {
public:
ILayerNormImpl_cuda(const ILayerNorm_Op &op) : OperatorImpl(op, "cuda") {}
static std::unique_ptr<ILayerNormImpl_cuda> create(const ILayerNorm_Op &op) {
return std::make_unique<ILayerNormImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
{DataType::Float16},
};
}
void forward() override;
void backward() override;
private:
std::shared_ptr<Tensor> mInput0Fallback;
std::shared_ptr<Tensor> mInput1Fallback;
std::shared_ptr<Tensor> mInput2Fallback;
std::shared_ptr<Tensor> mOutputGradFallback;
template <class T> void forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2);
template <class T> void backward_(const Tensor& output_grad);
};
// Implementation entry point registration to Operator
REGISTRAR(ILayerNorm_Op, "cuda", Aidge::ILayerNormImpl_cuda::create);
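/**
 * @note Illustrative sketch of how this entry point is typically reached (the graph-level
 * calls below are assumptions for illustration, not taken from this file): once a node
 * holding an ILayerNorm_Op is switched to the CUDA backend, the registration above
 * resolves to ILayerNormImpl_cuda::create.
 * @code
 * auto node = Aidge::ILayerNorm("ilayernorm");                  // assumed factory helper
 * node->getOperator()->setDataType(Aidge::DataType::Float32);
 * node->getOperator()->setBackend("cuda");                      // picks ILayerNormImpl_cuda
 * @endcode
 */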
} // namespace Aidge
#endif /* AIDGE_BACKEND_CUDA_OPERATOR_ILAYERNORMIMPL_H_ */
/********************************************************************************
* Copyright (c) 2024 Thales
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
* Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
* Date: 10.09.2024
*
********************************************************************************/
#ifndef AIDGE_CUDA_OPERATOR_ILAYERNORMIMPL_FORWARD_KERNEL_H_
#define AIDGE_CUDA_OPERATOR_ILAYERNORMIMPL_FORWARD_KERNEL_H_
#include <stdexcept>
#include <cfloat>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include "aidge/data/Data.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
namespace Aidge {
/**
* @brief Compute the forward for ILayerNorm
* @param input: Input tensor
* @param SF: Scaling factor of input tensor
* @param dims: Dimensions of input tensor
* @param quantized_tensor: Quantized output tensor
* @param square_tensor: Tensor used for intermediate computation
* @param weight: weight of ILayerNorm layer
* @param bias: bias of ILayerNorm layer
* @param new_SF: Scaling factor of the output, which can be used to dequantize it
*/
template <class T>
__global__ void ILayerNormforward_(T* input, double SF, int* dims, int* quantized_tensor,long long int* square_tensor, T* weight, T* biase, double new_SF);
/**
* @brief Wrapper function to execute ILayerNormforward_
* @note The output corresponds to the non-quantized tensor; to obtain the quantized tensor, copy quantized_tensor rather than input_cuda_tensor
* @param input: Input tensor
* @param output: Output tensor (not quantized)
* @param SF: Scaling factor of input tensor
* @param weight_raw: weight of ILayerNorm layer
* @param bias_raw: bias of ILayerNorm layer
* @param size: Number of elements in the input tensor
* @param dims_input: Dimensions of the input tensor
*/
template <class T>
void ILayerNormforward(const T* input, T* output, double SF, const T* weight_raw, const T* bias_raw, size_t size, std::vector<long unsigned int> dims_input);
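/**
 * @note Minimal host-side usage sketch (illustrative only; the shape, values and scaling
 * factor below are assumptions, not taken from this file):
 * @code
 * std::vector<float> in(8, 0.5f), out(8, 0.0f);    // activations for a {1,1,2,4} tensor
 * std::vector<float> w(4, 1.0f), b(4, 0.0f);       // per-channel weight and bias
 * double sf = 0.05;                                // assumed per-tensor scaling factor
 * Aidge::ILayerNormforward<float>(in.data(), out.data(), sf,
 *                                 w.data(), b.data(), in.size(), {1, 1, 2, 4});
 * @endcode
 */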
/**
* @brief Compute the backward for ILayerNorm
* @param output_grad: Gradient of output tensor
* @param input_tensor: Input tensor
* @param output_tensor: Output tensor obtained after forward
* @param mean: Arithmetic mean of input tensor
* @param var: Arithmetic variance of input tensor
* @param weight: weight of ILayerNorm layer
* @param bias: bias of ILayerNorm layer
* @param input_grad: Gradient of input tensor
* @param weight_grad: Gradient of ILayerNorm weight
* @param bias_grad: Gradient of ILayerNorm bias
* @param size: Number of elements in the input tensor
*/
template <class T>
__global__ void ILayerNormbackward_(T* output_grad, T* input_tensor, T* output_tensor, T* mean, T* var, T* weight, T* bias, T* input_grad, T* weight_grad, T* bias_grad, int size);
/**
* @brief Wrapper function to execute ILayerNormbackward_
* @param input_tensor: Input tensor
* @param output_grad: Gradient of output tensor
* @param output_tensor: Output tensor obtained after forward
* @param mean: Arithmetic mean of input tensor
* @param var: Arithmetic variance of input tensor
* @param weight: weight of ILayerNorm layer
* @param bias: bias of ILayerNorm layer
* @param input_grad: Gradient of input tensor
* @param weight_grad: Gradient of ILayerNorm weight
* @param bias_grad: Gradient of ILayerNorm bias
* @param size: Number of elements in the input tensor
*/
template <class T>
void ILayerNormbackward(const T* input_tensor, const T* output_grad, const T* output_tensor,const T* mean,const T* var, const T* weight, const T* bias, T* input_grad, T* weight_grad, T* bias_grad, size_t size);
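/**
 * @note For reference, the backward kernel computes, per element i (with eps = 1e-6):
 * d_norm = dy_i * weight_i,
 * d_var = -0.5 * d_norm * (x_i - mean_i) * (var_i + eps)^(-3/2),
 * d_mean = -d_norm / sqrt(var_i + eps) - 2 * d_var * mean_i / size,
 * dx_i = d_norm / sqrt(var_i + eps) + 2 * d_var * (x_i - mean_i) / size + d_mean / size.
 */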
}
#endif /* AIDGE_CUDA_OPERATOR_ILAYERNORMIMPL_FORWARD_KERNEL_H_ */
\ No newline at end of file
@@ -29,12 +29,11 @@
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
namespace Aidge {
// Operator implementation entry point for the backend
class ShiftGELUImpl_cuda : public OperatorImpl {
public:
ShiftGELUImpl_cuda(const ShiftGELU_Op& op) : OperatorImpl(op, "cuda") {}
ShiftGELUImpl_cuda(const ShiftGELU_Op &op) : OperatorImpl(op, "cuda") {}
static std::unique_ptr<ShiftGELUImpl_cuda> create(const ShiftGELU_Op& op) {
static std::unique_ptr<ShiftGELUImpl_cuda> create(const ShiftGELU_Op &op) {
return std::make_unique<ShiftGELUImpl_cuda>(op);
}
@@ -46,12 +46,17 @@ public:
};
}
void forward() override;
void backward() override;
private:
std::shared_ptr<Tensor> mInputFallback;
std::shared_ptr<Tensor> mOutputGradFallback;
template <class T> void forward_(const Tensor& input);
template <class T> void backward_(const Tensor& output_grad);
};
// Implementation entry point registration to Operator
@@ -25,10 +25,54 @@
namespace Aidge {
extern void ShiftGELULaunchKernel(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input);
/**
* @brief Compute the forward for ShiftGELU
* @param input: Input tensor
* @param quantized_tensor: Quantized output tensor
* @param GELUtensor: Pointer to an empty memory block allocated on the GPU (used only for intermediate computation)
* @param SumTensor: Pointer to an empty memory block allocated on the GPU (used only for intermediate computation)
* @param dims: Dimensions of input tensor
* @param SF: Scaling factor of input tensor
* @param N: Arithmetic precision, currently set to 15 as in I-ViT (the larger N is, the more precise the operation, but the more bits it requires)
* @param output_bits: Desired bit precision (8 for int8, for example)
*/
template <class T>
__global__ void ShiftGELUforward_(T* input,int* quantized_tensor,int* GELUtensor,int* SumTensor, int* dims, double SF, int N, int output_bits);
/**
* @brief Wrapper function to execute ShiftGELUforward_
* @note The output corresponds to the non-quantized tensor; to obtain the quantized tensor, copy quantized_tensor rather than input_cuda_tensor
* @param input: Input tensor
* @param output: Output tensor (not quantized)
* @param SF: Scaling factor of input tensor
* @param N: Arithmetic precision, currently set to 15 as in I-ViT (the larger N is, the more precise the operation, but the more bits it requires)
* @param output_bits: Desired bit precision (8 for int8, for example)
* @param size: Number of elements in the input tensor
* @param dims_input: Dimensions of input tensor
*/
template <class T>
void ShiftGELUforward(const T* input, T* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input);
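/**
 * @note Minimal host-side usage sketch (illustrative only; the shape and values are
 * assumptions). Following the backend implementation, the per-tensor scaling factor is
 * derived from the largest absolute activation m as m / (2^(output_bits-1) - 1):
 * @code
 * std::vector<float> in = { -1.27f, 0.0f, 0.3f, 2.54f };  // a {1,1,1,4} tensor
 * std::vector<float> out(in.size(), 0.0f);
 * double sf = 2.54 / 127.0;            // m = 2.54, output_bits = 8 -> sf = 0.02
 * int N = 15, output_bits = 8;
 * Aidge::ShiftGELUforward<float>(in.data(), out.data(), sf, N, output_bits,
 *                                in.size(), {1, 1, 1, 4});
 * @endcode
 */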
/**
* @brief Compute the backward for ShiftGELU
* @param input_grad: Gradient of input tensor (that we want to obtain)
* @param output_tensor: Output tensor obtained after forward
* @param output_grad: Gradient of output tensor
* @param size: Number of elements in the input tensor
*/
template <class T>
__global__ void ShiftGELUWholeKernel(T* input,int* quantized_tensor,int* GELUtensor,int* SumTensor, int* dims, double SF, int N, int output_bits);
__global__ void ShiftGELUbackward_(T* input_grad, const T* output_tensor, const T* output_grad, int size);
/**
* @brief Wrapper function to execute ShiftGELUbackward_
* @param output_tensor: Output tensor obtained after forward
* @param output_grad: Gradient of output tensor
* @param input_grad: Gradient of input tensor (that we want to obtain)
* @param size: Number of elements in the input tensor
*/
template <class T>
void ShiftGELUbackward(const T* output_tensor, const T* output_grad, T* input_grad, size_t size);
}
#endif /* AIDGE_CUDA_OPERATOR_SHIFTGELUIMPL_KERNELS_H_ */
\ No newline at end of file
#endif /* AIDGE_CUDA_OPERATOR_SHIFTGELUIMPL_FORWARD_KERNEL_H_ */
@@ -29,12 +29,11 @@
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
namespace Aidge {
// Operator implementation entry point for the backend
class ShiftMaxImpl_cuda : public OperatorImpl {
public:
ShiftMaxImpl_cuda(const ShiftMax_Op& op) : OperatorImpl(op, "cuda") {}
ShiftMaxImpl_cuda(const ShiftMax_Op &op) : OperatorImpl(op, "cuda") {}
static std::unique_ptr<ShiftMaxImpl_cuda> create(const ShiftMax_Op& op) {
static std::unique_ptr<ShiftMaxImpl_cuda> create(const ShiftMax_Op &op) {
return std::make_unique<ShiftMaxImpl_cuda>(op);
}
@@ -47,11 +46,15 @@ public:
}
void forward() override;
void backward() override;
private:
std::shared_ptr<Tensor> mInputFallback;
std::shared_ptr<Tensor> mOutputGradFallback;
template <class T> void forward_(const Tensor& input);
template <class T> void backward_(const Tensor& output_grad);
};
// Implementation entry point registration to Operator
@@ -25,10 +25,55 @@
namespace Aidge {
extern void ShiftMaxLaunchKernel(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input);
/**
* @brief Compute the forward for ShiftMax
* @param input: Input tensor
* @param quantized_tensor: Quantized output tensor
* @param factor: Pointer to an empty memory block allocated on the GPU (used only for intermediate computation)
* @param dims: Dimensions of input tensor
* @param SF: Scaling factor of input tensor
* @param N: Arithmetic precision, currently set to 15 as in I-ViT (the larger N is, the more precise the operation, but the more bits it requires)
* @param output_bits: Desired bit precision (8 for int8, for example)
* @param new_SF: Scaling factor of the output, which can be used to dequantize it
*/
template <class T>
__global__ void ShiftMaxforward_(T* input,int* quantized_tensor,int* factor, int* dims, double SF, int N, int output_bits,double new_SF);
/**
* @brief Wrapper function to execute ShiftMaxforward_
* @note The output corresponds to the non-quantized tensor; to obtain the quantized tensor, copy quantized_tensor rather than input_cuda_tensor
* @param input: Input tensor
* @param output: Output tensor (not quantized)
* @param SF: Scaling factor of input tensor
* @param N: Arithmetic precision, currently set to 15 as in I-ViT (the larger N is, the more precise the operation, but the more bits it requires)
* @param output_bits: Desired bit precision (8 for int8, for example)
* @param size: Number of elements in the input tensor
* @param dims_input: Dimensions of input tensor
*/
template <class T>
void ShiftMaxforward(const T* input, T* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input);
/**
* @brief Compute the backward for ShiftMax
* @param input_grad: Gradient of input tensor (that we want to obtain)
* @param output_tensor: Output tensor obtained after forward
* @param output_grad: Gradient of output tensor
* @param dims: Dimensions of input tensor
*/
template <class T>
__global__ void ShiftMaxWholeKernel(T* input,int* quantized_tensor,int* factor, int* dims, double SF, int N, int output_bits,double new_SF) ;
__global__ void ShiftMaxbackward_(T* input_grad, const T* output_tensor, const T* output_grad, const int* dims);
/**
* @brief Wrapper function to execute ShiftMaxbackward_
* @param output_tensor: Output tensor obtained after forward
* @param output_grad: Gradient of output tensor
* @param input_grad: Gradient of input tensor (that we want to obtain)
* @param size: Number of elements in the input tensor
* @param dims: Dimensions of input tensor
*/
template <class T>
void ShiftMaxbackward(const T* output_tensor, const T* output_grad, T* input_grad, size_t size, std::vector<long unsigned int> dims);
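/**
 * @note For reference, the backward kernel implements the usual softmax Jacobian-vector
 * product along the last dimension: with y the forward output and dy the output gradient,
 * dx_i = y_i * (dy_i - sum_j y_j * dy_j).
 */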
}
#endif /* AIDGE_CUDA_OPERATOR_SHIFTMAXIMPL_KERNELS_H_ */
\ No newline at end of file
#endif /* AIDGE_CUDA_OPERATOR_SHIFTMAXIMPL_FORWARD_KERNEL_H_ */
/********************************************************************************
* Copyright (c) 2024 Thales
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
* Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
* Date: 10.09.2024
*
********************************************************************************/
#include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector>
#include <algorithm> // For std::max
#include <cmath> // For pow
#include <typeinfo>
#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/ILayerNormImpl.hpp"
#include "aidge/backend/cuda/operator/ILayerNormImpl_CUDA_kernels.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
#include "aidge/operator/ILayerNorm.hpp"
#include "aidge/utils/Types.h"
void Aidge::ILayerNormImpl_cuda::forward() {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
assert(mOp.getRawInput(0) && "missing input #0");
assert(mOp.getRawInput(1) && "missing input #1");
assert(mOp.getRawInput(2) && "missing input #2");
const auto& input0 = op.getInput(0)->refCastFrom(mInput0Fallback, *op.getOutput(0));
const auto& input1 = op.getInput(1)->refCastFrom(mInput1Fallback, *op.getOutput(0));
const auto& input2 = op.getInput(2)->refCastFrom(mInput2Fallback, *op.getOutput(0));
switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
case DataType::Float64:
forward_<double>(input0, input1, input2);
break;
case DataType::Float32:
forward_<float>(input0, input1, input2);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
}
}
template<class T>
void Aidge::ILayerNormImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2)
{
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
const T * input_raw = static_cast<const T*>(input0.getImpl()->rawPtr());
const T * weight = static_cast<const T*>(input1.getImpl()->rawPtr());
const T * bias = static_cast<const T*>(input2.getImpl()->rawPtr());
T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
int N = 15;
int output_bits = 8;
size_t size = input0.size();
std::vector<DimSize_t> dims_input = input0.dims();
// maybe find a more efficient way to compute the scaling factor (a max and a min function could help retrieve its value)
double min = std::numeric_limits<double>::max();
double max = std::numeric_limits<double>::min();
for(std::size_t i = 0; i < dims_input[0]; i++) {
for(std::size_t j = 0; j < dims_input[1]; j++) {
for(std::size_t k = 0; k < dims_input[2]; k++) {
for(std::size_t l = 0; l < dims_input[3]; l++) {
std::vector<std::size_t> coordIdx = {i, j, k, l};
std::size_t newFlatIdx = input0.getIdx(coordIdx);
if (newFlatIdx < min) {
min = newFlatIdx;
}
if (newFlatIdx > max) {
max = newFlatIdx;
}
}
}
}
}
double m = std::max(std::abs(min), std::abs(max));
double normalization_factor = static_cast<double>(1 << (output_bits - 1)) - 1;
double scaling_factor = m / normalization_factor;
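// Worked example of the scaling-factor computation above (illustrative values, not from
// the code): with output_bits = 8 the normalization factor is 2^(8-1) - 1 = 127; if the
// largest absolute value found were m = 2.54, the scaling factor would be 2.54 / 127 = 0.02,
// and an input value of 1.27 would be quantized by the kernel to roundf(1.27 / 0.02) = 64.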
// The new scaling factor that we can use to dequantize the returned tensor (not used here)
// double new_SF = 1/std::pow(2,2*output_bits-1);
ILayerNormforward(input_raw, output, scaling_factor, weight, bias, size, dims_input);
}
void Aidge::ILayerNormImpl_cuda::backward() {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
assert(op.getOutput(0)->grad() && "missing output #0");
const auto& output_grad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad());
if (op.getInput(0)->grad()->dataType() == DataType::Float64) {
backward_<double>(output_grad);
}
else {
backward_<float>(output_grad);
}
}
template <class T>
void Aidge::ILayerNormImpl_cuda::backward_(const Tensor& output_grad) {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
size_t size = output_grad.size();
std::vector<DimSize_t> dims_input = output_grad.dims();
const T * output = static_cast<const T*>(std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr());
T * input_grad = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr());
T * weight_grad = static_cast<T*>(op.getInput(1)->grad()->getImpl()->rawPtr());
T * bias_grad = static_cast<T*>(op.getInput(2)->grad()->getImpl()->rawPtr());
const T * input = static_cast<const T*>(op.getInput(0)->getImpl()->rawPtr());
const T * weight = static_cast<const T*>(op.getInput(1)->getImpl()->rawPtr());
const T * bias = static_cast<const T*>(op.getInput(2)->getImpl()->rawPtr());
// maybe find a more efficient way to compute the mean and variance tensors
std::vector<std::vector<std::vector<std::vector<T>>>> means(dims_input[0],
std::vector<std::vector<std::vector<T>>>(dims_input[1],
std::vector<std::vector<T>>(dims_input[2],
std::vector<T>(dims_input[3], 0.0f))));
for (std::size_t i = 0; i < dims_input[0]; i++) {
for (std::size_t j = 0; j < dims_input[1]; j++) {
for (std::size_t k = 0; k < dims_input[2]; k++) {
T sum = 0.0f;
for (std::size_t l = 0; l < dims_input[3]; l++) {
std::vector<std::size_t> coordIdx = {i, j, k, l};
sum += output_grad.getIdx(coordIdx);
}
for (std::size_t l = 0; l < dims_input[3]; l++) {
std::vector<std::size_t> coordIdx = {i, j, k, l};
means[i][j][k][l] = sum / static_cast<T>(dims_input[3]);
}
}
}
}
std::vector<T> flat_means;
for (const auto &vec3d : means) {
for (const auto &vec2d : vec3d) {
for (const auto &vec1d : vec2d) {
flat_means.insert(flat_means.end(), vec1d.begin(), vec1d.end());
}
}
}
std::vector<std::vector<std::vector<std::vector<T>>>> vars(dims_input[0],
std::vector<std::vector<std::vector<T>>>(dims_input[1],
std::vector<std::vector<T>>(dims_input[2],
std::vector<T>(dims_input[3], 0.0f))));
for (std::size_t i = 0; i < dims_input[0]; i++) {
for (std::size_t j = 0; j < dims_input[1]; j++) {
for (std::size_t k = 0; k < dims_input[2]; k++) {
T sum_sq_diff = 0.0f;
for (std::size_t l = 0; l < dims_input[3]; l++) {
std::vector<std::size_t> coordIdx = {i, j, k, l};
T value = static_cast<T>(output_grad.getIdx(coordIdx));
T diff = value - means[i][j][k][l];
sum_sq_diff += diff * diff;
}
T variance = sum_sq_diff / static_cast<T>(dims_input[3]);
for (std::size_t l = 0; l < dims_input[3]; l++) {
vars[i][j][k][l] = variance;
}
}
}
}
std::vector<T> flat_vars;
for (const auto &vec3d : vars) {
for (const auto &vec2d : vec3d) {
for (const auto &vec1d : vec2d) {
flat_vars.insert(flat_vars.end(), vec1d.begin(), vec1d.end());
}
}
}
const T* mean_ = flat_means.data();
const T* var_ = flat_vars.data();
const T * output_grad_raw = static_cast<const T*>(output_grad.getImpl()->rawPtr());
ILayerNormbackward(output, output_grad_raw, input, mean_, var_, weight, bias, input_grad, weight_grad, bias_grad, size);
}
/********************************************************************************
* Copyright (c) 2024 Thales
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
* Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
* Date: 10.09.2024
*
********************************************************************************/
#define MAX(X,Y) (((X) > (Y)) ? (X) : (Y))
#define CLAMP(X) (((X) < (0)) ? (0) : (X))
#include <stdio.h>
#include <cuda_runtime.h>
#include "aidge/backend/cuda/operator/ILayerNormImpl_CUDA_kernels.hpp"
namespace Aidge{
template <class T>
__global__ void ILayerNormforward_(T* input, double SF, int* dims, int* quantized_tensor,long long int* square_tensor, T* weight, T* biase, double new_SF) {
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
int z = blockIdx.z * blockDim.z + threadIdx.z;
int k = 1 << 16;
long long int sum = 0;
if (x < dims[0] && y < dims[1] && z < dims[2]) {
int maxIdx = x * dims[1] * dims[2] * dims[3] + y * dims[2] * dims[3] + z * dims[3];
int val;
int mean_val = 0;
for (int i = 0; i < dims[3]; i++) {
int idx = maxIdx + i;
val = roundf(input[idx] / SF);
quantized_tensor[idx] = val;
mean_val += val;
}
for (int i = 0; i < dims[3]; i++) {
int idx = maxIdx + i;
quantized_tensor[idx] -= (mean_val/dims[3]) ;
square_tensor[idx] = (quantized_tensor[idx] * quantized_tensor[idx]); // I-ViT code implementation
//square_tensor[idx] = (quantized_tensor[idx] * quantized_tensor[idx])/dims[3]; // I-ViT paper implementation
}
for (int i = 0; i < dims[3]; i++) {
int idx = maxIdx + i;
sum += square_tensor[idx];
biase[i] = (biase[i]/weight[i])/new_SF;
weight[i] = weight[i] * new_SF;
}
for(int h = 0; h < 10 ; h++)
{
k = floorf((k + floorf(sum / k))/2);
}
int factor = (((1 << 31) - 1) / k);
for (int i = 0; i < dims[3]; i++) {
int idx = maxIdx + i;
square_tensor[idx] = (biase[idx]/weight[idx])/new_SF;
quantized_tensor[idx] = (quantized_tensor[idx] * factor / 2) + biase[maxIdx];
input[idx] = quantized_tensor[idx] * new_SF;
}
}
}
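// Note on the loop above (added for clarity): the ten iterations of
// k = floorf((k + floorf(sum / k)) / 2) are an integer Newton iteration that approximates
// floor(sqrt(sum)), starting from the fixed estimate k = 2^16, as in the I-ViT integer
// LayerNorm. A minimal host-side sketch of the same recurrence (illustrative only, not
// used by the kernels in this file):
static inline long long iSqrtNewtonSketch(long long s)
{
    long long k = 1 << 16;                  // same fixed starting estimate as the kernel
    for (int it = 0; it < 10; ++it) {
        k = (k + s / k) / 2;                // integer Newton step for sqrt(s)
    }
    return k;                               // ~ floor(sqrt(s)) for the value range used here
}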
template <>
void ILayerNormforward<float>(const float* input, float* output, double SF, const float* weight_raw, const float* bias_raw, size_t size, std::vector<long unsigned int> dims_input)
{
int dims_input_cuda[4] = {1, 1, 1, 1};
for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) {
dims_input_cuda[i] = static_cast<int>(dims_input[i]);
}
double new_SF = std::sqrt(dims_input_cuda[3]) / (1 << 30);
float* input_cuda_tensor;
cudaMalloc(&input_cuda_tensor,size*sizeof(float));
cudaMemcpy(input_cuda_tensor,input, size * sizeof(float),cudaMemcpyHostToDevice);
int *quantized_tensor;
cudaMalloc(&quantized_tensor, size * sizeof(int));
int *dims;
cudaMalloc(&dims, 4 * sizeof(int));
cudaMemcpy(dims, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice);
float *weight;
cudaMalloc(&weight,dims_input_cuda[3]*sizeof(float));
cudaMemcpy(weight,weight_raw,dims_input_cuda[3]*sizeof(float),cudaMemcpyHostToDevice);
float *bias;
cudaMalloc(&bias,dims_input_cuda[3]*sizeof(float));
cudaMemcpy(bias,bias_raw,dims_input_cuda[3]*sizeof(float),cudaMemcpyHostToDevice);
long long int* Squaretensor;
cudaMalloc(&Squaretensor,(size)*sizeof(long long int));
dim3 threadsPerBlock(10, 10, 10);
dim3 numBlocks((dims_input_cuda[0] + threadsPerBlock.x - 1) / threadsPerBlock.x,
(dims_input_cuda[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,
(dims_input_cuda[2] + threadsPerBlock.z - 1) / threadsPerBlock.z);
ILayerNormforward_<float><<<numBlocks,threadsPerBlock>>>(input_cuda_tensor,SF,dims,quantized_tensor,Squaretensor,weight,bias,new_SF);
cudaDeviceSynchronize();
cudaError_t err = cudaGetLastError();
if(err != cudaSuccess)
{
std::cerr << "CUDA Error: " << cudaGetErrorString(err) << std::endl;
}
cudaMemcpy(output,input_cuda_tensor, (size ) * sizeof(float), cudaMemcpyDeviceToHost);
cudaFree(input_cuda_tensor);
cudaFree(weight);
cudaFree(bias);
cudaFree(dims);
cudaFree(quantized_tensor);
}
template <>
void ILayerNormforward<double>(const double* input, double* output, double SF, const double* weight_raw, const double* bias_raw, size_t size, std::vector<long unsigned int> dims_input)
{
int dims_input_cuda[4] = {1, 1, 1, 1};
for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) {
dims_input_cuda[i] = static_cast<int>(dims_input[i]);
}
double new_SF = std::sqrt(dims_input_cuda[3]) / (1 << 30);
double* input_cuda_tensor;
cudaMalloc(&input_cuda_tensor,size*sizeof(double));
cudaMemcpy(input_cuda_tensor,input, size * sizeof(double),cudaMemcpyHostToDevice);
int *quantized_tensor;
cudaMalloc(&quantized_tensor, size * sizeof(int));
int *dims;
cudaMalloc(&dims, 4 * sizeof(int));
cudaMemcpy(dims, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice);
double *weight;
cudaMalloc(&weight,dims_input_cuda[3]*sizeof(double));
cudaMemcpy(weight,weight_raw,dims_input_cuda[3]*sizeof(double),cudaMemcpyHostToDevice);
double *bias;
cudaMalloc(&bias,dims_input_cuda[3]*sizeof(double));
cudaMemcpy(bias,bias_raw,dims_input_cuda[3]*sizeof(double),cudaMemcpyHostToDevice);
long long int* Squaretensor;
cudaMalloc(&Squaretensor,(size)*sizeof(long long int));
dim3 threadsPerBlock(10, 10, 10);
dim3 numBlocks((dims_input_cuda[0] + threadsPerBlock.x - 1) / threadsPerBlock.x,
(dims_input_cuda[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,
(dims_input_cuda[2] + threadsPerBlock.z - 1) / threadsPerBlock.z);
ILayerNormforward_<double><<<numBlocks,threadsPerBlock>>>(input_cuda_tensor,SF,dims,quantized_tensor,Squaretensor,weight,bias,new_SF);
cudaDeviceSynchronize();
cudaError_t err = cudaGetLastError();
if(err != cudaSuccess)
{
std::cerr << "CUDA Error: " << cudaGetErrorString(err) << std::endl;
}
cudaMemcpy(output,input_cuda_tensor, (size ) * sizeof(double), cudaMemcpyDeviceToHost);
cudaFree(input_cuda_tensor);
cudaFree(weight);
cudaFree(bias);
cudaFree(dims);
cudaFree(quantized_tensor);
}
template <class T>
__global__ void ILayerNormbackward_(T* output_grad, T* input_tensor, T* output_tensor, T* mean, T* var, T* weight, T* bias, T* input_grad, T* weight_grad, T* bias_grad, int size)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < size) {
T d_norm = output_grad[i] * weight[i];
T d_var = d_norm * (input_tensor[i] - mean[i]) * -0.5 * powf(var[i] + 1e-6, -1.5);
T d_mean = d_norm * -1 / sqrtf(var[i] + 1e-6) + d_var * -2 * mean[i] / size;
T d_input = d_norm / sqrtf(var[i] + 1e-6) + d_var * 2 * (input_tensor[i] - mean[i]) / size + d_mean / size;
input_grad[i] = d_input;
weight_grad[i] = output_grad[i] * output_tensor[i];
bias_grad[i] = output_grad[i];
}
}
template <>
void ILayerNormbackward<float>(const float* input_tensor, const float* output_grad, const float* output_tensor,const float* mean,const float* var, const float* weight, const float* bias, float* input_grad, float* weight_grad, float* bias_grad, size_t size)
{
float* input_cuda_tensor;
cudaMalloc(&input_cuda_tensor,size*sizeof(float));
cudaMemcpy(input_cuda_tensor,input_tensor,size*sizeof(float),cudaMemcpyHostToDevice);
float* output_grad_;
cudaMalloc(&output_grad_,size*sizeof(float));
cudaMemcpy(output_grad_,output_grad,size*sizeof(float),cudaMemcpyHostToDevice);
float* output_tensor_;
cudaMalloc(&output_tensor_,size*sizeof(float));
cudaMemcpy(output_tensor_,output_tensor,size*sizeof(float),cudaMemcpyHostToDevice);
float* mean_;
cudaMalloc(&mean_,size*sizeof(float));
cudaMemcpy(mean_,mean,size*sizeof(float),cudaMemcpyHostToDevice);
float* var_;
cudaMalloc(&var_,size*sizeof(float));
cudaMemcpy(var_,var,size*sizeof(float),cudaMemcpyHostToDevice);
float* weight_;
cudaMalloc(&weight_,size*sizeof(float));
cudaMemcpy(weight_,weight,size*sizeof(float),cudaMemcpyHostToDevice);
float* bias_;
cudaMalloc(&bias_,size*sizeof(float));
cudaMemcpy(bias_,bias,size*sizeof(float),cudaMemcpyHostToDevice);
float* input_grad_;
cudaMalloc(&input_grad_,size*sizeof(float));
float* weight_grad_;
cudaMalloc(&weight_grad_,size*sizeof(float));
float* bias_grad_;
cudaMalloc(&bias_grad_,size*sizeof(float));
dim3 threadParBlock(256);
dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
ILayerNormbackward_<<<Blocks,threadParBlock>>>(output_grad_,input_cuda_tensor,output_tensor_,mean_,var_,weight_,bias_,input_grad_, weight_grad_, bias_grad_, size);
cudaDeviceSynchronize();
cudaError_t err = cudaGetLastError();
if(err != cudaSuccess)
{
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
cudaMemcpy(input_grad , input_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(weight_grad , weight_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(bias_grad , bias_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost);
cudaFree(input_cuda_tensor);
cudaFree(output_grad_);
cudaFree(mean_);
cudaFree(var_);
cudaFree(weight_);
cudaFree(bias_);
cudaFree(input_grad_);
cudaFree(weight_grad_);
cudaFree(bias_grad_);
}
template <>
void ILayerNormbackward<double>(const double* input_tensor, const double* output_grad, const double* output_tensor,const double* mean,const double* var, const double* weight, const double* bias, double* input_grad, double* weight_grad, double* bias_grad, size_t size)
{
double* input_cuda_tensor;
cudaMalloc(&input_cuda_tensor,size*sizeof(double));
cudaMemcpy(input_cuda_tensor,input_tensor,size*sizeof(double),cudaMemcpyHostToDevice);
double* output_grad_;
cudaMalloc(&output_grad_,size*sizeof(double));
cudaMemcpy(output_grad_,output_grad,size*sizeof(double),cudaMemcpyHostToDevice);
double* output_tensor_;
cudaMalloc(&output_tensor_,size*sizeof(double));
cudaMemcpy(output_tensor_,output_tensor,size*sizeof(double),cudaMemcpyHostToDevice);
double* mean_;
cudaMalloc(&mean_,size*sizeof(double));
cudaMemcpy(mean_,mean,size*sizeof(double),cudaMemcpyHostToDevice);
double* var_;
cudaMalloc(&var_,size*sizeof(double));
cudaMemcpy(var_,var,size*sizeof(double),cudaMemcpyHostToDevice);
double* weight_;
cudaMalloc(&weight_,size*sizeof(double));
cudaMemcpy(weight_,weight,size*sizeof(double),cudaMemcpyHostToDevice);
double* bias_;
cudaMalloc(&bias_,size*sizeof(double));
cudaMemcpy(bias_,bias,size*sizeof(double),cudaMemcpyHostToDevice);
double* input_grad_;
cudaMalloc(&input_grad_,size*sizeof(double));
double* weight_grad_;
cudaMalloc(&weight_grad_,size*sizeof(double));
double* bias_grad_;
cudaMalloc(&bias_grad_,size*sizeof(double));
dim3 threadParBlock(256);
dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
ILayerNormbackward_<<<Blocks,threadParBlock>>>(output_grad_,input_cuda_tensor,output_tensor_,mean_,var_,weight_,bias_,input_grad_, weight_grad_, bias_grad_, size);
cudaDeviceSynchronize();
cudaError_t err = cudaGetLastError();
if(err != cudaSuccess)
{
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
cudaMemcpy(input_grad , input_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost);
cudaMemcpy(weight_grad , weight_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost);
cudaMemcpy(bias_grad , bias_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost);
cudaFree(input_cuda_tensor);
cudaFree(output_grad_);
cudaFree(mean_);
cudaFree(var_);
cudaFree(weight_);
cudaFree(bias_);
cudaFree(input_grad_);
cudaFree(weight_grad_);
cudaFree(bias_grad_);
}
}
\ No newline at end of file
@@ -34,19 +34,13 @@ void Aidge::ShiftGELUImpl_cuda::forward() {
assert(mOp.getRawInput(0) && "missing input #0");
const auto& input = op.getInput(0)->refCastFrom(mInputFallback, *op.getOutput(0));
//forward_<float>(input);
//should be templated to handle the type change
switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
case DataType::Float64:
forward_<float>(input);
forward_<double>(input);
break;
case DataType::Float32:
forward_<float>(input);
break;
case DataType::Float16:
forward_<float>(input);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
}
@@ -57,13 +51,15 @@ void Aidge::ShiftGELUImpl_cuda::forward_(const Tensor& input)
{
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
const T * input_raw = static_cast<const T*>(input.getImpl()->rawPtr());
T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
int N = 15;
int output_bits = 8;
size_t size = input.size();
std::vector<DimSize_t> dims_input = input.dims();
// maybe find a more efficient way to compute the scaling factor (a max and a min function could help retrieve its value)
double min = std::numeric_limits<double>::max();
double max = std::numeric_limits<double>::min();
for(std::size_t i = 0; i < dims_input[0]; i++) {
@@ -84,13 +80,40 @@ void Aidge::ShiftGELUImpl_cuda::forward_(const Tensor& input)
}
double m = std::max(std::abs(min), std::abs(max));
// Calculate the normalization factor
double normalization_factor = static_cast<double>(1 << (output_bits - 1)) - 1;
double scaling_factor = m / normalization_factor;
// The new scaling factor that we can use to dequantize the returned tensor (not used here)
// double new_SF = 1/std::pow(2,2*output_bits-1);
ShiftGELUforward(input_raw, output, scaling_factor,N, output_bits, size, dims_input);
}
void Aidge::ShiftGELUImpl_cuda::backward() {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
assert(op.getOutput(0)->grad() && "missing output #0");
const auto& output_grad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad());
if (op.getInput(0)->grad()->dataType() == DataType::Float64) {
backward_<double>(output_grad);
}
else {
backward_<float>(output_grad);
}
}
template <class T>
void Aidge::ShiftGELUImpl_cuda::backward_(const Tensor& output_grad) {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
const T * input = static_cast<const T*>(std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr());
size_t size = output_grad.size();
T * output = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr());
const T * output_grad_raw = static_cast<const T*>(output_grad.getImpl()->rawPtr());
ShiftGELUbackward(input, output_grad_raw, output, size);
// Return the normalized maximum
double final_sf = m / normalization_factor;
T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
double new_SF = 1/std::pow(2,2*output_bits-1); // The new scaling factor returned by the shiftmax function, used to dequantize the returned tensor
ShiftGELULaunchKernel(input_raw, output, final_sf,N, output_bits, size, dims_input);
}
\ No newline at end of file
@@ -26,45 +26,35 @@ __device__ inline int ExpShift(int I,int N, double SF)
int q = floorf(Ip / (I0));
int r = Ip -(I0*q);
int Ib = r/2 - I0;
Ib = CLAMP(Ib * powf(2,N-q));//BitShift?
Ib = CLAMP(Ib * powf(2,N-q));
return (int)Ib;
}
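// Added reading note: with q = floor(Ip / I0) and r = Ip - I0 * q (visible above), the
// function returns CLAMP((r/2 - I0) * 2^(N - q)), i.e. the remainder term rescaled by a
// power of two that depends on the quotient q; this is the integer-only exponential
// shared by the ShiftGELU and ShiftMax forward kernels.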
namespace Aidge{
template <class T>
__global__ void ShiftGELUWholeKernel(T* input,int* quantized_tensor,int* GELUtensor,int* SumTensor, int* dims, double SF, int N, int output_bits) {
/*
* Forward kernel for GELU
* Input => tensor holding the (non-quantized, floating-point) input (pointer to a memory block of type T)
* quantized_tensor => pointer to an empty memory block allocated on the GPU
* GELUtensor => pointer to an empty memory block allocated on the GPU
* SumTensor => pointer to an empty memory block allocated on the GPU
* dims => int[4] passed as a pointer, holding the 4 dimensions of the tensor
* SF => scaling factor
* N => precision of the arithmetic softmax (the larger N is, the more precise the operation, but the more bits it requires)
* output_bits => desired bit precision (8 for int8, for example)
*/
int x = blockIdx.x * blockDim.x + threadIdx.x; // Dim1
int y = blockIdx.y * blockDim.y + threadIdx.y; // Dim2
int z = blockIdx.z * blockDim.z + threadIdx.z; // Dim3
double SF_sig = SF * 1.702; // SF multiplied by a constant used in the algorithm
__global__ void ShiftGELUforward_(T* input,int* quantized_tensor,int* GELUtensor,int* SumTensor, int* dims, double SF, int N, int output_bits) {
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
int z = blockIdx.z * blockDim.z + threadIdx.z;
double SF_sig = SF * 1.702;
double Final_SF = SF / powf(2,(output_bits-1));
if (x < dims[0] && y < dims[1] && z < dims[2]) {
int maxIdx = x * dims[1] * dims[2] * dims[3] + y * dims[2] * dims[3] + z * dims[3];
for (int i = 0; i < dims[3]; i++) { //Quantization (1 thread per last dim of tensor)
for (int i = 0; i < dims[3]; i++) {
int idx = maxIdx + i;
quantized_tensor[idx] = roundf(input[idx] / SF);
}
int maxVal = quantized_tensor[maxIdx];
for (int i = 1; i < dims[3]; i++) { // Computing max value
for (int i = 1; i < dims[3]; i++) {
int idx = maxIdx + i;
maxVal = MAX(maxVal, quantized_tensor[idx]);
}
int Max_Exp = ExpShift(-maxVal,N,SF_sig);
for (int i = 0; i < dims[3]; i++) { //Exponential (arithmetic)
for (int i = 0; i < dims[3]; i++) {
int idx = maxIdx + i;
GELUtensor[idx] = ExpShift(quantized_tensor[idx] - maxVal,N,SF_sig);
if(GELUtensor[idx] > INT_MAX - Max_Exp) {
@@ -74,7 +64,6 @@ __global__ void ShiftGELUWholeKernel(T* input,int* quantized_tensor,int* GELUten
{
SumTensor[idx] = floorf(INT_MAX/(GELUtensor[idx] + Max_Exp));
}
//SigMoidInt.
SumTensor[idx] = floorf((GELUtensor[idx] * SumTensor[idx]) >> (31 - output_bits + 1));
quantized_tensor[idx] *= SumTensor[idx];
input[idx] = quantized_tensor[idx] * Final_SF;
@@ -82,66 +71,186 @@ __global__ void ShiftGELUWholeKernel(T* input,int* quantized_tensor,int* GELUten
}
}
//TODO Template
void ShiftGELULaunchKernel(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
template <>
void ShiftGELUforward<float>(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
double new_SF = 1/std::pow(2,2*output_bits-1); // The new scaling factor returned by the shiftgelu function, used to dequantize the returned tensor
double new_SF = 1/std::pow(2,2*output_bits-1);
int dims_input_cuda[4];
if (dims_input.size() >= 4) {
// Fixed-size array to store the first 4 elements
// Copy the first 4 elements from dims_input into dims_input_cuda
for (std::size_t i = 0; i < 4; ++i) {
dims_input_cuda[i] = static_cast<int>(dims_input[i]);
}
}
float* input_cuda_tensor;
cudaMalloc(&input_cuda_tensor,size*sizeof(float)); //|
cudaMalloc(&input_cuda_tensor,size*sizeof(float));
cudaMemcpy(input_cuda_tensor,input,size*sizeof(float),cudaMemcpyHostToDevice);
//|
int* quantized_tensor; //|
cudaMalloc(&quantized_tensor,size*sizeof(int)); //|
//| => Allocate the memory blocks on the GPU
int* quantized_tensor;
cudaMalloc(&quantized_tensor,size*sizeof(int));
int* GELUtensor;
cudaMalloc(&GELUtensor,size*sizeof(int));
int* SumTensor;
cudaMalloc(&SumTensor,size*sizeof(int)); //|
//|
int* dims; //|
cudaMalloc(&SumTensor,size*sizeof(int));
int* dims;
cudaMalloc(&dims,4*sizeof(int));
cudaMemcpy(dims,dims_input_cuda,4*sizeof(int),cudaMemcpyHostToDevice);
dim3 threadsPerBlock(10, 10, 10); //| Compute the number of threads per block and the number of blocks to launch in parallel on
dim3 numBlocks((dims_input[0] + threadsPerBlock.x - 1) / threadsPerBlock.x, //| the GPU according to the dimensions of the input tensor
(dims_input[1] + threadsPerBlock.y - 1) / threadsPerBlock.y, //|
(dims_input[2] + threadsPerBlock.z - 1) / threadsPerBlock.z); //|
dim3 threadsPerBlock(10, 10, 10);
dim3 numBlocks((dims_input[0] + threadsPerBlock.x - 1) / threadsPerBlock.x,
(dims_input[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,
(dims_input[2] + threadsPerBlock.z - 1) / threadsPerBlock.z);
ShiftGELUWholeKernel<float><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor,GELUtensor,SumTensor, dims, SF,N,8); //Kernel launch
cudaDeviceSynchronize(); //Wait for the kernel to finish. This line is essential: without it, the CPU would continue executing the program without waiting for the GPU to return
ShiftGELUforward_<float><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor,GELUtensor,SumTensor, dims, SF,N,8);
cudaDeviceSynchronize();
cudaError_t err = cudaGetLastError();
if(err != cudaSuccess)
{
printf("Erreur CUDA: %s\n", cudaGetErrorString(err));
} //Checks des possibles erreurs sur le GPU
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
cudaMemcpy(output,input_cuda_tensor,size*sizeof(float),cudaMemcpyDeviceToHost);
cudaFree(quantized_tensor);
cudaFree(GELUtensor);
cudaFree(SumTensor);
cudaFree(dims);
cudaFree(input_cuda_tensor);
}
//float* ControlFinal = (float*)malloc(size*sizeof(float));
//cudaMemcpy(ControlFinal,input_cuda_tensor,size*sizeof(float),cudaMemcpyDeviceToHost);
//MyTensor<float> control(ControlFinal,x.dims);
template <>
void ShiftGELUforward<double>(const double* input, double* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
cudaMemcpy(output,input_cuda_tensor,size*sizeof(float),cudaMemcpyDeviceToHost);
double new_SF = 1/std::pow(2,2*output_bits-1);
int dims_input_cuda[4];
if (dims_input.size() >= 4) {
for (std::size_t i = 0; i < 4; ++i) {
dims_input_cuda[i] = static_cast<int>(dims_input[i]);
}
}
double* input_cuda_tensor;
cudaMalloc(&input_cuda_tensor,size*sizeof(double));
cudaMemcpy(input_cuda_tensor,input,size*sizeof(double),cudaMemcpyHostToDevice);
int* quantized_tensor;
cudaMalloc(&quantized_tensor,size*sizeof(int));
int* GELUtensor;
cudaMalloc(&GELUtensor,size*sizeof(int));
int* SumTensor;
cudaMalloc(&SumTensor,size*sizeof(int));
int* dims;
cudaMalloc(&dims,4*sizeof(int));
cudaMemcpy(dims,dims_input_cuda,4*sizeof(int),cudaMemcpyHostToDevice);
dim3 threadsPerBlock(10, 10, 10);
dim3 numBlocks((dims_input[0] + threadsPerBlock.x - 1) / threadsPerBlock.x,
(dims_input[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,
(dims_input[2] + threadsPerBlock.z - 1) / threadsPerBlock.z);
cudaFree(quantized_tensor); //|
cudaFree(GELUtensor); //|
cudaFree(SumTensor); //|
cudaFree(dims); //| => Free on the GPU and CPU (essentially everything that was malloc'd and cudaMalloc'd)
cudaFree(input_cuda_tensor);//|
ShiftGELUforward_<double><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor,GELUtensor,SumTensor, dims, SF,N,8);
cudaDeviceSynchronize();
cudaError_t err = cudaGetLastError();
if(err != cudaSuccess)
{
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
cudaMemcpy(output,input_cuda_tensor,size*sizeof(double),cudaMemcpyDeviceToHost);
cudaFree(quantized_tensor);
cudaFree(GELUtensor);
cudaFree(SumTensor);
cudaFree(dims);
cudaFree(input_cuda_tensor);
}
template <class T>
__global__ void ShiftGELUbackward_(T* input_grad, const T* output_tensor, const T* output_grad, int size) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < size) {
float x = output_tensor[index];
float grad = output_grad[index];
float cdf = 0.5 * (1.0 + tanh(sqrt(2.0 / M_PI) * (x + 0.044715 * pow(x, 3))));
float pdf = exp(-0.5 * x * x) / sqrt(2.0 * M_PI);
float dx = pdf + x * cdf;
float backprop_grad = grad * dx;
input_grad[index] = backprop_grad;
}
}
template <>
void ShiftGELUbackward<float>(const float* output_tensor, const float* output_grad, float* input_grad, size_t size)
{
float* output_cuda_tensor;
cudaMalloc(&output_cuda_tensor,size*sizeof(float));
cudaMemcpy(output_cuda_tensor,output_tensor,size*sizeof(float),cudaMemcpyHostToDevice);
float* output_grad_;
cudaMalloc(&output_grad_,size*sizeof(float));
cudaMemcpy(output_grad_,output_grad,size*sizeof(float),cudaMemcpyHostToDevice);
float *input_grad_;
cudaMalloc(&input_grad_, size * sizeof(float));
dim3 threadParBlock(256);
dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
ShiftGELUbackward_<float><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,size);
cudaDeviceSynchronize();
cudaError_t err = cudaGetLastError();
if(err != cudaSuccess)
{
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
cudaMemcpy(input_grad,input_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost);
cudaFree(output_cuda_tensor);
cudaFree(input_grad_);
cudaFree(output_grad_);
}
template <>
void ShiftGELUbackward<double>(const double* output_tensor, const double* output_grad, double* input_grad, size_t size)
{
double* output_cuda_tensor;
cudaMalloc(&output_cuda_tensor,size*sizeof(double));
cudaMemcpy(output_cuda_tensor,output_tensor,size*sizeof(double),cudaMemcpyHostToDevice);
double* output_grad_;
cudaMalloc(&output_grad_,size*sizeof(double));
cudaMemcpy(output_grad_,output_grad,size*sizeof(double),cudaMemcpyHostToDevice);
double *input_grad_;
cudaMalloc(&input_grad_, size * sizeof(double));
dim3 threadParBlock(256);
dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
ShiftGELUbackward_<double><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,size);
cudaDeviceSynchronize();
cudaError_t err = cudaGetLastError();
if(err != cudaSuccess)
{
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
cudaMemcpy(input_grad,input_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost);
cudaFree(output_cuda_tensor);
cudaFree(input_grad_);
cudaFree(output_grad_);
}
}
\ No newline at end of file
@@ -34,19 +34,13 @@ void Aidge::ShiftMaxImpl_cuda::forward() {
assert(mOp.getRawInput(0) && "missing input #0");
const auto& input = op.getInput(0)->refCastFrom(mInputFallback, *op.getOutput(0));
//forward_<float>(input);
//should be templated to handle the type change
switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
case DataType::Float64:
forward_<float>(input);
forward_<double>(input);
break;
case DataType::Float32:
forward_<float>(input);
break;
case DataType::Float16:
forward_<float>(input);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
}
@@ -57,13 +51,15 @@ void Aidge::ShiftMaxImpl_cuda::forward_(const Tensor& input)
{
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
const T * input_raw = static_cast<const T*>(input.getImpl()->rawPtr());
T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
int N = 15;
int output_bits = 8;
size_t size = input.size();
std::vector<DimSize_t> dims_input = input.dims();
// maybe find a more efficient way to compute the scaling factor (a max and a min function could help retrieve its value)
double min = std::numeric_limits<double>::max();
double max = std::numeric_limits<double>::min();
for(std::size_t i = 0; i < dims_input[0]; i++) {
@@ -84,13 +80,42 @@ void Aidge::ShiftMaxImpl_cuda::forward_(const Tensor& input)
}
double m = std::max(std::abs(min), std::abs(max));
// Calculate the normalization factor
double normalization_factor = static_cast<double>(1 << (output_bits - 1)) - 1;
double scaling_factor = m / normalization_factor;
// The new scaling factor that we can use to dequantize the returned tensor (not used here)
// double new_SF = 1/std::pow(2,2*output_bits-1);
ShiftMaxforward(input_raw, output, scaling_factor,N, output_bits, size, dims_input);
}
void Aidge::ShiftMaxImpl_cuda::backward() {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
assert(op.getOutput(0)->grad() && "missing output #0");
const auto& output_grad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad());
if (op.getInput(0)->grad()->dataType() == DataType::Float64) {
backward_<double>(output_grad);
}
else {
backward_<float>(output_grad);
}
}
template <class T>
void Aidge::ShiftMaxImpl_cuda::backward_(const Tensor& output_grad) {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
const T * output_tensor = static_cast<const T*>(std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr());
size_t size = output_grad.size();
std::vector<DimSize_t> dims_output = output_grad.dims();
T * input_grad = static_cast<T*>(op.getInput(0)->grad()->getImpl()->rawPtr());
const T * output_grad_raw = static_cast<const T*>(output_grad.getImpl()->rawPtr());
ShiftMaxbackward(output_tensor, output_grad_raw, input_grad, size, dims_output);
// Return the normalized maximum
double final_sf = m / normalization_factor;
T * output = static_cast<T*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
double new_SF = 1/std::pow(2,2*output_bits-1); // The new scaling factor returned by the shiftmax function, used to dequantize the returned tensor
ShiftMaxLaunchKernel(input_raw, output, final_sf,N, output_bits, size, dims_input);
}
\ No newline at end of file
@@ -26,53 +26,38 @@ __device__ inline int ExpShift(int I,int N, double SF)
int q = floorf(Ip / (I0));
int r = Ip -(I0*q);
int Ib = r/2 - I0;
Ib = CLAMP(Ib * powf(2,N-q));//BitShift?
Ib = CLAMP(Ib * powf(2,N-q));
return (int)Ib;
}
namespace Aidge{
template <class T>
__global__ void ShiftMaxWholeKernel(T* input,int* quantized_tensor,int* factor, int* dims, double SF, int N, int output_bits,double new_SF)
/*
* Forward kernel for Shiftmax
* Input => tensor holding the (non-quantized, floating-point) input (pointer to a memory block of type T)
* quantized_tensor => pointer to an empty memory block allocated on the GPU
* factor => pointer to an empty memory block allocated on the GPU
* dims => int[4] passed as a pointer, holding the 4 dimensions of the tensor
* SF => scaling factor
* N => precision of the arithmetic softmax (the larger N is, the more precise the operation, but the more bits it requires)
* output_bits => desired bit precision (8 for int8, for example)
* new_SF => new SF used to dequantize the tensor
*/
__global__ void ShiftMaxforward_(T* input,int* quantized_tensor,int* factor, int* dims, double SF, int N, int output_bits,double new_SF)
{
int x = blockIdx.x * blockDim.x + threadIdx.x; // Dim1
int y = blockIdx.y * blockDim.y + threadIdx.y; // Dim2
int z = blockIdx.z * blockDim.z + threadIdx.z; // Dim3
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
int z = blockIdx.z * blockDim.z + threadIdx.z;
int sum = 0;
/*
* x, y and z are the indices over dimensions 1, 2 and 3 of the tensor; all possible combinations of x, y and z
* run in parallel, which provides the GPU speedup.
* The "for" loops below iterate over the last dimension.
*/
if (x < dims[0] && y < dims[1] && z < dims[2]) {
int maxIdx = x * dims[1] * dims[2] * dims[3] + y * dims[2] * dims[3] + z * dims[3];
for (int i = 0; i < dims[3]; i++) { //Quantization (1 thread per last dim of tensor)
for (int i = 0; i < dims[3]; i++) {
int idx = maxIdx + i;
quantized_tensor[idx] = roundf(input[idx] / SF);
}
int maxVal = quantized_tensor[maxIdx];
for (int i = 1; i < dims[3]; i++) { // max value over the last dimension
for (int i = 1; i < dims[3]; i++) {
int idx = maxIdx + i;
maxVal = MAX(maxVal, quantized_tensor[idx]);
}
for (int i = 0; i < dims[3]; i++) { //Exponential (arithmetic)
for (int i = 0; i < dims[3]; i++) {
int idx = maxIdx + i;
quantized_tensor[idx] = ExpShift(quantized_tensor[idx]-maxVal,N,SF);
}
for (int i = 0; i < dims[3]; i++) { // Sum, clamped on overflow
for (int i = 0; i < dims[3]; i++) {
int idx = maxIdx + i;
if(quantized_tensor[idx] > 0 && sum > INT_MAX - quantized_tensor[idx])//CLAMP(2**31-1)
if(quantized_tensor[idx] > 0 && sum > INT_MAX - quantized_tensor[idx])
{
sum = INT_MAX;
break;
@@ -82,7 +67,7 @@ __global__ void ShiftMaxWholeKernel(T* input,int* quantized_tensor,int* factor,
}
}
factor[x * dims[1] * dims[2] + y * dims[2] + z] = floorf(INT_MAX/sum);
for(int i = 0; i < dims[3]; ++i) //bit shift to quantize to 8 bits
for(int i= 0; i < dims[3]; ++i)
{
int idx = maxIdx + i;
quantized_tensor[idx] = (quantized_tensor[idx] * factor[x * dims[1] * dims[2] + y * dims[2] + z]) >> (31-(2*output_bits-1));
@@ -91,62 +76,211 @@ __global__ void ShiftMaxWholeKernel(T* input,int* quantized_tensor,int* factor,
}
}
//TODO Template
void ShiftMaxLaunchKernel(const float* input, float* output, double SF,int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
template <>
void ShiftMaxforward<float>(const float* input, float* output, double SF, int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
double new_SF = 1 / std::pow(2, 2 * output_bits - 1); // New scaling factor
int dims_input_cuda[4] = {1, 1, 1, 1};
for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) {
dims_input_cuda[i] = static_cast<int>(dims_input[i]);
}
// Allocate memory on the GPU
float* input_cuda_tensor;
cudaMalloc(&input_cuda_tensor, size * sizeof(float));
cudaMemcpy(input_cuda_tensor, input, size * sizeof(float), cudaMemcpyHostToDevice);
int* quantized_tensor;
cudaMalloc(&quantized_tensor, size * sizeof(int));
int* factor;
cudaMalloc(&factor, size * sizeof(int));
int* dims;
cudaMalloc(&dims, 4 * sizeof(int));
cudaMemcpy(dims, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice);
// Calculate grid and block dimensions
dim3 threadsPerBlock(10, 10, 10);
dim3 numBlocks(
(dims_input_cuda[0] + threadsPerBlock.x - 1) / threadsPerBlock.x,
(dims_input_cuda[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,
(dims_input_cuda[2] + threadsPerBlock.z - 1) / threadsPerBlock.z
);
// Launch the forward kernel
ShiftMaxforward_<float><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor, factor, dims, SF, N, output_bits, new_SF);
cudaDeviceSynchronize();
// Check for CUDA errors
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
std::cerr << "CUDA Error: " << cudaGetErrorString(err) << std::endl;
}
// Copy the result back to host
cudaMemcpy(output, input_cuda_tensor, size * sizeof(float), cudaMemcpyDeviceToHost);
// Free allocated memory on GPU
cudaFree(quantized_tensor);
cudaFree(factor);
cudaFree(dims);
cudaFree(input_cuda_tensor);
}
template <>
void ShiftMaxforward<double>(const double* input, double* output, double SF, int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
double new_SF = 1 / std::pow(2, 2 * output_bits - 1);
int dims_input_cuda[4] = {1, 1, 1, 1};
for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) {
dims_input_cuda[i] = static_cast<int>(dims_input[i]);
}
// Allocate memory on the GPU
double* input_cuda_tensor;
cudaMalloc(&input_cuda_tensor, size * sizeof(double));
cudaMemcpy(input_cuda_tensor, input, size * sizeof(double), cudaMemcpyHostToDevice);
int* quantized_tensor;
cudaMalloc(&quantized_tensor, size * sizeof(int));
int* factor;
cudaMalloc(&factor, size * sizeof(int));
int* dims;
cudaMalloc(&dims, 4 * sizeof(int));
cudaMemcpy(dims, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice);
// Calculate grid and block dimensions
dim3 threadsPerBlock(10, 10, 10);
dim3 numBlocks(
(dims_input_cuda[0] + threadsPerBlock.x - 1) / threadsPerBlock.x,
(dims_input_cuda[1] + threadsPerBlock.y - 1) / threadsPerBlock.y,
(dims_input_cuda[2] + threadsPerBlock.z - 1) / threadsPerBlock.z
);
// Launch the forward kernel
ShiftMaxforward_<double><<<numBlocks, threadsPerBlock>>>(input_cuda_tensor, quantized_tensor, factor, dims, SF, N, output_bits, new_SF);
cudaDeviceSynchronize();
// Check for CUDA errors
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
std::cerr << "CUDA Error: " << cudaGetErrorString(err) << std::endl;
}
double new_SF = 1/std::pow(2,2*output_bits-1); // The new scaling factor returned by the shiftmax function, used to dequantize the returned tensor
// Copy the result back to host
cudaMemcpy(output, input_cuda_tensor, size * sizeof(double), cudaMemcpyDeviceToHost);
// Free allocated memory on GPU
cudaFree(quantized_tensor);
cudaFree(factor);
cudaFree(dims);
cudaFree(input_cuda_tensor);
}
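/*
 * Illustrative host-side usage sketch for ShiftMaxforward (not part of the operator
 * implementation). The SF, N and output_bits values below are arbitrary assumptions
 * chosen only for the example; in practice they come from the network quantization.
 *
 *   std::vector<float> input = {0.1f, 0.5f, 0.2f, 0.9f, 0.3f, 0.7f, 0.4f, 0.6f};
 *   std::vector<float> output(input.size());
 *   std::vector<long unsigned int> dims = {1, 1, 2, 4};   // NCHW; ShiftMax runs along the last dimension
 *   ShiftMaxforward<float>(input.data(), output.data(),
 *                          1.0 / 127.0,   // SF: assumed input scaling factor
 *                          15,            // N: assumed bit-shift precision parameter
 *                          8,             // output_bits: assumed output bit width
 *                          input.size(), dims);
 */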
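// Backward kernel for ShiftMax: each thread handles one element and applies the standard
// softmax-style gradient along the last dimension,
//   dL/dx_i = y_i * (dL/dy_i - sum_j y_j * dL/dy_j),
// where y is the forward output and dL/dy the incoming gradient.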
template <class T>
__global__ void ShiftMaxbackward_(T* input_grad, const T* output_tensor, const T* output_grad, const int* dims) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < dims[0] * dims[1] * dims[2] * dims[3]) {
int w = (index / dims[3]) % dims[2];
int h = (index / dims[3] / dims[2]) % dims[1];
int n = index / dims[3] / dims[2] / dims[1];
T sum = 0; // accumulate the dot product y . dL/dy over the last dimension, in the tensor's own precision
for (int i = 0; i < dims[3]; ++i) {
    sum += output_tensor[n * dims[1] * dims[2] * dims[3] + h * dims[2] * dims[3] + w * dims[3] + i] * output_grad[n * dims[1] * dims[2] * dims[3] + h * dims[2] * dims[3] + w * dims[3] + i];
}
input_grad[index] = output_tensor[index] * (output_grad[index] - sum);
}
}
template <>
void ShiftMaxbackward<float>(const float* output_tensor, const float* output_grad, float* input_grad, size_t size, std::vector<long unsigned int> dims)
{
int dims_input_cuda[4] = {1, 1, 1, 1};
for (std::size_t i = 0; i < std::min(dims.size(), size_t(4)); ++i) {
dims_input_cuda[i] = static_cast<int>(dims[i]);
}
float* output_cuda_tensor;
cudaMalloc(&output_cuda_tensor,size*sizeof(float));
cudaMemcpy(output_cuda_tensor,output_tensor,size*sizeof(float),cudaMemcpyHostToDevice);
float* output_grad_;
cudaMalloc(&output_grad_,size*sizeof(float));
cudaMemcpy(output_grad_,output_grad,size*sizeof(float),cudaMemcpyHostToDevice);
float *input_grad_;
cudaMalloc(&input_grad_, size * sizeof(float));
int *dims_;
cudaMalloc(&dims_, 4 * sizeof(int));
cudaMemcpy(dims_, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice);
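// 1D launch configuration: one thread per element of the flattened tensor, 256 threads per block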
dim3 threadParBlock(256);
dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
ShiftMaxbackward_<float><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,dims_);
cudaDeviceSynchronize();
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
    printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
cudaMemcpy(input_grad, input_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost);
cudaFree(output_cuda_tensor);
cudaFree(input_grad_);
cudaFree(dims_);
cudaFree(output_grad_);
}
template <>
void ShiftMaxbackward<double>(const double* output_tensor, const double* output_grad, double* input_grad, size_t size, std::vector<long unsigned int> dims)
{
int dims_input_cuda[4] = {1, 1, 1, 1};
for (std::size_t i = 0; i < std::min(dims.size(), size_t(4)); ++i) {
dims_input_cuda[i] = static_cast<int>(dims[i]);
}
double* output_cuda_tensor;
cudaMalloc(&output_cuda_tensor,size*sizeof(double));
cudaMemcpy(output_cuda_tensor,output_tensor,size*sizeof(double),cudaMemcpyHostToDevice);
double* output_grad_;
cudaMalloc(&output_grad_,size*sizeof(double));
cudaMemcpy(output_grad_,output_grad,size*sizeof(double),cudaMemcpyHostToDevice);
double *input_grad_;
cudaMalloc(&input_grad_, size * sizeof(double));
int *dims_;
cudaMalloc(&dims_, 4 * sizeof(int));
cudaMemcpy(dims_, dims_input_cuda, 4 * sizeof(int), cudaMemcpyHostToDevice);
dim3 threadParBlock(256);
dim3 Blocks((size + threadParBlock.x -1) / threadParBlock.x);
ShiftMaxbackward_<double><<<Blocks,threadParBlock>>>(input_grad_,output_cuda_tensor,output_grad_,dims_);
cudaDeviceSynchronize();
cudaError_t err = cudaGetLastError();
if(err != cudaSuccess)
{
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
cudaMemcpy(input_grad,input_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost);
cudaFree(output_cuda_tensor);
cudaFree(input_grad_);
cudaFree(dims_);
cudaFree(output_grad_);
}
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2024 Thales
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
* Author: Lucas RAKOTOARIVONY, Thales Research & Technology France
* Date: 10.09.2024
*
********************************************************************************/
#include <array>
#include <catch2/catch_test_macros.hpp>
#include "Test_cuda.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/backend/cpu.hpp"
#include "aidge/backend/cuda.hpp"
using namespace Aidge;
TEST_CASE("[gpu/operator] ILayerNorm(forward)", "[ILayerNorm][GPU]") {
SECTION("4D Tensor") {
std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<float,2,2,2,10> {
{
{
{
{0.96, 0.48, 0.54, 0.49, 0.59, 0.93, 0.00, 0.00, 0.61, 0.61},
{0.85, 0.06, 0.11, 0.87, 0.55, 0.12, 0.80, 0.48, 0.41, 0.16}
},
{
{0.24, 0.46, 0.97, 0.19, 0.65, 0.12, 0.44, 1.00, 0.37, 0.09},
{0.44, 0.64, 0.21, 0.58, 0.05, 0.24, 0.56, 0.07, 0.49, 0.79}
}
},
{
{
{0.00, 0.13, 0.55, 0.42, 0.49, 0.28, 0.52, 0.55, 0.34, 0.85},
{0.98, 0.32, 0.09, 0.05, 0.37, 0.47, 0.63, 0.13, 0.70, 0.02}
},
{
{0.69, 0.13, 0.74, 0.61, 0.25, 0.87, 0.46, 0.40, 0.81, 0.06},
{0.89, 0.32, 0.61, 0.24, 0.70, 0.23, 0.09, 0.03, 0.14, 0.80}
}
}
}
});
std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<float, 10>{{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}});
std::shared_ptr<Tensor> myWeight = std::make_shared<Tensor>(Array1D<float, 10>{{0.1617684f, 0.3833238f ,-0.6842308f ,-0.4342245f ,-0.4717381f ,-0.1776187f, -0.2728751f, -0.4638580f, 0.2936697f, -0.9011016f}});
myWeight->setBackend("cuda");
myBias->setBackend("cuda");
std::shared_ptr<Node> myILayerNorm = ILayerNorm();
auto op = std::static_pointer_cast<OperatorTensor>(myILayerNorm -> getOperator());
op -> associateInput(1, myWeight);
op -> associateInput(2, myBias);
input0->setBackend("cuda");
op -> associateInput(0,input0);
op->setDataType(DataType::Float32);
op->setBackend("cuda");
op->forward();
// expected output
std::shared_ptr<Tensor> output_ilayernorm = std::make_shared<Tensor>(Array4D<float,2,2,2,10> {
{
{
{
{9.8821178e-02, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02},
{4.9410585e-02, 0.0000000e+00, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00}
},
{
{0.0000000e+00, 4.9410585e-02, 9.8821178e-02, 0.0000000e+00, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 9.8821178e-02, 4.9410585e-02, 0.0000000e+00},
{4.9410585e-02, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 0.0000000e+00, 0.0000000e+00, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02}
}
},
{
{
{0.0000000e+00, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02},
{9.8821178e-02, 4.9410585e-02, 0.0000000e+00, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 0.0000000e+00}
},
{
{4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00},
{4.9410585e-02, 4.9410585e-02, 4.9410585e-02, 0.0000000e+00, 4.9410585e-02, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.9410585e-02}
}
}
}
});
float* computedOutput = new float[output_ilayernorm->size()]();
cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * output_ilayernorm->size(), cudaMemcpyDeviceToHost);
//test if forward results are as expected
for(int i = 0; i < output_ilayernorm->size(); i++){
const float targetOutput = *(static_cast<float*>(output_ilayernorm->getImpl()->rawPtr()) + i);
REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
}
delete[] computedOutput;
}
}
TEST_CASE("[gpu/operator] ILayerNorm(backward)", "[ILayerNorm][GPU]")
{
std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { //NCHW
{
{
{
{1.46650600, 1.24083233, -0.33106008, -0.15137172, 0.06625678, -1.8326609, 0.53444749, -0.05167147},
},
},
}
});
std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { //NCHW
{
{
{
{0.96, 0.54, 0.22, -0.15, 0.17, 0.26, -0.85, 0.5},
},
},
}
});
std::shared_ptr<Tensor> myWeight = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { //NCHW
{
{
{
{1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0},
},
},
}
});
myWeight->setBackend("cuda");
myBias->setBackend("cuda");
std::shared_ptr<Node> myILayerNorm = ILayerNorm();
auto op = std::static_pointer_cast<OperatorTensor>(myILayerNorm -> getOperator());
op -> associateInput(1, myWeight);
op -> associateInput(2, myBias);
input0->setBackend("cuda");
op -> associateInput(0,input0);
op->setDataType(DataType::Float32);
op->setBackend("cuda");
myILayerNorm->forward();
std::shared_ptr<Tensor> myOutputGrad = std::make_shared<Tensor>(Array4D<float,1,1,1,8> {
{
{
{
{ 1.34347093, 0.90813798, 0.39607167, 1.20428133, 0.16845724, 0.48487359, 0.40748054, -0.21790814},
},
},
}
});
myOutputGrad->setBackend("cuda");
std::shared_ptr<Tensor> predictedOutput = op->getOutput(0);
std::shared_ptr<Tensor> input = op->getInput(0);
predictedOutput->setGrad(myOutputGrad);
REQUIRE_NOTHROW(myILayerNorm->backward());
std::shared_ptr<Tensor> expectedInputGradILayerNorm = std::make_shared<Tensor>(Array4D<float,1,1,1,8> {
{
{
{
{ 0.467678, 0.310749, 0.1129, 0.351786, 0.0507252, 0.101587, 0.130249, -0.0646476},
},
},
}
});
float *computedInputGradCuda = new float[myOutputGrad->size()]();
cudaMemcpy(computedInputGradCuda, op->getInput(0)->grad()->getImpl()->rawPtr(), sizeof(float) * myOutputGrad->size(), cudaMemcpyDeviceToHost);
//test if backward results are as expected
for(int i = 0; i < expectedInputGradILayerNorm->size(); i++){
const float targetOutput = *(static_cast<float*>(expectedInputGradILayerNorm->getImpl()->rawPtr()) + i);
REQUIRE(fabs(computedInputGradCuda[i] - targetOutput) < 2e-6);
}
delete[] computedInputGradCuda;
}
......@@ -51,6 +51,7 @@ TEST_CASE("[gpu/operator] ShiftGELU(forward)", "[ShiftGELU][GPU]") {
}
});
//expected output of shiftgelu forward operator
std::shared_ptr<Tensor> output_shiftGELU = std::make_shared<Tensor>(Array4D<float,2,2,2,10> {
{
{
......@@ -76,6 +77,7 @@ TEST_CASE("[gpu/operator] ShiftGELU(forward)", "[ShiftGELU][GPU]") {
}
});
//expected output of GELU forward operator (computed with PyTorch)
std::shared_ptr<Tensor> output_GELU = std::make_shared<Tensor>(Array4D<float, 2, 2, 2, 10> {
{
{
......@@ -99,7 +101,7 @@ TEST_CASE("[gpu/operator] ShiftGELU(forward)", "[ShiftGELU][GPU]") {
}
}
}
}); //value given by torch nn GELU
});
std::shared_ptr<Node> myShiftGELU = ShiftGELU();
auto op = std::static_pointer_cast<OperatorTensor>(myShiftGELU -> getOperator());
......@@ -111,21 +113,108 @@ TEST_CASE("[gpu/operator] ShiftGELU(forward)", "[ShiftGELU][GPU]") {
float* computedOutput = new float[output_shiftGELU->size()]();
cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * output_shiftGELU->size(), cudaMemcpyDeviceToHost);
//test if forward results are as expected
for(int i = 0; i < output_shiftGELU->size(); i++){
const float targetOutput = *(static_cast<float*>(output_shiftGELU->getImpl()->rawPtr()) + i);
REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
}
//measure difference between GELU and shiftgelu
float sum = 0.0;
for(int i = 0; i < output_GELU->size(); i++){
const float targetOutput = *(static_cast<float*>(output_GELU->getImpl()->rawPtr()) + i);
sum += fabs(computedOutput[i] - targetOutput);
}
sum = sum / output_GELU->size();
std::cout << sum << "\n";
REQUIRE(sum < 1.5e-1);
delete[] computedOutput;
}
}
\ No newline at end of file
}
TEST_CASE("[gpu/operator] ShiftGELU(backward)", "[ShiftGELU][GPU]")
{
std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { //NCHW
{
{
{
{1.46650600, 1.24083233, -0.33106008, -0.15137172, 0.06625678, -1.8326609, 0.53444749, -0.05167147},
},
},
}
});
input0->setBackend("cuda");
std::shared_ptr<Node> myShiftGELU = ShiftGELU();
auto op = std::static_pointer_cast<OperatorTensor>(myShiftGELU->getOperator());
op->associateInput(0, input0);
op->setDataType(DataType::Float32);
op->setBackend("cuda");
myShiftGELU->forward();
std::shared_ptr<Tensor> myOutputGrad = std::make_shared<Tensor>(Array4D<float,1,1,1,8> {
{
{
{
{ 1.34347093, 0.90813798, 0.39607167, 1.20428133, 0.16845724, 0.48487359, 0.40748054, -0.21790814},
},
},
}
});
myOutputGrad->setBackend("cuda");
std::shared_ptr<Tensor> predictedOutput = op->getOutput(0);
std::shared_ptr<Tensor> input = op->getInput(0);
predictedOutput->setGrad(myOutputGrad);
REQUIRE_NOTHROW(myShiftGELU->backward());
//expected output of shiftgelu backward operator
std::shared_ptr<Tensor> expectedInputGradShiftGELU = std::make_shared<Tensor>(Array4D<float,1,1,1,8> {
{
{
{
{ 1.88094, 1.09182, 0.134203, 0.439603, 0.0696628, 0.173469, 0.254718, -0.084009},
},
},
}
});
//expected output of gelu backward operator (computed with PyTorch)
std::shared_ptr<Tensor> expectedInputGradGELU = std::make_shared<Tensor>(Array4D<float,1,1,1,8> {
{
{
{
{ 1.5159, 1.0188, 0.0971, 0.4578, 0.0931, -0.0499, 0.3620, -0.1000},
},
},
}
});
float *computedGradCuda = new float[myOutputGrad->size()]();
cudaMemcpy(computedGradCuda, input->grad()->getImpl()->rawPtr(), sizeof(float) * myOutputGrad->size(), cudaMemcpyDeviceToHost);
//test if backward results are as expected
for(int i = 0; i < expectedInputGradShiftGELU->size(); i++){
const float targetOutput = *(static_cast<float*>(expectedInputGradShiftGELU->getImpl()->rawPtr()) + i);
REQUIRE(fabs(computedGradCuda[i] - targetOutput) < 2e-6);
}
//measure difference between gelu and shiftgelu
float sum = 0.0;
for(int i = 0; i < expectedInputGradGELU->size(); i++){
const float targetOutput = *(static_cast<float*>(expectedInputGradGELU->getImpl()->rawPtr()) + i);
sum += fabs(computedGradCuda[i] - targetOutput);
}
sum = sum / expectedInputGradGELU->size();
REQUIRE(sum < 2e-1);
delete[] computedGradCuda;
}
......@@ -50,6 +50,7 @@ TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") {
}
}
});
//expected output of shiftmax forward operator
std::shared_ptr<Tensor> output_shiftmax = std::make_shared<Tensor>(Array4D<float,2,2,2,10> {
{
{
......@@ -74,6 +75,7 @@ TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") {
}
}
});
//expected output of softmax forward operator (computed with PyTorch)
std::shared_ptr<Tensor> output_softmax = std::make_shared<Tensor>(Array4D<float, 2, 2, 2, 10> {
{
{
......@@ -97,7 +99,7 @@ TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") {
}
}
}
}); //softmax value given by torch softmax
});
std::shared_ptr<Node> myShiftMax = ShiftMax();
auto op = std::static_pointer_cast<OperatorTensor>(myShiftMax -> getOperator());
......@@ -109,11 +111,13 @@ TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") {
float* computedOutput = new float[output_shiftmax->size()]();
cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * output_shiftmax->size(), cudaMemcpyDeviceToHost);
//test if forward results are as expected
for(int i = 0; i < output_shiftmax->size(); i++){
const float targetOutput = *(static_cast<float*>(output_shiftmax->getImpl()->rawPtr()) + i);
REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
}
//measure difference between softmax and shiftmax
float sum = 0.0;
for(int i = 0; i < output_softmax->size(); i++){
const float targetOutput = *(static_cast<float*>(output_softmax->getImpl()->rawPtr()) + i);
......@@ -125,4 +129,89 @@ TEST_CASE("[gpu/operator] ShiftMax(forward)", "[ShiftMax][GPU]") {
delete[] computedOutput;
}
}
\ No newline at end of file
}
TEST_CASE("[gpu/operator] ShiftMax(backward)", "[ShiftMax][GPU]")
{
std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<float,1,1,1,8> { //NCHW
{
{
{
{1.46650600, 1.24083233, -0.33106008, -0.15137172, 0.06625678, -1.8326609, 0.53444749, -0.05167147},
},
},
}
});
input0->setBackend("cuda");
std::shared_ptr<Node> myShiftMax = ShiftMax();
auto op = std::static_pointer_cast<OperatorTensor>(myShiftMax->getOperator());
op->associateInput(0, input0);
op->setDataType(DataType::Float32);
op->setBackend("cuda");
myShiftMax->forward();
std::shared_ptr<Tensor> myOutputGrad = std::make_shared<Tensor>(Array4D<float,1,1,1,8> {
{
{
{
{ 1.34347093, 0.90813798, 0.39607167, 1.20428133, 0.16845724, 0.48487359, 0.40748054, -0.21790814},
},
},
}
});
myOutputGrad->setBackend("cuda");
std::shared_ptr<Tensor> predictedOutput = op->getOutput(0);
std::shared_ptr<Tensor> input = op->getInput(0);
predictedOutput->setGrad(myOutputGrad);
REQUIRE_NOTHROW(myShiftMax->backward());
//expected output of shiftmax backward operator
std::shared_ptr<Tensor> expectedInputGradShiftMax = std::make_shared<Tensor>(Array4D<float,1,1,1,8> {
{
{
{
{ 0.159378, 0.0249331, -0.0250217, 0.0262418, -0.0514701, -0.00459638, -0.0551896, -0.0739511},
},
},
}
});
//expected output of softmax backward operator (computed with PyTorch)
std::shared_ptr<Tensor> expectedInputGradSoftmax = std::make_shared<Tensor>(Array4D<float,1,1,1,8> {
{
{
{
{ 0.1672, 0.0198, -0.0236, 0.0241, -0.0535, -0.0042, -0.0547, -0.0752},
},
},
}
});
float *computedGradCuda = new float[myOutputGrad->size()]();
cudaMemcpy(computedGradCuda, input->grad()->getImpl()->rawPtr(), sizeof(float) * myOutputGrad->size(), cudaMemcpyDeviceToHost);
//test if backward results are as expected
for(int i = 0; i < expectedInputGradShiftMax->size(); i++){
const float targetOutput = *(static_cast<float*>(expectedInputGradShiftMax->getImpl()->rawPtr()) + i);
REQUIRE(fabs(computedGradCuda[i] - targetOutput) < 1e-6);
}
//measure difference between softmax and shiftmax
float sum = 0.0;
for(int i = 0; i < expectedInputGradSoftmax->size(); i++){
const float targetOutput = *(static_cast<float*>(expectedInputGradSoftmax->getImpl()->rawPtr()) + i);
sum += fabs(computedGradCuda[i] - targetOutput);
}
sum = sum / expectedInputGradSoftmax->size();
REQUIRE(sum < 4e-3);
delete[] computedGradCuda;
}