diff --git a/include/aidge/backend/cuda.hpp b/include/aidge/backend/cuda.hpp
index a6bae174471e665f229d08a489d6b9f7911a6e9f..cfae53b64115aa7946580d00f45be56f17163d7f 100644
--- a/include/aidge/backend/cuda.hpp
+++ b/include/aidge/backend/cuda.hpp
@@ -14,5 +14,6 @@
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/ConvImpl.hpp"
+#include "aidge/backend/cuda/operator/ProducerImpl.hpp"
 
 #endif /* AIDGE_BACKEND_CUDA_IMPORTS_H_ */
\ No newline at end of file
diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp
index 73b211113b3b21b9c8294e51e16cc001afad25e1..c61e926c88a9baf1fcdf64794c2a975a1b891356 100644
--- a/include/aidge/backend/cuda/data/TensorImpl.hpp
+++ b/include/aidge/backend/cuda/data/TensorImpl.hpp
@@ -5,11 +5,23 @@
 #include "aidge/data/Tensor.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/future_std/span.hpp"
 
 #include "aidge/backend/cuda/utils/CudaUtils.hpp"
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
 
 namespace Aidge {
+
+template <typename SRC_T, typename DST_T>
+void thrust_copy(const SRC_T* /*srcData*/, DST_T* /*dstData*/, size_t /*size*/);
+template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type* = nullptr>
+void thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size);
+template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type* = nullptr>
+void thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size);
+template <>
+void thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size);
+
 /**
  * @brief Abstract class for the TensorImpl_cuda class template.
  * @details Its purpose is to provide access to base methods that are specific
@@ -29,16 +41,31 @@ public:
 
 template <class T>
 class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ {
-   private:
+private:
+    static T* cudaAlloc(NbElts_t length) {
+        T* data;
+        CHECK_CUDA_STATUS(cudaMalloc(reinterpret_cast<void**>(&data), length * sizeof(T)));
+        return data;
+    }
+
+    static void cudaDelete(T* data) {
+        // Should not be called if data is nullptr, according to the standard
+        cudaFree(data);
+    }
+
+private:
     const Tensor &mTensor; // Impl needs to access Tensor information, but is not
                            // supposed to change it!
-    T* mData = nullptr;
+    /// Pointer to the data and its capacity
+    future_std::span<T> mData;
+    /// If this instance owns the data, std::unique_ptr manages it
+    std::unique_ptr<T, decltype(&cudaDelete)> mDataOwner;
     mutable cudnnTensorDescriptor_t mCudnnTensor = nullptr;
 
-   public:
+public:
     static constexpr const char *Backend = "cuda";
 
-    TensorImpl_cuda(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor) {}
+    TensorImpl_cuda(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor), mDataOwner(nullptr, cudaDelete) {}
 
     bool operator==(const TensorImpl &otherImpl) const override final;
 
@@ -47,23 +74,111 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ {
     }
 
     // native interface
-    const T* data() const { return mData; }
+    const future_std::span<T>& data() const { return mData; }
+    std::size_t size() const override { return mData.size(); }
 
     std::size_t scalarSize() const override { return sizeof(T); }
 
-    void copy(const void *src, NbElts_t length) override {
-        CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyHostToDevice));
+    void setDevice(DeviceIdx_t device) override {
+        mDevice = device;
     }
 
-    void *rawPtr() override {
-        lazyInit(reinterpret_cast<void**>(&mData));
-        return mData;
+    void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override {
+        void* dst = static_cast<void*>(static_cast<T*>(rawPtr()) + offset);
+        CHECK_CUDA_STATUS(cudaMemcpy(dst, src, length * sizeof(T), cudaMemcpyDeviceToDevice));
     }
 
-    void* getRaw(std::size_t idx) {
-        return static_cast<void*>(static_cast<T*>(rawPtr()) + idx);
+    void copyCast(const void *src, NbElts_t length, const DataType srcDt) override {
+        if (length == 0) {
+            return;
+        }
+
+        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
+        if (srcDt == DataType::Float64) {
+            thrust_copy(static_cast<const double*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::Float32) {
+            thrust_copy(static_cast<const float*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::Float16) {
+            thrust_copy(static_cast<const half_float::half*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::Int64) {
+            thrust_copy(static_cast<const int64_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::UInt64) {
+            thrust_copy(static_cast<const uint64_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::Int32) {
+            thrust_copy(static_cast<const int32_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::UInt32) {
+            thrust_copy(static_cast<const uint32_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::Int16) {
+            thrust_copy(static_cast<const int16_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::UInt16) {
+            thrust_copy(static_cast<const uint16_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::Int8) {
+            thrust_copy(static_cast<const int8_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else if (srcDt == DataType::UInt8) {
+            thrust_copy(static_cast<const uint8_t*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
+        else {
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
+        }
+    }
+
+    void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) override {
+        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
+        CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyDeviceToDevice));
+    }
+
+    void copyFromHost(const void *src, NbElts_t length) override {
+        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
+        CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyHostToDevice));
     }
 
+    void copyToHost(void *dst, NbElts_t length) const override {
+        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
+        CHECK_CUDA_STATUS(cudaMemcpy(dst, rawPtr(), length * sizeof(T), cudaMemcpyDeviceToHost));
+    }
+
+    void *rawPtr(NbElts_t offset = 0) override {
+        lazyInit();
+        return (mData.data() + offset);
+    };
+
+    const void *rawPtr(NbElts_t offset = 0) const override {
+        AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr");
+        return (mData.data() + offset);
+    };
+
     const cudnnTensorDescriptor_t& getCudnnTensorDesc() const override {
         if (mCudnnTensor == nullptr) {
             CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&mCudnnTensor));
@@ -97,22 +212,25 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ {
         return mCudnnTensor;
     }
 
-    virtual ~TensorImpl_cuda() {
-        if (mData != nullptr)
-            cudaFree(mData);
+    void setRawPtr(void *ptr, NbElts_t length) override final {
+        AIDGE_ASSERT(length >= mTensor.size(), "trying to set raw pointer of insufficient capacity");
+        mData = future_std::span<T>(static_cast<T *>(ptr), length);
+        mDataOwner.reset();
+    };
 
+    virtual ~TensorImpl_cuda() {
         if (mCudnnTensor != nullptr)
             cudnnDestroyTensorDescriptor(mCudnnTensor);
     }
 
-    void setRawPtr(void* /*ptr*/) override final {
-        printf("Not implemented yet.");
-    };
-
-   private:
-    void lazyInit(void** data) {
-        if (*data == nullptr)
-            CHECK_CUDA_STATUS(cudaMalloc(data, mTensor.size() * sizeof(T)));
+private:
+    void lazyInit() {
+        if (mData.size() < mTensor.size()) {
+            // Need more data, a re-allocation will occur
+            AIDGE_ASSERT(mData.empty() || mDataOwner != nullptr, "trying to enlarge non-owned data");
+            mDataOwner.reset(cudaAlloc(mTensor.size()));
+            mData = future_std::span<T>(mDataOwner.get(), mTensor.size());
+        }
     }
 };
 
