diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index 83bb4afeacdd6de181fd6738edad2229736854c8..60283039b709b783484ba0b1cf821497e5bb3a8f 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -51,6 +51,7 @@ void addCtor(py::class_<Tensor, return newTensor; }), py::arg("array"), py::arg("backend")="cpu") + .def(py::init<T>(), py::arg("val")) .def("__setitem__", (void (Tensor::*)(std::size_t, T)) &Tensor::set) .def("__setitem__", (void (Tensor::*)(std::vector<std::size_t>, T)) &Tensor::set) ; @@ -73,6 +74,7 @@ void init_Tensor(py::module& m){ (m,"Tensor", py::multiple_inheritance(), py::buffer_protocol()); pyClassTensor.def(py::init<>()) + .def(py::init<const std::vector<std::size_t>&>(), py::arg("dims")) .def(py::self + py::self) .def(py::self - py::self) .def(py::self * py::self) @@ -86,7 +88,7 @@ void init_Tensor(py::module& m){ .def("dtype", &Tensor::dataType) .def("size", &Tensor::size) .def("capacity", &Tensor::capacity) - .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize) + .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize, py::arg("dims"), py::arg("strides") = std::vector<DimSize_t>()) .def("has_impl", &Tensor::hasImpl) .def("get_coord", &Tensor::getCoord) .def("get_idx", &Tensor::getIdx) diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index a7959419ab22fae8443ee9cd3ca286874fa65725..9528e511be230cd8ac689876689f313782c9b0ab 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -175,10 +175,16 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd } size_t inputIdx = 0; - for (auto input : mInputNodes) { + for (const auto& input : mInputNodes) { if (input.first != nullptr) { - fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"→{}\"|{}_{}\n", inputIdx, inputIdx, + const auto& op_ = std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator()); + if (op_->getInput(input.second) && (!op_->getInput(input.second)->empty())) { + fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"→{}{}\"|{}_{}\n", inputIdx, inputIdx, + input.second, op_->getInput(input.second)->dims(), input.first->type(), namePtrTable.at(input.first)); + } else { + fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"→{}\"|{}_{}\n", inputIdx, inputIdx, input.second, input.first->type(), namePtrTable.at(input.first)); + } } else { fmt::print(fp.get(), "input{}((in#{})):::inputCls\n", inputIdx, inputIdx);