From e45c2e38cab58909b8da68df53fde01347795f25 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 8 Dec 2023 22:39:14 +0100
Subject: [PATCH] Various fixes

---
 include/aidge/data/Tensor.hpp            |  9 +++------
 include/aidge/operator/ConvDepthWise.hpp | 10 +++++++++-
 include/aidge/operator/Convert.hpp       |  4 ----
 include/aidge/operator/FC.hpp            | 11 +++++++++--
 include/aidge/utils/TensorUtils.hpp      |  8 ++++----
 src/backend/TensorImpl.cpp               |  4 ++++
 src/data/Tensor.cpp                      |  2 +-
 7 files changed, 30 insertions(+), 18 deletions(-)

diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index 903ce2f10..7195f0b20 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -377,11 +377,8 @@ class Tensor : public Data,
      */
     void setDataType(const DataType dt) {
         if (mImpl && (dataType() != dt)) {
-            // get ptr before changing Tensor backend or the type difference will trigger a warning
-            const void *data = mImpl->rawPtr();
-            mDataType = dt;
             std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), dt})(*this);
-            newImpl->copy(data, size());  // /!\ it does not cast data but reinterpret them
+            newImpl->copyCast(mImpl->rawPtr(), size(), mDataType);
             mImpl = std::move(newImpl);
         }
         mDataType = dt;
@@ -487,7 +484,7 @@ class Tensor : public Data,
 
 
 
-    std::string toString() {
+    std::string toString() const {
         if (dims().empty()) { return "{}"; }
         std::string res;
         std::size_t dim = 0;
@@ -580,7 +577,7 @@ class Tensor : public Data,
         return res;
     }
 
-    inline void print() { printf("%s\n", toString().c_str()); }
+    inline void print() const { printf("%s\n", toString().c_str()); }
 
     std::shared_ptr<Tensor> grad() {
         if (!mGrad) {
diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp
index c374c7f71..e03524643 100644
--- a/include/aidge/operator/ConvDepthWise.hpp
+++ b/include/aidge/operator/ConvDepthWise.hpp
@@ -169,11 +169,19 @@ public:
         mImpl = Registrar<ConvDepthWise_Op<DIM>>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
 
-        // FIXME: temporary workaround
+        // By default, automatically set backend for weight and bias inputs
         getInput(1)->setBackend(name, device);
         getInput(2)->setBackend(name, device);
     }
 
+    void setDataType(const DataType& dt) const override {
+        mOutputs[0]->setDataType(dt);
+
+        // By default, automatically set data type for weight and bias inputs
+        getInput(1)->setDataType(dt);
+        getInput(2)->setDataType(dt);
+    }
+
     static const std::vector<std::string> getInputsName(){
         return {"data_input", "weight", "bias"};
     }
diff --git a/include/aidge/operator/Convert.hpp b/include/aidge/operator/Convert.hpp
index 6a08fbf0d..cb54ffbf7 100644
--- a/include/aidge/operator/Convert.hpp
+++ b/include/aidge/operator/Convert.hpp
@@ -57,10 +57,6 @@ public:
         mOutputs[0]->setBackend(name, device);
     }
 
-    void setDataType(const DataType& dataType) const override {
-        mOutputs[0]->setDataType(dataType);
-    }
-
     void forward() override;
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp
index beb5f82f5..2e7b3a22f 100644
--- a/include/aidge/operator/FC.hpp
+++ b/include/aidge/operator/FC.hpp
@@ -99,12 +99,19 @@ public:
         mImpl = Registrar<FC_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
 
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name, device);
+        // By default, automatically set backend for weight and bias inputs
         getInput(1)->setBackend(name, device);
         getInput(2)->setBackend(name, device);
     }
 
+    void setDataType(const DataType& dt) const override {
+        mOutputs[0]->setDataType(dt);
+
+        // By default, automatically set data type for weight and bias inputs
+        getInput(1)->setDataType(dt);
+        getInput(2)->setDataType(dt);
+    }
+
     static const std::vector<std::string> getInputsName(){
         return {"data_input", "weight", "bias"};
     }
diff --git a/include/aidge/utils/TensorUtils.hpp b/include/aidge/utils/TensorUtils.hpp
index 638761954..cb10f2f8d 100644
--- a/include/aidge/utils/TensorUtils.hpp
+++ b/include/aidge/utils/TensorUtils.hpp
@@ -31,10 +31,10 @@
  * @param absolute absolute error allowed (shoulmd be positive)
  * @return true if both tensor are approximately equal and have the datatype, shape. Else return false
  */
-template <typename T>
+template <typename T1, typename T2 = T1>
 bool approxEq(Aidge::Tensor t1, Aidge::Tensor t2, float relative, float absolute){
-    assert(t1.dataType() == t2.dataType());
-    assert(t1.dataType() == NativeType<T>::type);
+    assert(t1.dataType() == NativeType<T1>::type);
+    assert(t2.dataType() == NativeType<T2>::type);
     assert(relative >= 0);
     assert(absolute >= 0 && absolute<=1);
 
@@ -42,7 +42,7 @@ bool approxEq(Aidge::Tensor t1, Aidge::Tensor t2, float relative, float absolute
         return false;
     }
     for(size_t i; i < t1.size(); ++i){
-        if (static_cast<float>(std::abs(t1.get<T>(i) - t2.get<T>(i))) > (absolute + (relative * static_cast<float>(std::abs(t2.get<T>(i)))))){
+        if (static_cast<float>(std::abs(t1.get<T1>(i) - t2.get<T2>(i))) > (absolute + (relative * static_cast<float>(std::abs(t2.get<T2>(i)))))){
             return false;
         }
     }
diff --git a/src/backend/TensorImpl.cpp b/src/backend/TensorImpl.cpp
index 371d775d7..282f1222e 100644
--- a/src/backend/TensorImpl.cpp
+++ b/src/backend/TensorImpl.cpp
@@ -15,6 +15,10 @@
 #include "aidge/utils/ErrorHandling.hpp"
 
 void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length) {
+    if (srcImpl == *this) {
+        return;
+    }
+
     if (srcImpl.device() != device()) {
         if (srcImpl.backend() == backend()) {
             // Same backend, but different device
diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp
index a4690ea42..1f8257a70 100644
--- a/src/data/Tensor.cpp
+++ b/src/data/Tensor.cpp
@@ -42,7 +42,7 @@ const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, c
         const auto device = getImpl()->device();
         fallback->setBackend(device.first, device.second);
         fallback->resize(dims());
-        fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dt);
+        fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dataType());
         return *fallback;
     }
 }
-- 
GitLab