@@ -121,6 +239,8 @@ static Registrar<Tensor> registrarTensorImpl_cuda_Float64(
     {"cuda", DataType::Float64}, Aidge::TensorImpl_cuda<double>::create);
 static Registrar<Tensor> registrarTensorImpl_cuda_Float32(
     {"cuda", DataType::Float32}, Aidge::TensorImpl_cuda<float>::create);
+static Registrar<Tensor> registrarTensorImpl_cuda_Float16(
+    {"cuda", DataType::Float16}, Aidge::TensorImpl_cuda<half_float::half>::create);
 static Registrar<Tensor> registrarTensorImpl_cuda_Int32(
     {"cuda", DataType::Int32}, Aidge::TensorImpl_cuda<int>::create);
 } // namespace
diff --git a/include/aidge/backend/cuda/operator/ConvImpl.hpp b/include/aidge/backend/cuda/operator/ConvImpl.hpp
index 31b80adbae0602211fa5c11873875a1a10eb40db..0ad995b24082782c611b93fbcc040d1319a7362f 100644
--- a/include/aidge/backend/cuda/operator/ConvImpl.hpp
+++ b/include/aidge/backend/cuda/operator/ConvImpl.hpp
@@ -49,7 +49,7 @@ public:
     ~ConvImpl_cuda();
 
 private:
-    template <class T> void forward_();
+    template <class T> void forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2);
 };
 
 namespace {
diff --git a/include/aidge/backend/cuda/operator/ProducerImpl.hpp b/include/aidge/backend/cuda/operator/ProducerImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..9912133072e23181df8f384841660bf89a829b60
--- /dev/null
+++ b/include/aidge/backend/cuda/operator/ProducerImpl.hpp
@@ -0,0 +1,40 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CUDA_OPERATOR_PRODUCERIMPL_H_
+#define AIDGE_CUDA_OPERATOR_PRODUCERIMPL_H_
+
+#include <memory>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+namespace Aidge {
+class ProducerImpl_cuda : public OperatorImpl {
+public:
+    ProducerImpl_cuda(const Producer_Op &op) : OperatorImpl(op) {}
+
+    static std::unique_ptr<ProducerImpl_cuda> create(const Producer_Op &op) {
+        return std::make_unique<ProducerImpl_cuda>(op);
+    }
+
+    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
+    void forward() override;
+};
+
+namespace {
+static Registrar<Producer_Op> registrarProducerImpl_cuda("cuda", Aidge::ProducerImpl_cuda::create);
+} // namespace
+} // namespace Aidge
+
+#endif /* AIDGE_CUDA_OPERATOR_PRODUCERIMPL_H_ */
diff --git a/include/aidge/backend/cuda/utils/CudaContext.hpp b/include/aidge/backend/cuda/utils/CudaContext.hpp
index a66ccdf690603c39ba4a7bf691f0dffea64ddddb..82dd395e6bbb33bae29c5d881290d6996bfb0332 100644
--- a/include/aidge/backend/cuda/utils/CudaContext.hpp
+++ b/include/aidge/backend/cuda/utils/CudaContext.hpp
@@ -128,6 +128,11 @@
 }
 
 namespace Aidge {
+    template <>
+    struct CudaContext::data_type<half_float::half> {
+        static const cudnnDataType_t value = CUDNN_DATA_HALF;
+    };
+
     template <>
     struct CudaContext::data_type<float> {
         static const cudnnDataType_t value = CUDNN_DATA_FLOAT;
@@ -139,25 +144,25 @@ namespace Aidge {
     };
 
     inline cudnnDataType_t DataTypeToCudnn(DataType type) {
-        if (type == DataType::Float32)
-            return CUDNN_DATA_FLOAT;
-
-        if (type == DataType::Float64)
+        switch (type) {
+        case DataType::Float64:
             return CUDNN_DATA_DOUBLE;
-
-        if (type == DataType::Int8)
+        case DataType::Float32:
+            return CUDNN_DATA_FLOAT;
+        case DataType::Float16:
+            return CUDNN_DATA_HALF;
+        case DataType::Int8:
             return CUDNN_DATA_INT8;
-
-        if (type == DataType::UInt8)
+        case DataType::UInt8:
             return CUDNN_DATA_UINT8;
-
-        if (type == DataType::Int32)
+        case DataType::Int32:
             return CUDNN_DATA_INT32;
-
-        if (type == DataType::Int64)
+        case DataType::Int64:
             return CUDNN_DATA_INT64;
-
-        assert(false && "Unsupported CuDNN type");
+        default:
+            assert(false && "Unsupported CuDNN type");
+        }
+        return CUDNN_DATA_FLOAT; // TODO: undefined behavior
     }
 }
diff --git a/include/aidge/backend/cuda/utils/CudaUtils.hpp b/include/aidge/backend/cuda/utils/CudaUtils.hpp
index 76d7ea48e02473deeaa2cb0801a292623a666a1d..2f66d0e778778400f0b7def345619d635cc37674 100644
--- a/include/aidge/backend/cuda/utils/CudaUtils.hpp
+++ b/include/aidge/backend/cuda/utils/CudaUtils.hpp
@@ -67,24 +67,6 @@ namespace Cuda {
     // Enable Peer-to-Peer communications between devices
     // when it is possible
     void setMultiDevicePeerAccess(unsigned int size, unsigned int* devices);
-
-    // CuDNN scaling parameters are typically "alpha" and "beta".
-    // Their type must be "float" for HALF and FLOAT (default template)
-    // and "double" for DOUBLE (specialized template)
-    template <class T>
-    struct cudnn_scaling_type {
-        typedef float type;
-    };
-
-    template <>
-    struct cudnn_scaling_type<double> {
-        typedef double type;
-    };
-
-    template <class T>
-    struct cuda_type {
-        typedef T type;
-    };
 }
 }
diff --git a/src/data/TensorImpl.cu b/src/data/TensorImpl.cu
index ba2a1348b4c5eae8499884bfe2488d67f016a060..ecacd4d678dd7d79462332fb28e238b063d8bdd1 100644
--- a/src/data/TensorImpl.cu
+++ b/src/data/TensorImpl.cu
@@ -14,6 +14,79 @@
 #include <thrust/equal.h>
 #include <thrust/device_ptr.h>
 
