From 0efd2f60bb76f21f2e3b011e02151ec8090f9145 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 7 Dec 2023 00:09:29 +0100
Subject: [PATCH] Minor fixes

---
 include/aidge/backend/TensorImpl.hpp        | 12 ++++++++++--
 python_binding/data/pybind_Tensor.cpp       |  4 ++--
 python_binding/graph/pybind_GraphView.cpp   |  2 +-
 python_binding/operator/pybind_Operator.cpp |  2 +-
 4 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/include/aidge/backend/TensorImpl.hpp b/include/aidge/backend/TensorImpl.hpp
index 965483ae7..1aabd7b2b 100644
--- a/include/aidge/backend/TensorImpl.hpp
+++ b/include/aidge/backend/TensorImpl.hpp
@@ -16,8 +16,15 @@
 #include <cstdio>
 #include "aidge/data/Data.hpp"
 #include "aidge/utils/Types.h"
+#include "aidge/utils/ErrorHandling.hpp"
 
 namespace Aidge {
+/**
+ * This class manages the raw data storage of a Tensor and provide generic copy
+ * primitives from other devices and from/to host.
+ * It can own the data or not (use setRawPtr() to set an external data owner).
+ * It only knows the data type and data capacity, but does not handle anything else.
+*/
 class TensorImpl {
 public:
     TensorImpl() = delete;
@@ -90,10 +97,11 @@ public:
      * UNSAFE: directly setting the device pointer may lead to undefined behavior
      * if it does not match the required storage.
      * @param ptr A valid device pointer.
+     * @param length Storage capacity at the provided pointer
     */
-    virtual void setRawPtr(void* /*ptr*/)
+    virtual void setRawPtr(void* /*ptr*/, NbElts_t /*length*/)
     {
-        printf("Cannot set raw pointer for backend %s\n", mBackend);
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend);
     };
 
     virtual void* getRaw(std::size_t /*idx*/)=0;
diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index babc534bd..067ad8c00 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -42,7 +42,7 @@ void addCtor(py::class_<Tensor,
         std::set<std::string> availableBackends = Tensor::getAvailableBackends();
         if (availableBackends.find("cpu") != availableBackends.end()){
             newTensor->setBackend("cpu");
-            newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr));
+            newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr), newTensor->size());
         }else{
             printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n");
         }
@@ -71,7 +71,7 @@ void init_Tensor(py::module& m){
         (m,"Tensor", py::multiple_inheritance(), py::buffer_protocol());
 
     pyClassTensor.def(py::init<>())
-    .def("set_backend", &Tensor::setBackend, py::arg("name"))
+    .def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0)
     .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
     .def("dtype", &Tensor::dataType)
     .def("size", &Tensor::size)
diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp
index 195e2740b..19e3e70d6 100644
--- a/python_binding/graph/pybind_GraphView.cpp
+++ b/python_binding/graph/pybind_GraphView.cpp
@@ -99,7 +99,7 @@ void init_GraphView(py::module& m) {
           .def("forward_dims", &GraphView::forwardDims)
           .def("__call__", &GraphView::operator(), py::arg("connectors"))
           .def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
-          .def("set_backend", &GraphView::setBackend, py::arg("backend"))
+          .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0)
           //   .def("__getitem__", [](Tensor& b, size_t idx)-> py::object {
           //      // TODO : Should return error if backend not compatible with get
           //      if (idx >= b.size()) throw py::index_error();
diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp
index f9482eda2..3cfa4b15e 100644
--- a/python_binding/operator/pybind_Operator.cpp
+++ b/python_binding/operator/pybind_Operator.cpp
@@ -29,7 +29,7 @@ void init_Operator(py::module& m){
     .def("nb_outputs", &Operator::nbOutputs)
     .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data"))
     .def("set_datatype", &Operator::setDataType, py::arg("dataType"))
-    .def("set_backend", &Operator::setBackend, py::arg("name"))
+    .def("set_backend", &Operator::setBackend, py::arg("name"), py::arg("device") = 0)
     .def("forward", &Operator::forward)
     // py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected !
     .def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>())
-- 
GitLab