Skip to content
Snippets Groups Projects
Commit 956bb3f0 authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'chore_26-compare-layer-import-export-onnxruntime' into 'dev'

Chore_26_compare-layer-import-export-onnx

See merge request !157
parents 63779e04 edc2e9f0
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!157Chore_26_compare-layer-import-export-onnx
Pipeline #50321 passed
......@@ -51,6 +51,7 @@ void addCtor(py::class_<Tensor,
return newTensor;
}), py::arg("array"), py::arg("backend")="cpu")
.def(py::init<T>(), py::arg("val"))
.def("__setitem__", (void (Tensor::*)(std::size_t, T)) &Tensor::set)
.def("__setitem__", (void (Tensor::*)(std::vector<std::size_t>, T)) &Tensor::set)
;
......@@ -73,6 +74,7 @@ void init_Tensor(py::module& m){
(m,"Tensor", py::multiple_inheritance(), py::buffer_protocol());
pyClassTensor.def(py::init<>())
.def(py::init<const std::vector<std::size_t>&>(), py::arg("dims"))
.def(py::self + py::self)
.def(py::self - py::self)
.def(py::self * py::self)
......@@ -86,7 +88,7 @@ void init_Tensor(py::module& m){
.def("dtype", &Tensor::dataType)
.def("size", &Tensor::size)
.def("capacity", &Tensor::capacity)
.def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
.def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize, py::arg("dims"), py::arg("strides") = std::vector<DimSize_t>())
.def("has_impl", &Tensor::hasImpl)
.def("get_coord", &Tensor::getCoord)
.def("get_idx", &Tensor::getIdx)
......
......@@ -175,10 +175,16 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd
}
size_t inputIdx = 0;
for (auto input : mInputNodes) {
for (const auto& input : mInputNodes) {
if (input.first != nullptr) {
fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"&rarr;{}\"|{}_{}\n", inputIdx, inputIdx,
const auto& op_ = std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator());
if (op_->getInput(input.second) && (!op_->getInput(input.second)->empty())) {
fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"&rarr;{}{}\"|{}_{}\n", inputIdx, inputIdx,
input.second, op_->getInput(input.second)->dims(), input.first->type(), namePtrTable.at(input.first));
} else {
fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"&rarr;{}\"|{}_{}\n", inputIdx, inputIdx,
input.second, input.first->type(), namePtrTable.at(input.first));
}
}
else {
fmt::print(fp.get(), "input{}((in#{})):::inputCls\n", inputIdx, inputIdx);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment