Skip to content
Snippets Groups Projects
Commit cf02c586 authored by Maxence Naud's avatar Maxence Naud
Browse files

Remove iostream include from cudaUtils and ShiftMax kernel

parent 79f05729
No related branches found
No related tags found
No related merge requests found
#ifndef AIDGE_BACKEND_CUDA_CUDA_UTILS_H
#define AIDGE_BACKEND_CUDA_CUDA_UTILS_H
#ifndef AIDGE_BACKEND_CUDA_CUDA_UTILS_H_
#define AIDGE_BACKEND_CUDA_CUDA_UTILS_H_
#include <string>
#include <memory>
#include <sstream>
#include <iostream>
#include <stdexcept>
#include <fmt/core.h>
#include <fmt/format.h>
#include <cublas_v2.h>
#include <cuda.h>
......@@ -18,31 +18,29 @@
do { \
const cudnnStatus_t e = (status); \
if (e != CUDNN_STATUS_SUCCESS) { \
std::stringstream error; \
error << "CUDNN failure: " << cudnnGetErrorString(e) << " (" \
<< static_cast<int>(e) << ") in " << __FILE__ << ':' << __LINE__; \
int status_dev; \
if (cudaGetDevice(&status_dev) == cudaSuccess) \
error << " on device #" << status_dev; \
std::cerr << error.str() << std::endl; \
std::string error = fmt::format("CUDNN failure: {} ({}) in {}:{}", \
cudnnGetErrorString(e), static_cast<int>(e), __FILE__, __LINE__); \
int status_dev; \
if (cudaGetDevice(&status_dev) == cudaSuccess) \
error = fmt::format("{} on device #{}", error, status_dev); \
fmt::print(stderr, "{}\n", error); \
cudaDeviceReset(); \
throw std::runtime_error(error.str()); \
throw std::runtime_error(error); \
} \
} while(0)
// Evaluates a CUDA runtime call and aborts the current operation on failure.
// On error it formats a message with the error string, numeric code, source
// location, and (when retrievable) the active device id, prints it to stderr,
// resets the device, and throws std::runtime_error with the same message.
// Defect fixed: the previous text was merged diff residue containing both the
// old std::stringstream version and the new fmt version (duplicated `if` and
// duplicated `throw`), which could not compile; only the fmt version is kept.
#define CHECK_CUDA_STATUS(status) \
do { \
    const cudaError_t e = (status); \
    if (e != cudaSuccess) { \
        std::string error = fmt::format("Cuda failure: {} ({}) in {}:{}", \
            cudaGetErrorString(e), static_cast<int>(e), __FILE__, __LINE__); \
        int status_dev; \
        if (cudaGetDevice(&status_dev) == cudaSuccess) \
            error = fmt::format("{} on device #{}", error, status_dev); \
        fmt::print(stderr, "{}\n", error); \
        cudaDeviceReset(); \
        throw std::runtime_error(error); \
    } \
} while(0)
......@@ -50,16 +48,14 @@
do { \
const cublasStatus_t e = (status); \
if (e != CUBLAS_STATUS_SUCCESS) { \
std::stringstream error; \
error << "Cublas failure: " \
<< Aidge::Cuda::cublasGetErrorString(e) << " (" \
<< static_cast<int>(e) << ") in " << __FILE__ << ':' << __LINE__; \
int status_dev; \
if (cudaGetDevice(&status_dev) == cudaSuccess) \
error << " on device #" << status_dev; \
std::cerr << error.str() << std::endl; \
std::string error = fmt::format("Cublas failure: {} ({}) in {}:{}", \
Aidge::Cuda::cublasGetErrorString(e), static_cast<int>(e), __FILE__, __LINE__); \
int status_dev; \
if (cudaGetDevice(&status_dev) == cudaSuccess) \
error = fmt::format("{} on device #{}", error, status_dev); \
fmt::print(stderr, "{}\n", error); \
cudaDeviceReset(); \
throw std::runtime_error(error.str()); \
throw std::runtime_error(error); \
} \
} while(0)
......@@ -96,4 +92,4 @@ namespace Cuda {
}
}
#endif // AIDGE_BACKEND_CUDA_CUDA_UTILS_H
#endif // AIDGE_BACKEND_CUDA_CUDA_UTILS_H_
......@@ -13,11 +13,17 @@
#define MAX(X,Y) (((X) > (Y)) ? (X) : (Y))
#define CLAMP(X) (((X) < (0)) ? (0) : (X))
#include <stdio.h>
#include <cuda_runtime.h>
#include "aidge/backend/cuda/operator/ShiftMaxImpl_CUDA_kernels.hpp"
#include <algorithm> // std::min
#include <math.h>
#include <cstddef> // std::size_t
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <fmt/core.h>
__device__ inline int ExpShift(int I,int N, double SF)
{
int Ip = I + (I >> 1) - (I >> 4);
......@@ -79,7 +85,7 @@ __global__ void ShiftMaxforward_(T* input,int* quantized_tensor,int* factor, int
template <>
void ShiftMaxforward<float>(const float* input, float* output, double SF, int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
double new_SF = 1 / std::pow(2, 2 * output_bits - 1); // New scaling factor
double new_SF = 1 / powf(2, 2 * output_bits - 1); // New scaling factor
int dims_input_cuda[4] = {1, 1, 1, 1};
for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) {
......@@ -116,7 +122,7 @@ void ShiftMaxforward<float>(const float* input, float* output, double SF, int N,
// Check for CUDA errors
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
std::cerr << "CUDA Error: " << cudaGetErrorString(err) << std::endl;
fmt::print(stderr, "CUDA Error: {}\n", cudaGetErrorString(err));
}
// Copy the result back to host
......@@ -132,7 +138,7 @@ void ShiftMaxforward<float>(const float* input, float* output, double SF, int N,
template <>
void ShiftMaxforward<double>(const double* input, double* output, double SF, int N, int output_bits, size_t size, std::vector<long unsigned int> dims_input) {
double new_SF = 1 / std::pow(2, 2 * output_bits - 1);
double new_SF = 1 / powf(2, 2 * output_bits - 1);
int dims_input_cuda[4] = {1, 1, 1, 1};
for (std::size_t i = 0; i < std::min(dims_input.size(), size_t(4)); ++i) {
......@@ -169,7 +175,7 @@ void ShiftMaxforward<double>(const double* input, double* output, double SF, int
// Check for CUDA errors
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
std::cerr << "CUDA Error: " << cudaGetErrorString(err) << std::endl;
fmt::print(stderr, "CUDA Error: {}\n", cudaGetErrorString(err));
}
// Copy the result back to host
......@@ -201,7 +207,7 @@ __global__ void ShiftMaxbackward_(T* input_grad, const T* output_tensor, const T
template <>
void ShiftMaxbackward<float>(const float* output_tensor, const float* output_grad, float* input_grad, size_t size, std::vector<long unsigned int> dims)
{
{
int dims_input_cuda[4] = {1, 1, 1, 1};
for (std::size_t i = 0; i < std::min(dims.size(), size_t(4)); ++i) {
dims_input_cuda[i] = static_cast<int>(dims[i]);
......@@ -230,7 +236,7 @@ void ShiftMaxbackward<float>(const float* output_tensor, const float* output_gra
cudaError_t err = cudaGetLastError();
if(err != cudaSuccess)
{
printf("CUDA Error: %s\n", cudaGetErrorString(err));
fmt::print(stderr, "CUDA Error: {}\n", cudaGetErrorString(err));
}
cudaMemcpy(input_grad, input_grad_, (size) * sizeof(float), cudaMemcpyDeviceToHost);
......@@ -242,7 +248,7 @@ void ShiftMaxbackward<float>(const float* output_tensor, const float* output_gra
template <>
void ShiftMaxbackward<double>(const double* output_tensor, const double* output_grad, double* input_grad, size_t size, std::vector<long unsigned int> dims)
{
{
int dims_input_cuda[4] = {1, 1, 1, 1};
for (std::size_t i = 0; i < std::min(dims.size(), size_t(4)); ++i) {
dims_input_cuda[i] = static_cast<int>(dims[i]);
......@@ -271,7 +277,7 @@ void ShiftMaxbackward<double>(const double* output_tensor, const double* output_
cudaError_t err = cudaGetLastError();
if(err != cudaSuccess)
{
printf("CUDA Error: %s\n", cudaGetErrorString(err));
fmt::print(stderr, "CUDA Error: {}\n", cudaGetErrorString(err));
}
cudaMemcpy(input_grad,input_grad_, (size) * sizeof(double), cudaMemcpyDeviceToHost);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment