// Node.hpp - forked from Eclipse Projects / aidge / aidge_core.
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef AIDGE_CORE_GRAPH_NODE_H_
#define AIDGE_CORE_GRAPH_NODE_H_

#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include <utility>

#include "aidge/graph/Connector.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/utils/Types.h"

namespace Aidge {

// Forward declarations: only pointers to these types are needed before their
// full definitions, so this header does not have to rely on a previously
// included header having declared Node.
class Node;
class GraphView;

/// Shared-ownership handle to a Node.
using NodePtr = std::shared_ptr<Node>;

/**
 * @brief Object carrying the topological information of the computational graph.
 */
class Node : public std::enable_shared_from_this<Node> {
private:
  struct weakCompare {
      bool operator()(const std::weak_ptr<Aidge::GraphView>& a, const std::weak_ptr<Aidge::GraphView>& b) const {
          // Compare the content of the weak_ptrs
          auto sharedA = a.lock();
          auto sharedB = b.lock();
          if (!sharedB) return false; // nothing after expired pointer
          if (!sharedA) return true;
          return sharedA < sharedB; // shared_ptr has a valid comparison operator
      }
  };
  std::string mName; /** Name of the Node. Should be unique. */

  std::set<std::weak_ptr<GraphView>, weakCompare> mViews; /** Set of pointers to GraphView instances including this Node instance. */
  const std::shared_ptr<Operator> mOperator; // Pointer to the associated Operator

  std::vector<NodePtr> mParents; /** List of parent node for each input (Parent --> Node --> Child) */
  std::vector<std::vector<std::weak_ptr<Node>>> mChildren; /** List of children nodes for each output (Parent --> Node --> Child) */
  std::vector<std::vector<IOIndex_t>> mIdInChildren; /** List of input index for each Node linked to each output of the Node. */
  std::vector<IOIndex_t> mIdOutParents; /** index of the output linked to each input of the Node. Default: gk_IODefaultIndex. */

public:
  Node() = delete;

  /**
   * @brief Construct a new Node object associated with the input Operator.
   * @param op Operator giving the Node its number of connections.
   * @param name (optional) name for the Node.
   */
  Node(std::shared_ptr<Operator> op, const std::string& name = "");

  virtual ~Node() = default;

  friend bool operator==(const Node &lhs, const Node &rhs) {
    return lhs.shared_from_this() == rhs.shared_from_this();
  }

public:
  ///////////////////////////////////////////////////////
  //        FUNCTIONAL DESCRIPTION
  ///////////////////////////////////////////////////////

  /**
   * @brief Functional operator for user-friendly connection interface using an ordered set of Connectors.
   * @param ctors Ordered Connectors linking their associated Node to the input of the current Node with the same index.
   * @return Connector
   */
  Connector operator()(const std::vector<Connector> &ctors);

public:
  ///////////////////////////////////////////////////////
  //        INNER
  ///////////////////////////////////////////////////////

  /**
   * @brief Name of the Node.
   * @return std::string
   */
  inline std::string name() const noexcept { return mName; }

  /**
   * @brief Set the Node name.
   * @warning Undefined behaviour when several Nodes have the same name.
   * @param name New name for the node.
   */
  void setName(const std::string &name);

  /**
   * @brief Type of the node.
   * @return std::string
   */
  inline std::string type() const { return mOperator->type(); }

  ///////////////////////////////////////////////////////
  //        OPERATORS
  ///////////////////////////////////////////////////////

  /**
   * @brief Run forward() function of the associated Operator.
   */
  void forward();

  /**
   * @brief Run backward() function of the associated Operator.
   */
  void backward();

  /**
   * @brief Get the Operator object of the Node.
   * @return std::shared_ptr<Operator>
   */
  inline std::shared_ptr<Operator> getOperator() const { return mOperator; }

  ///////////////////////////////////////////////////////
  //        TENSOR MANAGEMENT
  ///////////////////////////////////////////////////////

  /**
   * @brief Whether or not every input of the Node is linked to a Parent.
   * If true then the Node is ready to be executed.
   * @return true
   * @return false
   */
  bool valid() const;
  /**
   * @brief List of pair <Parent, ID of the data intput>. When an input is not
   * linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
   * Data inputs exclude inputs expecting parameters (weights or bias).
   * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>
   */
  std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const;

  /**
   * @brief List of pair <Parent, ID of the parent output>. When an input is not linked
   * to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
   * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>
   */
  std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;

  /**
   * @brief Parent and its output Tensor ID linked to the inID-th input Tensor.
   * If the input is not linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
   * @param inID
   * @return std::pair<std::shared_ptr<Node>, IOIndex_t>
   */
  inline std::pair<NodePtr, IOIndex_t> input(const IOIndex_t inID) const {
    assert((inID != gk_IODefaultIndex) && (inID < nbInputs()) && "Input index out of bound.");
    return std::pair<NodePtr, IOIndex_t>(mParents[inID], mIdOutParents[inID]);
  }


