From 0efd2f60bb76f21f2e3b011e02151ec8090f9145 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Thu, 7 Dec 2023 00:09:29 +0100 Subject: [PATCH] Minor fixes --- include/aidge/backend/TensorImpl.hpp | 12 ++++++++++-- python_binding/data/pybind_Tensor.cpp | 4 ++-- python_binding/graph/pybind_GraphView.cpp | 2 +- python_binding/operator/pybind_Operator.cpp | 2 +- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/include/aidge/backend/TensorImpl.hpp b/include/aidge/backend/TensorImpl.hpp index 965483ae7..1aabd7b2b 100644 --- a/include/aidge/backend/TensorImpl.hpp +++ b/include/aidge/backend/TensorImpl.hpp @@ -16,8 +16,15 @@ #include <cstdio> #include "aidge/data/Data.hpp" #include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" namespace Aidge { +/** + * This class manages the raw data storage of a Tensor and provide generic copy + * primitives from other devices and from/to host. + * It can own the data or not (use setRawPtr() to set an external data owner). + * It only knows the data type and data capacity, but does not handle anything else. +*/ class TensorImpl { public: TensorImpl() = delete; @@ -90,10 +97,11 @@ public: * UNSAFE: directly setting the device pointer may lead to undefined behavior * if it does not match the required storage. * @param ptr A valid device pointer. + * @param length Storage capacity at the provided pointer */ - virtual void setRawPtr(void* /*ptr*/) + virtual void setRawPtr(void* /*ptr*/, NbElts_t /*length*/) { - printf("Cannot set raw pointer for backend %s\n", mBackend); + AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend); }; virtual void* getRaw(std::size_t /*idx*/)=0; diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index babc534bd..067ad8c00 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -42,7 +42,7 @@ void addCtor(py::class_<Tensor, std::set<std::string> availableBackends = Tensor::getAvailableBackends(); if (availableBackends.find("cpu") != availableBackends.end()){ newTensor->setBackend("cpu"); - newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr)); + newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr), newTensor->size()); }else{ printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n"); } @@ -71,7 +71,7 @@ void init_Tensor(py::module& m){ (m,"Tensor", py::multiple_inheritance(), py::buffer_protocol()); pyClassTensor.def(py::init<>()) - .def("set_backend", &Tensor::setBackend, py::arg("name")) + .def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0) .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims) .def("dtype", &Tensor::dataType) .def("size", &Tensor::size) diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 195e2740b..19e3e70d6 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -99,7 +99,7 @@ void init_GraphView(py::module& m) { .def("forward_dims", &GraphView::forwardDims) .def("__call__", &GraphView::operator(), py::arg("connectors")) .def("set_datatype", &GraphView::setDataType, py::arg("datatype")) - .def("set_backend", &GraphView::setBackend, py::arg("backend")) + .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0) // .def("__getitem__", [](Tensor& b, size_t idx)-> py::object { // // TODO : Should return error if backend not compatible with get // if (idx >= b.size()) throw py::index_error(); diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index f9482eda2..3cfa4b15e 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -29,7 +29,7 @@ void init_Operator(py::module& m){ .def("nb_outputs", &Operator::nbOutputs) .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) .def("set_datatype", &Operator::setDataType, py::arg("dataType")) - .def("set_backend", &Operator::setBackend, py::arg("name")) + .def("set_backend", &Operator::setBackend, py::arg("name"), py::arg("device") = 0) .def("forward", &Operator::forward) // py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected ! .def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>()) -- GitLab