From 53f2e8b82fe030c1c03a20f185c08ba336f040a4 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Wed, 20 Sep 2023 06:18:23 +0000
Subject: [PATCH] [Tensor] Add get & set method at Tensor level.

---
 include/aidge/backend/TensorImpl.hpp  |  3 ++
 include/aidge/data/Tensor.hpp         | 49 +++++++++++++++++----------
 python_binding/data/pybind_Tensor.cpp | 37 ++++++++++++++------
 3 files changed, 61 insertions(+), 28 deletions(-)

diff --git a/include/aidge/backend/TensorImpl.hpp b/include/aidge/backend/TensorImpl.hpp
index c56f66fc0..dfe3d932a 100644
--- a/include/aidge/backend/TensorImpl.hpp
+++ b/include/aidge/backend/TensorImpl.hpp
@@ -27,6 +27,9 @@ public:
     {
         printf("Cannot set raw pointer for backend %s\n", mBackend);
     };
+
+    virtual void* getRaw(std::size_t /*idx*/)=0;
+
     virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes)
     constexpr const char *backend() const { return mBackend; }
     virtual ~TensorImpl() = default;
diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index 468f48fea..0a67d73a9 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -446,18 +446,33 @@ class Tensor : public Data,
      */
     bool empty() const { return mDims.empty(); }
 
-    template <typename expectedType, std::array<std::size_t, 1>::size_type DIM>
-    constexpr expectedType &get(std::array<std::size_t, DIM> idx) {
-        assert(DIM == mDims.size());
-        assert(mImpl);
-        std::size_t unfoldedIdx = 0;
-        for (std::size_t i = 0; i < DIM - std::size_t(1); ++i) {
-            unfoldedIdx = (unfoldedIdx + idx[i]) * mDims[i + 1];
-        }
-        unfoldedIdx += idx[DIM - 1];
-        return static_cast<expectedType *>(mImpl->rawPtr())[unfoldedIdx];
+    template <typename expectedType>
+    expectedType& get(std::size_t idx){
+        // TODO : add assert expected Type compatible with datatype
+        // TODO : add assert idx < Size
+        return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx));
+    }
+
+    template <typename expectedType>
+    expectedType& get(std::vector<std::size_t> coordIdx){
+        return get<expectedType>(getIdx(coordIdx));
+    }
+
+    template <typename expectedType>
+    void set(std::size_t idx, expectedType value){
+        // TODO : add assert expected Type compatible with datatype
+        // TODO : add assert idx < Size
+        void* dataPtr = mImpl->getRaw(idx);
+        std::memcpy(dataPtr, &value, sizeof(expectedType));
     }
 
+    template <typename expectedType>
+    void set(std::vector<std::size_t> coordIdx, expectedType value){
+        set<expectedType>(getIdx(coordIdx), value);
+    }
+
+
+
     std::string toString() {
         if (dims().empty()) { return "{}"; }
         std::string res;
@@ -565,10 +580,10 @@ class Tensor : public Data,
      * @param flatIdx 1D index of the value considering a flatten tensor.
      * @return std::vector<DimSize_t>
      */
-    std::vector<DimSize_t> getCoord(DimSize_t flatIdx){
-        std::vector<DimSize_t> coordIdx = {};
-        DimSize_t idx = flatIdx;
-        for (DimSize_t d: mDims){
+    std::vector<std::size_t> getCoord(std::size_t flatIdx){
+        std::vector<std::size_t> coordIdx = {};
+        std::size_t idx = flatIdx;
+        for (std::size_t d: mDims){
             coordIdx.push_back(idx % d);
             idx/=d;
         }
@@ -581,9 +596,9 @@ class Tensor : public Data,
      * @param coordIdx Coordinate to an element in the tensor
      * @return DimSize_t
      */
-    DimSize_t getIdx(std::vector<DimSize_t> coordIdx){
-        DimSize_t flatIdx = 0;
-        DimSize_t stride = 1;
+    std::size_t getIdx(std::vector<std::size_t> coordIdx){
+        std::size_t flatIdx = 0;
+        std::size_t stride = 1;
         assert(coordIdx.size() == mDims.size() && "Coordinates does not match number of dimensions");
         for(std::size_t i=0; i< mDims.size(); ++i){
             assert(coordIdx[i] < mDims[i] && "Coordinates dimensions does not fit the dimensions of the tensor");
diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index 168c2c946..31470e0eb 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -48,7 +48,10 @@ void addCtor(py::class_<Tensor,
         }
 
         return newTensor;
-    }));
+    }))
+    .def("__setitem__", (void (Tensor::*)(std::size_t, T)) &Tensor::set)
+    .def("__setitem__", (void (Tensor::*)(std::vector<std::size_t>, T)) &Tensor::set)
+    ;
 }
 
 
@@ -84,15 +87,27 @@ void init_Tensor(py::module& m){
         return b.size();
     })
     .def("__getitem__", [](Tensor& b, size_t idx)-> py::object {
-        // TODO : Should return error if backend not compatible with get
         if (idx >= b.size()) throw py::index_error();
         switch(b.dataType()){
             case DataType::Float64:
-                return py::cast(static_cast<double*>(b.getImpl()->rawPtr())[idx]);
+                return py::cast(b.get<double>(idx));
             case DataType::Float32:
-                return py::cast(static_cast<float*>(b.getImpl()->rawPtr())[idx]);
+                return py::cast(b.get<float>(idx));
             case DataType::Int32:
-                return py::cast(static_cast<int*>(b.getImpl()->rawPtr())[idx]);
+                return py::cast(b.get<int>(idx));
+            default:
+                return py::none();
+        }
+    })
+    .def("__getitem__", [](Tensor& b, std::vector<size_t> coordIdx)-> py::object {
+        if (b.getIdx(coordIdx) >= b.size()) throw py::index_error();
+        switch(b.dataType()){
+            case DataType::Float64:
+                return py::cast(b.get<double>(coordIdx));
+            case DataType::Float32:
+                return py::cast(b.get<float>(coordIdx));
+            case DataType::Int32:
+                return py::cast(b.get<int>(coordIdx));
             default:
                 return py::none();
         }
@@ -128,12 +143,12 @@ void init_Tensor(py::module& m){
         }
 
         return py::buffer_info(
-            tensorImpl->rawPtr(),                       /* Pointer to buffer */
-            tensorImpl->scalarSize(),                   /* Size of one scalar */
-            dataFormatDescriptor,                /* Python struct-style format descriptor */
-            b.nbDims(),                                 /* Number of dimensions */
-            dims,                                       /* Buffer dimensions */
-            strides                                     /* Strides (in bytes) for each index */
+            tensorImpl->rawPtr(),       /* Pointer to buffer */
+            tensorImpl->scalarSize(),   /* Size of one scalar */
+            dataFormatDescriptor,       /* Python struct-style format descriptor */
+            b.nbDims(),                 /* Number of dimensions */
+            dims,                       /* Buffer dimensions */
+            strides                     /* Strides (in bytes) for each index */
         );
     });
 
-- 
GitLab