  /**
   * @brief Get the lowest index in the InputData Parent list equal to the
   * nullptr.
   * Data inputs exclude inputs expecting parameters (weights or bias).
   * @return std::size_t
   */
  inline IOIndex_t getFirstFreeDataInput() const {
    IOIndex_t i = 0;
    for (; (i < nbData()) && (input(i).second != gk_IODefaultIndex); ++i) {}
    // assert((i<nbData()) && "No free data input for Node");
    return (i < nbData()) ? i : gk_IODefaultIndex;
  }


  IOIndex_t getNbFreeDataInputs() const;

  /**
   * @brief List input ids of children linked to outputs of the node. The vector
   * size is garanteed to match the number of outputs of the node. If there is
   * no connection to a given output, the corresponding sub-vector will be empty.
   * @return std::vector<std::vector<std::pair<std::shared_ptr<Node>,
   * IOIndex_t>>>
   */
  std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const;

  /**
   * @brief Children and their input Tensor ID linked to the outId-th output
   * Tensor.
   * @param outId
   * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>
   */
  std::vector<std::pair<NodePtr, IOIndex_t>>
  output(IOIndex_t outId) const;

  /**
   * @brief Number of inputs, including both data and learnable parameters.
   * @details [data, data, weight, bias] => 4
   * @return IOIndex_t
   */
  inline IOIndex_t nbInputs() const noexcept { return getOperator()->nbInputs(); }

  /**
   * @brief Number of input specifically for data.
   * Data inputs exclude inputs expecting parameters (weights or bias).
   * @details [data, data, weight, bias] => 2
   * @return IOIndex_t
   */
  inline IOIndex_t nbData() const noexcept {
    return getOperator()->nbData();
  }

  /**
   * @brief Number of inputs linked to a Parent's output.
   * @return IOIndex_t
   */
  IOIndex_t nbValidInputs() const;

  /**
   * @brief Getter for the number of Output Tensors of the Node.
   * @return IOIndex_t
   */
  inline IOIndex_t nbOutputs() const noexcept { return getOperator()->nbOutputs(); }

  IOIndex_t nbValidOutputs() const;

  ///////////////////////////////////////////////////////
  //        TOPOLOGY
  ///////////////////////////////////////////////////////

  /**
   * @brief Vector of pointers to each GraphView containing the object
   * @return std::vector<GraphView>
   */
  inline std::set<std::shared_ptr<GraphView>> views() const noexcept {
    std::set<std::shared_ptr<GraphView>> res;
    for (const auto &v : mViews) {
      res.insert(v.lock());
    }
    return res;
  }

  /**
   * @brief Add a GraphView pointer to the list of GraphView containing
   * the current Node. This feature allows transparent GraphViews.
   * @param graphPtr Pointer to GraphView to add to the list.
   */
  inline void addView(const std::shared_ptr<GraphView> &graphPtr) {
    mViews.insert(std::weak_ptr<GraphView>(graphPtr));
  }

  inline void removeView(const std::shared_ptr<GraphView> &graphPtr) {
    mViews.erase(graphPtr);
  }

  /**
   * @brief Link another Node to an output of the current Node.
   * @param otherNode Pointer to the other Node.
   * @param outId ID of the current Node output to connect to the other Node.
   * Default to 0.
   * @param otherInId ID of the other Node input to connect to the current Node.
   * Default to the first avaible data input.
   */
  void addChild(NodePtr otherNode,
                const IOIndex_t outId = IOIndex_t(0),
                IOIndex_t otherInId = gk_IODefaultIndex);

  /**
   * @brief Link a Node from a specific GraphView to the current Node.
   * @param otherView Pointer to the GraphView whose content should be
   * linked to the current Node.
   * @param outId ID of the output Tensor to connect to the other Node.
   * Default to 0.
   * @param otherInId Pair of pointer to Node and Tensor ID for specifying the
   * connection. If the GraphView whose content is linked has only one input
   * Node, then it defaults to the first available data input Tensor of this
   * Node.
   */
  void addChild(std::shared_ptr<GraphView> otherView,
                const IOIndex_t outId = IOIndex_t(0),
                std::pair<NodePtr, IOIndex_t> otherInId =
                std::pair<NodePtr, IOIndex_t>(nullptr, gk_IODefaultIndex));

  /**
   * @brief Get the list of parent Nodes. As an input is linked to a unique Node,
   * if none is linked then the parent is a nullptr.
   * @return std::vector<std::shared_ptr<Node>>
   */
  std::vector<NodePtr> getParents() const;

  /**
   * @brief Get the pointer to parent of the specified input index. This pointer is nullptr if no parent is linked.
   * @param inId Input index.
   * @return std::shared_ptr<Node>&
   */
  inline NodePtr &getParent(const IOIndex_t inId) {
    assert(inId != gk_IODefaultIndex);
    return mParents.at(inId);
  }

  /**
   * @brief Unlink the parent Node at the specified input index and return its pointer.
   * Return a nullptr is no parent was linked.
   * @param inId Input index.
   * @return std::shared_ptr<Node>
   */
  NodePtr popParent(const IOIndex_t inId);

