From e45c2e38cab58909b8da68df53fde01347795f25 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 8 Dec 2023 22:39:14 +0100 Subject: [PATCH] Various fixes --- include/aidge/data/Tensor.hpp | 9 +++------ include/aidge/operator/ConvDepthWise.hpp | 10 +++++++++- include/aidge/operator/Convert.hpp | 4 ---- include/aidge/operator/FC.hpp | 11 +++++++++-- include/aidge/utils/TensorUtils.hpp | 8 ++++---- src/backend/TensorImpl.cpp | 4 ++++ src/data/Tensor.cpp | 2 +- 7 files changed, 30 insertions(+), 18 deletions(-) diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 903ce2f10..7195f0b20 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -377,11 +377,8 @@ class Tensor : public Data, */ void setDataType(const DataType dt) { if (mImpl && (dataType() != dt)) { - // get ptr before changing Tensor backend or the type difference will trigger a warning - const void *data = mImpl->rawPtr(); - mDataType = dt; std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), dt})(*this); - newImpl->copy(data, size()); // /!\ it does not cast data but reinterpret them + newImpl->copyCast(mImpl->rawPtr(), size(), mDataType); mImpl = std::move(newImpl); } mDataType = dt; @@ -487,7 +484,7 @@ class Tensor : public Data, - std::string toString() { + std::string toString() const { if (dims().empty()) { return "{}"; } std::string res; std::size_t dim = 0; @@ -580,7 +577,7 @@ class Tensor : public Data, return res; } - inline void print() { printf("%s\n", toString().c_str()); } + inline void print() const { printf("%s\n", toString().c_str()); } std::shared_ptr<Tensor> grad() { if (!mGrad) { diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp index c374c7f71..e03524643 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -169,11 +169,19 @@ public: mImpl = Registrar<ConvDepthWise_Op<DIM>>::create(name)(*this); mOutputs[0]->setBackend(name, device); - // FIXME: temporary workaround + // By default, automatically set backend for weight and bias inputs getInput(1)->setBackend(name, device); getInput(2)->setBackend(name, device); } + void setDataType(const DataType& dt) const override { + mOutputs[0]->setDataType(dt); + + // By default, automatically set data type for weight and bias inputs + getInput(1)->setDataType(dt); + getInput(2)->setDataType(dt); + } + static const std::vector<std::string> getInputsName(){ return {"data_input", "weight", "bias"}; } diff --git a/include/aidge/operator/Convert.hpp b/include/aidge/operator/Convert.hpp index 6a08fbf0d..cb54ffbf7 100644 --- a/include/aidge/operator/Convert.hpp +++ b/include/aidge/operator/Convert.hpp @@ -57,10 +57,6 @@ public: mOutputs[0]->setBackend(name, device); } - void setDataType(const DataType& dataType) const override { - mOutputs[0]->setDataType(dataType); - } - void forward() override; static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index beb5f82f5..2e7b3a22f 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -99,12 +99,19 @@ public: mImpl = Registrar<FC_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); + // By default, automatically set backend for weight and bias inputs getInput(1)->setBackend(name, device); getInput(2)->setBackend(name, device); } + void setDataType(const DataType& dt) const override { + mOutputs[0]->setDataType(dt); + + // By default, automatically set data type for weight and bias inputs + getInput(1)->setDataType(dt); + getInput(2)->setDataType(dt); + } + static const std::vector<std::string> getInputsName(){ return {"data_input", "weight", "bias"}; } diff --git a/include/aidge/utils/TensorUtils.hpp b/include/aidge/utils/TensorUtils.hpp index 638761954..cb10f2f8d 100644 --- a/include/aidge/utils/TensorUtils.hpp +++ b/include/aidge/utils/TensorUtils.hpp @@ -31,10 +31,10 @@ * @param absolute absolute error allowed (shoulmd be positive) * @return true if both tensor are approximately equal and have the datatype, shape. Else return false */ -template <typename T> +template <typename T1, typename T2 = T1> bool approxEq(Aidge::Tensor t1, Aidge::Tensor t2, float relative, float absolute){ - assert(t1.dataType() == t2.dataType()); - assert(t1.dataType() == NativeType<T>::type); + assert(t1.dataType() == NativeType<T1>::type); + assert(t2.dataType() == NativeType<T2>::type); assert(relative >= 0); assert(absolute >= 0 && absolute<=1); @@ -42,7 +42,7 @@ bool approxEq(Aidge::Tensor t1, Aidge::Tensor t2, float relative, float absolute return false; } for(size_t i; i < t1.size(); ++i){ - if (static_cast<float>(std::abs(t1.get<T>(i) - t2.get<T>(i))) > (absolute + (relative * static_cast<float>(std::abs(t2.get<T>(i)))))){ + if (static_cast<float>(std::abs(t1.get<T1>(i) - t2.get<T2>(i))) > (absolute + (relative * static_cast<float>(std::abs(t2.get<T2>(i)))))){ return false; } } diff --git a/src/backend/TensorImpl.cpp b/src/backend/TensorImpl.cpp index 371d775d7..282f1222e 100644 --- a/src/backend/TensorImpl.cpp +++ b/src/backend/TensorImpl.cpp @@ -15,6 +15,10 @@ #include "aidge/utils/ErrorHandling.hpp" void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length) { + if (srcImpl == *this) { + return; + } + if (srcImpl.device() != device()) { if (srcImpl.backend() == backend()) { // Same backend, but different device diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index a4690ea42..1f8257a70 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -42,7 +42,7 @@ const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, c const auto device = getImpl()->device(); fallback->setBackend(device.first, device.second); fallback->resize(dims()); - fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dt); + fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dataType()); return *fallback; } } -- GitLab