// Skip to content
// Snippets Groups Projects
// Forked from Eclipse Projects / aidge / aidge_core
// 2435 commits behind the upstream repository.
// Code owners
// Assign users and groups as approvers for specific file changes. Learn more.
// Node.hpp 15.17 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef AIDGE_CORE_GRAPH_NODE_H_
#define AIDGE_CORE_GRAPH_NODE_H_

#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include <utility>

#include "aidge/graph/Connector.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/utils/Types.h"

namespace Aidge {

using NodePtr = std::shared_ptr<Node>;

class GraphView;

/**
 * @brief Object carrying the topological information of the computational graph.
 */
class Node : public std::enable_shared_from_this<Node> {
private:
  /**
   * @brief Strict weak ordering used to store weak_ptr<GraphView> in a std::set.
   * @details The ordering is owner-based (std::weak_ptr::owner_before): it depends
   * only on the control block each weak_ptr refers to, so the relative order of
   * two stored elements never changes when a GraphView expires. Ordering on the
   * result of lock() would move an element as soon as its GraphView is destroyed,
   * which breaks the ordering invariant std::set relies on.
   */
  struct weakCompare {
    bool operator()(const std::weak_ptr<Aidge::GraphView>& a,
                    const std::weak_ptr<Aidge::GraphView>& b) const noexcept {
      return a.owner_before(b);
    }
  };

  std::string mName; /** Name of the Node. Should be unique. */

  std::set<std::weak_ptr<GraphView>, weakCompare> mViews; /** Set of (non-owning) pointers to GraphView instances including this Node instance. */
  const std::shared_ptr<Operator> mOperator; /** Pointer to the associated Operator. Fixed for the lifetime of the Node. */

  std::vector<NodePtr> mParents; /** List of parent node for each input (Parent --> Node --> Child) */
  std::vector<std::vector<std::weak_ptr<Node>>> mChildren; /** List of children nodes for each output (Parent --> Node --> Child) */
  std::vector<std::vector<IOIndex_t>> mIdInChildren; /** List of input index for each Node linked to each output of the Node. */
  std::vector<IOIndex_t> mIdOutParents; /** Index of the output linked to each input of the Node. Default: gk_IODefaultIndex. */

public:
  Node() = delete;

  /**
   * @brief Construct a new Node object associated with the input Operator.
   * @param op Operator giving the Node its number of connections.
   * @param name (optional) name for the Node.
   */
  Node(std::shared_ptr<Operator> op, const std::string& name = "");

  virtual ~Node() = default;

  /**
   * @brief Identity comparison: two Nodes are equal only if they are the same object.
   * @details Addresses are compared directly. This gives the same result as comparing
   * the pointers returned by shared_from_this(), without throwing std::bad_weak_ptr
   * when one of the Nodes is not (or not yet) owned by a std::shared_ptr.
   */
  friend bool operator==(const Node &lhs, const Node &rhs) noexcept {
    return &lhs == &rhs;
  }

public:
  ///////////////////////////////////////////////////////
  //        FUNCTIONAL DESCRIPTION
  ///////////////////////////////////////////////////////

  /**
   * @brief Functional operator for user-friendly connection interface using an ordered set of Connectors.
   * @param ctors Ordered Connectors linking their associated Node to the input of the current Node with the same index.
   * @return Connector
   */
  Connector operator()(const std::vector<Connector> &ctors);

public:
  ///////////////////////////////////////////////////////
  //        INNER
  ///////////////////////////////////////////////////////

  /**
   * @brief Name of the Node.
   * @return std::string Copy of the name.
   */
  inline std::string name() const noexcept { return mName; }

  /**
   * @brief Set the Node name.
   * @warning Undefined behaviour when several Nodes have the same name.
   * @param name New name for the node.
   */
  void setName(const std::string &name);

  /**
   * @brief Type of the node, as reported by its Operator.
   * @return std::string
   */
  inline std::string type() const { return mOperator->type(); }

  ///////////////////////////////////////////////////////
  //        OPERATORS
  ///////////////////////////////////////////////////////

  /**
   * @brief Run forward() function of the associated Operator.
   */
  void forward();

  /**
   * @brief Run backward() function of the associated Operator.
   */
  void backward();

  /**
   * @brief Get the Operator object of the Node.
   * @return std::shared_ptr<Operator>
   */
  inline std::shared_ptr<Operator> getOperator() const { return mOperator; }