+template <typename SRC_T, typename DST_T>
+void Aidge::thrust_copy(const SRC_T* srcData, DST_T* dstData, size_t size)
+{
+    const thrust::device_ptr<const SRC_T> thrustSrcPtr(srcData);
+    thrust::device_ptr<DST_T> thrustDstPtr(dstData);
+    thrust::copy(thrustSrcPtr, thrustSrcPtr + size, thrustDstPtr);
+}
+
+template <typename SRC_T>
+__global__ void
+cudaCopyToH_kernel(const SRC_T* srcData,
+                   __half* dstData,
+                   size_t size)
+{
+    const unsigned int index = blockIdx.x * blockDim.x + threadIdx.x;
+    const unsigned int stride = blockDim.x * gridDim.x;
+
+    for (unsigned int i = index; i < size; i += stride) {
+        dstData[i] = __float2half(static_cast<float>(srcData[i]));
+    }
+}
+
+template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type* = nullptr>
+void Aidge::thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size)
+{
+    cudaCopyToH_kernel<SRC_T><<<(size + 255) / 256, 256>>>
+        (srcData, reinterpret_cast<__half*>(dstData), size);
+    CHECK_CUDA_STATUS(cudaPeekAtLastError());
+}
+
+template <typename DST_T>
+__global__ void
+cudaCopyFromH_kernel(const __half* srcData,
+                     DST_T* dstData,
+                     size_t size)
+{
+    const unsigned int index = blockIdx.x * blockDim.x + threadIdx.x;
+    const unsigned int stride = blockDim.x * gridDim.x;
+
+    for (unsigned int i = index; i < size; i += stride) {
+        dstData[i] = static_cast<DST_T>(__half2float(srcData[i]));
+    }
+}
+
+template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type* = nullptr>
+void Aidge::thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size)
+{
+    cudaCopyFromH_kernel<DST_T><<<(size + 255) / 256, 256>>>
+        (reinterpret_cast<const __half*>(srcData), dstData, size);
+    CHECK_CUDA_STATUS(cudaPeekAtLastError());
+}
+
+__global__ void
+cudaCopyHToH_kernel(const __half* srcData,
+                    __half* dstData,
+                    size_t size)
+{
+    const unsigned int index = blockIdx.x * blockDim.x + threadIdx.x;
+    const unsigned int stride = blockDim.x * gridDim.x;
+
+    for (unsigned int i = index; i < size; i += stride) {
+        dstData[i] = srcData[i];
+    }
+}
+
+template <>
+void Aidge::thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size)
+{
+    cudaCopyHToH_kernel<<<(size + 255) / 256, 256>>>
+        (reinterpret_cast<const __half*>(srcData), reinterpret_cast<__half*>(dstData), size);
+    CHECK_CUDA_STATUS(cudaPeekAtLastError());
+}
+
 template <class T>
 bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const {
     const auto& otherImplCuda = static_cast<const TensorImpl_cuda<T>&>(otherImpl);
@@ -21,7 +94,7 @@ bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const {
     if (mTensor.size() != otherImplCuda.mTensor.size())
         return false;
 
-    thrust::device_ptr<T> thrustData(mData);
-    thrust::device_ptr<T> thrustOtherData(otherImplCuda.mData);
+    thrust::device_ptr<T> thrustData(mData.data());
+    thrust::device_ptr<T> thrustOtherData(otherImplCuda.mData.data());
 
     return thrust::equal(thrustData, thrustData + mTensor.size(), thrustOtherData);
 }
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index 515f5f19d7702ea5bc037b672e182e97800a703b..9c3684e89f6b27133ca99be16b332c4e9f9a27b1 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -25,8 +25,17 @@ template <Aidge::DimIdx_t DIM>
 void Aidge::ConvImpl_cuda<DIM>::forward() {
     // FIXME: uncomment the following code once memory handling will work
-    assert(mOp.getInput(0) && "missing input #0");
-    assert(mOp.getInput(1) && "missing input #1");
+    assert(mOp.getRawInput(0) && "missing input #0");
+    assert(mOp.getRawInput(1) && "missing input #1");
+
+    // Convert input data (no overhead if not needed!)
+    // TODO: right now, if needed, memory will be allocated/deallocated at each
+    // call to forward(). We might put the following shared_ptr as members of
+    // this class to avoid that.
+    std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
+    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
 
     // Lazy-initialize CuDNN convolution descriptor
     if (mConvDesc == nullptr) {
@@ -43,16 +52,16 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
             &strides[0],
             &upscales[0],
             CUDNN_CROSS_CORRELATION,
-            DataTypeToCudnn(mOp.getOutput(0)->dataType())));
+            DataTypeToCudnn(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType())));
     }
 
     // Lazy-initialize CuDNN filter descriptor
     if (mFilterDesc == nullptr) {
-        const std::vector<int> kernels(mOp.getInput(1)->dims().begin(), mOp.getInput(1)->dims().end());
+        const std::vector<int> kernels(input1.dims().begin(), input1.dims().end());
 
         CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc));
         CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc,
-            DataTypeToCudnn(mOp.getInput(1)->dataType()),
+            DataTypeToCudnn(input1.dataType()),
             CUDNN_TENSOR_NCHW,
             kernels.size(),
             &kernels[0]));
@@ -72,10 +81,10 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
 
         CHECK_CUDNN_STATUS(cudnnFindConvolutionForwardAlgorithm(
             CudaContext::cudnnHandle(),
-            dynamic_cast<TensorImpl_cuda_*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(),
+            dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(),
             mFilterDesc,
             mConvDesc,
-            dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),
+            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
             maxAlgoIterations,
             &returnAlgoCounts,
             &returnFwdAlgo[0]));
@@ -86,10 +95,10 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
 
     CHECK_CUDNN_STATUS(cudnnGetConvolutionForwardWorkspaceSize(
         CudaContext::cudnnHandle(),
-        dynamic_cast<TensorImpl_cuda_*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(),
+        dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(),
         mFilterDesc,
         mConvDesc,
-        dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),
+        dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
         mFwdAlgo,
         &workspaceSize));
 
@@ -100,43 +109,43 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
     // Do the actual forward computation
     // Template is only for scaling parameters, which are always in float
     // excepted when the convolution is performed in double precision.
-    if (mOp.getOutput(0)->dataType() == DataType::Float64) {
-        forward_<double>();
+    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
+        forward_<double>(input0, input1, input2);
     }
     else {
-        forward_<float>();
+        forward_<float>(input0, input1, input2);
     }
 }
 
 template <Aidge::DimIdx_t DIM>
 template <class T>
