Skip to content
Snippets Groups Projects
Commit 88f6e241 authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge remote-tracking branch 'origin/convert' into dev

parents 2b4d63d1 13d6efaf
No related branches found
No related tags found
1 merge request!29Temporary master branch
Pipeline #36746 failed
......@@ -47,7 +47,7 @@ class TensorImpl_cpu : public TensorImpl {
std::size_t size() const override { return mData.size(); }
std::size_t scalarSize() const override { return sizeof(T); }
void setDevice(int device) override {
void setDevice(DeviceIdx_t device) override {
AIDGE_ASSERT(device == 0, "device cannot be != 0 for CPU backend");
}
......@@ -112,7 +112,7 @@ class TensorImpl_cpu : public TensorImpl {
}
}
void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, int>& device) override {
void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) override {
AIDGE_ASSERT(device.first == Backend, "backend must match");
AIDGE_ASSERT(device.second == 0, "device cannot be != 0 for CPU backend");
copy(src, length);
......@@ -129,29 +129,24 @@ class TensorImpl_cpu : public TensorImpl {
static_cast<T *>(dst));
}
void *rawPtr() override {
void *rawPtr(NbElts_t offset = 0) override {
lazyInit();
return mData.data();
return (mData.data() + offset);
};
const void *rawPtr() const override {
const void *rawPtr(NbElts_t offset = 0) const override {
AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr");
return mData.data();
return (mData.data() + offset);
};
void *hostPtr() override {
void *hostPtr(NbElts_t offset = 0) override {
lazyInit();
return mData.data();
return (mData.data() + offset);
};
const void *hostPtr() const override {
const void *hostPtr(NbElts_t offset = 0) const override {
AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const hostPtr");
return mData.data();
};
void* getRawPtr(NbElts_t idx) override final {
AIDGE_ASSERT(idx < mData.size(), "idx out of range");
return static_cast<void*>(static_cast<T*>(rawPtr()) + idx);
return (mData.data() + offset);
};
void setRawPtr(void *ptr, NbElts_t length) override final {
......
......@@ -45,7 +45,7 @@ TEST_CASE("Tensor creation") {
REQUIRE(x.get<int>({0, 0, 1}) == 2);
REQUIRE(x.get<int>({0, 1, 1}) == 4);
REQUIRE(x.get<int>({1, 1, 0}) == 7);
x.get<int>({1, 1, 1}) = 36;
x.set<int>({1, 1, 1}, 36);
REQUIRE(x.get<int>({1, 1, 1}) == 36);
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment