From 0efec0ee82f9a686727a9fb32223a4192e24e471 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Wed, 28 Feb 2024 15:46:46 +0100
Subject: [PATCH] Update getCudnnTensorDesc() calls to pass the tensor

---
 src/operator/AvgPoolingImpl.cpp | 13 ++++++++-----
 src/operator/MaxPoolingImpl.cpp | 13 ++++++++-----
 src/operator/ReLUImpl.cpp       | 11 +++++++----
 3 files changed, 23 insertions(+), 14 deletions(-)
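
Patch notes (ignored by git am): getCudnnTensorDesc() previously took no
arguments, so the cuDNN descriptor could not reflect the tensor it described.
Every call site now passes the Tensor, letting the implementation build or
refresh the descriptor from the tensor's current dimensions. A minimal sketch
of what the accessor presumably does on the implementation side (hypothetical;
the real code lives in the TensorImpl_cuda headers and may differ):

    // Hypothetical sketch, not part of this patch. Assumes the impl caches a
    // cudnnTensorDescriptor_t member named mCudnnTensor and knows its cuDNN
    // data type (CudaContext::data_type<T>::value is assumed here).
    const cudnnTensorDescriptor_t& getCudnnTensorDesc(const Tensor& tensor) {
        if (mCudnnTensor == nullptr) {
            CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&mCudnnTensor));
            std::vector<int> dims(tensor.dims().begin(), tensor.dims().end());
            // cuDNN requires at least 4 dimensions; pad trailing dims with 1.
            if (dims.size() < 4) dims.resize(4, 1);
            // Contiguous (row-major) strides, innermost dimension last.
            std::vector<int> strides(dims.size(), 1);
            for (int d = static_cast<int>(dims.size()) - 2; d >= 0; --d)
                strides[d] = strides[d + 1] * dims[d + 1];
            CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(
                mCudnnTensor, CudaContext::data_type<T>::value,
                static_cast<int>(dims.size()), dims.data(), strides.data()));
        }
        return mCudnnTensor;
    }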

diff --git a/src/operator/AvgPoolingImpl.cpp b/src/operator/AvgPoolingImpl.cpp
index 0257f19..4737e3e 100644
--- a/src/operator/AvgPoolingImpl.cpp
+++ b/src/operator/AvgPoolingImpl.cpp
@@ -24,14 +24,16 @@
 
 template <Aidge::DimIdx_t DIM>
 void Aidge::AvgPoolingImpl_cuda<DIM>::forward() {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
+
     assert(mOp.getRawInput(0) && "missing input #0");
 
     std::shared_ptr<Tensor> inputFallback;
-    const auto& input = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input = std::static_pointer_cast<Tensor>(op.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(op.getRawOutput(0)));
 
     // Lazy-initialize CuDNN AvgPooling descriptor
     if (mAvgPoolingDesc == nullptr) {
-        const AvgPooling_Op<DIM>& avgPoolingOp = static_cast<const AvgPooling_Op<DIM>&>(mOp);
+        const AvgPooling_Op<DIM>& avgPoolingOp = static_cast<const AvgPooling_Op<DIM>&>(op);
         const std::vector<int> strides(avgPoolingOp.template getAttr<AvgPoolingAttr::StrideDims>().begin(), avgPoolingOp.template getAttr<AvgPoolingAttr::StrideDims>().end());
         const std::vector<int> paddings(DIM, 0);
         const std::vector<int> window_dims(avgPoolingOp.template getAttr<AvgPoolingAttr::KernelDims>().begin(), avgPoolingOp.template getAttr<AvgPoolingAttr::KernelDims>().end());
@@ -58,6 +60,7 @@ void Aidge::AvgPoolingImpl_cuda<DIM>::forward() {
 template <Aidge::DimIdx_t DIM>
 template <class T>
 void Aidge::AvgPoolingImpl_cuda<DIM>::forward_(const Tensor& input) {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
     const T alpha = 1.0f;
     const T beta = 0.0f;
     CHECK_CUDNN_STATUS(
@@ -65,11 +68,11 @@ void Aidge::AvgPoolingImpl_cuda<DIM>::forward_(const Tensor& input) {
             CudaContext::cudnnHandle(),
             mAvgPoolingDesc,
             &alpha,
-            dynamic_cast<TensorImpl_cuda_*>(input.getImpl().get())->getCudnnTensorDesc(),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input),
             input.getImpl()->rawPtr(),
             &beta,
-            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
-            std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+            op.getOutput(0)->getImpl()->rawPtr()
         )
     );
 }
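
Review note: besides threading the tensor through to getCudnnTensorDesc(),
the call sites switch from dynamic_cast on the raw pointer to
std::dynamic_pointer_cast on the shared_ptr. The shared_ptr cast participates
in ownership, so the result cannot outlive the object it points to; both
variants return null/empty on a type mismatch. A self-contained illustration
with hypothetical types:

    #include <memory>

    struct Base { virtual ~Base() = default; };
    struct CudaImpl : Base { void touch() {} };

    void demo(std::shared_ptr<Base> impl) {
        // Raw-pointer cast: nullptr on type mismatch; does not share ownership.
        CudaImpl* raw = dynamic_cast<CudaImpl*>(impl.get());
        // shared_ptr cast: empty on type mismatch; shares ownership with impl.
        std::shared_ptr<CudaImpl> shared = std::dynamic_pointer_cast<CudaImpl>(impl);
        if (raw) raw->touch();
        if (shared) shared->touch();
    }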
diff --git a/src/operator/MaxPoolingImpl.cpp b/src/operator/MaxPoolingImpl.cpp
index 8ba19a6..9304160 100644
--- a/src/operator/MaxPoolingImpl.cpp
+++ b/src/operator/MaxPoolingImpl.cpp
@@ -24,14 +24,16 @@
 
 template <Aidge::DimIdx_t DIM>
 void Aidge::MaxPoolingImpl_cuda<DIM>::forward() {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
+
     assert(mOp.getRawInput(0) && "missing input #0");
 
     std::shared_ptr<Tensor> inputFallback;
-    const auto& input = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input = std::static_pointer_cast<Tensor>(op.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(op.getRawOutput(0)));
 
     // Lazy-initialize CuDNN MaxPooling descriptor
     if (mMaxPoolingDesc == nullptr) {
-        const MaxPooling_Op<DIM>& maxPoolingOp = static_cast<const MaxPooling_Op<DIM>&>(mOp);
+        const MaxPooling_Op<DIM>& maxPoolingOp = static_cast<const MaxPooling_Op<DIM>&>(op);
         const std::vector<int> strides(maxPoolingOp.template getAttr<MaxPoolingAttr::StrideDims>().begin(), maxPoolingOp.template getAttr<MaxPoolingAttr::StrideDims>().end());
         const std::vector<int> paddings(DIM, 0);
         const std::vector<int> window_dims(maxPoolingOp.template getAttr<MaxPoolingAttr::KernelDims>().begin(), maxPoolingOp.template getAttr<MaxPoolingAttr::KernelDims>().end());
@@ -58,6 +60,7 @@ void Aidge::MaxPoolingImpl_cuda<DIM>::forward() {
 template <Aidge::DimIdx_t DIM>
 template <class T>
 void Aidge::MaxPoolingImpl_cuda<DIM>::forward_(const Tensor& input) {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
     const T alpha = 1.0f;
     const T beta = 0.0f;
     CHECK_CUDNN_STATUS(
@@ -65,11 +68,11 @@ void Aidge::MaxPoolingImpl_cuda<DIM>::forward_(const Tensor& input) {
             CudaContext::cudnnHandle(),
             mMaxPoolingDesc,
             &alpha,
-            dynamic_cast<TensorImpl_cuda_*>(input.getImpl().get())->getCudnnTensorDesc(),
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input),
             input.getImpl()->rawPtr(),
             &beta,
-            dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
-            std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
+            std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+            op.getOutput(0)->getImpl()->rawPtr()
         )
     );
 }
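
Review note: both pooling implementations create their descriptor lazily on
the first forward() call, from the strides, paddings and window_dims vectors
visible in the hunks above. The elided descriptor setup presumably resembles
this sketch (hypothetical; mode and names inferred from the context above):

    CHECK_CUDNN_STATUS(cudnnCreatePoolingDescriptor(&mMaxPoolingDesc));
    CHECK_CUDNN_STATUS(cudnnSetPoolingNdDescriptor(
        mMaxPoolingDesc,
        CUDNN_POOLING_MAX,            // CUDNN_POOLING_AVERAGE_* for AvgPooling
        CUDNN_NOT_PROPAGATE_NAN,      // NaN propagation policy
        DIM,                          // number of spatial dimensions
        window_dims.data(),           // kernel size per spatial dimension
        paddings.data(),              // all zero here (see above)
        strides.data()));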
diff --git a/src/operator/ReLUImpl.cpp b/src/operator/ReLUImpl.cpp
index 4636087..ed2e5d4 100644
--- a/src/operator/ReLUImpl.cpp
+++ b/src/operator/ReLUImpl.cpp
@@ -23,10 +23,12 @@
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
 
 void Aidge::ReLUImpl_cuda::forward() {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
+
     assert(mOp.getRawInput(0) && "missing input #0");
 
     std::shared_ptr<Tensor> inputFallback;
-    const auto& input = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input = std::static_pointer_cast<Tensor>(op.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(op.getRawOutput(0)));
 
     // Lazy-initialize CuDNN ReLU descriptor
     if (mReLUDesc == nullptr) {
@@ -49,17 +51,18 @@ void Aidge::ReLUImpl_cuda::forward() {
 
 template <class T>
 void Aidge::ReLUImpl_cuda::forward_(const Tensor& input) {
+    const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
     const T alpha = 1.0f;
     const T beta = 0.0f;
     CHECK_CUDNN_STATUS(
         cudnnActivationForward(CudaContext::cudnnHandle(),
                                mReLUDesc,
                                &alpha,
-							   dynamic_cast<TensorImpl_cuda_*>(input.getImpl().get())->getCudnnTensorDesc(),
+                               std::dynamic_pointer_cast<TensorImpl_cuda_>(input.getImpl())->getCudnnTensorDesc(input),
                                input.getImpl()->rawPtr(),
                                &beta,
-                               dynamic_cast<TensorImpl_cuda_*>(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl().get())->getCudnnTensorDesc(),
-                               std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()));
+                               std::dynamic_pointer_cast<TensorImpl_cuda_>(op.getOutput(0)->getImpl())->getCudnnTensorDesc(*op.getOutput(0)),
+                               op.getOutput(0)->getImpl()->rawPtr()));
 }
 
 Aidge::ReLUImpl_cuda::~ReLUImpl_cuda() {
-- 
GitLab
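
Review note (after the trailer, outside the patch proper): the ReLU path uses
the same lazy-initialization pattern as the pooling ops. The activation
descriptor setup elided from the ReLU hunk presumably resembles this sketch
(hypothetical; only the mReLUDesc name comes from the context above):

    CHECK_CUDNN_STATUS(cudnnCreateActivationDescriptor(&mReLUDesc));
    CHECK_CUDNN_STATUS(cudnnSetActivationDescriptor(
        mReLUDesc,
        CUDNN_ACTIVATION_RELU,      // activation mode
        CUDNN_NOT_PROPAGATE_NAN,    // NaN propagation policy
        0.0));                      // coef, unused by plain ReLU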