diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp index 9de5210cfffc2b1bb24061eb1a4c5fea02103694..cce8c3f66b2551de45eba32923c389ae6a90ded0 100644 --- a/include/aidge/backend/cpu/data/TensorImpl.hpp +++ b/include/aidge/backend/cpu/data/TensorImpl.hpp @@ -47,7 +47,7 @@ class TensorImpl_cpu : public TensorImpl { std::size_t size() const override { return mData.size(); } std::size_t scalarSize() const override { return sizeof(T); } - void setDevice(int device) override { + void setDevice(DeviceIdx_t device) override { AIDGE_ASSERT(device == 0, "device cannot be != 0 for CPU backend"); } @@ -112,7 +112,7 @@ class TensorImpl_cpu : public TensorImpl { } } - void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, int>& device) override { + void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) override { AIDGE_ASSERT(device.first == Backend, "backend must match"); AIDGE_ASSERT(device.second == 0, "device cannot be != 0 for CPU backend"); copy(src, length); @@ -129,29 +129,24 @@ class TensorImpl_cpu : public TensorImpl { static_cast<T *>(dst)); } - void *rawPtr() override { + void *rawPtr(NbElts_t offset = 0) override { lazyInit(); - return mData.data(); + return (mData.data() + offset); }; - const void *rawPtr() const override { + const void *rawPtr(NbElts_t offset = 0) const override { AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr"); - return mData.data(); + return (mData.data() + offset); }; - void *hostPtr() override { + void *hostPtr(NbElts_t offset = 0) override { lazyInit(); - return mData.data(); + return (mData.data() + offset); }; - const void *hostPtr() const override { + const void *hostPtr(NbElts_t offset = 0) const override { AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const hostPtr"); - return mData.data(); - }; - - void* getRawPtr(NbElts_t idx) override final { - AIDGE_ASSERT(idx < mData.size(), "idx out of range"); - return static_cast<void*>(static_cast<T*>(rawPtr()) + idx); + return (mData.data() + offset); }; void setRawPtr(void *ptr, NbElts_t length) override final { diff --git a/unit_tests/data/Test_TensorImpl.cpp b/unit_tests/data/Test_TensorImpl.cpp index 6c75c4dc19ff1b646308858ad262441d43390122..b75c49077f190ed61486fea8eaa18152423a73ed 100644 --- a/unit_tests/data/Test_TensorImpl.cpp +++ b/unit_tests/data/Test_TensorImpl.cpp @@ -45,7 +45,7 @@ TEST_CASE("Tensor creation") { REQUIRE(x.get<int>({0, 0, 1}) == 2); REQUIRE(x.get<int>({0, 1, 1}) == 4); REQUIRE(x.get<int>({1, 1, 0}) == 7); - x.get<int>({1, 1, 1}) = 36; + x.set<int>({1, 1, 1}, 36); REQUIRE(x.get<int>({1, 1, 1}) == 36); }