  ///////////////////////////////////////////////////////
  //        TENSOR MANAGEMENT
  ///////////////////////////////////////////////////////

  /**
   * @brief Whether or not every input of the Node is linked to a Parent.
   * If true then the Node is ready to be executed.
   * @return true
   * @return false
   */
  bool valid() const;

  /**
   * @brief List of pair <Parent, ID of the data input>. When an input is not
   * linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
   * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>
   */
  std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const;

  /**
   * @brief List of pair <Parent, ID of the parent output>. When an input is not linked
   * to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
   * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>
   */
  std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;

  /**
   * @brief Parent and its output Tensor ID linked to the inID-th input Tensor.
   * If the input is not linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
   * @param inID Input index. Must differ from gk_IODefaultIndex and be lower than
   * nbInputs(); this is only checked by an assert (debug builds).
   * @return std::pair<std::shared_ptr<Node>, IOIndex_t>
   */
  inline std::pair<NodePtr, IOIndex_t> input(const IOIndex_t inID) const {
    assert((inID != gk_IODefaultIndex) && (inID < nbInputs()) && "Input index out of bound.");
    return std::pair<NodePtr, IOIndex_t>(mParents[inID], mIdOutParents[inID]);
  }

  /**
   * @brief Set fix value for the specified input by creating a Producer wrapping the given Tensor.
   *
   * @param idx Input index.
   * @param tensor Constant Tensor to add as parent for specified index.
   */
  void setInput(const IOIndex_t idx, const std::shared_ptr<Tensor> tensor);

  /**
   * @brief Get the lowest data input index that is not linked to any parent
   * output (i.e. whose parent output index is gk_IODefaultIndex).
   * @return IOIndex_t Index of the first free data input, or gk_IODefaultIndex
   * if every data input is already linked.
   */
  inline IOIndex_t getFirstFreeDataInput() const {
    IOIndex_t i = 0;
    // Skip every data input that already has a parent output attached.
    for (; (i < nbDataInputs()) && (input(i).second != gk_IODefaultIndex); ++i) {}
    return (i < nbDataInputs()) ? i : gk_IODefaultIndex;
  }

  /**
   * @brief Number of free data inputs.
   * NOTE(review): presumably the count of data inputs not linked to a parent;
   * the implementation is not visible in this header - confirm.
   * @return IOIndex_t
   */
  IOIndex_t getNbFreeDataInputs() const;

  /**
   * @brief List input ids of children linked to outputs of the node
   * @return std::vector<std::vector<std::pair<std::shared_ptr<Node>,
   * IOIndex_t>>>
   */
  std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const;

  /**
   * @brief Children and their input Tensor ID linked to the outId-th output
   * Tensor.
   * @param outId
   * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>
   */
  std::vector<std::pair<NodePtr, IOIndex_t>>
  output(IOIndex_t outId) const;

  /**
   * @brief Number of inputs, including both data and learnable parameters.
   * @details [data, data, weight, bias] => 4
   * @return IOIndex_t
   */
  inline IOIndex_t nbInputs() const noexcept { return mOperator->nbInputs(); }

  /**
   * @brief Number of input specifically for data
   * @details [data, data, weight, bias] => 2
   * @return IOIndex_t
   */
  inline IOIndex_t nbDataInputs() const noexcept {
    return mOperator->nbDataInputs();
  }

  /**
   * @brief Number of inputs linked to a Parent's output.
   * @return IOIndex_t
   */
  IOIndex_t nbValidInputs() const;

  /**
   * @brief Getter for the number of Output Tensors of the Node.
   * @return IOIndex_t
   */
  inline IOIndex_t nbOutputs() const noexcept { return mOperator->nbOutputs(); }

  /**
   * @brief Number of valid outputs.
   * NOTE(review): presumably the number of outputs linked to at least one child;
   * the implementation is not visible in this header - confirm.
   * @return IOIndex_t
   */
  IOIndex_t nbValidOutputs() const;

  ///////////////////////////////////////////////////////
  //        TOPOLOGY
  ///////////////////////////////////////////////////////

  /**
   * @brief Set of pointers to each still-alive GraphView containing the object.
   * @details GraphViews that have been destroyed since they were registered are
   * skipped, so the returned set never contains a nullptr.
   * @return std::set<std::shared_ptr<GraphView>>
   */
  inline std::set<std::shared_ptr<GraphView>> views() const noexcept {
    std::set<std::shared_ptr<GraphView>> res;
    for (const auto &v : mViews) {
      if (auto view = v.lock()) {
        res.insert(std::move(view));
      }
    }
    return res;
  }

