Commit c959af7a authored by Houssem ROUIS

Merge branch 'dev' of gitlab.eclipse.org:eclipse/aidge/aidge_backend_cuda into rework_unit_tests

parents 25693404 f08d08ed
......@@ -24,6 +24,7 @@
#include "aidge/backend/cuda/operator/ReLUImpl.hpp"
#include "aidge/backend/cuda/operator/ReshapeImpl.hpp"
#include "aidge/backend/cuda/operator/SigmoidImpl.hpp"
#include "aidge/backend/cuda/operator/SubImpl.hpp"
#include "aidge/backend/cuda/operator/TanhImpl.hpp"
#endif /* AIDGE_BACKEND_CUDA_IMPORTS_H_ */
\ No newline at end of file
#endif /* AIDGE_BACKEND_CUDA_IMPORTS_H_ */
......@@ -89,6 +89,10 @@ public:
std::size_t scalarSize() const noexcept override { return sizeof(T); }
void zeros() override final {
CHECK_CUDA_STATUS(cudaMemset(rawPtr(), T(0), mNbElts * sizeof(T)));
}
void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override {
AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "TensorImpl_cuda<{}>::copy(): copy length ({}) is above capacity ({})", typeid(T).name(), length, mNbElts);
const T* srcT = static_cast<const T *>(src);
......
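Note on the new zeros()/copy() pair: cudaMemset fills bytes, not elements, so the T(0) argument is narrowed to the byte value 0; this yields element-wise zeros only because the all-zero bit pattern encodes 0 for the integer and IEEE-754 floating-point types used here. A minimal stand-alone sketch of the same zero-then-copy sequence with plain CUDA runtime calls (illustrative names, host-to-device copy assumed):

#include <cassert>
#include <cstddef>
#include <vector>
#include <cuda_runtime.h>

int main() {
    const std::size_t nbElts = 8;
    float* devPtr = nullptr;
    assert(cudaMalloc(reinterpret_cast<void**>(&devPtr), nbElts * sizeof(float)) == cudaSuccess);
    // zeros(): byte-fill with 0, a valid element-wise zero for float
    assert(cudaMemset(devPtr, 0, nbElts * sizeof(float)) == cudaSuccess);
    // copy(): here host -> device; the backend implementation may also copy device -> device
    std::vector<float> host(nbElts, 1.5f);
    assert(cudaMemcpy(devPtr, host.data(), nbElts * sizeof(float), cudaMemcpyHostToDevice) == cudaSuccess);
    std::vector<float> back(nbElts, 0.f);
    assert(cudaMemcpy(back.data(), devPtr, nbElts * sizeof(float), cudaMemcpyDeviceToHost) == cudaSuccess);
    assert(back.front() == 1.5f && back.back() == 1.5f);
    cudaFree(devPtr);
    return 0;
}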
......@@ -36,6 +36,7 @@ private:
cudnnActivationMode_t mReLUDesc = nullptr;
#endif
std::shared_ptr<Tensor> mInputFallback;
std::shared_ptr<Tensor> mOutputGradFallback;
public:
ReLUImpl_cuda(const ReLU_Op &op) : OperatorImpl(op, "cuda") {}
......@@ -46,10 +47,12 @@ public:
public:
void forward();
void backward();
~ReLUImpl_cuda();
private:
template <class T> void forward_(const Tensor& input);
template <class T> void backward_(const Tensor& output_grad);
};
namespace {
......
......@@ -36,6 +36,7 @@ private:
cudnnActivationMode_t mSigmoidDesc = nullptr;
#endif
std::shared_ptr<Tensor> mInputFallback;
std::shared_ptr<Tensor> mOutputGradFallback;
public:
SigmoidImpl_cuda(const Sigmoid_Op &op) : OperatorImpl(op, "cuda") {}
......@@ -46,10 +47,12 @@ public:
public:
void forward();
void backward();
~SigmoidImpl_cuda();
private:
template <class T> void forward_(const Tensor& input);
template <class T> void backward_(const Tensor& output_grad);
};
namespace {
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_BACKEND_CUDA_OPERATOR_SUBIMPL_H_
#define AIDGE_BACKEND_CUDA_OPERATOR_SUBIMPL_H_
#include <array>
#include <memory>
#include <tuple>
#include <vector>
#include <cudnn.h>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Sub.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
namespace Aidge {
class SubImpl_cuda : public OperatorImpl {
private:
public:
SubImpl_cuda(const Sub_Op &op) : OperatorImpl(op, "cuda") {}
static std::unique_ptr<SubImpl_cuda> create(const Sub_Op &op) {
return std::make_unique<SubImpl_cuda>(op);
}
public:
void forward();
// ~SubImpl_cuda();
private:
template <class T> void forward_(const std::vector<Tensor>& inputs, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides);
};
namespace {
// add cuda backend to Sub_Op implementation registry
static Registrar<Sub_Op> registrarSubImpl_cuda("cuda", Aidge::SubImpl_cuda::create);
} // namespace
} // namespace Aidge
#endif /* AIDGE_BACKEND_CUDA_OPERATOR_SUBIMPL_H_ */
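The anonymous-namespace Registrar above registers the factory under the key "cuda" at static-initialization time, so selecting the cuda backend on a Sub node resolves to SubImpl_cuda::create. A stripped-down illustration of that pattern (this is not aidge's actual Registrar, only a sketch of the mechanism):

#include <functional>
#include <map>
#include <memory>
#include <string>

struct Op {};                                   // stand-in for Sub_Op
struct Impl { virtual ~Impl() = default; };     // stand-in for OperatorImpl

// backend name -> factory, filled before main() by static Registrar objects
using Factory = std::function<std::unique_ptr<Impl>(const Op&)>;
std::map<std::string, Factory>& registry() {
    static std::map<std::string, Factory> r;
    return r;
}
struct Registrar {
    Registrar(const std::string& key, Factory f) { registry()[key] = std::move(f); }
};

struct SubImplCudaSketch : Impl {
    static std::unique_ptr<Impl> create(const Op&) { return std::make_unique<SubImplCudaSketch>(); }
};
// mirrors: static Registrar<Sub_Op> registrarSubImpl_cuda("cuda", Aidge::SubImpl_cuda::create);
static Registrar registrarSub("cuda", SubImplCudaSketch::create);

int main() {
    Op subOp;
    auto impl = registry().at("cuda")(subOp);   // what backend selection boils down to
    return impl ? 0 : 1;
}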
......@@ -36,6 +36,7 @@ private:
cudnnActivationMode_t mTanhDesc = nullptr;
#endif
std::shared_ptr<Tensor> mInputFallback;
std::shared_ptr<Tensor> mOutputGradFallback;
public:
TanhImpl_cuda(const Tanh_Op &op) : OperatorImpl(op, "cuda") {}
......@@ -46,10 +47,12 @@ public:
public:
void forward();
void backward();
~TanhImpl_cuda();
private:
template <class T> void forward_(const Tensor& input);
template <class T> void backward_(const Tensor& output_grad);
};
namespace {
......
......@@ -76,14 +76,14 @@ void Aidge::AddImpl_cuda::forward_(const std::vector<Tensor>& inputs, const std:
const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
// Create a Tensor descriptor with the broadcasted dims and strides
cudnnTensorDescriptor_t tesnsorDesc;
CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&tesnsorDesc));
CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tesnsorDesc, CudaContext::data_type<T>::value, inputsDims[0].size(), inputsDims[0].data(), inputsStrides[0].data()));
cudnnTensorDescriptor_t tensorDesc;
CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&tensorDesc));
CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc, CudaContext::data_type<T>::value, inputsDims[0].size(), inputsDims[0].data(), inputsStrides[0].data()));
// Add first input
CHECK_CUDNN_STATUS(
cudnnAddTensor(CudaContext::cudnnHandle(),
&alpha,
tesnsorDesc,
tensorDesc,
inputs[0].getImpl()->rawPtr(),
&beta,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
......@@ -92,16 +92,16 @@ void Aidge::AddImpl_cuda::forward_(const std::vector<Tensor>& inputs, const std:
// Add other inputs if there are any
for (size_t i = 1; i < op.nbInputs(); ++i)
{
CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tesnsorDesc, CudaContext::data_type<T>::value, inputsDims[i].size(), inputsDims[i].data(), inputsStrides[i].data()));
CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc, CudaContext::data_type<T>::value, inputsDims[i].size(), inputsDims[i].data(), inputsStrides[i].data()));
CHECK_CUDNN_STATUS(
cudnnAddTensor(CudaContext::cudnnHandle(),
&alpha,
tesnsorDesc,
tensorDesc,
inputs[i].getImpl()->rawPtr(),
&alpha,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr())
);
}
CHECK_CUDNN_STATUS(cudnnDestroyTensorDescriptor(tesnsorDesc));
CHECK_CUDNN_STATUS(cudnnDestroyTensorDescriptor(tensorDesc));
}
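For context: cudnnAddTensor(handle, alpha, aDesc, A, beta, cDesc, C) computes C = alpha*A + beta*C, broadcasting A over its size-1 axes. The first call above therefore writes input 0 into the output (beta = 0) and every later call accumulates input i on top of it (both factors equal to 1). A plain C++ reference of that accumulation, assuming already-broadcast, equally-sized inputs:

#include <cstddef>
#include <vector>

std::vector<float> addReference(const std::vector<std::vector<float>>& inputs) {
    std::vector<float> out = inputs[0];              // C = 1*in[0] + 0*C  (first input)
    for (std::size_t i = 1; i < inputs.size(); ++i)  // C = 1*in[i] + 1*C  (accumulate)
        for (std::size_t k = 0; k < out.size(); ++k)
            out[k] += inputs[i][k];
    return out;
}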
......@@ -64,6 +64,55 @@ void Aidge::ReLUImpl_cuda::forward_(const Tensor& input) {
std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr()));
}
void Aidge::ReLUImpl_cuda::backward() {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
assert(op.getOutput(0)->grad() && "missing output #0");
const auto& output_grad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad());
// Lazy-initialize CuDNN ReLU descriptor
if (mReLUDesc == nullptr) {
#if CUDNN_VERSION >= 5000
CHECK_CUDNN_STATUS(cudnnCreateActivationDescriptor(&mReLUDesc));
CHECK_CUDNN_STATUS(cudnnSetActivationDescriptor(
mReLUDesc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0));
#else
mReLUDesc = CUDNN_ACTIVATION_RELU;
#endif
}
// Do the actual backward computation
// Template is only for the scaling parameters, which are always float
// except when the operation is performed in double precision.
if (op.getInput(0)->grad()->dataType() == DataType::Float64) {
backward_<double>(output_grad);
}
else {
backward_<float>(output_grad);
}
}
template <class T>
void Aidge::ReLUImpl_cuda::backward_(const Tensor& output_grad) {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
CHECK_CUDNN_STATUS(
cudnnActivationBackward(CudaContext::cudnnHandle(),
mReLUDesc,
&alpha,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr(),
std::dynamic_pointer_cast<TensorImpl_cuda_>(output_grad.getImpl())->getCudnnTensorDesc(output_grad),
output_grad.getImpl()->rawPtr(),
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getInput(0)->getImpl())->getCudnnTensorDesc(*op.getInput(0)),
std::static_pointer_cast<Tensor>(op.getRawInput(0))->getImpl()->rawPtr(),
&beta,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getInput(0)->grad()->getImpl())->getCudnnTensorDesc(*op.getInput(0)->grad()),
op.getInput(0)->grad()->getImpl()->rawPtr()));
}
Aidge::ReLUImpl_cuda::~ReLUImpl_cuda() {
if (mReLUDesc != nullptr) {
#if CUDNN_VERSION >= 5000
......
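For reference, with alpha = 1 and beta = 0, cudnnActivationBackward for CUDNN_ACTIVATION_RELU computes dx = dy where the forward input x is positive and 0 elsewhere. A minimal CPU equivalent, for illustration only:

#include <cstddef>
#include <vector>

// dx[k] = dy[k] if x[k] > 0 else 0  (ReLU derivative applied to the incoming gradient)
std::vector<float> reluBackwardRef(const std::vector<float>& x,
                                   const std::vector<float>& dy) {
    std::vector<float> dx(x.size());
    for (std::size_t k = 0; k < x.size(); ++k)
        dx[k] = (x[k] > 0.f) ? dy[k] : 0.f;
    return dx;
}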
......@@ -64,6 +64,55 @@ void Aidge::SigmoidImpl_cuda::forward_(const Tensor& input) {
std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr()));
}
void Aidge::SigmoidImpl_cuda::backward() {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
assert(op.getOutput(0)->grad() && "missing output #0");
const auto& output_grad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad());
// Lazy-initialize CuDNN Sigmoid descriptor
if (mSigmoidDesc == nullptr) {
#if CUDNN_VERSION >= 5000
CHECK_CUDNN_STATUS(cudnnCreateActivationDescriptor(&mSigmoidDesc));
CHECK_CUDNN_STATUS(cudnnSetActivationDescriptor(
mSigmoidDesc, CUDNN_ACTIVATION_SIGMOID, CUDNN_NOT_PROPAGATE_NAN, 0.0));
#else
mSigmoidDesc = CUDNN_ACTIVATION_SIGMOID;
#endif
}
// Do the actual backward computation
// Template is only for the scaling parameters, which are always float
// except when the operation is performed in double precision.
if (op.getInput(0)->grad()->dataType() == DataType::Float64) {
backward_<double>(output_grad);
}
else {
backward_<float>(output_grad);
}
}
template <class T>
void Aidge::SigmoidImpl_cuda::backward_(const Tensor& output_grad) {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
CHECK_CUDNN_STATUS(
cudnnActivationBackward(CudaContext::cudnnHandle(),
mSigmoidDesc,
&alpha,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr(),
std::dynamic_pointer_cast<TensorImpl_cuda_>(output_grad.getImpl())->getCudnnTensorDesc(output_grad),
output_grad.getImpl()->rawPtr(),
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getInput(0)->getImpl())->getCudnnTensorDesc(*op.getInput(0)),
std::static_pointer_cast<Tensor>(op.getRawInput(0))->getImpl()->rawPtr(),
&beta,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getInput(0)->grad()->getImpl())->getCudnnTensorDesc(*op.getInput(0)->grad()),
op.getInput(0)->grad()->getImpl()->rawPtr()));
}
Aidge::SigmoidImpl_cuda::~SigmoidImpl_cuda() {
if (mSigmoidDesc != nullptr) {
#if CUDNN_VERSION >= 5000
......
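Same reading for the sigmoid case: with alpha = 1 and beta = 0, the call computes dx = dy * y * (1 - y), where y is the forward output. A minimal CPU equivalent, for illustration only:

#include <cstddef>
#include <vector>

// dx[k] = dy[k] * y[k] * (1 - y[k]), y being the sigmoid forward output
std::vector<float> sigmoidBackwardRef(const std::vector<float>& y,
                                      const std::vector<float>& dy) {
    std::vector<float> dx(y.size());
    for (std::size_t k = 0; k < y.size(); ++k)
        dx[k] = dy[k] * y[k] * (1.f - y[k]);
    return dx;
}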
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>
#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/SubImpl.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
#include "aidge/operator/Sub.hpp"
#include "aidge/utils/Types.h"
void Aidge::SubImpl_cuda::forward() {
const Sub_Op& op = static_cast<const Sub_Op&>(mOp);
// Check inputs
AIDGE_ASSERT(op.getInput(0), "missing input in Sub operator");
AIDGE_ASSERT(op.getInput(0)->hasImpl(), "cannot run Sub forward because the 0-th input has no implementation.");
DataType datatypeFirstInput = op.getInput(0)->dataType();
for (IOIndex_t i = 1; i < op.nbInputs(); ++i) {
AIDGE_ASSERT(op.getInput(i), "missing input in Sub operator");
AIDGE_ASSERT(op.getInput(i)->hasImpl(), "cannot run Sub forward because the {}-th input has no implementation.", i);
AIDGE_ASSERT(op.getInput(i)->dataType() == datatypeFirstInput, "Cannot subtract inputs with different data types.");
}
std::vector<std::shared_ptr<Tensor>> inputFallbacks(op.nbInputs());
std::vector<Tensor> inputs(op.nbInputs());
std::vector<std::vector<int>> dims(op.nbInputs()); // For broadcasted dims
std::vector<std::vector<int>> strides(op.nbInputs()); // For the corresponding strides
for (IOIndex_t i = 0; i < op.nbInputs(); ++i) {
inputs[i] = op.getInput(i)->refCastFrom(inputFallbacks[i], *op.getOutput(0));
// Get tensor dims and broadcast them
std::copy(inputs[i].dims().begin(), inputs[i].dims().end(), std::back_inserter(dims[i]));
dims[i].insert(dims[i].cbegin(), op.getOutput(0)->nbDims() - dims[i].size(), int(1));
// Compute the corresponding strides
std::vector<int> tensorStrides(dims[i].size());
int product = 1;
for (size_t j = dims[i].size(); j > 0; --j) {
tensorStrides[j - 1] = product;
product *= dims[i][j - 1];
}
strides[i] = tensorStrides;
}
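// Worked example of the broadcast/stride preparation above (illustrative values):
//   output dims {2, 3, 4, 5}, input dims {3, 1, 5}
//   -> padded dims {1, 3, 1, 5}   (leading 1s up to the output rank)
//   -> strides     {15, 5, 5, 1}  (contiguous row-major strides of the padded shape)
// cudnnAddTensor then broadcasts the size-1 axes of the input descriptor against the output.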
switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
case DataType::Float64:
forward_<double>(inputs, dims, strides);
break;
case DataType::Float32:
forward_<float>(inputs, dims, strides);
break;
case DataType::Float16:
forward_<half>(inputs, dims, strides);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
}
}
template <class T>
void Aidge::SubImpl_cuda::forward_(const std::vector<Tensor>& inputs, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides) {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
const typename Cuda::cudnn_scaling_type<T>::type gamma = -1.0f;
// Create a Tensor descriptor with the broadcasted dims and strides
cudnnTensorDescriptor_t tensorDesc;
CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&tensorDesc));
CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc, CudaContext::data_type<T>::value, inputsDims[0].size(), inputsDims[0].data(), inputsStrides[0].data()));
// Add first input to the output
CHECK_CUDNN_STATUS(
cudnnAddTensor(CudaContext::cudnnHandle(),
&alpha,
tensorDesc,
inputs[0].getImpl()->rawPtr(),
&beta,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr())
);
// Subtract other inputs if there are any
for (size_t i = 1; i < op.nbInputs(); ++i)
{
CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(tensorDesc, CudaContext::data_type<T>::value, inputsDims[i].size(), inputsDims[i].data(), inputsStrides[i].data()));
CHECK_CUDNN_STATUS(
cudnnAddTensor(CudaContext::cudnnHandle(),
&gamma,
tensorDesc,
inputs[i].getImpl()->rawPtr(),
&alpha,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr())
);
}
CHECK_CUDNN_STATUS(cudnnDestroyTensorDescriptor(tensorDesc));
}
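Since cuDNN has no dedicated subtraction primitive, the forward above reuses cudnnAddTensor (C = alpha*A + beta*C) with gamma = -1 for every input after the first, i.e. out = in0 - in1 - ... - inN. A plain C++ reference of that accumulation, assuming already-broadcast, equally-sized inputs:

#include <cstddef>
#include <vector>

std::vector<float> subReference(const std::vector<std::vector<float>>& inputs) {
    std::vector<float> out = inputs[0];              // C = 1*in[0] + 0*C   (first input)
    for (std::size_t i = 1; i < inputs.size(); ++i)  // C = -1*in[i] + 1*C  (remaining inputs)
        for (std::size_t k = 0; k < out.size(); ++k)
            out[k] -= inputs[i][k];
    return out;
}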
......@@ -64,6 +64,55 @@ void Aidge::TanhImpl_cuda::forward_(const Tensor& input) {
std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr()));
}
void Aidge::TanhImpl_cuda::backward() {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
assert(op.getOutput(0)->grad() && "missing output #0");
const auto& output_grad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad());
// Lazy-initialize CuDNN Tanh descriptor
if (mTanhDesc == nullptr) {
#if CUDNN_VERSION >= 5000
CHECK_CUDNN_STATUS(cudnnCreateActivationDescriptor(&mTanhDesc));
CHECK_CUDNN_STATUS(cudnnSetActivationDescriptor(
mTanhDesc, CUDNN_ACTIVATION_TANH, CUDNN_NOT_PROPAGATE_NAN, 0.0));
#else
mTanhDesc = CUDNN_ACTIVATION_TANH;
#endif
}
// Do the actual backward computation
// Template is only for the scaling parameters, which are always float
// except when the operation is performed in double precision.
if (op.getInput(0)->grad()->dataType() == DataType::Float64) {
backward_<double>(output_grad);
}
else {
backward_<float>(output_grad);
}
}
template <class T>
void Aidge::TanhImpl_cuda::backward_(const Tensor& output_grad) {
const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
CHECK_CUDNN_STATUS(
cudnnActivationBackward(CudaContext::cudnnHandle(),
mTanhDesc,
&alpha,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr(),
std::dynamic_pointer_cast<TensorImpl_cuda_>(output_grad.getImpl())->getCudnnTensorDesc(output_grad),
output_grad.getImpl()->rawPtr(),
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getInput(0)->getImpl())->getCudnnTensorDesc(*op.getInput(0)),
std::static_pointer_cast<Tensor>(op.getRawInput(0))->getImpl()->rawPtr(),
&beta,
std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getInput(0)->grad()->getImpl())->getCudnnTensorDesc(*op.getInput(0)->grad()),
op.getInput(0)->grad()->getImpl()->rawPtr()));
}
Aidge::TanhImpl_cuda::~TanhImpl_cuda() {
if (mTanhDesc != nullptr) {
#if CUDNN_VERSION >= 5000
......
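For reference, with the descriptor set to CUDNN_ACTIVATION_TANH, alpha = 1 and beta = 0, the backward call computes dx = dy * (1 - y*y), where y is the tanh forward output. A minimal CPU equivalent, for illustration only:

#include <cstddef>
#include <vector>

// dx[k] = dy[k] * (1 - y[k]*y[k]), y being the tanh forward output
std::vector<float> tanhBackwardRef(const std::vector<float>& y,
                                   const std::vector<float>& dy) {
    std::vector<float> dx(y.size());
    for (std::size_t k = 0; k < y.size(); ++k)
        dx[k] = dy[k] * (1.f - y[k] * y[k]);
    return dx;
}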