Forked from
Eclipse Projects / aidge / aidge_core
2525 commits behind the upstream repository.
-
Maxence Naud authored
[Merge] GraphView.hpp [Doc] update documentation for Node.hpp and GraphView.hpp [Upd] IOIndex_t from signed<uint16_t> to uint16_t. IONb_t is replaced by IOIndex_t because both types represent the same thing. [Upd] Make many member-function arguments const for Node and GraphView to improve safety
Maxence Naud authored [Merge] GraphView.hpp [Doc] update documentation for Node.hpp and GraphView.hpp [Upd] IOIndex_t from signed<uint16_t> to uint16_t. IONb_t is replaced by IOIndex_t because both types represent the same thing. [Upd] Make many member-function arguments const for Node and GraphView to improve safety
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
GraphView.hpp 13.31 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_CORE_GRAPH_GRAPHVIEW_H__
#define __AIDGE_CORE_GRAPH_GRAPHVIEW_H__
#include <cassert>
#include <initializer_list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
enum class DataType;
/**
 * @brief Group of Nodes forming a computational graph on which properties and
 * functions can easily and safely be applied or run.
 */
class GraphView : public std::enable_shared_from_this<GraphView> {
private:
    /// @brief Name of the GraphView.
    std::string mName;

    /// @brief Set of Nodes included in the GraphView.
    std::set<NodePtr> mNodes;

    /// @brief Nodes included in the GraphView that have a name, indexed by that name.
    std::map<std::string, NodePtr> mNodeRegistry;

    /// @brief Nodes without input link.
    std::set<NodePtr> mInputNodes;

    /// @brief Nodes without output link.
    std::set<NodePtr> mOutputNodes;

public:
    /**
     * @brief Construct a GraphView holding no Node.
     * @param name Name of the GraphView. Default: empty string.
     */
    GraphView(std::string name = "")
        : mName(std::move(name))  // 'name' is a by-value sink: move it instead of copying again
    {
        // ctor
    }

    // GraphView(std::set<NodePtr> nodes, std::string name="")
    //     : mName(name)
    // {
    //     add(nodes);
    // }

    /**
     * @brief Compare two GraphViews by the set of Node pointers they hold.
     * Names, registries and input/output sets are not part of the comparison.
     * @param gv GraphView to compare with.
     * @return true if both GraphViews reference exactly the same Nodes.
     */
    bool operator==(const GraphView &gv) const
    {
        return mNodes == gv.mNodes;
    }

    /**
     * @brief Access a Node of the GraphView through its registered name.
     * @param name Name under which the Node is registered.
     * @return NodePtr Node registered under this name.
     * @throws std::out_of_range raised by std::map::at when no Node has this
     * name and assertions are compiled out; with assertions enabled the
     * missing name triggers the assert first.
     */
    NodePtr operator[](const std::string &name) const
    {
        assert(mNodeRegistry.find(name) != mNodeRegistry.end() && "Could not find Node in the GraphView.");
        return mNodeRegistry.at(name);
    }

///////////////////////////////////////////////////////
//        FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////

    /**
     * @brief Functional-style call of the GraphView on a list of Connectors.
     * NOTE(review): defined out of line; presumably connects the given
     * Connectors to the GraphView inputs and returns a Connector on its
     * output - confirm against the implementation.
     * @param ctors Connectors passed to the GraphView.
     * @return Connector
     */
    Connector operator()(const std::vector<Connector> ctors);

///////////////////////////////////////////////////////
//        INNER
///////////////////////////////////////////////////////
public:
    /**
     * @brief Name of the GraphView.
     * @return std::string
     */
    std::string name() const;

    /**
     * @brief Set the GraphView name.
     * NOTE(review): the previous comment warned about undefined behaviour
     * when several Nodes share the same name; that wording looks copied from
     * Node::setName - confirm whether it applies to GraphView names.
     * @param name New name for the GraphView.
     */
    void setName(const std::string &name);

    /**
     * @brief Save the GraphView as a Mermaid graph in a .md file at the
     * specified location.
     * @param path Location of the file to write.
     * @param verbose Default: false. NOTE(review): effect not visible from
     * this header - presumably adds details to the output; confirm.
     */
    void save(std::string path, bool verbose = false) const;

