/********************************************************************************
 * Copyright (c) 2024 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <cassert>
#include <vector>

#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/ReLUImpl.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/utils/Types.h"

void Aidge::ReLUImpl_cuda::forward() {
    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);

    assert(mOp.getRawInput(0) && "missing input #0");

    // Bring the input to the output's backend and data type, copying into
    // mInputFallback when a conversion is needed
    const auto& input = op.getInput(0)->refCastFrom(mInputFallback, *op.getOutput(0));

    // Lazy-initialize CuDNN ReLU descriptor
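    // cuDNN 5.0 introduced opaque activation descriptors; older versions pass
    // the activation mode enum directly to the activation functions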
    if (mReLUDesc == nullptr) {
        #if CUDNN_VERSION >= 5000
            CHECK_CUDNN_STATUS(cudnnCreateActivationDescriptor(&mReLUDesc));
            CHECK_CUDNN_STATUS(cudnnSetActivationDescriptor(
                mReLUDesc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0));
        #else
            mReLUDesc = CUDNN_ACTIVATION_RELU;
        #endif
    }

    // Do the actual forward computation.
    // The template parameter only selects the type of the scaling factors,
    // which are float except when the operation runs in double precision.
    if (op.getOutput(0)->dataType() == DataType::Float64) {
        forward_<double>(input);
    }
    else {
        forward_<float>(input);
    }
}

template <class T>
void Aidge::ReLUImpl_cuda::forward_(const Tensor& input) {
    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
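    // cuDNN applies output = alpha * ReLU(input) + beta * output, so alpha = 1
    // and beta = 0 simply overwrite the output tensor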
    const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
    const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
    CHECK_CUDNN_STATUS(
        cudnnActivationForward(CudaContext::cudnnHandle(),
                               mReLUDesc,
                               &alpha,
                               std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input),
                               input.getImpl()->rawPtr(),
                               &beta,
                               std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
                               std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr()));
}

void Aidge::ReLUImpl_cuda::backward() {
    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);

    assert(op.getOutput(0)->grad() && "missing output #0 gradient");

    // Bring the output gradient to the expected backend and data type,
    // copying into mOutputGradFallback when a conversion is needed
    const auto& output_grad = op.getOutput(0)->grad()->refCastFrom(mOutputGradFallback, *op.getOutput(0)->grad());

    // Lazy-initialize CuDNN ReLU descriptor
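    // Same lazy initialization as in forward(), so backward() also works when
    // it is the first method called on this implementation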
    if (mReLUDesc == nullptr) {
        #if CUDNN_VERSION >= 5000
            CHECK_CUDNN_STATUS(cudnnCreateActivationDescriptor(&mReLUDesc));
            CHECK_CUDNN_STATUS(cudnnSetActivationDescriptor(
                mReLUDesc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0));
        #else
            mReLUDesc = CUDNN_ACTIVATION_RELU;
        #endif
    }

    // Do the actual backward computation.
    // The template parameter only selects the type of the scaling factors,
    // which are float except when the operation runs in double precision.
    if (op.getInput(0)->grad()->dataType() == DataType::Float64) {
        backward_<double>(output_grad);
    }
    else {
        backward_<float>(output_grad);
    }
}

template <class T>
void Aidge::ReLUImpl_cuda::backward_(const Tensor& output_grad) {
    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
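    // cudnnActivationBackward() combines the forward output, its gradient and
    // the forward input to produce the input gradient:
    // input_grad = alpha * ReLU'(input) * output_grad + beta * input_grad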
    const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
    const typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
    CHECK_CUDNN_STATUS(
        cudnnActivationBackward(CudaContext::cudnnHandle(),
                                mReLUDesc,
                                &alpha,
                                std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
                                std::static_pointer_cast<Tensor>(op.getRawOutput(0))->getImpl()->rawPtr(),
                                std::dynamic_pointer_cast<TensorImpl_cuda_>(output_grad.getImpl())->getCudnnTensorDesc(output_grad),
                                output_grad.getImpl()->rawPtr(),
                                std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getInput(0)->getImpl())->getCudnnTensorDesc(*op.getInput(0)),
                                std::static_pointer_cast<Tensor>(op.getRawInput(0))->getImpl()->rawPtr(),
                                &beta,
                                std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getInput(0)->grad()->getImpl())->getCudnnTensorDesc(*op.getInput(0)->grad()),
                                op.getInput(0)->grad()->getImpl()->rawPtr()));
}

Aidge::ReLUImpl_cuda::~ReLUImpl_cuda() {
    if (mReLUDesc != nullptr) {
        #if CUDNN_VERSION >= 5000
            cudnnDestroyActivationDescriptor(mReLUDesc);
        #else
            // Before cuDNN 5.0, mReLUDesc is just the activation mode enum and
            // requires no explicit destruction
        #endif
    }
}