Commit 0efd2f60 authored by Olivier BICHLER

Minor fixes

parent 49620ef4
@@ -16,8 +16,15 @@
#include <cstdio>
#include "aidge/data/Data.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
namespace Aidge {
+/**
+ * This class manages the raw data storage of a Tensor and provides generic copy
+ * primitives from other devices and from/to host.
+ * It can own the data or not (use setRawPtr() to set an external data owner).
+ * It only knows the data type and data capacity, but does not handle anything else.
+ */
class TensorImpl {
public:
TensorImpl() = delete;
@@ -90,10 +97,11 @@ public:
* UNSAFE: directly setting the device pointer may lead to undefined behavior
* if it does not match the required storage.
* @param ptr A valid device pointer.
+ * @param length Storage capacity at the provided pointer.
*/
-virtual void setRawPtr(void* /*ptr*/)
+virtual void setRawPtr(void* /*ptr*/, NbElts_t /*length*/)
{
-printf("Cannot set raw pointer for backend %s\n", mBackend);
+AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend);
};
virtual void* getRaw(std::size_t /*idx*/)=0;
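The hunks above change setRawPtr() to also receive the storage capacity, and turn the previous printf diagnostic into a hard error. Below is a standalone sketch (not part of this commit, nor Aidge code) of the contract this enables: an implementation that knows the capacity of an externally owned buffer can bounds-check accesses instead of silently trusting the pointer. NbElts_t is assumed to be an unsigned size type as in aidge/utils/Types.h; the MiniImpl class, its members and the float element size are illustrative only.

```cpp
// Standalone sketch, not part of the commit: mimics the updated
// setRawPtr(void*, NbElts_t) contract so the capacity of an external buffer
// travels with the pointer and can be checked later.
#include <cstddef>
#include <stdexcept>
#include <vector>

using NbElts_t = std::size_t;  // assumption: matches aidge/utils/Types.h

class MiniImpl {
public:
    // New two-argument form: pointer plus storage capacity (in elements).
    // Ownership stays with the caller, as with an external owner in TensorImpl.
    void setRawPtr(void* ptr, NbElts_t length) {
        mRawPtr = ptr;
        mNbElts = length;
    }

    // With the capacity known, element access can be bounds-checked instead of
    // relying on the caller to have passed a large-enough buffer.
    void* getRaw(std::size_t idx) {
        if (idx >= mNbElts) {
            throw std::runtime_error("getRaw(): index exceeds external buffer capacity");
        }
        return static_cast<char*>(mRawPtr) + idx * sizeof(float);  // element size assumed
    }

private:
    void* mRawPtr = nullptr;
    NbElts_t mNbElts = 0;
};

int main() {
    std::vector<float> buffer(64, 0.0f);           // externally owned storage
    MiniImpl impl;
    impl.setRawPtr(buffer.data(), buffer.size());  // same call shape as the pybind hunk below
    return impl.getRaw(10) != nullptr ? 0 : 1;
}
```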
@@ -42,7 +42,7 @@ void addCtor(py::class_<Tensor,
std::set<std::string> availableBackends = Tensor::getAvailableBackends();
if (availableBackends.find("cpu") != availableBackends.end()){
newTensor->setBackend("cpu");
-newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr));
+newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr), newTensor->size());
}else{
printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n");
}
@@ -71,7 +71,7 @@ void init_Tensor(py::module& m){
(m,"Tensor", py::multiple_inheritance(), py::buffer_protocol());
pyClassTensor.def(py::init<>())
.def("set_backend", &Tensor::setBackend, py::arg("name"))
.def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0)
.def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
.def("dtype", &Tensor::dataType)
.def("size", &Tensor::size)
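The Python-binding hunks (Tensor above, GraphView and Operator below) all apply the same pybind11 pattern: set_backend gains a device keyword argument defaulting to 0, so existing calls such as t.set_backend("cpu") keep working while a device index can now be passed explicitly. A minimal, self-contained sketch of that pattern follows; the tensor_sketch module name, the stand-in Tensor struct and the int device type are illustrative, only the py::arg(...) signature comes from the diff.

```cpp
// Minimal pybind11 sketch of the defaulted "device" keyword argument used in
// the bindings above; the Tensor struct here is a stand-in, not the Aidge class.
#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

struct Tensor {
    std::string mBackend;
    int mDevice = 0;
    void setBackend(const std::string& name, int device = 0) {
        mBackend = name;
        mDevice = device;
    }
};

PYBIND11_MODULE(tensor_sketch, m) {
    py::class_<Tensor>(m, "Tensor")
        .def(py::init<>())
        // py::arg("device") = 0 keeps t.set_backend("cpu") valid from Python
        // and allows t.set_backend("cpu", device=1) to target another device.
        .def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0);
}
```

From Python, each of the bound classes can then be called either as set_backend("cpu") or as set_backend("cpu", device=1).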
@@ -99,7 +99,7 @@ void init_GraphView(py::module& m) {
.def("forward_dims", &GraphView::forwardDims)
.def("__call__", &GraphView::operator(), py::arg("connectors"))
.def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
.def("set_backend", &GraphView::setBackend, py::arg("backend"))
.def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0)
// .def("__getitem__", [](Tensor& b, size_t idx)-> py::object {
// // TODO : Should return error if backend not compatible with get
// if (idx >= b.size()) throw py::index_error();
@@ -29,7 +29,7 @@ void init_Operator(py::module& m){
.def("nb_outputs", &Operator::nbOutputs)
.def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data"))
.def("set_datatype", &Operator::setDataType, py::arg("dataType"))
.def("set_backend", &Operator::setBackend, py::arg("name"))
.def("set_backend", &Operator::setBackend, py::arg("name"), py::arg("device") = 0)
.def("forward", &Operator::forward)
// py::keep_alive prevents Python from garbage collecting the implementation while the Operator has not been garbage collected!
.def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>())