Skip to content
Snippets Groups Projects
Commit 0efd2f60 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Minor fixes

parent 49620ef4
No related branches found
No related tags found
1 merge request!57Add Convert operator (a.k.a. Transmitter)
Pipeline #35435 passed
......@@ -16,8 +16,15 @@
#include <cstdio>
#include "aidge/data/Data.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
namespace Aidge {
/**
* This class manages the raw data storage of a Tensor and provide generic copy
* primitives from other devices and from/to host.
* It can own the data or not (use setRawPtr() to set an external data owner).
* It only knows the data type and data capacity, but does not handle anything else.
*/
class TensorImpl {
public:
TensorImpl() = delete;
......@@ -90,10 +97,11 @@ public:
* UNSAFE: directly setting the device pointer may lead to undefined behavior
* if it does not match the required storage.
* @param ptr A valid device pointer.
* @param length Storage capacity at the provided pointer
*/
virtual void setRawPtr(void* /*ptr*/)
virtual void setRawPtr(void* /*ptr*/, NbElts_t /*length*/)
{
printf("Cannot set raw pointer for backend %s\n", mBackend);
AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend);
};
virtual void* getRaw(std::size_t /*idx*/)=0;
......
......@@ -42,7 +42,7 @@ void addCtor(py::class_<Tensor,
std::set<std::string> availableBackends = Tensor::getAvailableBackends();
if (availableBackends.find("cpu") != availableBackends.end()){
newTensor->setBackend("cpu");
newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr));
newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr), newTensor->size());
}else{
printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n");
}
......@@ -71,7 +71,7 @@ void init_Tensor(py::module& m){
(m,"Tensor", py::multiple_inheritance(), py::buffer_protocol());
pyClassTensor.def(py::init<>())
.def("set_backend", &Tensor::setBackend, py::arg("name"))
.def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0)
.def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
.def("dtype", &Tensor::dataType)
.def("size", &Tensor::size)
......
......@@ -99,7 +99,7 @@ void init_GraphView(py::module& m) {
.def("forward_dims", &GraphView::forwardDims)
.def("__call__", &GraphView::operator(), py::arg("connectors"))
.def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
.def("set_backend", &GraphView::setBackend, py::arg("backend"))
.def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0)
// .def("__getitem__", [](Tensor& b, size_t idx)-> py::object {
// // TODO : Should return error if backend not compatible with get
// if (idx >= b.size()) throw py::index_error();
......
......@@ -29,7 +29,7 @@ void init_Operator(py::module& m){
.def("nb_outputs", &Operator::nbOutputs)
.def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data"))
.def("set_datatype", &Operator::setDataType, py::arg("dataType"))
.def("set_backend", &Operator::setBackend, py::arg("name"))
.def("set_backend", &Operator::setBackend, py::arg("name"), py::arg("device") = 0)
.def("forward", &Operator::forward)
// py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected !
.def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment