Commit 3e47b059 authored by Maxence Naud

Merge branch 'view' into 'dev'

New proposal for handling tensor views

See merge request eclipse/aidge/aidge_backend_cuda!5
parents e69bdb10 4d17363d
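The central API change in this diff is that getCudnnTensorDesc() now receives the Tensor it describes: the implementation no longer keeps a back-reference to a single Tensor, and the cuDNN descriptor is built from tensor.dims() and tensor.strides() instead of recomputing packed strides from the dimensions alone. The short standalone sketch below (plain C++, not Aidge code; the shapes are made up for illustration) shows why that matters once views exist: a view carved out of a larger buffer inherits its parent's strides, which are not the packed strides its own dims would imply.

// Minimal sketch (not Aidge code): why a view must carry its own strides.
// packedStrides() reproduces the row-major stride computation that the
// previous getCudnnTensorDesc() performed internally from dims only.
#include <cstddef>
#include <iostream>
#include <vector>

static std::vector<int> packedStrides(const std::vector<int>& dims) {
    std::vector<int> strides(dims.size(), 1);
    for (std::size_t dim = 1; dim < dims.size(); ++dim) {
        strides[dims.size() - dim - 1] = strides[dims.size() - dim] * dims[dims.size() - dim];
    }
    return strides;
}

int main() {
    const std::vector<int> parentDims = {2, 8, 4, 4}; // N,C,H,W (hypothetical)
    const std::vector<int> viewDims   = {2, 2, 4, 4}; // 2-channel slice of the parent

    // The view inherits the parent's strides; packed strides recomputed from
    // the view's own dims would be wrong for every dimension above the slice.
    const std::vector<int> inheritedStrides  = packedStrides(parentDims); // {128, 16, 4, 1}
    const std::vector<int> packedForViewDims = packedStrides(viewDims);   // { 32, 16, 4, 1}

    std::cout << "batch stride of the view: " << inheritedStrides[0]
              << ", packed-from-dims batch stride: " << packedForViewDims[0] << '\n';
    return 0;
}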
@@ -29,6 +29,9 @@ void thrust_copy(const half_float::half* srcData, half_float::half* dstData, siz
  * class), but whose data type does not need to be known.
  */
 class TensorImpl_cuda_ {
+protected:
+    mutable cudnnTensorDescriptor_t mCudnnTensor = nullptr;
+
 public:
     /**
      * @brief Return the CuDNN tensor descriptor of the tensor.
@@ -36,7 +39,12 @@ public:
      * (which is therefore mutable in the derived class).
      * @return cudnnTensorDescriptor_t CuDNN tensor descriptor.
      */
-    virtual const cudnnTensorDescriptor_t& getCudnnTensorDesc() const = 0;
+    virtual const cudnnTensorDescriptor_t& getCudnnTensorDesc(const Tensor& tensor) const = 0;
+
+    virtual ~TensorImpl_cuda_() {
+        if (mCudnnTensor != nullptr)
+            cudnnDestroyTensorDescriptor(mCudnnTensor);
+    }
 };
 
 template <class T>
@@ -54,119 +62,116 @@ private:
     }
 
 private:
-    const Tensor &mTensor; // Impl needs to access Tensor information, but is not
-                           // supposed to change it!
+    /// Pointer to the data and its capacity
     future_std::span<T> mData;
     /// If this instance own the data, std::unique_ptr manages it
     std::unique_ptr<T, decltype(&cudaDelete)> mDataOwner;
-    mutable cudnnTensorDescriptor_t mCudnnTensor = nullptr;
 
 public:
     static constexpr const char *Backend = "cuda";
 
-    TensorImpl_cuda(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor), mDataOwner(nullptr, cudaDelete) {}
+    TensorImpl_cuda(DeviceIdx_t device, NbElts_t length) : TensorImpl(Backend, device, length), mDataOwner(nullptr, cudaDelete) {}
 
     bool operator==(const TensorImpl &otherImpl) const override final;
 
-    static std::unique_ptr<TensorImpl_cuda> create(const Tensor &tensor) {
-        return std::make_unique<TensorImpl_cuda<T>>(tensor);
+    static std::shared_ptr<TensorImpl_cuda> create(DeviceIdx_t device, NbElts_t length) {
+        return std::make_shared<TensorImpl_cuda<T>>(device, length);
     }
 
     // native interface
     const future_std::span<T>& data() const { return mData; }
 
-    std::size_t size() const override { return mData.size(); }
     std::size_t scalarSize() const override { return sizeof(T); }
 
+    void setDevice(DeviceIdx_t device) override {
+        mDevice = device;
+    }
+
     void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override {
-        void* dst = static_cast<void*>(static_cast<T*>(rawPtr()) + offset);
-        CHECK_CUDA_STATUS(cudaMemcpy(dst, src, length * sizeof(T), cudaMemcpyDeviceToDevice));
+        AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "copy length is above capacity");
+        const T* srcT = static_cast<const T *>(src);
+        T* dstT = static_cast<T *>(rawPtr(offset));
+        AIDGE_ASSERT(dstT < srcT || dstT >= srcT + length, "overlapping copy is not supported");
+        CHECK_CUDA_STATUS(cudaMemcpy(dstT, srcT, length * sizeof(T), cudaMemcpyDeviceToDevice));
     }
 
-    void copyCast(const void *src, NbElts_t length, const DataType srcDt) override {
+    void copyCast(const void *src, const DataType srcDt, NbElts_t length, NbElts_t offset = 0) override {
         if (length == 0) {
             return;
         }
 
-        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
-        if (srcDt == DataType::Float64) {
+        AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "copy length is above capacity");
+        switch (srcDt) {
+        case DataType::Float64:
             thrust_copy(static_cast<const double*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::Float32) {
+            break;
+        case DataType::Float32:
             thrust_copy(static_cast<const float*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::Float16) {
+            break;
+        case DataType::Float16:
             thrust_copy(static_cast<const half_float::half*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::Int64) {
+            break;
+        case DataType::Int64:
             thrust_copy(static_cast<const int64_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::UInt64) {
+            break;
+        case DataType::UInt64:
             thrust_copy(static_cast<const uint64_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::Int32) {
+            break;
+        case DataType::Int32:
             thrust_copy(static_cast<const int32_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::UInt32) {
+            break;
+        case DataType::UInt32:
             thrust_copy(static_cast<const uint32_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::Int16) {
+            break;
+        case DataType::Int16:
             thrust_copy(static_cast<const int16_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::UInt16) {
+            break;
+        case DataType::UInt16:
             thrust_copy(static_cast<const uint16_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::Int8) {
+            break;
+        case DataType::Int8:
             thrust_copy(static_cast<const int8_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else if (srcDt == DataType::UInt8) {
+            break;
+        case DataType::UInt8:
             thrust_copy(static_cast<const uint8_t*>(src),
-                        static_cast<T*>(rawPtr()),
+                        static_cast<T*>(rawPtr(offset)),
                         length);
-        }
-        else {
+            break;
+        default:
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
+            break;
         }
     }
 
-    void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) override {
-        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
-        CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyDeviceToDevice));
+    void copyFromDevice(const void *src, const std::pair<std::string, DeviceIdx_t>& device, NbElts_t length, NbElts_t offset = 0) override {
+        AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "copy length is above capacity");
+        CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(offset), src, length * sizeof(T), cudaMemcpyDeviceToDevice));
     }
 
-    void copyFromHost(const void *src, NbElts_t length) override {
-        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
-        CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyHostToDevice));
+    void copyFromHost(const void *src, NbElts_t length, NbElts_t offset = 0) override {
+        AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "copy length is above capacity");
+        CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(offset), src, length * sizeof(T), cudaMemcpyHostToDevice));
     }
 
-    void copyToHost(void *dst, NbElts_t length) const override {
-        AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
-        CHECK_CUDA_STATUS(cudaMemcpy(dst, rawPtr(), length * sizeof(T), cudaMemcpyDeviceToHost));
+    void copyToHost(void *dst, NbElts_t length, NbElts_t offset = 0) const override {
+        AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "copy length is above capacity");
+        CHECK_CUDA_STATUS(cudaMemcpy(dst, rawPtr(offset), length * sizeof(T), cudaMemcpyDeviceToHost));
     }
 
     void *rawPtr(NbElts_t offset = 0) override {
@@ -175,30 +180,27 @@ public:
     };
 
     const void *rawPtr(NbElts_t offset = 0) const override {
-        AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr");
+        AIDGE_ASSERT(mData.size() >= mNbElts, "accessing uninitialized const rawPtr");
         return (mData.data() + offset);
     };
 
-    const cudnnTensorDescriptor_t& getCudnnTensorDesc() const override {
+    const cudnnTensorDescriptor_t& getCudnnTensorDesc(const Tensor& tensor) const override {
         if (mCudnnTensor == nullptr) {
             CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&mCudnnTensor));
 
-            if (mTensor.size() > 0) {
+            if (tensor.size() > 0) {
                 /**
                 ** cudNN Tensors are restricted to having at least 4 dimensions :
                 ** When working with lower dimensionsal data, unused dimensions are set to 1.
                 ** Referes to the cudnnSetTensorNdDescriptor documentation from :
                 ** https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html
                 **/
-                std::vector<int> dims(mTensor.dims().begin(), mTensor.dims().end());
+                std::vector<int> dims(tensor.dims().cbegin(), tensor.dims().cend());
+                std::vector<int> strides(tensor.strides().cbegin(), tensor.strides().cend());
 
-                if (dims.size() < 4)
+                if (dims.size() < 4) {
                     dims.resize(4, 1);
-
-                std::vector<int> strides(dims.size(), 1);
-                for (size_t dim = 1; dim < dims.size(); ++dim) {
-                    strides[dims.size() - dim - 1] = strides[dims.size() - dim] * dims[dims.size() - dim];
+                    strides.resize(4, 1);
                 }
 
                 CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(mCudnnTensor,
@@ -213,23 +215,20 @@ public:
     }
 
     void setRawPtr(void *ptr, NbElts_t length) override final {
-        AIDGE_ASSERT(length >= mTensor.size(), "trying to set raw pointer of insufficient capacity");
+        AIDGE_ASSERT(length >= mNbElts, "trying to set raw pointer of insufficient capacity");
         mData = future_std::span<T>(static_cast<T *>(ptr), length);
         mDataOwner.reset();
     };
 
-    virtual ~TensorImpl_cuda() {
-        if (mCudnnTensor != nullptr)
-            cudnnDestroyTensorDescriptor(mCudnnTensor);
-    }
+    virtual ~TensorImpl_cuda() = default;
 
 private:
     void lazyInit() {
-        if (mData.size() < mTensor.size()) {
+        if (mData.size() < mNbElts) {
             // Need more data, a re-allocation will occur
             AIDGE_ASSERT(mData.empty() || mDataOwner != nullptr, "trying to enlarge non-owned data");
-            mDataOwner.reset(cudaAlloc(mTensor.size()));
-            mData = future_std::span<T>(mDataOwner.get(), mTensor.size());
+            mDataOwner.reset(cudaAlloc(mNbElts));
+            mData = future_std::span<T>(mDataOwner.get(), mNbElts);
         }
     }
 };
@@ -35,12 +35,15 @@ private:
     cudnnFilterDescriptor_t mFilterDesc = nullptr;
     cudnnConvolutionFwdAlgo_t mFwdAlgo;
     size_t mWorkspaceSize = 0;
-    void* mWorkspace = nullptr;
+    void* mFwdWorkspace = nullptr;
+    std::shared_ptr<Tensor> mInput0Fallback;
+    std::shared_ptr<Tensor> mInput1Fallback;
+    std::shared_ptr<Tensor> mInput2Fallback;
 
 public:
     ConvImpl_cuda(const Conv_Op<DIM> &op) : OperatorImpl(op) {}
 
-    static std::unique_ptr<ConvImpl_cuda> create(const Conv_Op<2> &op) {
+    static std::unique_ptr<ConvImpl_cuda> create(const Conv_Op<DIM> &op) {
         return std::make_unique<ConvImpl_cuda>(op);
     }
@@ -91,10 +91,10 @@ template <class T>
 bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const {
     const auto& otherImplCuda = static_cast<const TensorImpl_cuda<T>&>(otherImpl);
 
-    if (mTensor.size() != otherImplCuda.mTensor.size())
+    if (mNbElts != otherImplCuda.size())
         return false;
 
     thrust::device_ptr<T> thrustData(mData.data());
     thrust::device_ptr<T> thrustOtherData(otherImplCuda.mData.data());
-    return thrust::equal(thrustData, thrustData + mTensor.size(), thrustOtherData);
+    return thrust::equal(thrustData, thrustData + mNbElts, thrustOtherData);
 }
@@ -24,18 +24,16 @@
 template <Aidge::DimIdx_t DIM>
 void Aidge::ConvImpl_cuda<DIM>::forward() {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
+
     // FIXME: uncomment the following code once memory handling will work
     assert(mOp.getRawInput(0) && "missing input #0");
     assert(mOp.getRawInput(1) && "missing input #1");
 
     // Convert input data (no overhead if not needed!)
-    // TODO: right now, if needed, memory will be allocated/deallocated at each
-    // call to forward(). We might put the following shared_ptr as members of
-    // this class to avoid that.
-    std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
-    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input0 = op.getInput(0)->refCastFrom(mInput0Fallback, *op.getOutput(0));
+    const auto& input1 = op.getInput(1)->refCastFrom(mInput1Fallback, *op.getOutput(0));
+    const auto& input2 = op.getInput(2)->refCastFrom(mInput2Fallback, *op.getOutput(0));
 
     // Lazy-initialize CuDNN convolution descriptor
     if (mConvDesc == nullptr) {
@@ -45,14 +43,13 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
         const std::vector<int> upscales(convOp.template getAttr<ConvAttr::DilationDims>().begin(), convOp.template getAttr<ConvAttr::DilationDims>().end());
 
         CHECK_CUDNN_STATUS(cudnnCreateConvolutionDescriptor(&mConvDesc));
-        CHECK_CUDNN_STATUS(
-            cudnnSetConvolutionNdDescriptor(mConvDesc,
-                DIM,
-                &paddings[0],
-                &strides[0],
-                &upscales[0],
-                CUDNN_CROSS_CORRELATION,
-                DataTypeToCudnn(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType())));
+        CHECK_CUDNN_STATUS(cudnnSetConvolutionNdDescriptor(mConvDesc,
+            DIM,
+            &paddings[0],
+            &strides[0],
+            &upscales[0],
+            CUDNN_CROSS_CORRELATION,
+            DataTypeToCudnn(op.getOutput(0)->dataType())));
     }
 
     // Lazy-initialize CuDNN filter descriptor
@@ -61,14 +58,14 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
         CHECK_CUDNN_STATUS(cudnnCreateFilterDescriptor(&mFilterDesc));
         CHECK_CUDNN_STATUS(cudnnSetFilterNdDescriptor(mFilterDesc,
                                                       DataTypeToCudnn(input1.dataType()),
                                                       CUDNN_TENSOR_NCHW,
                                                       kernels.size(),
                                                       &kernels[0]));
     }
 
     // Set forward algorithm and allocate the required workspace
-    if (mWorkspace == nullptr) {
+    if (mFwdWorkspace == nullptr) {
         // Find the best CuDNN forward algorithm (the one with the lowest compute time)
         int maxAlgoIterations = 0;
         cudnnGetConvolutionForwardAlgorithmMaxCount(CudaContext::cudnnHandle(),
@@ -80,14 +77,14 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
         std::vector<cudnnConvolutionFwdAlgoPerf_t> returnFwdAlgo(maxAlgoIterations);
         CHECK_CUDNN_STATUS(cudnnFindConvolutionForwardAlgorithm(
                             CudaContext::cudnnHandle(),
-                            dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(),
+                            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
                             mFilterDesc,
                             mConvDesc,
-                            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
+                            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
                             maxAlgoIterations,
                             &returnAlgoCounts,
                             &returnFwdAlgo[0]));
         mFwdAlgo = returnFwdAlgo[0].algo;
 
         // Allocate the workspace required by the chosen CuDNN forward algorithm
@@ -95,21 +92,21 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
         CHECK_CUDNN_STATUS(cudnnGetConvolutionForwardWorkspaceSize(
             CudaContext::cudnnHandle(),
-            dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
             mFilterDesc,
             mConvDesc,
-            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
             mFwdAlgo,
             &workspaceSize));
 
-        CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, workspaceSize));
+        CHECK_CUDA_STATUS(cudaMalloc(&mFwdWorkspace, workspaceSize));
         mWorkspaceSize = workspaceSize;
     }
 
     // Do the actual forward computation
     // Template is only for scaling parameters, which are always in float
     // excepted when the convolution is performed in double precision.
-    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
+    if (op.getOutput(0)->dataType() == DataType::Float64) {
         forward_<double>(input0, input1, input2);
     }
     else {
@@ -120,23 +117,23 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
 template <Aidge::DimIdx_t DIM>
 template <class T>
 void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
     const T alpha = 1.0f;
     const T beta = 0.0f;
-    CHECK_CUDNN_STATUS(
-        cudnnConvolutionForward(CudaContext::cudnnHandle(),
-            &alpha,
-            dynamic_cast<TensorImpl_cuda_*>(input0.getImpl().get())->getCudnnTensorDesc(),
-            input0.getImpl()->rawPtr(),
-            mFilterDesc,
-            input1.getImpl()->rawPtr(),
-            mConvDesc,
-            mFwdAlgo,
-            mWorkspace,
-            mWorkspaceSize,
-            &beta,
-            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
-            std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()));
+    CHECK_CUDNN_STATUS(cudnnConvolutionForward(CudaContext::cudnnHandle(),
+        &alpha,
+        std::dynamic_pointer_cast<TensorImpl_cuda_>(input0.getImpl())->getCudnnTensorDesc(input0),
+        input0.getImpl()->rawPtr(),
+        mFilterDesc,
+        input1.getImpl()->rawPtr(),
+        mConvDesc,
+        mFwdAlgo,
+        mFwdWorkspace,
+        mWorkspaceSize,
+        &beta,
+        std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+        op.getOutput(0)->getImpl()->rawPtr()));
 
     // Add bias (if there is any)
     if (mOp.getRawInput(2) && input2.size() > 0) {
@@ -151,12 +148,12 @@ void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& inp
         // TODO: find a more elegant solution(?)
         CHECK_CUDNN_STATUS(cudnnAddTensor(CudaContext::cudnnHandle(),
                            &alpha,
-                           dynamic_cast<TensorImpl_cuda_*>(bias.getImpl().get())->getCudnnTensorDesc(),
+                           std::dynamic_pointer_cast<TensorImpl_cuda_>(bias.getImpl())->getCudnnTensorDesc(bias),
                            input2.getImpl()->rawPtr(),
                            &alpha,
-                           dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
-                           std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()));
+                           std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+                           op.getOutput(0)->getImpl()->rawPtr()));
     }
 }
@@ -170,8 +167,8 @@ Aidge::ConvImpl_cuda<DIM>::~ConvImpl_cuda() {
         cudnnDestroyFilterDescriptor(mFilterDesc);
     }
 
-    if (mWorkspace != nullptr) {
-        cudaFree(mWorkspace);
+    if (mFwdWorkspace != nullptr) {
+        cudaFree(mFwdWorkspace);
     }
 }
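As a usage note, the descriptor construction pattern above can be exercised outside Aidge. The sketch below is standalone apart from cuDNN itself (the dims and strides are illustrative and not taken from this commit): it describes a strided view to cuDNN the same way the new getCudnnTensorDesc() does, padding both dims and strides up to the 4-dimension minimum required by cudnnSetTensorNdDescriptor.

// Hedged, standalone sketch: build a cuDNN descriptor for a strided view,
// mirroring the pad-to-4-dims pattern used in getCudnnTensorDesc() above.
#include <cudnn.h>
#include <cstdio>
#include <vector>

int main() {
    cudnnTensorDescriptor_t desc;
    if (cudnnCreateTensorDescriptor(&desc) != CUDNN_STATUS_SUCCESS)
        return 1;

    // A 2x2x4x4 view over a 2x8x4x4 parent buffer: the dims are the view's,
    // the strides are inherited from the parent (hence not packed).
    std::vector<int> dims    = {2, 2, 4, 4};
    std::vector<int> strides = {128, 16, 4, 1};

    // cuDNN tensors need at least 4 dimensions; pad with 1 when below that.
    if (dims.size() < 4) {
        dims.resize(4, 1);
        strides.resize(4, 1);
    }

    const cudnnStatus_t status = cudnnSetTensorNdDescriptor(
        desc, CUDNN_DATA_FLOAT,
        static_cast<int>(dims.size()), dims.data(), strides.data());
    std::printf("cudnnSetTensorNdDescriptor: %s\n", cudnnGetErrorString(status));

    cudnnDestroyTensorDescriptor(desc);
    return status == CUDNN_STATUS_SUCCESS ? 0 : 1;
}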