diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 94dcdd7f94e2e5b959742577506a7869ca783baf..a41d0d92835be2b5ef07d30c4a5233da1e3906b7 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -89,7 +89,19 @@ void init_GraphView(py::module& m) { :type to_tensor: int )mydelimiter") - .def_static("replace", &GraphView::replace, py::arg("old_nodes"), py::arg("new_nodes"), + .def_static("replace", py::overload_cast<const std::shared_ptr<GraphView>&, const std::shared_ptr<GraphView>&>(&GraphView::replace), py::arg("old_graph"), py::arg("new_graph"), + R"mydelimiter( + Replace the old set of Nodes in a GraphView with the new set of given Nodes in a GraphView if possible in every GraphView. + + :param old_graph: GraphView of Nodes actually connected in GraphViews. + :type old_graph: GraphView + :param new_graph: GraphView of Nodes with inner connections already taken care of. + :type new_graph: GraphView + :return: Whether any replacement has been made. + :rtype: bool + )mydelimiter") + + .def_static("replace", py::overload_cast<const std::set<NodePtr>&, const std::set<NodePtr>&>(&GraphView::replace), py::arg("old_nodes"), py::arg("new_nodes"), R"mydelimiter( Replace the old set of Nodes with the new set of given Nodes if possible in every GraphView. diff --git a/src/recipes/FuseMulAdd.cpp b/src/recipes/FuseMulAdd.cpp index 3047ad71563cd6d66979a0d70b950784d4b6ee7e..f408959a13d007853c24e30c1ef683648cf9c200 100644 --- a/src/recipes/FuseMulAdd.cpp +++ b/src/recipes/FuseMulAdd.cpp @@ -47,18 +47,20 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< } std::shared_ptr<Node> weight = nullptr; - if (matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type - && matmulNode->getParent(0)->getOperator()->type() != Producer_Op::Type) + if ((matmulNode->getParent(1) && !matmulNode->getParent(0)) + || (matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type + && matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() != Producer_Op::Type)) { weight = matmulNode->getParent(1)->cloneSharedOperators(); } - else if (matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type - && matmulNode->getParent(1)->getOperator()->type() != Producer_Op::Type) + else if ((matmulNode->getParent(0) && !matmulNode->getParent(1)) + || (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type + && matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() != Producer_Op::Type)) { weight = matmulNode->getParent(0)->cloneSharedOperators(); } - else if (matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type - && matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type) + else if (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type + && matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type) { // If both inputs are producers, there is an ambiguity, but both options // result in a correct solution.