Commit d30c27b1 authored by Olivier BICHLER

Adaptations to core changes

parent 594e9898
1 merge request: !4 Add Convert operator (a.k.a. Transmitter)
Pipeline #36493 failed
@@ -76,14 +76,16 @@ public:
     // native interface
     const future_std::span<T>& data() const { return mData; }

+    std::size_t size() const override { return mData.size(); }
     std::size_t scalarSize() const override { return sizeof(T); }

     void setDevice(int device) override {
         mDevice = device;
     }

-    void copy(const void *src, NbElts_t length) override {
-        CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyDeviceToDevice));
+    void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override {
+        void* dst = static_cast<void*>(static_cast<T*>(rawPtr()) + offset);
+        CHECK_CUDA_STATUS(cudaMemcpy(dst, src, length * sizeof(T), cudaMemcpyDeviceToDevice));
     }

     void copyCast(const void *src, NbElts_t length, const DataType srcDt) override {
@@ -177,10 +179,6 @@ public:
         return mData.data();
     };

-    void* getRaw(std::size_t idx) {
-        return static_cast<void*>(static_cast<T*>(rawPtr()) + idx);
-    }
-
     const cudnnTensorDescriptor_t& getCudnnTensorDesc() const override {
         if (mCudnnTensor == nullptr) {
             CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&mCudnnTensor));
@@ -214,10 +212,10 @@ public:
         return mCudnnTensor;
     }

-    virtual ~TensorImpl_cuda() {
-        if (mCudnnTensor != nullptr)
-            cudnnDestroyTensorDescriptor(mCudnnTensor);
-    }
+    void* getRawPtr(NbElts_t idx) override final {
+        AIDGE_ASSERT(idx < mData.size(), "idx out of range");
+        return static_cast<void*>(static_cast<T*>(rawPtr()) + idx);
+    };

     void setRawPtr(void *ptr, NbElts_t length) override final {
         AIDGE_ASSERT(length >= mTensor.size(), "trying to set raw pointer of insufficient capacity");
@@ -225,6 +223,11 @@ public:
         mDataOwner.reset();
     };

+    virtual ~TensorImpl_cuda() {
+        if (mCudnnTensor != nullptr)
+            cudnnDestroyTensorDescriptor(mCudnnTensor);
+    }
+
 private:
     void lazyInit() {
         if (mData.size() < mTensor.size()) {
...
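The new `offset` parameter on `copy()` allows writing into a sub-region of the destination buffer, so a tensor can be filled piecewise. A minimal sketch of the resulting semantics, assuming device-resident source buffers; `fillPiecewise` and the chunk names are illustrative and not part of the repository:

#include <cuda_runtime.h>

// Mirrors what the new copy(src, length, offset) does internally:
//   void* dst = static_cast<T*>(rawPtr()) + offset;
//   cudaMemcpy(dst, src, length * sizeof(T), cudaMemcpyDeviceToDevice);
// Here two device chunks of 4 floats fill one device buffer of 8 floats.
void fillPiecewise(float* dst, const float* chunk0, const float* chunk1) {
    cudaMemcpy(dst,     chunk0, 4 * sizeof(float), cudaMemcpyDeviceToDevice); // offset = 0
    cudaMemcpy(dst + 4, chunk1, 4 * sizeof(float), cudaMemcpyDeviceToDevice); // offset = 4
}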
@@ -67,24 +67,6 @@ namespace Cuda {
     // Enable Peer-to-Peer communications between devices
     // when it is possible
     void setMultiDevicePeerAccess(unsigned int size, unsigned int* devices);
-
-    // CuDNN scaling parameters are typically "alpha" and "beta".
-    // Their type must be "float" for HALF and FLOAT (default template)
-    // and "double" for DOUBLE (specialized template)
-    template <class T>
-    struct cudnn_scaling_type {
-        typedef float type;
-    };
-
-    template <>
-    struct cudnn_scaling_type<double> {
-        typedef double type;
-    };
-
-    template <class T>
-    struct cuda_type {
-        typedef T type;
-    };

 }
 }
...
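For reference, a standalone sketch of what the deleted `cudnn_scaling_type` trait encoded; this is plain C++ and the `static_assert` checks are illustrative. cuDNN expects its alpha/beta scaling factors as `float` for HALF and FLOAT tensors and as `double` for DOUBLE tensors, which is exactly the mapping the trait performed:

#include <type_traits>

// Default: HALF and FLOAT tensors take float scaling factors.
template <class T>
struct cudnn_scaling_type {
    typedef float type;
};

// DOUBLE tensors take double scaling factors.
template <>
struct cudnn_scaling_type<double> {
    typedef double type;
};

static_assert(std::is_same<cudnn_scaling_type<float>::type, float>::value,
              "float tensors scale with float");
static_assert(std::is_same<cudnn_scaling_type<double>::type, double>::value,
              "double tensors scale with double");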
@@ -120,8 +120,8 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {

 template <Aidge::DimIdx_t DIM>
 template <class T>
 void Aidge::ConvImpl_cuda<DIM>::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2) {
-    const typename Cuda::cudnn_scaling_type<T>::type alpha = 1.0f;
-    typename Cuda::cudnn_scaling_type<T>::type beta = 0.0f;
+    const T alpha = 1.0f;
+    const T beta = 0.0f;
     CHECK_CUDNN_STATUS(
         cudnnConvolutionForward(CudaContext::cudnnHandle(),
...
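For context on the scaling factors above, a minimal sketch of how they feed into the cuDNN call; descriptor, algorithm, and workspace setup are elided, and the parameter names are illustrative, not the repository's. `cudnnConvolutionForward` receives alpha and beta through `const void*` and computes y = alpha * conv(x, w) + beta * y:

#include <cudnn.h>

// Illustrative wrapper: descriptors, algorithm choice and workspace are
// assumed to be prepared by the caller.
void convForward(cudnnHandle_t handle,
                 cudnnTensorDescriptor_t xDesc, const void* x,
                 cudnnFilterDescriptor_t wDesc, const void* w,
                 cudnnConvolutionDescriptor_t convDesc,
                 cudnnConvolutionFwdAlgo_t algo,
                 void* workspace, size_t workspaceBytes,
                 cudnnTensorDescriptor_t yDesc, void* y) {
    const float alpha = 1.0f; // keep the convolution result as-is
    const float beta  = 0.0f; // overwrite any previous content of y
    cudnnConvolutionForward(handle, &alpha, xDesc, x, wDesc, w, convDesc,
                            algo, workspace, workspaceBytes, &beta, yDesc, y);
}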