From 2ed4315dfc72a0e0b809bb228dfbaa09928bab26 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Fri, 5 Jan 2024 16:46:04 +0000 Subject: [PATCH] fix GraphView::compile() binding & add showProducer option in GraphView::save() --- include/aidge/graph/GraphView.hpp | 2 +- python_binding/graph/pybind_GraphView.cpp | 4 ++-- src/graph/GraphView.cpp | 13 +++++++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 0fe66e4b6..813301a14 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -96,7 +96,7 @@ public: * specified location. * @param path */ - void save(std::string path, bool verbose = false) const; + void save(std::string path, bool verbose = false, bool showProducers = true) const; inline bool inView(NodePtr nodePtr) const { return mNodes.find(nodePtr) != mNodes.end(); diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 32151a66a..8e0da01c8 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -23,7 +23,7 @@ namespace Aidge { void init_GraphView(py::module& m) { py::class_<GraphView, std::shared_ptr<GraphView>>(m, "GraphView") .def(py::init<>()) - .def("save", &GraphView::save, py::arg("path"), py::arg("verbose") = false, + .def("save", &GraphView::save, py::arg("path"), py::arg("verbose") = false, py::arg("show_producers") = true, R"mydelimiter( Save the GraphView as a Mermaid graph in a .md file at the specified location. @@ -97,7 +97,7 @@ void init_GraphView(py::module& m) { .def("get_nodes", &GraphView::getNodes) .def("get_node", &GraphView::getNode, py::arg("node_name")) .def("forward_dims", &GraphView::forwardDims) - .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype")) + .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0) .def("__call__", &GraphView::operator(), py::arg("connectors")) .def("set_datatype", &GraphView::setDataType, py::arg("datatype")) .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0) diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index c2439a459..968e98e75 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -55,7 +55,7 @@ std::string Aidge::GraphView::name() const { return mName; } void Aidge::GraphView::setName(const std::string &name) { mName = name; } -void Aidge::GraphView::save(std::string path, bool verbose) const { +void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) const { FILE *fp = std::fopen((path + ".mmd").c_str(), "w"); std::fprintf(fp, "%%%%{init: {'flowchart': { 'curve': 'monotoneY'}, " @@ -68,7 +68,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { for (const std::shared_ptr<Node> &node_ptr : mNodes) { const std::string currentType = node_ptr->type(); if (typeCounter.find(currentType) == typeCounter.end()) - typeCounter[currentType] = 0; + typeCounter[currentType] = 0; ++typeCounter[currentType]; std::string givenName = @@ -83,13 +83,18 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { givenName.c_str()); } else { - std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(), - givenName.c_str()); + if ((currentType != "Producer") || showProducers) { + std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(), + givenName.c_str()); + } } } // Write every link for (const std::shared_ptr<Node> &node_ptr : mNodes) { + if ((node_ptr -> type() == "Producer") && !showProducers) { + continue; + } IOIndex_t outputIdx = 0; for (auto childs : node_ptr->getOrderedChildren()) { for (auto child : childs) { -- GitLab