From ef618824ad18dd99a3a8fb7d8459c2c9e423454e Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 5 Jan 2024 15:02:52 +0100 Subject: [PATCH] Additionnal fixes --- include/aidge/backend/TensorImpl.hpp | 6 ++---- include/aidge/data/Tensor.hpp | 8 +++++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/include/aidge/backend/TensorImpl.hpp b/include/aidge/backend/TensorImpl.hpp index 1e5c49baf..a27f0317c 100644 --- a/include/aidge/backend/TensorImpl.hpp +++ b/include/aidge/backend/TensorImpl.hpp @@ -159,11 +159,9 @@ public: */ void copyFrom(const TensorImpl& srcImpl, NbElts_t length); -private: - const char *mBackend; - protected: - int mDevice; + const char *mBackend; + DeviceIdx_t mDevice; }; } // namespace Aidge diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index c1458df04..aab6f3757 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -391,6 +391,8 @@ class Tensor : public Data, std::string toString() const { + AIDGE_ASSERT(mImpl && mImpl->hostPtr() != nullptr, "tensor should have a valid host pointer"); + // TODO: move lambda elsewhere? auto ptrToString = [](DataType dt, void* ptr, size_t idx) { switch (dt) { @@ -447,9 +449,9 @@ class Tensor : public Data, for (; dimVals[dim] < static_cast<std::size_t>(dims()[dim]); ++dimVals[dim]) { res += spaceString + "{"; for (DimSize_t j = 0; j < dims()[dim + 1] - 1; ++j) { - res += " " + ptrToString(mDataType, mImpl->rawPtr(), counter++) + ","; + res += " " + ptrToString(mDataType, mImpl->hostPtr(), counter++) + ","; } - res += " " + ptrToString(mDataType, mImpl->rawPtr(), counter++) + "}"; + res += " " + ptrToString(mDataType, mImpl->hostPtr(), counter++) + "}"; if (dimVals[dim] < static_cast<std::size_t>(dims()[dim] - 1)) { res += ","; } @@ -469,7 +471,7 @@ class Tensor : public Data, } else { res += "{"; for (DimSize_t j = 0; j < dims()[0]; ++j) { - res += " " + ptrToString(mDataType, mImpl->rawPtr(), j) + ((j < dims()[0]-1) ? "," : ""); + res += " " + ptrToString(mDataType, mImpl->hostPtr(), j) + ((j < dims()[0]-1) ? "," : ""); } } res += "}"; -- GitLab