From 60247cc73141a045afaa431bd36c4df96dda3140 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Tue, 5 Dec 2023 15:23:50 +0100
Subject: [PATCH] Move Convert function to TensorImpl::copyFrom() and
 Tensor::copyCastFrom()

---
 include/aidge/backend/TensorImpl.hpp | 13 +++++-
 include/aidge/data/Tensor.hpp        | 44 +++++++++++++-------
 include/aidge/operator/Convert.hpp   |  3 --
 src/backend/TensorImpl.cpp           | 47 +++++++++++++++++++++
 src/data/Tensor.cpp                  | 35 ++++++++++++++++
 src/operator/Convert.cpp             | 61 +---------------------------
 6 files changed, 124 insertions(+), 79 deletions(-)
 create mode 100644 src/backend/TensorImpl.cpp
 create mode 100644 src/data/Tensor.cpp

diff --git a/include/aidge/backend/TensorImpl.hpp b/include/aidge/backend/TensorImpl.hpp
index 2060c9273..965483ae7 100644
--- a/include/aidge/backend/TensorImpl.hpp
+++ b/include/aidge/backend/TensorImpl.hpp
@@ -69,19 +69,21 @@ public:
      * @param src Host pointer to copy to.
      * @param length Number of bytes to copy.
     */
-    virtual void copyToHost(void *dst, NbElts_t length) = 0;
+    virtual void copyToHost(void *dst, NbElts_t length) const = 0;
 
     /**
      * Return the raw device pointer.
      * The raw pointer is garanteed to be valid only on the *same* device.
     */
     virtual void* rawPtr() = 0;
+    virtual const void* rawPtr() const = 0;
 
     /**
      * Return the host pointer.
      * If the implementation does not have a valid host pointer, nullptr is returned.
     */
-    virtual void* hostPtr() = 0;
+    virtual void* hostPtr() { return nullptr; }
+    virtual const void* hostPtr() const { return nullptr; }
 
     /**
      * Sets the device pointer.
@@ -101,6 +103,13 @@ public:
     virtual ~TensorImpl() = default;
     virtual bool operator==(const TensorImpl &othImpl) const = 0;
 
+    /**
+     * Copy from another backend.
+     * @param srcImpl Source TensorImpl to copy from.
+     * @param length Number of elements of size scalarSize() to copy
+    */
+    void copyFrom(const TensorImpl& srcImpl, NbElts_t length);
+
 private:
     const char *mBackend;
 
diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index 9fabd2d4e..3d50c2500 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -178,6 +178,8 @@ class Tensor : public Data,
     {
         if (otherTensor.hasImpl()) {
             mImpl = Registrar<Tensor>::create({otherTensor.mImpl->backend(), dataType()})(*this);
+            mImpl->setDevice(otherTensor.mImpl->device().second);
+            // Same backend, same device => directly use copy()
             mImpl->copy(otherTensor.mImpl->rawPtr(), mSize);
         }
     }
@@ -195,7 +197,7 @@ class Tensor : public Data,
           mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
           mSize(SIZE_0),
           mSizeM1(SIZE_0) {
-        mImpl->copy(&arr.data[0], SIZE_0);
+        mImpl->copyFromHost(&arr.data[0], SIZE_0);
     }
 
     template <typename T, std::size_t SIZE_0>
@@ -204,7 +206,7 @@ class Tensor : public Data,
         if (!mImpl) {
             mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
         }
-        mImpl->copy(&arr.data[0], SIZE_0);
+        mImpl->copyFromHost(&arr.data[0], SIZE_0);
         return *this;
     }
 
@@ -222,7 +224,7 @@ class Tensor : public Data,
           mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
           mSize(SIZE_0 * SIZE_1),
           mSizeM1(SIZE_1) {
-        mImpl->copy(&arr.data[0][0], SIZE_0 * SIZE_1);
+        mImpl->copyFromHost(&arr.data[0][0], SIZE_0 * SIZE_1);
     }
 
     template <typename T, std::size_t SIZE_0, std::size_t SIZE_1>
@@ -231,7 +233,7 @@ class Tensor : public Data,
         if (!mImpl) {
             mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
         }
-        mImpl->copy(&arr.data[0][0], SIZE_0 * SIZE_1);
+        mImpl->copyFromHost(&arr.data[0][0], SIZE_0 * SIZE_1);
         return *this;
     }
 
@@ -250,7 +252,7 @@ class Tensor : public Data,
           mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
           mSize(SIZE_0 * SIZE_1 * SIZE_2),
           mSizeM1(SIZE_1 * SIZE_2) {
-        mImpl->copy(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2);
+        mImpl->copyFromHost(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2);
     }
 
     template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2>
@@ -259,7 +261,7 @@ class Tensor : public Data,
         if (!mImpl) {
             mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
         }
-        mImpl->copy(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2);
+        mImpl->copyFromHost(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2);
         return *this;
     }
 
@@ -279,7 +281,7 @@ class Tensor : public Data,
           mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
           mSize(SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3),
           mSizeM1(SIZE_1 * SIZE_2 * SIZE_3) {
-        mImpl->copy(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3);
+        mImpl->copyFromHost(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3);
     }
 
     template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2, std::size_t SIZE_3>
@@ -288,7 +290,7 @@ class Tensor : public Data,
         if (!mImpl) {
             mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
         }
-        mImpl->copy(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3);
+        mImpl->copyFromHost(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3);
         return *this;
     }
 
@@ -301,8 +303,15 @@ class Tensor : public Data,
         resize(t.dims());
         setDataType(t.dataType());
         if (t.hasImpl()) {
-            setBackend(t.mImpl->backend(), t.mImpl->device().second);
-            mImpl->copy(t.mImpl->rawPtr(), size());
+            if (hasImpl()) {
+                copyCastFrom(t);
+            }
+            else {
+                mImpl = Registrar<Tensor>::create({t.mImpl->backend(), dataType()})(*this);
+                mImpl->setDevice(t.mImpl->device().second);
+                // Same backend, same device => directly use copy()
+                mImpl->copy(t.mImpl->rawPtr(), mSize);
+            }
         }
         else {
             mImpl = nullptr;
@@ -334,10 +343,8 @@ class Tensor : public Data,
                 // impl
                 std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({name, mDataType})(*this);
                 newImpl->setDevice(device);
-
-                //TODO: FIXME: copy() work only on same device!
-                //newImpl->copy(mImpl->rawPtr(), size());
-                //mImpl = std::move(newImpl);
+                newImpl->copyFrom(*mImpl, size());
+                mImpl = std::move(newImpl);
             }
         }
         else {
@@ -385,6 +392,7 @@ class Tensor : public Data,
      * @return constexpr const std::unique_ptr<TensorImpl>&
      */
     constexpr const std::unique_ptr<TensorImpl> &getImpl() { return mImpl; }
+    constexpr const std::unique_ptr<TensorImpl> &getImpl() const { return mImpl; }
 
     /**
      * @brief Return if an implementaiton has been associated.
@@ -621,6 +629,14 @@ class Tensor : public Data,
         return flatIdx + coordIdx[i];
     }
 
+    void copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& convertedSrc);
+    void copyCastFrom(const Tensor& src) {
+        // Internal buffers will be allocated and deallocated at each call
+        // (if they are needed)
+        std::shared_ptr<Tensor> convertedSrc;
+        copyCastFrom(src, convertedSrc);
+    }
+
 private:
     ///\bug not protected against overflow
     std::size_t computeSize() {
diff --git a/include/aidge/operator/Convert.hpp b/include/aidge/operator/Convert.hpp
index 19ffb9a0d..f115a243d 100644
--- a/include/aidge/operator/Convert.hpp
+++ b/include/aidge/operator/Convert.hpp
@@ -74,9 +74,6 @@ private:
     /// @brief Store the data to the right type on input device
     /// Required for any type conversion.
     std::shared_ptr<Tensor> mConvertedInput;
-    /// @brief Store the data to the right type on host
-    /// Required if there is no direct link between input and output devices
-    std::shared_ptr<Tensor> mHostBuffer;
 };
 
 inline std::shared_ptr<Node> Convert(const std::string& name = "") {
diff --git a/src/backend/TensorImpl.cpp b/src/backend/TensorImpl.cpp
new file mode 100644
index 000000000..371d775d7
--- /dev/null
+++ b/src/backend/TensorImpl.cpp
@@ -0,0 +1,47 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/backend/TensorImpl.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/utils/ErrorHandling.hpp"
+
+void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length) {
+    if (srcImpl.device() != device()) {
+        if (srcImpl.backend() == backend()) {
+            // Same backend, but different device
+            copyFromDevice(srcImpl.rawPtr(), length, srcImpl.device());
+        }
+        else if (srcImpl.hostPtr() != nullptr) {
+            // Different backend, but input is valid on host
+            copyFromHost(srcImpl.hostPtr(), length);
+        }
+        else if (hostPtr() != nullptr) {
+            // Different backend, but dst is valid on host
+            srcImpl.copyToHost(hostPtr(), length);
+        }
+        else {
+            // No direct link possible from src to dst device
+            // SLOW SOLUTION: must pass through the host, requires TWO copies
+            // Allocate a temporary host buffer just for the copy
+            // We might reuse a pre-allocated buffer, but for now this feature is not provided because:
+            // - There is currently no concrete use case
+            // - Just providing a pointer would be unsafe (risk of buffer overflow...)
+            auto tmpHostBuffer = std::unique_ptr<char[]>(new char[scalarSize() * length]);
+            srcImpl.copyToHost(tmpHostBuffer.get(), length);
+            copyFromHost(tmpHostBuffer.get(), length);
+        }
+    }
+    else {
+        // Same device: simple copy on device
+        copy(srcImpl.rawPtr(), length);
+    }
+}
diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp
new file mode 100644
index 000000000..1ae14a2b0
--- /dev/null
+++ b/src/data/Tensor.cpp
@@ -0,0 +1,35 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/utils/ErrorHandling.hpp"
+
+void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& convertedSrcPtr) {
+    // convertedSrcPtr stores data to the desired (dst) type
+    if (src.dataType() != dataType()) {
+        // Different type: create a new tensor on same src device
+        if (!convertedSrcPtr) {
+            convertedSrcPtr = std::make_shared<Tensor>(dataType());
+        }
+
+        convertedSrcPtr->setDataType(dataType());
+        const auto device = src.getImpl()->device();
+        convertedSrcPtr->setBackend(device.first, device.second);
+        convertedSrcPtr->resize(src.dims());
+
+        // Copy convert src to convertedSrcPtr
+        convertedSrcPtr->getImpl()->copyCast(src.getImpl()->rawPtr(), src.size(), src.dataType());
+    }
+
+    const Tensor& convertedSrc = (src.dataType() != dataType()) ? *convertedSrcPtr : src;
+    getImpl()->copyFrom(*(convertedSrc.getImpl()), convertedSrc.size());
+}
diff --git a/src/operator/Convert.cpp b/src/operator/Convert.cpp
index d88fcfed4..14769add6 100644
--- a/src/operator/Convert.cpp
+++ b/src/operator/Convert.cpp
@@ -9,74 +9,15 @@
  *
  ********************************************************************************/
 
-#include <cassert>
-#include <cstddef>
-#include <vector>
-#include <utility>
-
 #include "aidge/backend/OperatorImpl.hpp"
 #include "aidge/operator/Convert.hpp"
-#include "aidge/utils/Types.h"
-#include "aidge/utils/ErrorHandling.hpp"
 
 void Aidge::Convert_Op::forward() {
     if (mImpl) {
         mImpl->forward();
     }
     else {
-        // mConvertedInput stores data to the desired (output) type
-        if (mInputs[0]->dataType() != mOutputs[0]->dataType()) {
-            // Different type: create a new tensor on same input device
-            if (!mConvertedInput) {
-                mConvertedInput = std::make_shared<Tensor>(mOutputs[0]->dataType());
-            }
-
-            mConvertedInput->setDataType(mOutputs[0]->dataType());
-            const auto device = mInputs[0]->getImpl()->device();
-            mConvertedInput->setBackend(device.first, device.second);
-            mConvertedInput->resize(mInputs[0]->dims());
-
-            // Copy convert input to mConvertedInput
-            mConvertedInput->getImpl()->copyCast(mInputs[0]->getImpl()->rawPtr(), mInputs[0]->size(), mInputs[0]->dataType());
-        }
-        else {
-            // Same type: mConvertedInput *is* the input
-            mConvertedInput = mInputs[0];
-        }
-
-        // Copy to output device, if necessary
-        if (mConvertedInput->getImpl()->device() != mOutputs[0]->getImpl()->device()) {
-            if (mConvertedInput->getImpl()->backend() == mOutputs[0]->getImpl()->backend()) {
-                // Same backend, but different device
-                mOutputs[0]->getImpl()->copyFromDevice(mConvertedInput->getImpl()->rawPtr(), mConvertedInput->size(), mConvertedInput->getImpl()->device());
-            }
-            else if (mConvertedInput->getImpl()->hostPtr() != nullptr) {
-                // Different backend, but input is valid on host
-                mOutputs[0]->getImpl()->copyFromHost(mConvertedInput->getImpl()->hostPtr(), mConvertedInput->size());
-            }
-            else if (mOutputs[0]->getImpl()->hostPtr() != nullptr) {
-                // Different backend, but output is valid on host
-                mConvertedInput->getImpl()->copyToHost(mOutputs[0]->getImpl()->hostPtr(), mConvertedInput->size());
-            }
-            else {
-                // No direct link possible from input to output device
-                // SLOW SOLUTION: must pass through the host, requires TWO copies
-                const auto availableBackends = Tensor::getAvailableBackends();
-                AIDGE_ASSERT(availableBackends.find("cpu") != availableBackends.end(), "Conversion requires CPU backend");
-
-                if (!mHostBuffer) {
-                    mHostBuffer = std::make_shared<Tensor>(mOutputs[0]->dataType());
-                    mHostBuffer->setBackend("cpu");
-                }
-
-                mConvertedInput->getImpl()->copyToHost(mHostBuffer->getImpl()->rawPtr(), mConvertedInput->size());
-                mOutputs[0]->getImpl()->copyFromHost(mHostBuffer->getImpl()->rawPtr(), mConvertedInput->size());
-            }
-        }
-        else {
-            // Same device: simple copy on device
-            mConvertedInput->getImpl()->copy(mConvertedInput->getImpl()->rawPtr(), mConvertedInput->size());
-        }
+        mOutputs[0]->copyCastFrom(*(mInputs[0]), mConvertedInput);
     }
 
     runHooks();
-- 
GitLab