-void Aidge::ConvImpl_cuda<DIM>::forward_() {
-    const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
-    typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
+void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) {
+    const T alpha = 1.0f;
+    const T beta = 0.0f;
 
     CHECK_CUDNN_STATUS(
         cudnnConvolutionForward(CudaContext::cudnnHandle(),
            &alpha,
-            dynamic_cast<TensorImpl_cuda_*>(mOp.getInput(0)->getImpl().get())->getCudnnTensorDesc(),
-            mOp.getInput(0)->getImpl()->rawPtr(),
+            dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(),
+            input0.getImpl()->rawPtr(),
            mFilterDesc,
-            mOp.getInput(1)->getImpl()->rawPtr(),
+            input1.getImpl()->rawPtr(),
            mConvDesc,
            mFwdAlgo,
            mWorkspace,
           mWorkspaceSize,
            &beta,
-            dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),
-            mOp.getOutput(0)->getImpl()->rawPtr()));
+            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
+            std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()));
 
     // Add bias (if there is any)
-    if (mOp.getInput(2) && mOp.getInput(2)->size() > 0) {
+    if (mOp.getRawInput(2) && input2.size() > 0) {
         // Bias tensor needs to have the same number of dims than output tensor for cudnnAddTensor()
         std::vector<DimSize_t> biasDims(DIM+2, 1);
-        biasDims[1] = mOp.getInput(2)->size();
+        biasDims[1] = input2.size();
 
         // Create a dummy tensor with the right dims in order to get a CuDNN tensor descriptor (with getCudnnTensorDesc())
-        Tensor bias(mOp.getInput(2)->dataType());
+        Tensor bias(input2.dataType());
         bias.setBackend("cuda");
         bias.resize(biasDims);
         // TODO: find a more elegant solution(?)
@@ -144,10 +153,10 @@ void Aidge::ConvImpl_cuda<DIM>::forward_() {
         CHECK_CUDNN_STATUS(cudnnAddTensor(CudaContext::cudnnHandle(),
             &alpha,
             dynamic_cast<TensorImpl_cuda_*>(bias.getImpl().get())->getCudnnTensorDesc(),
-            mOp.getInput(2)->getImpl()->rawPtr(),
+            input2.getImpl()->rawPtr(),
             &alpha,
-            dynamic_cast<TensorImpl_cuda_*>(mOp.getOutput(0)->getImpl().get())->getCudnnTensorDesc(),
-            mOp.getOutput(0)->getImpl()->rawPtr()));
+            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
+            std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()));
     }
 }