  bool removeParent(const IOIndex_t inId);

  /**
   * @brief Get the set of pointers to children Nodes linked to the current Node.object.
   * @details The returned set does not include any nullptr as an output maybe linked to
   * an undifined number of Nodes. It does not change the computation of its associated Operator.
   * @return std::set<std::shared_ptr<Node>>>
   */
  std::set<NodePtr> getChildren() const;

  std::vector<std::vector<NodePtr>> getOrderedChildren() const;

  /**
   * @brief Get the list of children Nodes linked to the output at specified index.
   * @param outId Output index.
   * @return std::vector<std::shared_ptr<Node>>
   */
  std::vector<NodePtr> getChildren(const IOIndex_t outId) const;

  /**
   * @brief Remove registered child from children list of specified output if possible.
   * If so, also remove current Node from child Node from parent.
   * @param std::shared_ptr<Node> Node to remove.
   * @param outId Output index. Default 0.
   * @return true Child found and removed for given output index.
   * @return false Child not found at given index. Nothing removed.
   */
  bool removeChild(const NodePtr nodePtr, const IOIndex_t outId = 0);

  /**
   * @brief Remove every link of surrounding nodes to it and conversly
   */
  void resetConnections(bool includeLearnableParam = false);

  ///////////////////////////////////////////////////////
  //        CLONE
  ///////////////////////////////////////////////////////

  /**
   * @brief Clone the current Node. The Operator attribute of the new Node is not copied but shared with the current Node. The new node has no connection.
   * @return NodePtr
   */
  NodePtr cloneSharedOperators() const;

  /**
   * @brief Clone the Node. Every attribute is copied, even Operator pointer except for Producers for which it is shared. The new Node has no connection.
   * @return NodePtr
   */
  NodePtr cloneSharedProducers() const;

  /**
   * @brief Clone the Node and its Operator. The new Node has no connection.
   * @return NodePtr
   */
  NodePtr clone() const;

  /**
   * @brief Callback function to clone the Node keeping the same Operator object instance. The new Node has no connection.
   * @param node Node to clone.
   * @return NodePtr
   */
  static NodePtr cloneSharedOperators(NodePtr node) {
    return node->cloneSharedOperators();
  }

  /**
   * @brief Callback function to clone the Node. Every attribute is copied, even Operator pointer except for Producers for which it is shared. The new Node has no connection.
   * @param node Node to clone.
   * @return NodePtr
   */
  static NodePtr cloneSharedProducers(NodePtr node) {
    return node->cloneSharedProducers();
  }

  /**
   * @brief Callback function to clone the Node and its Operator. The new Node has no connection.
   * @param node Node to clone.
   * @return NodePtr
   */
  static NodePtr clone(NodePtr node) {
    return node->clone();
  }


  /**
   * @brief  Get the set of pointers to connected node at a distance of a delta.
   * @details the recution are cut
   * Return a nullptr is nofing found.
   * @param delta Input delta.
   * @return std::shared_ptr<Node>
   */

  std::set<NodePtr> getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee);


private:
  ///////////////////////////////////////////////////////
  //        OPERATORS
  ///////////////////////////////////////////////////////

  // cannot change operator for now
  // void setOperator(const std::shared_ptr<Operator> op_ptr);

  ///////////////////////////////////////////////////////
  //        TENSOR MANAGEMENT
  ///////////////////////////////////////////////////////

  /**
   * @brief Set the idInChildren parameter.
   * @param inID
   * @param newNodeOutID
   */
  void setInputId(const IOIndex_t inID, const IOIndex_t newNodeOutID);

  ///////////////////////////////////////////////////////
  //        TOPOLOGY
  ///////////////////////////////////////////////////////

  /**
   * @brief Add the given Node as a child for the current Node.
   * @param otherNode
   * @param outId
   * @param otherInId
   */
  void addChildOp(NodePtr otherNode, const IOIndex_t outId,
                  const IOIndex_t otherInId);

  /**
   * @brief Add the given GraphView's input Node as a child for the current Node
   * @param otherGraph
   * @param outId
   * @param otherInId pointer the GraphView's input Node and its input index. Defaults to the
   * only input Node if the GraphView has got one.
   */
  void addChildView(std::shared_ptr<GraphView> otherGraph,
                    const IOIndex_t outId,
                    std::pair<NodePtr, IOIndex_t> otherInId);

  /**
   * @brief Add a Node to the list of parents.
   * @param otherNode Node to add to parents list.
   * @param inId index for adding the parent.
   */
  void addParent(const NodePtr otherNode, const IOIndex_t inId);

  // OPERATOR FUNCTIONNAL but commented out to avoid iostream inclusion
  // /**
  //  * @brief operator<< overload to ease print & debug of nodes
  //  * @param[inout] ostream to print to 
  //  * @param[in] n node to print
  //  */
  // friend std::ostream& operator << (std::ostream& os, Node& n); 
};

} // namespace Aidge


#endif /* AIDGE_CORE_GRAPH_NODE_H_ */