From 69cc72c54a0aab17b694a3d7b7a013a333c6ff82 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Fri, 12 Jan 2024 14:26:50 +0000
Subject: [PATCH] [Upd] tensorImpl 'data()' and switch cases

---
 include/aidge/backend/cpu/data/TensorImpl.hpp | 99 ++++++++++---------
 1 file changed, 51 insertions(+), 48 deletions(-)

diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp
index cce8c3f6..b02c9ef2 100644
--- a/include/aidge/backend/cpu/data/TensorImpl.hpp
+++ b/include/aidge/backend/cpu/data/TensorImpl.hpp
@@ -27,7 +27,7 @@ class TensorImpl_cpu : public TensorImpl {
 
     bool operator==(const TensorImpl &otherImpl) const override final {
         const auto& typedOtherImpl = reinterpret_cast<const TensorImpl_cpu<T> &>(otherImpl);
-        AIDGE_INTERNAL_ASSERT(typedOtherImpl.data().size() >= mTensor.size());
+        AIDGE_INTERNAL_ASSERT(typedOtherImpl.size() >= mTensor.size());
 
         std::size_t i = 0;
         for (; i < mTensor.size() &&
@@ -42,7 +42,7 @@ class TensorImpl_cpu : public TensorImpl {
     }
 
     // native interface
-    const future_std::span<T>& data() const { return mData; }
+    auto data() const -> decltype(mData.data()) { return mData.data(); }
 
     std::size_t size() const override { return mData.size(); }
     std::size_t scalarSize() const override { return sizeof(T); }
@@ -63,52 +63,55 @@ class TensorImpl_cpu : public TensorImpl {
         }
 
         AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
-        if (srcDt == DataType::Float64) {
-            std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length,
-                    static_cast<T *>(rawPtr()));
-        }
-        else if (srcDt == DataType::Float32) {
-            std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length,
-                    static_cast<T *>(rawPtr()));
-        }
-        else if (srcDt == DataType::Float16) {
-            std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length,
-                    static_cast<T *>(rawPtr()));
-        }
-        else if (srcDt == DataType::Int64) {
-            std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length,
-                    static_cast<T *>(rawPtr()));
-        }
-        else if (srcDt == DataType::UInt64) {
-            std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length,
-                    static_cast<T *>(rawPtr()));
-        }
-        else if (srcDt == DataType::Int32) {
-            std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length,
-                    static_cast<T *>(rawPtr()));
-        }
-        else if (srcDt == DataType::UInt32) {
-            std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length,
-                    static_cast<T *>(rawPtr()));
-        }
-        else if (srcDt == DataType::Int16) {
-            std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length,
-                    static_cast<T *>(rawPtr()));
-        }
-        else if (srcDt == DataType::UInt16) {
-            std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length,
-                    static_cast<T *>(rawPtr()));
-        }
-        else if (srcDt == DataType::Int8) {
-            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
-                    static_cast<T *>(rawPtr()));
-        }
-        else if (srcDt == DataType::UInt8) {
-            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
-                    static_cast<T *>(rawPtr()));
-        }
-        else {
-            AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
+        switch (srcDt)
+        {
+            case DataType::Float64:
+                std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length,
+                        static_cast<T *>(rawPtr()));
+                break;
+            case DataType::Float32:
+                std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length,
+                        static_cast<T *>(rawPtr()));
+                break;
+            case DataType::Float16:
+                std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length,
+                        static_cast<T *>(rawPtr()));
+                break;
+            case DataType::Int64:
+                std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length,
+                        static_cast<T *>(rawPtr()));
+                break;
+            case DataType::UInt64:
+                std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length,
+                        static_cast<T *>(rawPtr()));
+                break;
+            case DataType::Int32:
+                std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length,
+                        static_cast<T *>(rawPtr()));
+                break;
+            case DataType::UInt32:
+                std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length,
+                        static_cast<T *>(rawPtr()));
+                break;
+            case DataType::Int16:
+                std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length,
+                        static_cast<T *>(rawPtr()));
+                break;
+            case DataType::UInt16:
+                std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length,
+                        static_cast<T *>(rawPtr()));
+                break;
+            case ataType::Int8:
+                std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                        static_cast<T *>(rawPtr()));
+                break;
+            case DataType::UInt8:
+                std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                        static_cast<T *>(rawPtr()));
+                break;
+            default:
+                AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
+                break;
         }
     }
 
-- 
GitLab