diff --git a/src/operator/ProducerImpl.cpp b/src/operator/ProducerImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..aca3c4945e357be13017e302cb6e7f12ba61237c
--- /dev/null
+++ b/src/operator/ProducerImpl.cpp
@@ -0,0 +1,34 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <numeric> // std::accumulate
+#include <vector>
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/utils/Types.h"
+
+#include "aidge/backend/cuda/operator/ProducerImpl.hpp"
+
+Aidge::DimSize_t Aidge::ProducerImpl_cuda::getNbProducedData(
+    Aidge::IOIndex_t outputIdx) const
+{
+    // Requires the whole tensors, regardless of available data on inputs
+    assert(outputIdx == 0 && "operator has only one output");
+    (void) outputIdx;
+
+    return std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size();
+}
+
+void Aidge::ProducerImpl_cuda::forward()
+{
+}
diff --git a/unit_tests/Test_CastMove.cpp b/unit_tests/Test_CastMove.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0b68a4f9dcb6c72df91506a6f92be8c31e95f068
--- /dev/null
+++ b/unit_tests/Test_CastMove.cpp
@@ -0,0 +1,219 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <memory> +#include <string> + +#include "aidge/data/Tensor.hpp" +#include "aidge/utils/TensorUtils.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/OpArgs.hpp" +#include "aidge/scheduler/Scheduler.hpp" +#include "aidge/recipies/Recipies.hpp" + +#include "aidge/backend/cuda.hpp" + +using namespace Aidge; + +TEST_CASE("[cuda/castmove] CastMove(forward)") { + std::shared_ptr<Tensor> inputTensor = + std::make_shared<Tensor>(Array4D<int, 2, 1, 5, 5>{{{{{0, 1, 2, 3, 4}, + {5, 6, 7, 8, 9}, + {10, 11, 12, 13, 14}, + {15, 16, 17, 18, 19}, + {20, 21, 22, 23, 24}}}, + {{{25, 26, 27, 28, 29}, + {30, 31, 32, 33, 34}, + {35, 36, 37, 38, 39}, + {40, 41, 42, 43, 44}, + {45, 46, 47, 48, 49}}}}}); + + std::shared_ptr<Tensor> weight1 = std::make_shared<Tensor>( + Array4D<int, 3, 1, 3, 3>{{{{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}}, + {{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}}}, + {{{19, 20, 21}, {22, 23, 24}, {25, 26, 27}}}}}); + + std::shared_ptr<Tensor> bias1 = std::make_shared<Tensor>(Array1D<int, 3>{{1, 2, 3}}); + + SECTION("Test implicit") { + std::shared_ptr<GraphView> g = + Sequential({ + Conv(1, 3, {3, 3}, "conv1"), + Conv(3, 4, {1, 1}, "conv2"), + Conv(4, 3, {1, 1}, "conv3")}); + + g->getNode("conv1")->getOperator()->setInput(0, inputTensor); + g->getNode("conv1")->getOperator()->setInput(1, weight1); + g->getNode("conv1")->getOperator()->setInput(2, bias1); + + std::shared_ptr<Tensor> weight2 = + std::make_shared<Tensor>(Array4D<int, 4, 3, 1, 1>{{{{{1}}, {{2}}, {{3}}}, + {{{4}}, {{5}}, {{6}}}, + {{{7}}, {{8}}, {{9}}}, + {{{10}}, {{11}}, {{12}}}}}); + std::shared_ptr<Tensor> bias2 = std::make_shared<Tensor>(Array1D<int, 4>{{1, 2, 3, 4}}); + g->getNode("conv2")->getOperator()->setInput(1, weight2); + g->getNode("conv2")->getOperator()->setInput(2, bias2); + // *(g->getNode("conv2")->getOperator()->input(1, weight2); + + std::shared_ptr<Tensor> weight3 = std::make_shared<Tensor>( + Array4D<int, 3, 4, 1, 1>{{{{{1}}, {{2}}, {{3}}, {{4}}}, + {{{5}}, {{6}}, {{7}}, {{8}}}, + {{{9}}, {{10}}, {{11}}, {{12}}}}}); + std::shared_ptr<Tensor> bias3 = std::make_shared<Tensor>(Array1D<int, 3>{{1, 2, 3}}); + g->getNode("conv3")->getOperator()->setInput(1, weight3); + g->getNode("conv3")->getOperator()->setInput(2, bias3); + + // input->addChild(g); + g->setDataType(Aidge::DataType::Float32); + g->getNode("conv1")->getOperator()->setDataType(DataType::Float16); + g->getNode("conv3")->getOperator()->setDataType(DataType::Float64); + + g->setBackend("cuda"); + g->forwardDims(); + SequentialScheduler scheduler(g); + REQUIRE_NOTHROW(scheduler.forward()); + scheduler.saveSchedulingDiagram("schedulingSequential"); + + std::shared_ptr<Tensor> expectedOutput1 = std::make_shared<Tensor>(Array4D<int, 2, 3, 3, 3>{ + {{{{367, 412, 457}, {592, 637, 682}, {817, 862, 907}}, + {{854, 980, 1106}, {1484, 1610, 1736}, {2114, 2240, 2366}}, + {{1341, 1548, 1755}, {2376, 2583, 2790}, {3411, 3618, 3825}}}, + {{{1492, 1537, 1582}, {1717, 1762, 1807}, {1942, 1987, 2032}}, + {{4004, 4130, 4256}, {4634, 4760, 4886}, {5264, 5390, 5516}}, + {{6516, 6723, 6930}, {7551, 7758, 7965}, {8586, 8793, 9000}}}}}); + + std::shared_ptr<Tensor> expectedOutput2 = std::make_shared<Tensor>(Array4D<int, 2, 4, 3, 3>{ + {{{{6099, 7017, 7935}, {10689, 11607, 12525}, {15279, 16197, 17115}}, + {{13786, 15838, 17890}, {24046, 26098, 28150}, 
{34306, 36358, 38410}}, + {{21473, 24659, 27845}, {37403, 40589, 43775}, {53333, 56519, 59705}}, + {{29160, 33480, 37800}, {50760, 55080, 59400}, {72360, 76680, 81000}}}, + {{{29049, 29967, 30885}, {33639, 34557, 35475}, {38229, 39147, 40065}}, + {{65086, 67138, 69190}, {75346, 77398, 79450}, {85606, 87658, 89710}}, + {{101123, 104309, 107495}, {117053, 120239, 123425}, {132983, 136169, 139355}}, + {{137160, 141480, 145800}, {158760, 163080, 167400}, {180360, 184680, 189000}}}}}); + + std::shared_ptr<Tensor> expectedOutput3 = std::make_shared<Tensor>(Array4D<int, 2, 3, 3, 3>{ + {{{{214731, 246591, 278451}, {374031, 405891, 437751}, {533331, 565191, 597051}}, + {{496804, 570568, 644332}, {865624, 939388, 1013152}, {1234444, 1308208, 1381972}}, + {{778877, 894545, 1010213}, {1357217, 1472885, 1588553}, {1935557, 2051225, 2166893}}}, + {{{1011231, 1043091, 1074951}, {1170531, 1202391, 1234251}, {1329831, 1361691, 1393551}}, + {{2340904, 2414668, 2488432}, {2709724, 2783488, 2857252}, {3078544, 3152308, 3226072}}, + {{3670577, 3786245, 3901913}, {4248917, 4364585, 4480253}, {4827257, 4942925, 5058593}}}}}); + + std::shared_ptr<Tensor> other1 = std::static_pointer_cast<OperatorTensor>(g->getNode("conv1")->getOperator())->getOutput(0); + Tensor hostOther1(other1->dataType()); + hostOther1.setBackend("cpu"); + hostOther1.copyCastFrom(*other1); + REQUIRE(approxEq<half_float::half, int>(hostOther1, *expectedOutput1, 0.001, 0.0)); + + std::shared_ptr<Tensor> other2 = std::static_pointer_cast<OperatorTensor>(g->getNode("conv2")->getOperator())->getOutput(0); + Tensor hostOther2(other2->dataType()); + hostOther2.setBackend("cpu"); + hostOther2.copyCastFrom(*other2); + REQUIRE(approxEq<float, int>(hostOther2, *expectedOutput2, 0.001, 0.0)); + + std::shared_ptr<Tensor> other3 = std::static_pointer_cast<OperatorTensor>(g->getNode("conv3")->getOperator())->getOutput(0); + Tensor hostOther3(other3->dataType()); + hostOther3.setBackend("cpu"); + hostOther3.copyCastFrom(*other3); + REQUIRE(approxEq<double, int>(hostOther3, *expectedOutput3, 0.001, 0.0)); + } + + SECTION("Test explicit") { + std::shared_ptr<GraphView> g = + Sequential({ + Conv(1, 3, {3, 3}, "conv1"), + Conv(3, 4, {1, 1}, "conv2"), + Conv(4, 3, {1, 1}, "conv3")}); + + g->getNode("conv1")->getOperator()->setInput(0, inputTensor); + g->getNode("conv1")->getOperator()->setInput(1, weight1); + g->getNode("conv1")->getOperator()->setInput(2, bias1); + + std::shared_ptr<Tensor> weight2 = + std::make_shared<Tensor>(Array4D<int, 4, 3, 1, 1>{{{{{1}}, {{2}}, {{3}}}, + {{{4}}, {{5}}, {{6}}}, + {{{7}}, {{8}}, {{9}}}, + {{{10}}, {{11}}, {{12}}}}}); + std::shared_ptr<Tensor> bias2 = std::make_shared<Tensor>(Array1D<int, 4>{{1, 2, 3, 4}}); + g->getNode("conv2")->getOperator()->setInput(1, weight2); + g->getNode("conv2")->getOperator()->setInput(2, bias2); + // *(g->getNode("conv2")->getOperator()->input(1, weight2); + + std::shared_ptr<Tensor> weight3 = std::make_shared<Tensor>( + Array4D<int, 3, 4, 1, 1>{{{{{1}}, {{2}}, {{3}}, {{4}}}, + {{{5}}, {{6}}, {{7}}, {{8}}}, + {{{9}}, {{10}}, {{11}}, {{12}}}}}); + std::shared_ptr<Tensor> bias3 = std::make_shared<Tensor>(Array1D<int, 3>{{1, 2, 3}}); + g->getNode("conv3")->getOperator()->setInput(1, weight3); + g->getNode("conv3")->getOperator()->setInput(2, bias3); + + // input->addChild(g); + g->setDataType(Aidge::DataType::Float32); + g->getNode("conv1")->getOperator()->setDataType(DataType::Float16); + g->getNode("conv3")->getOperator()->setDataType(DataType::Float64); + + explicitCastMove(g); + 
g->setBackend("cuda"); + g->forwardDims(); + + SequentialScheduler scheduler(g); + REQUIRE_NOTHROW(scheduler.forward()); + scheduler.saveSchedulingDiagram("schedulingSequential"); + + std::shared_ptr<Tensor> expectedOutput1 = std::make_shared<Tensor>(Array4D<int, 2, 3, 3, 3>{ + {{{{367, 412, 457}, {592, 637, 682}, {817, 862, 907}}, + {{854, 980, 1106}, {1484, 1610, 1736}, {2114, 2240, 2366}}, + {{1341, 1548, 1755}, {2376, 2583, 2790}, {3411, 3618, 3825}}}, + {{{1492, 1537, 1582}, {1717, 1762, 1807}, {1942, 1987, 2032}}, + {{4004, 4130, 4256}, {4634, 4760, 4886}, {5264, 5390, 5516}}, + {{6516, 6723, 6930}, {7551, 7758, 7965}, {8586, 8793, 9000}}}}}); + + std::shared_ptr<Tensor> expectedOutput2 = std::make_shared<Tensor>(Array4D<int, 2, 4, 3, 3>{ + {{{{6099, 7017, 7935}, {10689, 11607, 12525}, {15279, 16197, 17115}}, + {{13786, 15838, 17890}, {24046, 26098, 28150}, {34306, 36358, 38410}}, + {{21473, 24659, 27845}, {37403, 40589, 43775}, {53333, 56519, 59705}}, + {{29160, 33480, 37800}, {50760, 55080, 59400}, {72360, 76680, 81000}}}, + {{{29049, 29967, 30885}, {33639, 34557, 35475}, {38229, 39147, 40065}}, + {{65086, 67138, 69190}, {75346, 77398, 79450}, {85606, 87658, 89710}}, + {{101123, 104309, 107495}, {117053, 120239, 123425}, {132983, 136169, 139355}}, + {{137160, 141480, 145800}, {158760, 163080, 167400}, {180360, 184680, 189000}}}}}); + + std::shared_ptr<Tensor> expectedOutput3 = std::make_shared<Tensor>(Array4D<int, 2, 3, 3, 3>{ + {{{{214731, 246591, 278451}, {374031, 405891, 437751}, {533331, 565191, 597051}}, + {{496804, 570568, 644332}, {865624, 939388, 1013152}, {1234444, 1308208, 1381972}}, + {{778877, 894545, 1010213}, {1357217, 1472885, 1588553}, {1935557, 2051225, 2166893}}}, + {{{1011231, 1043091, 1074951}, {1170531, 1202391, 1234251}, {1329831, 1361691, 1393551}}, + {{2340904, 2414668, 2488432}, {2709724, 2783488, 2857252}, {3078544, 3152308, 3226072}}, + {{3670577, 3786245, 3901913}, {4248917, 4364585, 4480253}, {4827257, 4942925, 5058593}}}}}); + + std::shared_ptr<Tensor> other1 = std::static_pointer_cast<OperatorTensor>(g->getNode("conv1")->getOperator())->getOutput(0); + Tensor hostOther1(other1->dataType()); + hostOther1.setBackend("cpu"); + hostOther1.copyCastFrom(*other1); + REQUIRE(approxEq<half_float::half, int>(hostOther1, *expectedOutput1, 0.001, 0.0)); + + std::shared_ptr<Tensor> other2 = std::static_pointer_cast<OperatorTensor>(g->getNode("conv2")->getOperator())->getOutput(0); + Tensor hostOther2(other2->dataType()); + hostOther2.setBackend("cpu"); + hostOther2.copyCastFrom(*other2); + REQUIRE(approxEq<float, int>(hostOther2, *expectedOutput2, 0.001, 0.0)); + + std::shared_ptr<Tensor> other3 = std::static_pointer_cast<OperatorTensor>(g->getNode("conv3")->getOperator())->getOutput(0); + Tensor hostOther3(other3->dataType()); + hostOther3.setBackend("cpu"); + hostOther3.copyCastFrom(*other3); + REQUIRE(approxEq<double, int>(hostOther3, *expectedOutput3, 0.001, 0.0)); + } +} diff --git a/unit_tests/Test_ConvImpl.cpp b/unit_tests/Test_ConvImpl.cpp index 659528dd1b2a45fcdd67ca0bd3440391a0e79654..b7faadd677336b9ff72274ea250251f95785b24f 100644 --- a/unit_tests/Test_ConvImpl.cpp +++ b/unit_tests/Test_ConvImpl.cpp @@ -25,8 +25,9 @@ using namespace Aidge; TEST_CASE("[gpu/operator] Conv(forward)") { SECTION("Simple Conv no bias") { std::shared_ptr<Node> myConv = Conv(1,1,{3,3}, "myconv"); - myConv->getOperator()->setDatatype(DataType::Float32); - myConv->getOperator()->setBackend("cuda"); + auto op = std::static_pointer_cast<OperatorTensor>(myConv->getOperator()); + 
op->setDataType(DataType::Float32); + op->setBackend("cuda"); std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,1,1,3,3> { { { @@ -50,12 +51,12 @@ TEST_CASE("[gpu/operator] Conv(forward)") { myInput->setBackend("cuda"); myWeights->setBackend("cuda"); - myConv->getOperator()->associateInput(0,myInput); - myConv->getOperator()->associateInput(1,myWeights); - myConv->getOperator()->computeOutputDims(); + op->associateInput(0,myInput); + op->associateInput(1,myWeights); + op->computeOutputDims(); myConv->forward(); - REQUIRE(myConv->getOperator()->getOutput(0)->size() == 1); + REQUIRE(op->getOutput(0)->size() == 1); std::array<float, 9> kernel; cudaMemcpy(&kernel[0], myWeights->getImpl()->rawPtr(), 9 * sizeof(float), cudaMemcpyDeviceToHost); @@ -68,15 +69,16 @@ TEST_CASE("[gpu/operator] Conv(forward)") { } float computedOutput; - cudaMemcpy(&computedOutput, myConv->getOperator()->getOutput(0)->getImpl()->rawPtr(), sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(&computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float), cudaMemcpyDeviceToHost); REQUIRE(fabs(computedOutput - myOutput) < 1e-6); } SECTION("Classic Conv") { std::shared_ptr<Node> myConv = Conv(3,4,{3,3}, "myconv"); - myConv->getOperator()->setDatatype(DataType::Float32); - myConv->getOperator()->setBackend("cuda"); + auto op = std::static_pointer_cast<OperatorTensor>(myConv->getOperator()); + op->setDataType(DataType::Float32); + op->setBackend("cuda"); std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,3,3,3> { { { @@ -205,15 +207,15 @@ TEST_CASE("[gpu/operator] Conv(forward)") { myWeights->setBackend("cuda"); myBias->setBackend("cuda"); - myConv->getOperator()->associateInput(0,myInput); - myConv->getOperator()->associateInput(1,myWeights); - myConv->getOperator()->associateInput(2,myBias); - myConv->getOperator()->computeOutputDims(); + op->associateInput(0,myInput); + op->associateInput(1,myWeights); + op->associateInput(2,myBias); + op->computeOutputDims(); myConv->forward(); - // myConv->getOperator()->getOutput(0)->print(); + // op->getOutput(0)->print(); float* computedOutput = new float[myOutput->size()](); - cudaMemcpy(computedOutput, myConv->getOperator()->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost); + cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost); for(int i = 0; i < myOutput->size(); i++){ const float targetOutput = *(static_cast<float*>(myOutput->getImpl()->rawPtr()) + i);