diff --git a/include/aidge/backend/cuda.hpp b/include/aidge/backend/cuda.hpp
index cfae53b64115aa7946580d00f45be56f17163d7f..d34ad64b1b7751e850be6c900f086810d7c002ae 100644
--- a/include/aidge/backend/cuda.hpp
+++ b/include/aidge/backend/cuda.hpp
@@ -15,5 +15,6 @@
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/ConvImpl.hpp"
 #include "aidge/backend/cuda/operator/ProducerImpl.hpp"
+#include "aidge/backend/cuda/operator/ReLUImpl.hpp"
 
 #endif /* AIDGE_BACKEND_CUDA_IMPORTS_H_ */
\ No newline at end of file
diff --git a/include/aidge/backend/cuda/operator/ReLUImpl.hpp b/include/aidge/backend/cuda/operator/ReLUImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..3b6cbcc6041a4757a919203fed6d080e30051d08
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/ReLUImpl.hpp
@@ -0,0 +1,60 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_BACKEND_CUDA_OPERATOR_RELUIMPL_H_
+#define AIDGE_BACKEND_CUDA_OPERATOR_RELUIMPL_H_
+
+#include <array>
+#include <memory>
+#include <tuple>
+#include <vector>
+
+#include <cudnn.h>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/ReLU.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+#include "aidge/backend/cuda/utils/CudaUtils.hpp"
+
+namespace Aidge {
+class ReLUImpl_cuda : public OperatorImpl {
+private:
+    // CuDNN specific variables
+    #if CUDNN_VERSION >= 5000
+    cudnnActivationDescriptor_t mReLUDesc = nullptr;
+    #else
+    cudnnActivationMode_t mReLUDesc = nullptr;
+    #endif
+
+public:
+    ReLUImpl_cuda(const ReLU_Op &op) : OperatorImpl(op) {}
+
+    static std::unique_ptr<ReLUImpl_cuda> create(const ReLU_Op &op) {
+        return std::make_unique<ReLUImpl_cuda>(op);
+    }
+
+public:
+    void forward();
+    ~ReLUImpl_cuda();
+
+private:
+    template <class T> void forward_(const Tensor& input);
+};
+
+namespace {
+// add cuda backend to ReLU_Op implementation registry
+static Registrar<ReLU_Op> registrarReLUImpl_cuda("cuda", Aidge::ReLUImpl_cuda::create);
+}  // namespace
+}  // namespace Aidge
+
+#endif /* AIDGE_BACKEND_CUDA_OPERATOR_RELUIMPL_H_ */
diff --git a/src/operator/ReLUImpl.cpp b/src/operator/ReLUImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..46360879514cbb6aa1ac1898afdf9b4c49a07153
--- /dev/null
+++ b/src/operator/ReLUImpl.cpp
@@ -0,0 +1,70 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <chrono>   // std::chrono::milliseconds
+#include <numeric>  // std::accumulate
+#include <thread>   // std::this_thread::sleep_for
+#include <vector>
+
+#include "aidge/utils/Types.h"
+#include "aidge/operator/ReLU.hpp"
+
+#include "aidge/backend/cuda/data/TensorImpl.hpp"
+#include "aidge/backend/cuda/operator/ReLUImpl.hpp"
+#include "aidge/backend/cuda/utils/CudaContext.hpp"
+
+void Aidge::ReLUImpl_cuda::forward() {
+    assert(mOp.getRawInput(0) && "missing input #0");
+
+    std::shared_ptr<Tensor> inputFallback;
+    const auto& input = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+
+    // Lazy-initialize CuDNN ReLU descriptor
+    if (mReLUDesc == nullptr) {
+        #if CUDNN_VERSION >= 5000
+        CHECK_CUDNN_STATUS(cudnnCreateActivationDescriptor(&mReLUDesc));
+        CHECK_CUDNN_STATUS(cudnnSetActivationDescriptor(
+            mReLUDesc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0));
+        #else
+        mReLUDesc = CUDNN_ACTIVATION_RELU;
+        #endif
+    }
+
+    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
+        forward_<double>(input);
+    }
+    else {
+        forward_<float>(input);
+    }
+}
+
+template <class T>
+void Aidge::ReLUImpl_cuda::forward_(const Tensor& input) {
+    const T alpha = 1.0f;
+    const T beta = 0.0f;
+    CHECK_CUDNN_STATUS(
+        cudnnActivationForward(CudaContext::cudnnHandle(),
+                               mReLUDesc,
+                               &alpha,
+                               dynamic_cast<TensorImpl_cuda_*>(input.getImpl().get())->getCudnnTensorDesc(),
+                               input.getImpl()->rawPtr(),
+                               &beta,
+                               dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
+                               std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()));
+}
+
+Aidge::ReLUImpl_cuda::~ReLUImpl_cuda() {
+    if (mReLUDesc != nullptr) {
+        cudnnDestroyActivationDescriptor(mReLUDesc);
+    }
+}
+
diff --git a/unit_tests/Test_ReLUImpl.cpp b/unit_tests/Test_ReLUImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..82da6fae6737ee39fc60d771c10dc69fa2dea5f6
--- /dev/null
+++ b/unit_tests/Test_ReLUImpl.cpp
@@ -0,0 +1,200 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <array>
+
+#include <catch2/catch_test_macros.hpp>
+
+#include "Test_cuda.hpp"
+
+#include "aidge/data/Tensor.hpp"
+
+#include "aidge/backend/cpu.hpp"
+#include "aidge/backend/cuda.hpp"
+
+using namespace Aidge;
+
+
+TEST_CASE("[gpu/operator] ReLU(forward)", "[ReLU][GPU]") {
+    SECTION("1D Tensor") {
+        std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array1D<float,10> {
+            {0, 1, 2,-3, 4,-5,-6, 7, 8, 9}
+        });
+        std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array1D<float,10> {
+            {0, 1, 2, 0, 4, 0, 0, 7, 8, 9}
+        });
+
+        std::shared_ptr<Node> myReLU = ReLU();
+        auto op = std::static_pointer_cast<OperatorTensor>(myReLU -> getOperator());
+        op->associateInput(0,input0);
+        op->setDataType(DataType::Float32);
+        op->setBackend("cuda");
+        op->computeOutputDims();
+        myReLU->forward();
+
+        float* computedOutput = new float[myOutput->size()]();
+        cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost);
+
+        for(int i = 0; i < myOutput->size(); i++){
+            const float targetOutput = *(static_cast<float*>(myOutput->getImpl()->rawPtr()) + i);
+            REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
+        }
+
+        delete[] computedOutput;
+    }
+
+    SECTION("2D Tensor") {
+        std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array2D<float,2,10> {
+            {
+                { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
+                {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
+            }
+        });
+        std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array2D<float,2,10> {
+            {
+                { 0, 1, 2, 0, 4, 0, 0, 7, 8, 9},
+                { 0, 4, 2, 0, 4, 0, 0, 7, 0,10}
+            }
+        });
+
+        std::shared_ptr<Node> myReLU = ReLU();
+        auto op = std::static_pointer_cast<OperatorTensor>(myReLU -> getOperator());
+        op->associateInput(0,input0);
+        op->setDataType(DataType::Float32);
+        op->setBackend("cuda");
+        op->computeOutputDims();
+        myReLU->forward();
+
+        float* computedOutput = new float[myOutput->size()]();
+        cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost);
+
+        for(int i = 0; i < myOutput->size(); i++){
+            const float targetOutput = *(static_cast<float*>(myOutput->getImpl()->rawPtr()) + i);
+            REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
+        }
+
+        delete[] computedOutput;
+    }
+
+    SECTION("3D Tensor") {
+        std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array3D<float,2,2,10> {
+            {
+                {
+                    { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
+                    {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
+                },
+                {
+                    { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
+                    {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
+                }
+            }
+        });
+        std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array3D<float,2,2,10> {
+            {
+                {
+                    { 0, 1, 2, 0, 4, 0, 0, 7, 8, 9},
+                    { 0, 4, 2, 0, 4, 0, 0, 7, 0,10}
+                },
+                {
+                    { 0, 1, 2, 0, 4, 0, 0, 7, 8, 9},
+                    { 0, 4, 2, 0, 4, 0, 0, 7, 0,10}
+                }
+            }
+        });
+
+        std::shared_ptr<Node> myReLU = ReLU();
+        auto op = std::static_pointer_cast<OperatorTensor>(myReLU -> getOperator());
+        op->associateInput(0,input0);
+        op->setDataType(DataType::Float32);
+        op->setBackend("cuda");
+        op->computeOutputDims();
+        myReLU->forward();
+
+        float* computedOutput = new float[myOutput->size()]();
+        cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost);
+
+        for(int i = 0; i < myOutput->size(); i++){
+            const float targetOutput = *(static_cast<float*>(myOutput->getImpl()->rawPtr()) + i);
+            REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
+        }
+
+        delete[] computedOutput;
+    }
+
+    SECTION("4D Tensor") {
+        std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<float,2,2,2,10> {
+            {
+                {
+                    {
+                        { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
+                        {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
+                    },
+                    {
+                        { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
+                        {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
+                    }
+                },
+                {
+                    {
+                        { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
+                        {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
+                    },
+                    {
+                        { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
+                        {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
+                    }
+                }
+            }
+        });
+        std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<float,2,2,2,10> {
+            {
+                {
+                    {
+                        { 0, 1, 2, 0, 4, 0, 0, 7, 8, 9},
+                        { 0, 4, 2, 0, 4, 0, 0, 7, 0,10}
+                    },
+                    {
+                        { 0, 1, 2, 0, 4, 0, 0, 7, 8, 9},
+                        { 0, 4, 2, 0, 4, 0, 0, 7, 0,10}
+                    }
+                },
+                {
+                    {
+                        { 0, 1, 2, 0, 4, 0, 0, 7, 8, 9},
+                        { 0, 4, 2, 0, 4, 0, 0, 7, 0,10}
+                    },
+                    {
+                        { 0, 1, 2, 0, 4, 0, 0, 7, 8, 9},
+                        { 0, 4, 2, 0, 4, 0, 0, 7, 0,10}
+                    }
+                }
+            }
+        });
+
+        std::shared_ptr<Node> myReLU = ReLU();
+        auto op = std::static_pointer_cast<OperatorTensor>(myReLU -> getOperator());
+        op->associateInput(0,input0);
+        op->setDataType(DataType::Float32);
+        op->setBackend("cuda");
+        op->computeOutputDims();
+        op->forward();
+
+        float* computedOutput = new float[myOutput->size()]();
+        cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost);
+
+        for(int i = 0; i < myOutput->size(); i++){
+            const float targetOutput = *(static_cast<float*>(myOutput->getImpl()->rawPtr()) + i);
+            REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
+        }
+
+        delete[] computedOutput;
+    }
+}