From d1fd75fd655ae02119d0079cc08186a104159bc1 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Tue, 16 Jan 2024 15:08:07 +0000 Subject: [PATCH] Add a trampoline to Node.add_child in order to set None default value to avoid conversion of nullptr -> std::shared_ptr. --- include/aidge/graph/Node.hpp | 2 +- python_binding/graph/pybind_Node.cpp | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index 5ae4eb5d8..de2a7b6aa 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -140,7 +140,7 @@ public: /** * @brief List of pair <Parent, ID of the data intput>. When an input is not - * linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. + * linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. * Data inputs exclude inputs expecting parameters (weights or bias). * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> */ diff --git a/python_binding/graph/pybind_Node.cpp b/python_binding/graph/pybind_Node.cpp index 1f655b50a..c4525dce8 100644 --- a/python_binding/graph/pybind_Node.cpp +++ b/python_binding/graph/pybind_Node.cpp @@ -63,12 +63,19 @@ void init_Node(py::module& m) { )mydelimiter") .def("add_child", - (void (Node::*)(std::shared_ptr<GraphView>, const IOIndex_t, - std::pair<std::shared_ptr<Node>, IOIndex_t>)) & - Node::addChild, + [](Node &self, std::shared_ptr<GraphView> other_graph, const IOIndex_t out_id=0, + py::object other_in_id = py::none()) { + std::pair<NodePtr, IOIndex_t> cpp_other_in_id; + // Note: default arg nullptr to allow python binding + if (other_in_id.is_none()) { + cpp_other_in_id = std::pair<NodePtr, IOIndex_t>(nullptr, gk_IODefaultIndex); + }else{ + cpp_other_in_id = other_in_id.cast<std::pair<NodePtr, IOIndex_t>>(); + } + self.addChild(other_graph, out_id, cpp_other_in_id); + }, py::arg("other_graph"), py::arg("out_id") = 0, - py::arg("other_in_id") = - std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex), + py::arg("other_in_id") = py::none(), R"mydelimiter( Link a Node from a specific GraphView to the current Node. -- GitLab