    /**
     * @brief Assess if the given Node is part of the GraphView.
     * @param nodePtr Node to look for.
     * @return true if the Node belongs to the set of Nodes of the GraphView.
     */
    bool inView(const NodePtr &nodePtr) const {
        return mNodes.find(nodePtr) != mNodes.end();
    }

///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////
public:
    /** @brief Get reference to the set of input Nodes. */
    const std::set<NodePtr>& inputNodes() const noexcept { return mInputNodes; }

    /** @brief Get reference to the set of output Nodes. */
    const std::set<NodePtr>& outputNodes() const noexcept { return mOutputNodes; }

    /** @brief Assess if the given Node is an input Node of the GraphView object. */
    bool isInputNode(const NodePtr &nodePtr) const {
        return mInputNodes.find(nodePtr) != mInputNodes.end();
    }

    /** @brief Assess if the given Node is an output Node of the GraphView object. */
    bool isOutputNode(const NodePtr &nodePtr) const {
        return mOutputNodes.find(nodePtr) != mOutputNodes.end();
    }

    /**
     * @brief List dataInput connections of the GraphView object's inputNodes.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const;

    /**
     * @brief List dataInput connections of the Node registered under the given name.
     * @param name Name of the Node.
     * @return Whatever Node::dataInputs() returns for that Node (documented
     * so far as std::vector<std::pair<NodePtr, IOIndex_t>> - confirm in Node.hpp).
     * @throws std::out_of_range raised by std::map::at if no Node has this name.
     */
    auto dataInputs(const std::string &name) const { return mNodeRegistry.at(name)->dataInputs(); }

    /**
     * @brief List input connections of the GraphView object's inputNodes.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;

    /**
     * @brief List input connections of the specified GraphView object's inputNode.
     * @param name Name of the Node.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    std::vector<std::pair<NodePtr, IOIndex_t>> inputs(std::string name) const;

    /**
     * @brief List output connections of the GraphView object's outputNodes.
     * @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>
     */
    std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const;

    /**
     * @brief List output connections of the specified Node of the GraphView object.
     * @param nodeName Name of the Node of which to show the output.
     * @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>
     */
    std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs(
            std::string nodeName) const;

    /**
     * @brief Compute dimensions of input/output Tensors for each Operator of the
     * GraphView object's Nodes.
     */
    void forwardDims();

    /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
    void setBackend(const std::string &backend);

    /** @brief Set the same data type for each Operator of the GraphView object's Nodes. */
    void setDatatype(const DataType &datatype);

///////////////////////////////////////////////////////
//        TOPOLOGY
///////////////////////////////////////////////////////
public:
    /**
     * @brief Get the parents Nodes of inputNodes.
     * @return std::set<NodePtr>
     */
    std::set<NodePtr> getParents() const;

    /**
     * @brief Get parents Nodes of the specified Node.
     * @param nodeName Name of the Node.
     * @return std::vector<NodePtr>
     */
    std::vector<NodePtr> getParents(const std::string nodeName) const;

    /**
     * @brief Get parents Nodes grouped in an ordered, nested list.
     * NOTE(review): ordering and grouping are defined out of line -
     * presumably one inner vector per input Node; confirm.
     * @return std::vector<std::vector<NodePtr>>
     */
    std::vector<std::vector<NodePtr>> getOrderedParents() const;

    /**
     * @brief Get the children Nodes of outputNodes.
     * @return std::set<NodePtr>
     */
    std::set<NodePtr> getChildren() const;

    /**
     * @brief Get children Nodes of the specified Node.
     * @param nodeName Name of the Node.
     * @return std::vector<std::vector<NodePtr>>
     */
    std::vector<std::vector<NodePtr>> getChildren(const std::string nodeName) const;

    /**
     * @brief Get children Nodes of the specified Node.
     * @param otherNode Node whose children are requested.
     * @return std::set<NodePtr>
     */
    std::set<NodePtr> getChildren(
            const NodePtr otherNode) const;  // TODO change it for a vector<vector> ?

