diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index f11136adaaa3d23fa9d3dc5749dd5d6771cbc42c..de90e6b2d5f73cdddf860af9627458569d2f9572 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -336,6 +336,27 @@ public: */ void updateOutputNodes(); + /** + * @brief Make a clone of the graph with shared operators. It a new graph, with new, cloned nodes, but the new nodes refer to the same operators as the origin ones. + * @param newNodes Set of Nodes. + * @return std::shared_ptr<GraphView> + */ + inline std::shared_ptr<GraphView> cloneSharedOperators() const { + return clone(&Node::cloneSharedOperators); + } + + /** + * @brief Make a clone of the graph with shared producers. All the other operators are copied. + */ + inline std::shared_ptr<GraphView> cloneSharedProducers() const { + return clone(&Node::cloneSharedProducers); + } + + /** + * @brief Make a clone of the graph. Everything is cloned: nodes and operators. + */ + std::shared_ptr<GraphView> clone(NodePtr(*cloneNode)(NodePtr) = &Node::clone) const; + private: /////////////////////////////////////////////////////// // TENSOR MANAGEMENT diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index 11def52dbab30159e9e882fb19d16f1549aa3887..3f347f386f0c175f63fcb9c559e3491577faf658 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -350,6 +350,57 @@ public: */ void resetConnections(bool includeLearnableParam = false); + /////////////////////////////////////////////////////// + // CLONE + /////////////////////////////////////////////////////// + + /** + * @brief Clone the node keeping the same operator object instance. The new node has no connection. + * @return NodePtr + */ + NodePtr cloneSharedOperators() const; + + /** + * @brief Clone the node keeping the same operator object instance only for producers. Any other operator object instance is cloned as wel. The new node has no connection. + * @return NodePtr + */ + NodePtr cloneSharedProducers() const; + + /** + * @brief Clone the node and its operator. The new node has no connection. + * @return NodePtr + */ + NodePtr clone() const; + + /** + * @brief Clone the node keeping the same operator object instance. The new node has no connection. + * @param node Node to clone. + * @return NodePtr + */ + static NodePtr cloneSharedOperators(NodePtr node) { + return node->cloneSharedOperators(); + } + + /** + * @brief Clone the node keeping the same operator object instance only for producers. Any other operator object instance is cloned as wel. The new node has no connection. + * @param node Node to clone. + * @return NodePtr + */ + static NodePtr cloneSharedProducers(NodePtr node) { + return node->cloneSharedProducers(); + } + + /** + * @brief Clone the node and its operator. The new node has no connection. + * @param node Node to clone. + * @return NodePtr + */ + static NodePtr clone(NodePtr node) { + return node->clone(); + } + + friend std::shared_ptr<GraphView> GraphView::clone(NodePtr(*cloneNode)(NodePtr)) const; + private: /////////////////////////////////////////////////////// // OPERATORS diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index a0641032281c6bedb4459a0d08da1193d6375129..4d8692d2d69e9177b63bfdd36fe4c63be0f8a769 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -673,4 +673,72 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) { mOutputNodes.erase(val); } } -} \ No newline at end of file +} + +std::shared_ptr<Aidge::GraphView> Aidge::GraphView::clone(NodePtr(*cloneNode)(NodePtr)) const { + std::shared_ptr<GraphView> newGraph; + std::map<NodePtr, NodePtr> oldToNewNodes; + + // Map for old node -> new node correspondance + for (const std::shared_ptr<Node> &node_ptr : mNodes) { + oldToNewNodes[node_ptr] = cloneNode(node_ptr); + } + + // For each node, convert old node -> new node references + for (auto &oldToNewNode : oldToNewNodes) { + // The only GraphView for the new graph is the one we return at the end + oldToNewNode.second->addView(newGraph); + + if (!(oldToNewNode.second->name()).empty()) + newGraph->mNodeRegistry.insert(std::make_pair(oldToNewNode.second->name(), oldToNewNode.second)); + + // Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr + size_t parentId = 0; + for (auto& parent : oldToNewNode.first->mParents) { + oldToNewNode.second->mParents.push_back(oldToNewNodes[parent]); + + if (oldToNewNodes[parent]) + oldToNewNode.second->mIdOutParents.push_back(oldToNewNode.first->mIdOutParents[parentId]); + else + oldToNewNode.second->mIdOutParents.push_back(gk_IODefaultIndex); + + ++parentId; + } + + // Connect child nodes. + size_t childId = 0; + for (const auto &childrenOfOneOutput : oldToNewNode.first->mChildren) { + oldToNewNode.second->mChildren.push_back(std::vector<std::weak_ptr<Node>>()); + oldToNewNode.second->mIdInChildren.push_back(std::vector<IOIndex_t>()); + + size_t j = 0; + for (const auto &oneChild : childrenOfOneOutput) { + if (oldToNewNodes[oneChild.lock()]) { + oldToNewNode.second->mChildren.back().push_back(std::weak_ptr<Node>(oldToNewNodes[oneChild.lock()])); + oldToNewNode.second->mIdInChildren.back().push_back(oldToNewNode.first->mIdInChildren[childId][j]); + } + + ++j; + } + + ++childId; + } + } + + // Copy remaining GraphView informations + newGraph->setName(mName); + + for (const auto& inputNode : mInputNodes) { + if (oldToNewNodes[inputNode]) { + newGraph->mInputNodes.insert(oldToNewNodes[inputNode]); + } + } + + for (const auto& outputNode : mOutputNodes) { + if (oldToNewNodes[outputNode]) { + newGraph->mOutputNodes.insert(oldToNewNodes[outputNode]); + } + } + + return newGraph; +} diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 5fcc0e1139d8ccd9368eaba90231fb12370e761e..7a4dc54c30cf0f843c421effe355da65b4d89815 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -321,6 +321,26 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { } } + /////////////////////////////////////////////////////// + // CLONE + /////////////////////////////////////////////////////// + +Aidge::NodePtr Aidge::Node::cloneSharedOperators() const { + return std::make_shared<Node>(mOperator, mName); +} + +Aidge::NodePtr Aidge::Node::cloneSharedProducers() const { + std::shared_ptr<Operator> op = (op->type() == Producer_Op::Type) + ? mOperator + : std::make_shared<Operator>(*mOperator); + + return std::make_shared<Node>(op, mName); +} + +Aidge::NodePtr Aidge::Node::clone() const { + return std::make_shared<Node>(std::make_shared<Operator>(*mOperator), mName); +} + ///////////////////////////////////////////////////////////////////////////////////////////// // private diff --git a/src/recipies/LabelGraph.cpp b/src/recipies/LabelGraph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a42a9fdaedb0137d94f2b9924cd03adaa8132096 --- /dev/null +++ b/src/recipies/LabelGraph.cpp @@ -0,0 +1,37 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <memory> + +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" + +using namespace Aidge; + +NodePtr nodeLabel(NodePtr node) { + // TODO: this is just a proof of concept right now! + if (node->type() == Conv_Op<2>::Type) { + auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator()); + // TODO: adapt the following code. + auto newOp = std::make_shared<GenericOperator_Op>("CenterCropPad", 3, 1, 1); + newOp->addParameter("KernelDims", op->get<ConvParam::KernelDims>()); + return std::make_shared<Node>(newOp, node->name()); + } + + // By default, remove the node from the graph + return nullptr; +} + +std::shared_ptr<GraphView> labelGraph(std::shared_ptr<GraphView> graph) { + return graph->clone(&nodeLabel); +}