  /**
   * @brief Add a GraphView pointer to the list of GraphView containing
   * the current Node. This feature allows transparent GraphViews.
   * @details Only a weak reference is stored. Adding an already registered
   * GraphView has no effect.
   * @param graphPtr Pointer to GraphView to add to the list.
   */
  inline void addView(const std::shared_ptr<GraphView> &graphPtr) {
    mViews.insert(std::weak_ptr<GraphView>(graphPtr));
  }

  /**
   * @brief Remove a GraphView pointer from the list of GraphView containing
   * the current Node.
   * @details Does nothing if the GraphView is not registered.
   * @param graphPtr Pointer to the GraphView to remove from the list.
   */
  inline void removeView(const std::shared_ptr<GraphView> &graphPtr) {
    for (auto viewIt = mViews.cbegin(); viewIt != mViews.cend(); ++viewIt) {
      if (viewIt->lock() == graphPtr) {
        // Erase by iterator; never touch the end iterator when nothing matches.
        mViews.erase(viewIt);
        return;
      }
    }
  }

  /**
   * @brief Link another Node to an output of the current Node.
   * @param otherNode Pointer to the other Node.
   * @param outId ID of the current Node output to connect to the other Node.
   * Default to 0.
   * @param otherInId ID of the other Node input to connect to the current Node.
   * Default to the first available data input.
   */
  void addChild(NodePtr otherNode,
                const IOIndex_t outId = IOIndex_t(0),
                IOIndex_t otherInId = gk_IODefaultIndex);

  /**
   * @brief Link a Node from a specific GraphView to the current Node.
   * @param otherView Pointer to the GraphView whose content should be
   * linked to the current Node.
   * @param outId ID of the output Tensor to connect to the other Node.
   * Default to 0.
   * @param otherInId Pair of pointer to Node and Tensor ID for specifying the
   * connection. If the GraphView whose content is linked has only one input
   * Node, then it defaults to the first available data input Tensor of this
   * Node.
   */
  void addChild(std::shared_ptr<GraphView> otherView,
                const IOIndex_t outId = IOIndex_t(0),
                std::pair<NodePtr, IOIndex_t> otherInId =
                std::pair<NodePtr, IOIndex_t>(nullptr, gk_IODefaultIndex));

  /**
   * @brief Get the list of parent Nodes. As an input is linked to a unique Node,
   * if none is linked then the parent is a nullptr.
   * @return std::vector<std::shared_ptr<Node>>
   */
  std::vector<NodePtr> getParents() const;

  /**
   * @brief Get the pointer to parent of the specified input index. This pointer is nullptr if no parent is linked.
   * @details Access is bounds-checked (std::vector::at), so an index outside the
   * parent list throws std::out_of_range.
   * @param inId Input index. Must differ from gk_IODefaultIndex (asserted).
   * @return std::shared_ptr<Node>&
   */
  inline NodePtr &getParents(const IOIndex_t inId) {
    assert(inId != gk_IODefaultIndex);
    return mParents.at(inId);
  }

  /**
   * @brief Unlink the parent Node at the specified input index and return its pointer.
   * Return a nullptr if no parent was linked.
   * @param inId Input index.
   * @return std::shared_ptr<Node>
   */
  NodePtr popParent(const IOIndex_t inId);

  /**
   * @brief Remove the parent linked to the specified input index.
   * NOTE(review): the meaning of the returned bool is not visible in this
   * header (presumably whether a parent was actually removed) - confirm.
   * @param inId Input index.
   * @return bool
   */
  bool removeParent(const IOIndex_t inId);

  /**
   * @brief Get the set of pointers to children Nodes linked to the current Node object.
   * @details The returned set does not include any nullptr as an output may be linked to
   * an undefined number of Nodes. It does not change the computation of its associated Operator.
   * @return std::set<std::shared_ptr<Node>>
   */
  std::set<NodePtr> getChildren() const;

  /**
   * @brief Get the children Nodes grouped in a vector of vectors.
   * NOTE(review): presumably one inner vector per output, in output order;
   * the implementation is not visible in this header - confirm.
   * @return std::vector<std::vector<std::shared_ptr<Node>>>
   */
  std::vector<std::vector<NodePtr>> getOrderedChildren() const;

