From c41f86433e0fe346448ba2408b25365c47746085 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Tue, 16 Apr 2024 17:43:52 +0200
Subject: [PATCH] minor cleanings

---
 .../operator/GlobalAveragePoolingImpl.hpp     |  2 +-
 src/operator/AddImpl.cpp                      |  2 +-
 src/operator/AvgPoolingImpl.cpp               | 17 +++++------------
 src/operator/BatchNormImpl.cpp                |  2 +-
 src/operator/GlobalAveragePoolingImpl.cpp     | 19 ++++++-------------
 src/operator/MaxPoolingImpl.cpp               | 17 +++++------------
 src/operator/ReLUImpl.cpp                     | 19 ++++++-------------
 src/operator/ReshapeImpl.cpp                  |  7 +++----
 8 files changed, 28 insertions(+), 57 deletions(-)

diff --git a/include/aidge/backend/cuda/operator/GlobalAveragePoolingImpl.hpp b/include/aidge/backend/cuda/operator/GlobalAveragePoolingImpl.hpp
index 79d1413..d1cd602 100644
--- a/include/aidge/backend/cuda/operator/GlobalAveragePoolingImpl.hpp
+++ b/include/aidge/backend/cuda/operator/GlobalAveragePoolingImpl.hpp
@@ -1,5 +1,5 @@
 /********************************************************************************
- * Copyright (c) 2023 CEA-List
+ * Copyright (c) 2024 CEA-List
  *
  * This program and the accompanying materials are made available under the
  * terms of the Eclipse Public License 2.0 which is available at
diff --git a/src/operator/AddImpl.cpp b/src/operator/AddImpl.cpp
index 11e577a..22ff4d8 100644
--- a/src/operator/AddImpl.cpp
+++ b/src/operator/AddImpl.cpp
@@ -1,5 +1,5 @@
 /********************************************************************************
- * Copyright (c) 2023 CEA-List
+ * Copyright (c) 2024 CEA-List
  *
  * This program and the accompanying materials are made available under the
  * terms of the Eclipse Public License 2.0 which is available at
diff --git a/src/operator/AvgPoolingImpl.cpp b/src/operator/AvgPoolingImpl.cpp
index eb9cc6a..f1180c1 100644
--- a/src/operator/AvgPoolingImpl.cpp
+++ b/src/operator/AvgPoolingImpl.cpp
@@ -45,18 +45,11 @@ void Aidge::AvgPoolingImpl_cuda<DIM>::forward() {
                                         &strides[0]));
     }
 
-    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
-        case DataType::Float64:
-            forward_<double>(input);
-            break;
-        case DataType::Float32:
-            forward_<float>(input);
-            break;
-        case DataType::Float16:
-            forward_<half>(input);
-            break;
-        default:
-            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
+    if (op.getOutput(0)->dataType() == DataType::Float64) {
+        forward_<double>(input);
+    }
+    else {
+        forward_<float>(input);
     }
 }
 
diff --git a/src/operator/BatchNormImpl.cpp b/src/operator/BatchNormImpl.cpp
index eb90baa..c194151 100644
--- a/src/operator/BatchNormImpl.cpp
+++ b/src/operator/BatchNormImpl.cpp
@@ -1,5 +1,5 @@
 /********************************************************************************
- * Copyright (c) 2023 CEA-List
+ * Copyright (c) 2024 CEA-List
  *
  * This program and the accompanying materials are made available under the
  * terms of the Eclipse Public License 2.0 which is available at
diff --git a/src/operator/GlobalAveragePoolingImpl.cpp b/src/operator/GlobalAveragePoolingImpl.cpp
index ca41b41..1192b63 100644
--- a/src/operator/GlobalAveragePoolingImpl.cpp
+++ b/src/operator/GlobalAveragePoolingImpl.cpp
@@ -1,5 +1,5 @@
 /********************************************************************************
- * Copyright (c) 2023 CEA-List
+ * Copyright (c) 2024 CEA-List
  *
  * This program and the accompanying materials are made available under the
  * terms of the Eclipse Public License 2.0 which is available at
@@ -39,18 +39,11 @@ void Aidge::GlobalAveragePoolingImpl_cuda::forward() {
         );
     }
 
-    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
-        case DataType::Float64:
-            forward_<double>(input);
-            break;
-        case DataType::Float32:
-            forward_<float>(input);
-            break;
-        case DataType::Float16:
-            forward_<half>(input);
-            break;
-        default:
-            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
+    if (op.getOutput(0)->dataType() == DataType::Float64) {
+        forward_<double>(input);
+    }
+    else {
+        forward_<float>(input);
     }
 }
 
diff --git a/src/operator/MaxPoolingImpl.cpp b/src/operator/MaxPoolingImpl.cpp
index b8d7c81..3054bb1 100644
--- a/src/operator/MaxPoolingImpl.cpp
+++ b/src/operator/MaxPoolingImpl.cpp
@@ -45,18 +45,11 @@ void Aidge::MaxPoolingImpl_cuda<DIM>::forward() {
                                         &strides[0]));
     }
 
-    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
-        case DataType::Float64:
-            forward_<double>(input);
-            break;
-        case DataType::Float32:
-            forward_<float>(input);
-            break;
-        case DataType::Float16:
-            forward_<half>(input);
-            break;
-        default:
-            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
+    if (op.getOutput(0)->dataType() == DataType::Float64) {
+        forward_<double>(input);
+    }
+    else {
+        forward_<float>(input);
     }
 }
 
diff --git a/src/operator/ReLUImpl.cpp b/src/operator/ReLUImpl.cpp
index 2ebd6b2..6dd211e 100644
--- a/src/operator/ReLUImpl.cpp
+++ b/src/operator/ReLUImpl.cpp
@@ -1,5 +1,5 @@
 /********************************************************************************
- * Copyright (c) 2023 CEA-List
+ * Copyright (c) 2024 CEA-List
  *
  * This program and the accompanying materials are made available under the
  * terms of the Eclipse Public License 2.0 which is available at
@@ -37,18 +37,11 @@ void Aidge::ReLUImpl_cuda::forward() {
 		#endif
     }
 
-    switch(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()) {
-        case DataType::Float64:
-            forward_<double>(input);
-            break;
-        case DataType::Float32:
-            forward_<float>(input);
-            break;
-        case DataType::Float16:
-            forward_<half>(input);
-            break;
-        default:
-            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type is not supported by Backend Cuda");
+    if (op.getOutput(0)->dataType() == DataType::Float64) {
+        forward_<double>(input);
+    }
+    else {
+        forward_<float>(input);
     }
 }
 
diff --git a/src/operator/ReshapeImpl.cpp b/src/operator/ReshapeImpl.cpp
index bd05bd2..59f1cfd 100644
--- a/src/operator/ReshapeImpl.cpp
+++ b/src/operator/ReshapeImpl.cpp
@@ -1,5 +1,5 @@
 /********************************************************************************
- * Copyright (c) 2023 CEA-List
+ * Copyright (c) 2024 CEA-List
  *
  * This program and the accompanying materials are made available under the
  * terms of the Eclipse Public License 2.0 which is available at
@@ -15,12 +15,11 @@
 #include <thread>  // std::this_thread::sleep_for
 #include <vector>
 
-#include "aidge/utils/Types.h"
-#include "aidge/operator/Reshape.hpp"
-
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/ReshapeImpl.hpp"
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
+#include "aidge/operator/Reshape.hpp"
+#include "aidge/utils/Types.h"
 
 void Aidge::ReshapeImpl_cuda::forward() {
     const OperatorTensor& op = static_cast<const OperatorTensor&>(mOp);
-- 
GitLab