/******************************************************************************** * Copyright (c) 2023 CEA-List * * This program and the accompanying materials are made available under the * terms of the Eclipse Public License 2.0 which is available at * http://www.eclipse.org/legal/epl-2.0. * * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ #ifndef AIDGE_CORE_GRAPH_NODE_H_ #define AIDGE_CORE_GRAPH_NODE_H_ #include <cassert> #include <memory> #include <set> #include <string> #include <vector> #include <utility> #include "aidge/graph/Connector.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/utils/Types.h" namespace Aidge { using NodePtr = std::shared_ptr<Node>; class GraphView; /** * @brief Object carrying the topological information of the computational graph. */ class Node : public std::enable_shared_from_this<Node> { private: struct weakCompare { bool operator()(const std::weak_ptr<Aidge::GraphView>& a, const std::weak_ptr<Aidge::GraphView>& b) const { // Compare the content of the weak_ptrs auto sharedA = a.lock(); auto sharedB = b.lock(); if (!sharedB) return false; // nothing after expired pointer if (!sharedA) return true; return sharedA < sharedB; // shared_ptr has a valid comparison operator } }; std::string mName; /** Name of the Node. Should be unique. */ std::set<std::weak_ptr<GraphView>, weakCompare> mViews; /** Set of pointers to GraphView instances including this Node instance. */ const std::shared_ptr<Operator> mOperator; // Pointer to the associated Operator std::vector<NodePtr> mParents; /** List of parent node for each input (Parent --> Node --> Child) */ std::vector<std::vector<std::weak_ptr<Node>>> mChildren; /** List of children nodes for each output (Parent --> Node --> Child) */ std::vector<std::vector<IOIndex_t>> mIdInChildren; /** List of input index for each Node linked to each output of the Node. */ std::vector<IOIndex_t> mIdOutParents; /** index of the output linked to each input of the Node. Default: gk_IODefaultIndex. */ public: Node() = delete; /** * @brief Construct a new Node object associated with the input Operator. * @param op Operator giving the Node its number of connections. * @param name (optional) name for the Node. */ Node(std::shared_ptr<Operator> op, const std::string& name = ""); virtual ~Node() = default; friend bool operator==(const Node &lhs, const Node &rhs) { return lhs.shared_from_this() == rhs.shared_from_this(); } public: /////////////////////////////////////////////////////// // FUNCTIONAL DESCRIPTION /////////////////////////////////////////////////////// /** * @brief Functional operator for user-friendly connection interface using an ordered set of Connectors. * @param ctors Ordered Connectors linking their associated Node to the input of the current Node with the same index. * @return Connector */ Connector operator()(const std::vector<Connector> &ctors); public: /////////////////////////////////////////////////////// // INNER /////////////////////////////////////////////////////// /** * @brief Name of the Node. * @return std::string */ inline std::string name() const noexcept { return mName; } /** * @brief Set the Node name. * @warning Undefined behaviour when several Nodes have the same name. * @param name New name for the node. */ void setName(const std::string &name); /** * @brief Type of the node. * @return std::string */ inline std::string type() const { return mOperator->type(); } /////////////////////////////////////////////////////// // OPERATORS /////////////////////////////////////////////////////// /** * @brief Run forward() function of the associated Operator. */ void forward(); /** * @brief Run backward() function of the associated Operator. */ void backward(); /** * @brief Get the Operator object of the Node. * @return std::shared_ptr<Operator> */ inline std::shared_ptr<Operator> getOperator() const { return mOperator; } /////////////////////////////////////////////////////// // TENSOR MANAGEMENT /////////////////////////////////////////////////////// /** * @brief Whether or not every input of the Node is linked to a Parent. * If true then the Node is ready to be executed. * @return true * @return false */ bool valid() const; /** * @brief List of pair <Parent, ID of the data intput>. When an input is not * linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; /** * @brief List of pair <Parent, ID of the parent output>. When an input is not linked * to any Parent, the pair is <nullptr, gk_IODefaultIndex>. * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; /** * @brief Parent and its output Tensor ID linked to the inID-th input Tensor. * If the input is not linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. * @param inID * @return std::pair<std::shared_ptr<Node>, IOIndex_t> */ inline std::pair<NodePtr, IOIndex_t> input(const IOIndex_t inID) const { assert((inID != gk_IODefaultIndex) && (inID < nbInputs()) && "Input index out of bound."); return std::pair<NodePtr, IOIndex_t>(mParents[inID], mIdOutParents[inID]); } /** * @brief Set fix value for the specified input by creating a Producer wrapping the given Tensor. * * @param idx Input index. * @param tensor Constant Tensor to add as parent for specified index. */ void setInput(const IOIndex_t idx, const std::shared_ptr<Tensor> tensor); /** * @brief Get the lowest index in the InputData Parent list equal to the * nullptr. * @return std::size_t */ inline IOIndex_t getFirstFreeDataInput() const { IOIndex_t i = 0; for (; (i < nbDataInputs()) && (input(i).second != gk_IODefaultIndex); ++i) {} // assert((i<nbDataInputs()) && "No free data input for Node"); return (i < nbDataInputs()) ? i : gk_IODefaultIndex; } IOIndex_t getNbFreeDataInputs() const; /** * @brief List input ids of children liked to outputs of the node * @return std::vector<std::vector<std::pair<std::shared_ptr<Node>, * IOIndex_t>>> */ std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const; /** * @brief Children and their input Tensor ID linked to the outId-th output * Tensor. * @param outId * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> output(IOIndex_t outId) const; /** * @brief Number of inputs, including both data and learnable parameters. * @details [data, data, weight, bias] => 4 * @return IOIndex_t */ inline IOIndex_t nbInputs() const noexcept { return getOperator()->nbInputs(); } /** * @brief Number of input specifically for data * @details [data, data, weight, bias] => 2 * @return IOIndex_t */ inline IOIndex_t nbDataInputs() const noexcept { return getOperator()->nbDataInputs(); } /** * @brief Number of inputs linked to a Parent's output. * @return IOIndex_t */ IOIndex_t nbValidInputs() const; /** * @brief Getter for the number of Output Tensors of the Node. * @return IOIndex_t */ inline IOIndex_t nbOutputs() const noexcept { return getOperator()->nbOutputs(); } IOIndex_t nbValidOutputs() const; /////////////////////////////////////////////////////// // TOPOLOGY /////////////////////////////////////////////////////// /** * @brief Vector of pointers to each GraphView containing the object * @return std::vector<GraphView> */ inline std::set<std::shared_ptr<GraphView>> views() const noexcept { std::set<std::shared_ptr<GraphView>> res; for (const auto &v : mViews) { res.insert(v.lock()); } return res; } /** * @brief Add a GraphView pointer to the list of GraphView containing * the current Node. This feature allows transparent GraphViews. * @param graphPtr Pointer to GraphView to add to the list. */ inline void addView(const std::shared_ptr<GraphView> &graphPtr) { mViews.insert(std::weak_ptr<GraphView>(graphPtr)); } inline void removeView(const std::shared_ptr<GraphView> &graphPtr) { std::set<std::weak_ptr<GraphView>, weakCompare>::const_iterator viewIt = mViews.cbegin(); for (; (viewIt != mViews.cend()) && ((*viewIt).lock() != graphPtr) ; ++viewIt) {} mViews.erase(*viewIt); } /** * @brief Link another Node to an output of the current Node. * @param otherNode Pointer to the other Node. * @param outId ID of the current Node output to connect to the other Node. * Default to 0. * @param otherInId ID of the other Node input to connect to the current Node. * Default to the first avaible data input. */ void addChild(NodePtr otherNode, const IOIndex_t outId = IOIndex_t(0), IOIndex_t otherInId = gk_IODefaultIndex); /** * @brief Link a Node from a specific GraphView to the current Node. * @param otherView Pointer to the GraphView whose content should be * linked to the current Node. * @param outId ID of the output Tensor to connect to the other Node. * Default to 0. * @param otherInId Pair of pointer to Node and Tensor ID for specifying the * connection. If the GraphView whose content is linked has only one input * Node, then it defaults to the first available data input Tensor of this * Node. */ void addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t outId = IOIndex_t(0), std::pair<NodePtr, IOIndex_t> otherInId = std::pair<NodePtr, IOIndex_t>(nullptr, gk_IODefaultIndex)); /** * @brief Get the list of parent Nodes. As an input is linked to a unique Node, * if none is linked then the parent is a nullptr. * @return std::vector<std::shared_ptr<Node>> */ std::vector<NodePtr> getParents() const; /** * @brief Get the pointer to parent of the specified input index. This pointer is nullptr if no parent is linked. * @param inId Input index. * @return std::shared_ptr<Node>& */ inline NodePtr &getParents(const IOIndex_t inId) { assert(inId != gk_IODefaultIndex); return mParents.at(inId); } /** * @brief Unlink the parent Node at the specified input index and return its pointer. * Return a nullptr is no parent was linked. * @param inId Input index. * @return std::shared_ptr<Node> */ NodePtr popParent(const IOIndex_t inId); bool removeParent(const IOIndex_t inId); /** * @brief Get the set of pointers to children Nodes linked to the current Node.object. * @details The returned set does not include any nullptr as an output maybe linked to * an undifined number of Nodes. It does not change the computation of its associated Operator. * @return std::set<std::shared_ptr<Node>>> */ std::set<NodePtr> getChildren() const; std::vector<std::vector<NodePtr>> getOrderedChildren() const; /** * @brief Get the list of children Nodes linked to the output at specified index. * @param outId Output index. * @return std::vector<std::shared_ptr<Node>> */ std::vector<NodePtr> getChildren(const IOIndex_t outId) const; /** * @brief Remove registered child from children list of specified output if possible. * If so, also remove current Node from child Node from parent. * @param std::shared_ptr<Node> Node to remove. * @param outId Output index. Default 0. * @return true Child found and removed for given output index. * @return false Child not found at given index. Nothing removed. */ bool removeChild(const NodePtr nodePtr, const IOIndex_t outId = 0); /** * @brief Remove every link of surrounding nodes to it and conversly */ void resetConnections(bool includeLearnableParam = false); /////////////////////////////////////////////////////// // CLONE /////////////////////////////////////////////////////// /** * @brief Clone the node keeping the same operator object instance. The new node has no connection. * @return NodePtr */ NodePtr cloneSharedOperators() const; /** * @brief Clone the node keeping the same operator object instance only for producers. Any other operator object instance is cloned as wel. The new node has no connection. * @return NodePtr */ NodePtr cloneSharedProducers() const; /** * @brief Clone the node and its operator. The new node has no connection. * @return NodePtr */ NodePtr clone() const; /** * @brief Clone the node keeping the same operator object instance. The new node has no connection. * @param node Node to clone. * @return NodePtr */ static NodePtr cloneSharedOperators(NodePtr node) { return node->cloneSharedOperators(); } /** * @brief Clone the node keeping the same operator object instance only for producers. Any other operator object instance is cloned as wel. The new node has no connection. * @param node Node to clone. * @return NodePtr */ static NodePtr cloneSharedProducers(NodePtr node) { return node->cloneSharedProducers(); } /** * @brief Clone the node and its operator. The new node has no connection. * @param node Node to clone. * @return NodePtr */ static NodePtr clone(NodePtr node) { return node->clone(); } friend std::shared_ptr<GraphView> GraphView::clone(NodePtr(*cloneNode)(NodePtr)) const; private: /////////////////////////////////////////////////////// // OPERATORS /////////////////////////////////////////////////////// // cannot change operator for now // void setOperator(const std::shared_ptr<Operator> op_ptr); /////////////////////////////////////////////////////// // TENSOR MANAGEMENT /////////////////////////////////////////////////////// /** * @brief Set the idInChildren parameter. * @param inID * @param newNodeOutID */ void setInputId(const IOIndex_t inID, const IOIndex_t newNodeOutID); /////////////////////////////////////////////////////// // TOPOLOGY /////////////////////////////////////////////////////// /** * @brief Add the given Node as a child for the current Node. * @param otherNode * @param outId * @param otherInId */ void addChildOp(NodePtr otherNode, const IOIndex_t outId, const IOIndex_t otherInId); /** * @brief Add the given GraphView's input Node as a child for the current Node * @param otherGraph * @param outId * @param otherInId pointer the GraphView's input Node and its input index. Defaults to the * only input Node if the GraphView has got one. */ void addChildView(std::shared_ptr<GraphView> otherGraph, const IOIndex_t outId, std::pair<NodePtr, IOIndex_t> otherInId); /** * @brief Add a Node to the list of parents. * @param otherNode Node to add to parents list. * @param inId index for adding the parent. */ void addParent(const NodePtr otherNode, const IOIndex_t inId); }; } // namespace Aidge #endif /* AIDGE_CORE_GRAPH_NODE_H_ */