    /**
     * @brief Get the Nodes pointed to by the GraphView object.
     * @return std::set<NodePtr> Copy of the internal set of Nodes.
     */
    std::set<NodePtr> getNodes() const { return mNodes; }

    /**
     * @brief Get the Node with the corresponding name if it is in the
     * GraphView.
     * @param nodeName Name of the node.
     * @return NodePtr returns a new empty node if the one asked for
     * was not found.
     */
    NodePtr getNode(const char *nodeName) const;

    /**
     * @brief Remove a Node from the current GraphView scope without affecting its connections.
     * @param nodePtr Node to remove
     * @param includeLearnableParam Whether learnable parameters should also be removed. Default true.
     */
    void remove(NodePtr nodePtr, bool includeLearnableParam = true);

    // Surrounding nodes management

    /**
     * @brief NOTE(review): undocumented and defined out of line; from the
     * parameter names it presumably binds input 'inID' of the GraphView to
     * output 'newNodeOutID' of a surrounding Node - confirm.
     * @param inID Input index.
     * @param newNodeOutID Output index.
     */
    void setInputId(IOIndex_t inID, IOIndex_t newNodeOutID);

    /**
     * @brief Include a Node to the current GraphView object.
     * @param otherNode Node to add.
     * @param includeLearnableParam Include non-data inputs, like weights and biases
     * in the GraphView automatically. Default: true.
     */
    void add(NodePtr otherNode, bool includeLearnableParam = true);

    /**
     * @brief Include a set of Nodes to the current GraphView object.
     * @param otherNodes Nodes to add.
     * @param includeLearnableParam Include non-data inputs, like weights and biases
     * in the GraphView automatically. Default: true.
     */
    void add(std::set<NodePtr> otherNodes,
             bool includeLearnableParam = true);

    /**
     * @brief Include every Node inside another GraphView to the current
     * GraphView.
     * @param otherGraph GraphView containing the Nodes to include.
     */
    void add(std::shared_ptr<GraphView> otherGraph);

    /**
     * @brief Include a Node in the current GraphView and link it to another
     * already contained Node.
     *
     * @param toOtherNode Pointer to the Node to add.
     * @param fromOutNode Pointer to the already included Node the new Node will
     * be linked to (it will become a parent of the new Node). If the GraphView
     * only has one output Node, then default to this Node.
     * @param fromTensor Output Tensor ID of the already included Node. Default to
     * 0.
     * @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning
     * first available data input for the Node.
     */
    void addChild(NodePtr toOtherNode, NodePtr fromOutNode = nullptr,
                  const IOIndex_t fromTensor = IOIndex_t(0),
                  IOIndex_t toTensor = gk_IODefaultIndex);

    /**
     * @brief Include a Node in the current GraphView and link it to another
     * already contained Node.
     *
     * @param toOtherNode Pointer to the Node to add.
     * @param fromOutNodeName Name of the already included Node the new Node will
     * be linked to (it will become a parent of the new Node). As a name is
     * optional, ensure such Node is in the GraphView: an unknown name triggers
     * the assert, or std::out_of_range from std::map::at when assertions are
     * compiled out.
     * @param fromTensor Output Tensor ID of the already included Node. Default to
     * 0.
     * @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning
     * first available data input for the Node.
     */
    void addChild(NodePtr toOtherNode, const std::string &fromOutNodeName,
                  const IOIndex_t fromTensor = IOIndex_t(0),
                  IOIndex_t toTensor = gk_IODefaultIndex) {
        assert(mNodeRegistry.find(fromOutNodeName) != mNodeRegistry.end() &&
               "No Node with this name found in the GraphView.");
        // Resolve the name, then forward to the NodePtr-based overload.
        addChild(toOtherNode, mNodeRegistry.at(fromOutNodeName), fromTensor, toTensor);
    }

    /**
     * @brief Include a GraphView content in the current GraphView and link
     * the two sets by linking one Node from each GraphView.
     * @param toOtherView Pointer to the GraphView whose content should be added.
     * @param fromOutNode Pair of pointer to Node and Tensor ID for specifying the
     * connection. If the GraphView including the other one has only one output
     * Node, then it defaults to the first output Tensor of this Node.
     * @param toNode Pair of pointer to Node and Tensor ID for specifying the
     * connection. If the GraphView whose content is included has only one input
     * Node, then it defaults to the first available data input Tensor of this
     * Node.
     */
    void addChild(std::shared_ptr<GraphView> toOtherView,
                  std::pair<NodePtr, IOIndex_t> fromOutNode =
                          std::pair<NodePtr, IOIndex_t>(nullptr, IOIndex_t(0)),
                  std::pair<NodePtr, IOIndex_t> toNode =
                          std::pair<NodePtr, IOIndex_t>(nullptr, gk_IODefaultIndex));

    /**
     * @brief Swap two Node instances if possible.
     * @param node First Node.
     * @param otherNode Second Node.
     * @return bool NOTE(review): presumably true when the swap was performed
     * and false otherwise - confirm against the implementation.
     */
    bool swap(Node &node, Node &otherNode);

    /**
     * @brief NOTE(review): undocumented and defined out of line; from the
     * parameter names it presumably links an input and an output identified
     * by strings - confirm the expected string format.
     * @param name1_inID Identifier of the input side.
     * @param name2_outID Identifier of the output side.
     */
    void link(std::string name1_inID, std::string name2_outID);

    /**
     * @brief NOTE(review): undocumented and defined out of line; presumably
     * inserts 'newNode' between 'inNode' and 'outNodes' on connection
     * 'tensorIdx' - confirm. Note that 'outNodes' holds Node values, not
     * pointers.
     * @param newNode Node to insert.
     * @param inNode Node on the input side.
     * @param outNodes Nodes on the output side.
     * @param tensorIdx Tensor index used for the connection.
     */
    void insert(Node &newNode, Node &inNode, std::initializer_list<Node> outNodes,
                IOIndex_t tensorIdx);

    /**
     * @brief Replace the current GraphView with the set of given Nodes if possible
     * @param newNodes Set of Nodes.
     * @return bool NOTE(review): presumably true when the replacement was
     * performed and false otherwise - confirm against the implementation.
     */
    bool replaceWith(std::set<NodePtr> newNodes);

    /**
     * @brief NOTE(review): undocumented; presumably the input-side
     * counterpart of updateOutputNodes(), recomputing the set of input Nodes
     * from scratch - confirm.
     */
    void updateInputNodes();

    /**
     * @brief Recompute the set of output Nodes from scratch.
     */
    void updateOutputNodes();

private:
///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////

    /**
     * @brief Get the sum of the number of dataInput Nodes for all inputNodes of the GraphView object.
     * @return IOIndex_t
     */
    IOIndex_t getNbDataInputs() const;

    /**
     * @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object.
     * @return IOIndex_t
     */
    IOIndex_t getNbFreeDataInputs() const;

    /**
     * @brief Update the set of inputNodes with a new Node, checking if it can be
     * added and removing any Node not part of mInputNodes anymore.
     * @param node Node to consider as a new input Node.
     */
    void updateInputNodes(NodePtr node);

    /**
     * @brief Update the set of outputNodes with a new Node, checking if it can be
     * added and removing any Node not part of mOutputNodes anymore.
     * @param node Node to consider as a new output Node.
     */
    void updateOutputNodes(NodePtr node);

///////////////////////////////////////////////////////
//        TOPOLOGY
///////////////////////////////////////////////////////

    /**
     * @brief NOTE(review): undocumented helper of forwardDims(); presumably
     * propagates dimensions through the given Nodes - confirm.
     * @param listNodes Nodes to process.
     */
    void _forwardDims(std::set<NodePtr> listNodes);

    /**
     * @brief NOTE(review): undocumented; presumably removes the Node with
     * this name from the set of input Nodes - confirm.
     * @param nodeName Name of the Node.
     */
    void removeInputNode(const std::string nodeName);

    /**
     * @brief NOTE(review): undocumented; presumably removes the Node with
     * this name from the set of output Nodes - confirm.
     * @param nodeName Name of the Node.
     */
    void removeOutputNode(const std::string nodeName);
};
} // namespace Aidge
#endif /* __AIDGE_CORE_GRAPH_GRAPHVIEW_H__ */