/******************************************************************************** * Copyright (c) 2023 CEA-List * * This program and the accompanying materials are made available under the * terms of the Eclipse Public License 2.0 which is available at * http://www.eclipse.org/legal/epl-2.0. * * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ #ifndef AIDGE_CORE_GRAPH_GRAPHVIEW_H_ #define AIDGE_CORE_GRAPH_GRAPHVIEW_H_ #include <map> #include <memory> #include <set> #include <string> #include <utility> #include <vector> #include "aidge/graph/Node.hpp" #include "aidge/utils/Types.h" namespace Aidge { enum class DataType; /** * @brief Groupement of Nodes forming a computational graph on which properties and functions * can easily and safely be applied or run. */ class GraphView : public std::enable_shared_from_this<GraphView> { private: /// @brief Name of the graphview std::string mName; /// @brief Set of nodes included in the GraphView std::set<NodePtr> mNodes; /// @brief Set of nodes included in the graphview with names std::map<std::string, NodePtr> mNodeRegistry; /// @brief Nodes without input link std::set<NodePtr> mInputNodes; /// @brief Nodes without output link std::set<NodePtr> mOutputNodes; public: GraphView(std::string name="") : mName(name) { // ctor } // GraphView(std::set<NodePtr> nodes, std::string name="") // : mName(name) // { // add(nodes); // } bool operator==(const GraphView &gv) const { return mNodes == gv.mNodes; } NodePtr operator[](std::string name) { assert(mNodeRegistry.find(name) != mNodeRegistry.end() && "Could not find Node in the GraphView."); return mNodeRegistry.at(name); } /////////////////////////////////////////////////////// // FUNCTIONAL DESCRIPTION /////////////////////////////////////////////////////// Connector operator()(const std::vector<Connector> ctors); /////////////////////////////////////////////////////// // INNER /////////////////////////////////////////////////////// public: /** * @brief Name of the node. * @return std::string */ std::string name() const; /** * @brief Set the node name. * @warning Undefined behaviour when several Nodes have the same name. * @param name New name for the node. */ void setName(const std::string &name); /** * @brief Save the GraphView as a Mermaid graph in a .md file at the * specified location. * @param path */ void save(std::string path, bool verbose = false) const; inline bool inView(NodePtr nodePtr) const { return mNodes.find(nodePtr) != mNodes.end(); } /////////////////////////////////////////////////////// // TENSOR MANAGEMENT /////////////////////////////////////////////////////// public: /** @brief Get reference to the set of input Nodes. */ inline const std::set<NodePtr>& inputNodes() const noexcept { return mInputNodes; } /** @brief Get reference to the set of output Nodes. */ inline const std::set<NodePtr>& outputNodes() const noexcept { return mOutputNodes; } /** @brief Assess if the given Node is an input Node of the GraphView object. */ inline bool isInputNode(NodePtr nodePtr) const { return (mInputNodes.find(nodePtr) != mInputNodes.end()) ? true : false; } /** @brief Assess if the given Node is an output Node of the GraphView object. */ inline bool isOutputNode(NodePtr nodePtr) const { return (mOutputNodes.find(nodePtr) != mOutputNodes.end()) ? true : false; } /** * @brief List dataInput connections of the GraphView object's inputNodes. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; /** * @brief List dataInput connections of the GraphView object's inputNodes. * @param name Name of the Node. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); } /** * @brief List input connections of the GraphView object's inputNodes. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; /** * @brief List input connections of the specified GraphView object's inputNode. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs(std::string name) const; /** * @brief List output connections of the GraphView object's outputNodes. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const; /** * @brief Specific i-th output connection of the GraphView object. * @param nodeName Name of the Node of which to show the output. * @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> */ std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs( std::string nodeName) const; /** * @brief Compute dimensions of input/output Tensors for each Operator of the * GraphView object's Nodes. */ void forwardDims(); /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ void setBackend(const std::string &backend); /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ void setDatatype(const DataType &datatype); /////////////////////////////////////////////////////// // TOPOLOGY /////////////////////////////////////////////////////// public: /** * @brief Get the parents Nodes of inputNodes. * @return std::set<NodePtr> */ std::set<NodePtr> getParents() const; /** * @brief Get parents Nodes of the specified Node. * @param nodeName Name of the Node. * @return std::vector<NodePtr> */ std::vector<NodePtr> getParents(const std::string nodeName) const; std::vector<std::vector<NodePtr>> getOrderedParents() const; /** * @brief Get the children Nodes of outputNodes. * @return std::set<NodePtr> */ std::set<NodePtr> getChildren() const; /** * @brief Get children Nodes of the specified Node. * @param nodeName Name of the Node. * @return std::vector<std::vector<NodePtr>> */ std::vector<std::vector<NodePtr>> getChildren(const std::string nodeName) const; std::set<NodePtr> getChildren( const NodePtr otherNode) const; // TODO change it for a vector<vector> ? /** * @brief Get the Nodes pointed to by the GraphView object. * @return std::set<NodePtr> */ inline std::set<NodePtr> getNodes() const { return mNodes; } /** * @brief Get the operator with the corresponding name if it is in the * GraphView. * @param nodeName Name of the node. * @return NodePtr returns a new empty node if the one asked for * was not found. */ NodePtr getNode(const char *nodeName) const; /** * @brief Remove a Node from the current GraphView scope without affecting its connections. * @param nodePtr Node to remove * @param includeLearnableParam Whether learnable parameters should also be removed. Default true. */ void remove(NodePtr nodePtr, bool includeLearnableParam = true); // Surrounding nodes management void setInputId(IOIndex_t inID, IOIndex_t newNodeOutID); /** * @brief Include a Node to the current GraphView object. * @param other_Nde Node to add. * @param includeLearnableParam Include non-data inputs, like weights and biases * in the GraphView automatically. Default: true. */ void add(NodePtr otherNode, bool includeLearnableParam = true); /** * @brief Include a set of Nodes to the current GraphView object. * @param otherNodes * @param includeLearnableParam */ void add(std::set<NodePtr> otherNodes, bool includeLearnableParam = true); /** * @brief Include every Node inside another GraphView to the current * GraphView. * @param other_graph GraphView containing the Nodes to include. */ void add(std::shared_ptr<GraphView> otherGraph); /** * @brief Include a Node in the current GraphView and link it to another * already contained Node. * * @param toOtherNode Pointer to the Node to add. * @param fromOutNode Pointer to the already included Node the new Node will * be linked to (it will become a parent of the new Node). If the GraphView * only has one output Node, then default to this Node. * @param fromTensor Ouput Tensor ID of the already included Node. Default to * 0. * @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning * first available data input for the Node. */ void addChild(NodePtr toOtherNode, NodePtr fromOutNode = nullptr, const IOIndex_t fromTensor = IOIndex_t(0), IOIndex_t toTensor = gk_IODefaultIndex); /** * @brief Include a Node in the current GraphView and link it to another * already contained Node. * * @param toOtherNode Pointer to the Node to add. * @param fromOutNodeName Name of the already included Node the new Node will * be linked to (it will become a parent of the new Node). As a name is * optional, ensure such Node is in the GraphView or it will send back an * error message. * @param fromTensor Ouput Tensor ID of the already included Node. Default to * 0. * @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning * first available data input for the Node. */ inline void addChild(NodePtr toOtherNode, std::string fromOutNodeName, const IOIndex_t fromTensor = IOIndex_t(0), IOIndex_t toTensor = gk_IODefaultIndex) { assert(mNodeRegistry.find(fromOutNodeName) != mNodeRegistry.end() && "No Node with this name found in the GraphView."); addChild(toOtherNode, mNodeRegistry.at(fromOutNodeName), fromTensor, toTensor); } /** * @brief Include a GraphView content in the current GraphView and link * the two sets by linking one Node from each GraphView. * @param toOtherView Pointer to the GraphView whose content should be added. * @param fromOutNode Pair of pointer to Node and Tensor ID for specifying the * connection. If the GraphView including the other one has only one output * Node, then it defaults to the first output Tensor of this Node. * @param toNode Pair of pointer to Node and Tensor ID for specifying the * connection. If the GraphView whose content is included has only one input * Node, then it defaults to the first available data input Tensor of this * Node. */ void addChild(std::shared_ptr<GraphView> toOtherView, std::pair<NodePtr, IOIndex_t> fromOutNode = std::pair<NodePtr, IOIndex_t>(nullptr, IOIndex_t(0)), std::pair<NodePtr, IOIndex_t> toNode = std::pair<NodePtr, IOIndex_t>(nullptr, gk_IODefaultIndex)); /** * @brief Swap two Node instances if possible. * @param node * @param otherNode * @return true * @return false */ bool swap(Node &node, Node &otherNode); void link(std::string name1_inID, std::string name2_outID); void insert(Node &newNode, Node &inNode, std::initializer_list<Node> outNodes, IOIndex_t tensorIdx); /** * @brief Replace the current GraphView with the set of given Nodes if possible * @param newNodes Set of Nodes. * @return true * @return false */ bool replaceWith(std::set<NodePtr> newNodes); void updateInputNodes(); /** * @brief Process from zero the set of output Nodes. */ void updateOutputNodes(); private: /////////////////////////////////////////////////////// // TENSOR MANAGEMENT /////////////////////////////////////////////////////// /** * @brief Get the sum of the number of dataInput Nodes for all inputNodes of the GraphView object. * @return IOIndex_t */ IOIndex_t getNbDataInputs() const; /** * @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object. * @return IOIndex_t */ IOIndex_t getNbFreeDataInputs() const; /** * @brief Update the set of inputNodes with a new Node, checking if it can be * added and removing any Node not part of mInputNode anymore. * @param nodePtr */ void updateInputNodes(NodePtr node); /** * @brief Update the set of outputNodes with a new Node, checking if it can be * added and removing any Node not part of mOutputNode anymore. * @param nodePtr */ void updateOutputNodes(NodePtr node); /////////////////////////////////////////////////////// // TOPOLOGY /////////////////////////////////////////////////////// void _forwardDims(std::set<NodePtr> listNodes); void removeInputNode(const std::string nodeName); void removeOutputNode(const std::string nodeName); }; } // namespace Aidge #endif /* AIDGE_CORE_GRAPH_GRAPHVIEW_H_ */