Skip to content
Snippets Groups Projects
Commit 2ed4315d authored by Maxence Naud's avatar Maxence Naud
Browse files

fix GraphView::compile() binding & add showProducer option in GraphView::save()

parent 37a9cafe
No related branches found
No related tags found
No related merge requests found
...@@ -96,7 +96,7 @@ public: ...@@ -96,7 +96,7 @@ public:
* specified location. * specified location.
* @param path * @param path
*/ */
void save(std::string path, bool verbose = false) const; void save(std::string path, bool verbose = false, bool showProducers = true) const;
inline bool inView(NodePtr nodePtr) const { inline bool inView(NodePtr nodePtr) const {
return mNodes.find(nodePtr) != mNodes.end(); return mNodes.find(nodePtr) != mNodes.end();
......
...@@ -23,7 +23,7 @@ namespace Aidge { ...@@ -23,7 +23,7 @@ namespace Aidge {
void init_GraphView(py::module& m) { void init_GraphView(py::module& m) {
py::class_<GraphView, std::shared_ptr<GraphView>>(m, "GraphView") py::class_<GraphView, std::shared_ptr<GraphView>>(m, "GraphView")
.def(py::init<>()) .def(py::init<>())
.def("save", &GraphView::save, py::arg("path"), py::arg("verbose") = false, .def("save", &GraphView::save, py::arg("path"), py::arg("verbose") = false, py::arg("show_producers") = true,
R"mydelimiter( R"mydelimiter(
Save the GraphView as a Mermaid graph in a .md file at the specified location. Save the GraphView as a Mermaid graph in a .md file at the specified location.
...@@ -97,7 +97,7 @@ void init_GraphView(py::module& m) { ...@@ -97,7 +97,7 @@ void init_GraphView(py::module& m) {
.def("get_nodes", &GraphView::getNodes) .def("get_nodes", &GraphView::getNodes)
.def("get_node", &GraphView::getNode, py::arg("node_name")) .def("get_node", &GraphView::getNode, py::arg("node_name"))
.def("forward_dims", &GraphView::forwardDims) .def("forward_dims", &GraphView::forwardDims)
.def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype")) .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0)
.def("__call__", &GraphView::operator(), py::arg("connectors")) .def("__call__", &GraphView::operator(), py::arg("connectors"))
.def("set_datatype", &GraphView::setDataType, py::arg("datatype")) .def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
.def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0) .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0)
......
...@@ -55,7 +55,7 @@ std::string Aidge::GraphView::name() const { return mName; } ...@@ -55,7 +55,7 @@ std::string Aidge::GraphView::name() const { return mName; }
void Aidge::GraphView::setName(const std::string &name) { mName = name; } void Aidge::GraphView::setName(const std::string &name) { mName = name; }
void Aidge::GraphView::save(std::string path, bool verbose) const { void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) const {
FILE *fp = std::fopen((path + ".mmd").c_str(), "w"); FILE *fp = std::fopen((path + ".mmd").c_str(), "w");
std::fprintf(fp, std::fprintf(fp,
"%%%%{init: {'flowchart': { 'curve': 'monotoneY'}, " "%%%%{init: {'flowchart': { 'curve': 'monotoneY'}, "
...@@ -68,7 +68,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { ...@@ -68,7 +68,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
for (const std::shared_ptr<Node> &node_ptr : mNodes) { for (const std::shared_ptr<Node> &node_ptr : mNodes) {
const std::string currentType = node_ptr->type(); const std::string currentType = node_ptr->type();
if (typeCounter.find(currentType) == typeCounter.end()) if (typeCounter.find(currentType) == typeCounter.end())
typeCounter[currentType] = 0; typeCounter[currentType] = 0;
++typeCounter[currentType]; ++typeCounter[currentType];
std::string givenName = std::string givenName =
...@@ -83,13 +83,18 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { ...@@ -83,13 +83,18 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
givenName.c_str()); givenName.c_str());
} }
else { else {
std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(), if ((currentType != "Producer") || showProducers) {
givenName.c_str()); std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
givenName.c_str());
}
} }
} }
// Write every link // Write every link
for (const std::shared_ptr<Node> &node_ptr : mNodes) { for (const std::shared_ptr<Node> &node_ptr : mNodes) {
if ((node_ptr -> type() == "Producer") && !showProducers) {
continue;
}
IOIndex_t outputIdx = 0; IOIndex_t outputIdx = 0;
for (auto childs : node_ptr->getOrderedChildren()) { for (auto childs : node_ptr->getOrderedChildren()) {
for (auto child : childs) { for (auto child : childs) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment