Skip to content
Snippets Groups Projects
Commit d6f3b1c6 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

GraphView cloning proposal + labelGraph proof of concept

parent 1b19726b
No related branches found
No related tags found
No related merge requests found
...@@ -336,6 +336,27 @@ public: ...@@ -336,6 +336,27 @@ public:
*/ */
void updateOutputNodes(); void updateOutputNodes();
/**
* @brief Make a clone of the graph with shared operators. It a new graph, with new, cloned nodes, but the new nodes refer to the same operators as the origin ones.
* @param newNodes Set of Nodes.
* @return std::shared_ptr<GraphView>
*/
inline std::shared_ptr<GraphView> cloneSharedOperators() const {
return clone(&Node::cloneSharedOperators);
}
/**
* @brief Make a clone of the graph with shared producers. All the other operators are copied.
*/
inline std::shared_ptr<GraphView> cloneSharedProducers() const {
return clone(&Node::cloneSharedProducers);
}
/**
* @brief Make a clone of the graph. Everything is cloned: nodes and operators.
*/
std::shared_ptr<GraphView> clone(NodePtr(*cloneNode)(NodePtr) = &Node::clone) const;
private: private:
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
// TENSOR MANAGEMENT // TENSOR MANAGEMENT
......
...@@ -350,6 +350,57 @@ public: ...@@ -350,6 +350,57 @@ public:
*/ */
void resetConnections(bool includeLearnableParam = false); void resetConnections(bool includeLearnableParam = false);
///////////////////////////////////////////////////////
// CLONE
///////////////////////////////////////////////////////
/**
* @brief Clone the node keeping the same operator object instance. The new node has no connection.
* @return NodePtr
*/
NodePtr cloneSharedOperators() const;
/**
* @brief Clone the node keeping the same operator object instance only for producers. Any other operator object instance is cloned as wel. The new node has no connection.
* @return NodePtr
*/
NodePtr cloneSharedProducers() const;
/**
* @brief Clone the node and its operator. The new node has no connection.
* @return NodePtr
*/
NodePtr clone() const;
/**
* @brief Clone the node keeping the same operator object instance. The new node has no connection.
* @param node Node to clone.
* @return NodePtr
*/
static NodePtr cloneSharedOperators(NodePtr node) {
return node->cloneSharedOperators();
}
/**
* @brief Clone the node keeping the same operator object instance only for producers. Any other operator object instance is cloned as wel. The new node has no connection.
* @param node Node to clone.
* @return NodePtr
*/
static NodePtr cloneSharedProducers(NodePtr node) {
return node->cloneSharedProducers();
}
/**
* @brief Clone the node and its operator. The new node has no connection.
* @param node Node to clone.
* @return NodePtr
*/
static NodePtr clone(NodePtr node) {
return node->clone();
}
friend std::shared_ptr<GraphView> GraphView::clone(NodePtr(*cloneNode)(NodePtr)) const;
private: private:
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
// OPERATORS // OPERATORS
......
...@@ -673,4 +673,72 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) { ...@@ -673,4 +673,72 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) {
mOutputNodes.erase(val); mOutputNodes.erase(val);
} }
} }
} }
\ No newline at end of file
std::shared_ptr<Aidge::GraphView> Aidge::GraphView::clone(NodePtr(*cloneNode)(NodePtr)) const {
std::shared_ptr<GraphView> newGraph;
std::map<NodePtr, NodePtr> oldToNewNodes;
// Map for old node -> new node correspondance
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
oldToNewNodes[node_ptr] = cloneNode(node_ptr);
}
// For each node, convert old node -> new node references
for (auto &oldToNewNode : oldToNewNodes) {
// The only GraphView for the new graph is the one we return at the end
oldToNewNode.second->addView(newGraph);
if (!(oldToNewNode.second->name()).empty())
newGraph->mNodeRegistry.insert(std::make_pair(oldToNewNode.second->name(), oldToNewNode.second));
// Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr
size_t parentId = 0;
for (auto& parent : oldToNewNode.first->mParents) {
oldToNewNode.second->mParents.push_back(oldToNewNodes[parent]);
if (oldToNewNodes[parent])
oldToNewNode.second->mIdOutParents.push_back(oldToNewNode.first->mIdOutParents[parentId]);
else
oldToNewNode.second->mIdOutParents.push_back(gk_IODefaultIndex);
++parentId;
}
// Connect child nodes.
size_t childId = 0;
for (const auto &childrenOfOneOutput : oldToNewNode.first->mChildren) {
oldToNewNode.second->mChildren.push_back(std::vector<std::weak_ptr<Node>>());
oldToNewNode.second->mIdInChildren.push_back(std::vector<IOIndex_t>());
size_t j = 0;
for (const auto &oneChild : childrenOfOneOutput) {
if (oldToNewNodes[oneChild.lock()]) {
oldToNewNode.second->mChildren.back().push_back(std::weak_ptr<Node>(oldToNewNodes[oneChild.lock()]));
oldToNewNode.second->mIdInChildren.back().push_back(oldToNewNode.first->mIdInChildren[childId][j]);
}
++j;
}
++childId;
}
}
// Copy remaining GraphView informations
newGraph->setName(mName);
for (const auto& inputNode : mInputNodes) {
if (oldToNewNodes[inputNode]) {
newGraph->mInputNodes.insert(oldToNewNodes[inputNode]);
}
}
for (const auto& outputNode : mOutputNodes) {
if (oldToNewNodes[outputNode]) {
newGraph->mOutputNodes.insert(oldToNewNodes[outputNode]);
}
}
return newGraph;
}
...@@ -321,6 +321,26 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { ...@@ -321,6 +321,26 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) {
} }
} }
///////////////////////////////////////////////////////
// CLONE
///////////////////////////////////////////////////////
Aidge::NodePtr Aidge::Node::cloneSharedOperators() const {
return std::make_shared<Node>(mOperator, mName);
}
Aidge::NodePtr Aidge::Node::cloneSharedProducers() const {
std::shared_ptr<Operator> op = (op->type() == Producer_Op::Type)
? mOperator
: std::make_shared<Operator>(*mOperator);
return std::make_shared<Node>(op, mName);
}
Aidge::NodePtr Aidge::Node::clone() const {
return std::make_shared<Node>(std::make_shared<Operator>(*mOperator), mName);
}
///////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////
// private // private
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
using namespace Aidge;
NodePtr nodeLabel(NodePtr node) {
// TODO: this is just a proof of concept right now!
if (node->type() == Conv_Op<2>::Type) {
auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator());
// TODO: adapt the following code.
auto newOp = std::make_shared<GenericOperator_Op>("CenterCropPad", 3, 1, 1);
newOp->addParameter("KernelDims", op->get<ConvParam::KernelDims>());
return std::make_shared<Node>(newOp, node->name());
}
// By default, remove the node from the graph
return nullptr;
}
std::shared_ptr<GraphView> labelGraph(std::shared_ptr<GraphView> graph) {
return graph->clone(&nodeLabel);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment