From 1c612b18efb3b90cb9621ea47145a121cba65c18 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Thu, 7 Sep 2023 16:12:28 +0200 Subject: [PATCH] Improved API --- include/aidge/graph/GraphView.hpp | 18 ++++++++++++++---- include/aidge/graph/Node.hpp | 6 +++--- src/graph/GraphView.cpp | 2 +- src/recipies/LabelGraph.cpp | 4 ++-- 4 files changed, 20 insertions(+), 10 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index de90e6b2d..deabfdec4 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -338,24 +338,34 @@ public: /** * @brief Make a clone of the graph with shared operators. It a new graph, with new, cloned nodes, but the new nodes refer to the same operators as the origin ones. - * @param newNodes Set of Nodes. * @return std::shared_ptr<GraphView> */ inline std::shared_ptr<GraphView> cloneSharedOperators() const { - return clone(&Node::cloneSharedOperators); + return cloneCallback(&Node::cloneSharedOperators); } /** * @brief Make a clone of the graph with shared producers. All the other operators are copied. + * @return std::shared_ptr<GraphView> */ inline std::shared_ptr<GraphView> cloneSharedProducers() const { - return clone(&Node::cloneSharedProducers); + return cloneCallback(&Node::cloneSharedProducers); } /** * @brief Make a clone of the graph. Everything is cloned: nodes and operators. + * @return std::shared_ptr<GraphView> + */ + inline std::shared_ptr<GraphView> clone() const { + return cloneCallback(&Node::clone); + } + + /** + * @brief This function clones the graph using a callback function for the node cloning, allowing to specify how each node should be cloned, or replaced by an other node type, or removed (i.e. replaced by identity). When a node is removed, the clone() method automatically finds the next valid parent in line, going backward in the graph and connects it if that makes sense without ambiguity (effectively treating the removed node as an identity operation). + * @param cloneNode Callback function to clone a node + * @return std::shared_ptr<GraphView> */ - std::shared_ptr<GraphView> clone(NodePtr(*cloneNode)(NodePtr) = &Node::clone) const; + std::shared_ptr<GraphView> cloneCallback(NodePtr(*cloneNode)(NodePtr)) const; private: /////////////////////////////////////////////////////// diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index 9b7c1b0ea..d07b61c6a 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -373,7 +373,7 @@ public: NodePtr clone() const; /** - * @brief Clone the node keeping the same operator object instance. The new node has no connection. + * @brief Callback function to clone the node keeping the same operator object instance. The new node has no connection. * @param node Node to clone. * @return NodePtr */ @@ -382,7 +382,7 @@ public: } /** - * @brief Clone the node keeping the same operator object instance only for producers. Any other operator object instance is cloned as wel. The new node has no connection. + * @brief Callback function to clone the node keeping the same operator object instance only for producers. Any other operator object instance is cloned as wel. The new node has no connection. * @param node Node to clone. * @return NodePtr */ @@ -391,7 +391,7 @@ public: } /** - * @brief Clone the node and its operator. The new node has no connection. + * @brief Callback function to clone the node and its operator. The new node has no connection. * @param node Node to clone. * @return NodePtr */ diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 44d8fbb18..af1b68fd3 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -675,7 +675,7 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) { } } -std::shared_ptr<Aidge::GraphView> Aidge::GraphView::clone(NodePtr(*cloneNode)(NodePtr)) const { +std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const { std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName); // Map for old node -> new node correspondance diff --git a/src/recipies/LabelGraph.cpp b/src/recipies/LabelGraph.cpp index f85e61b54..af09baeb7 100644 --- a/src/recipies/LabelGraph.cpp +++ b/src/recipies/LabelGraph.cpp @@ -20,7 +20,7 @@ Aidge::NodePtr Aidge::nodeLabel(NodePtr node) { if (node->type() == Conv_Op<2>::Type) { auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator()); // TODO: adapt the following code. - auto newOp = std::make_shared<GenericOperator_Op>("CenterCropPad", 3, 1, 1); + auto newOp = std::make_shared<GenericOperator_Op>("CenterCropPad", 1, 1, 1); newOp->addParameter("KernelDims", op->get<ConvParam::KernelDims>()); return std::make_shared<Node>(newOp, node->name()); } @@ -30,5 +30,5 @@ Aidge::NodePtr Aidge::nodeLabel(NodePtr node) { } std::shared_ptr<Aidge::GraphView> Aidge::labelGraph(std::shared_ptr<GraphView> graph) { - return graph->clone(&nodeLabel); + return graph->cloneCallback(&nodeLabel); } -- GitLab