  /**
   * @brief Get the list of children Nodes linked to the output at specified index.
   * @param outId Output index.
   * @return std::vector<std::shared_ptr<Node>>
   */
  std::vector<NodePtr> getChildren(const IOIndex_t outId) const;

  /**
   * @brief Remove registered child from children list of specified output if possible.
   * If so, also remove the current Node from the parents of the child Node.
   * @param nodePtr Node to remove.
   * @param outId Output index. Default 0.
   * @return true Child found and removed for given output index.
   * @return false Child not found at given index. Nothing removed.
   */
  bool removeChild(const NodePtr nodePtr, const IOIndex_t outId = 0);

  /**
   * @brief Remove every link of surrounding nodes to it and conversely
   * @param includeLearnableParam NOTE(review): presumably whether links to
   * learnable-parameter inputs are removed as well - confirm.
   */
  void resetConnections(bool includeLearnableParam = false);

  ///////////////////////////////////////////////////////
  //        CLONE
  ///////////////////////////////////////////////////////

  /**
   * @brief Clone the node keeping the same operator object instance. The new node has no connection.
   * @return NodePtr
   */
  NodePtr cloneSharedOperators() const;

  /**
   * @brief Clone the node keeping the same operator object instance only for producers. Any other operator object instance is cloned as well. The new node has no connection.
   * @return NodePtr
   */
  NodePtr cloneSharedProducers() const;

  /**
   * @brief Clone the node and its operator. The new node has no connection.
   * @return NodePtr
   */
  NodePtr clone() const;

  /**
   * @brief Clone the node keeping the same operator object instance. The new node has no connection.
   * @param node Node to clone. Must not be nullptr (it is dereferenced).
   * @return NodePtr
   */
  static NodePtr cloneSharedOperators(NodePtr node) {
    return node->cloneSharedOperators();
  }

  /**
   * @brief Clone the node keeping the same operator object instance only for producers. Any other operator object instance is cloned as well. The new node has no connection.
   * @param node Node to clone. Must not be nullptr (it is dereferenced).
   * @return NodePtr
   */
  static NodePtr cloneSharedProducers(NodePtr node) {
    return node->cloneSharedProducers();
  }

  /**
   * @brief Clone the node and its operator. The new node has no connection.
   * @param node Node to clone. Must not be nullptr (it is dereferenced).
   * @return NodePtr
   */
  static NodePtr clone(NodePtr node) {
    return node->clone();
  }

  // GraphView::clone() is granted access to the private members of Node.
  friend std::shared_ptr<GraphView> GraphView::clone(NodePtr(*cloneNode)(NodePtr)) const;

private:
  ///////////////////////////////////////////////////////
  //        OPERATORS
  ///////////////////////////////////////////////////////

  // cannot change operator for now
  // void setOperator(const std::shared_ptr<Operator> op_ptr);

  ///////////////////////////////////////////////////////
  //        TENSOR MANAGEMENT
  ///////////////////////////////////////////////////////

  /**
   * @brief Set the parent output index associated with an input.
   * NOTE(review): presumably updates mIdOutParents[inID]; the implementation
   * is not visible in this header - confirm.
   * @param inID Input index of the current Node.
   * @param newNodeOutID Output index of the parent Node linked to this input.
   */
  void setInputId(const IOIndex_t inID, const IOIndex_t newNodeOutID);

  ///////////////////////////////////////////////////////
  //        TOPOLOGY
  ///////////////////////////////////////////////////////

  /**
   * @brief Add the given Node as a child for the current Node.
   * @param otherNode
   * @param outId
   * @param otherInId
   */
  void addChildOp(NodePtr otherNode, const IOIndex_t outId,
                  const IOIndex_t otherInId);

  /**
   * @brief Add the given GraphView's input Node as a child for the current Node
   * @param otherGraph
   * @param outId
   * @param otherInId Pointer to the GraphView's input Node and its input index. Defaults to the
   * only input Node if the GraphView has got one.
   */
  void addChildView(std::shared_ptr<GraphView> otherGraph,
                    const IOIndex_t outId,
                    std::pair<NodePtr, IOIndex_t> otherInId);

  /**
   * @brief Add a Node to the list of parents.
   * @param otherNode Node to add to parents list.
   * @param inId index for adding the parent.
   */
  void addParent(const NodePtr otherNode, const IOIndex_t inId);
};
} // namespace Aidge

#endif /* AIDGE_CORE_GRAPH_NODE_H_ */