From 57dd2c6cbe689b0f11bcacbe2d7bc5373da3524d Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sun, 3 Mar 2024 10:54:15 +0100 Subject: [PATCH] Fixed binding issue and seg fault --- python_binding/graph/pybind_GraphView.cpp | 14 +++++++++++++- src/recipes/FuseMulAdd.cpp | 14 ++++++++------ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 94dcdd7f9..a41d0d928 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -89,7 +89,19 @@ void init_GraphView(py::module& m) { :type to_tensor: int )mydelimiter") - .def_static("replace", &GraphView::replace, py::arg("old_nodes"), py::arg("new_nodes"), + .def_static("replace", py::overload_cast<const std::shared_ptr<GraphView>&, const std::shared_ptr<GraphView>&>(&GraphView::replace), py::arg("old_graph"), py::arg("new_graph"), + R"mydelimiter( + Replace the old set of Nodes in a GraphView with the new set of given Nodes in a GraphView if possible in every GraphView. + + :param old_graph: GraphView of Nodes actually connected in GraphViews. + :type old_graph: GraphView + :param new_graph: GraphView of Nodes with inner connections already taken care of. + :type new_graph: GraphView + :return: Whether any replacement has been made. + :rtype: bool + )mydelimiter") + + .def_static("replace", py::overload_cast<const std::set<NodePtr>&, const std::set<NodePtr>&>(&GraphView::replace), py::arg("old_nodes"), py::arg("new_nodes"), R"mydelimiter( Replace the old set of Nodes with the new set of given Nodes if possible in every GraphView. diff --git a/src/recipes/FuseMulAdd.cpp b/src/recipes/FuseMulAdd.cpp index 3047ad715..f408959a1 100644 --- a/src/recipes/FuseMulAdd.cpp +++ b/src/recipes/FuseMulAdd.cpp @@ -47,18 +47,20 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< } std::shared_ptr<Node> weight = nullptr; - if (matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type - && matmulNode->getParent(0)->getOperator()->type() != Producer_Op::Type) + if ((matmulNode->getParent(1) && !matmulNode->getParent(0)) + || (matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type + && matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() != Producer_Op::Type)) { weight = matmulNode->getParent(1)->cloneSharedOperators(); } - else if (matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type - && matmulNode->getParent(1)->getOperator()->type() != Producer_Op::Type) + else if ((matmulNode->getParent(0) && !matmulNode->getParent(1)) + || (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type + && matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() != Producer_Op::Type)) { weight = matmulNode->getParent(0)->cloneSharedOperators(); } - else if (matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type - && matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type) + else if (matmulNode->getParent(0) && matmulNode->getParent(0)->getOperator()->type() == Producer_Op::Type + && matmulNode->getParent(1) && matmulNode->getParent(1)->getOperator()->type() == Producer_Op::Type) { // If both inputs are producers, there is an ambiguity, but both options // result in a correct solution. -- GitLab