From 23d46d9ef0906d84724fb1ed0d034c90c44cafc2 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sun, 3 Dec 2023 18:57:06 +0100
Subject: [PATCH] Initial concept for Convert operator (UNTESTED)

---
 include/aidge/backend/cpu/data/TensorImpl.hpp | 72 +++++++++++++++++++
 1 file changed, 72 insertions(+)

diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp
index 012ff5af..377b4546 100644
--- a/include/aidge/backend/cpu/data/TensorImpl.hpp
+++ b/include/aidge/backend/cpu/data/TensorImpl.hpp
@@ -5,6 +5,7 @@
 #include "aidge/data/Tensor.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
+#include "aidge/utils/ErrorHandling.hpp"
 
 namespace Aidge {
 template <class T>
@@ -37,16 +38,87 @@ class TensorImpl_cpu : public TensorImpl {
 
     std::size_t scalarSize() const override { return sizeof(T); }
 
+    void setDevice(int device) override {
+        AIDGE_ASSERT(device == 0, "device cannot be != 0 for CPU backend");
+    }
+
     void copy(const void *src, NbElts_t length) override {
         std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length,
                   static_cast<T *>(rawPtr()));
     }
 
+    void copyCast(const void *src, NbElts_t length, const DataType srcDt) override {
+        if (srcDt == DataType::Float64) {
+            std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length,
+                    static_cast<T *>(rawPtr()));
+        }
+        else if (srcDt == DataType::Float32) {
+            std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length,
+                    static_cast<T *>(rawPtr()));
+        }
+        else if (srcDt == DataType::Int64) {
+            std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length,
+                    static_cast<T *>(rawPtr()));
+        }
+        else if (srcDt == DataType::UInt64) {
+            std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length,
+                    static_cast<T *>(rawPtr()));
+        }
+        else if (srcDt == DataType::Int32) {
+            std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length,
+                    static_cast<T *>(rawPtr()));
+        }
+        else if (srcDt == DataType::UInt32) {
+            std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length,
+                    static_cast<T *>(rawPtr()));
+        }
+        else if (srcDt == DataType::Int16) {
+            std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length,
+                    static_cast<T *>(rawPtr()));
+        }
+        else if (srcDt == DataType::UInt16) {
+            std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length,
+                    static_cast<T *>(rawPtr()));
+        }
+        else if (srcDt == DataType::Int8) {
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                    static_cast<T *>(rawPtr()));
+        }
+        else if (srcDt == DataType::UInt8) {
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                    static_cast<T *>(rawPtr()));
+        }
+        else {
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
+        }
+    }
+
+    void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, int>& device) override {
+        AIDGE_ASSERT(device.first == Backend, "backend must match");
+        AIDGE_ASSERT(device.second == 0, "device cannot be != 0 for CPU backend");
+        copy(src, length);
+    }
+
+    void copyFromHost(const void *src, NbElts_t length) override {
+        copy(src, length);
+    }
+
+    void copyToHost(void *dst, NbElts_t length) override {
+        const T* src = static_cast<const T*>(rawPtr());
+        std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length,
+                  static_cast<T *>(dst));
+    }
+
     void *rawPtr() override {
         lazyInit(mData);
         return mData.data();
     };
 
+    void *hostPtr() override {
+        lazyInit(mData);
+        return mData.data();
+    };
+
    void* getRaw(std::size_t idx){
        return  static_cast<void*>(static_cast<T *>(rawPtr()) + idx);
    };
-- 
GitLab