Skip to content
Snippets Groups Projects
Forked from Eclipse Projects / aidge / aidge_core
2197 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
GraphView.hpp 16.18 KiB

/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef AIDGE_CORE_GRAPH_GRAPHVIEW_H_
#define AIDGE_CORE_GRAPH_GRAPHVIEW_H_

#include <cassert>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"

namespace Aidge {
enum class DataType;

/**
 * @brief Group of Nodes forming a computational graph on which properties and functions
 * can easily and safely be applied or run.
 */
class GraphView : public std::enable_shared_from_this<GraphView> {
private:
    /// @brief Name of the graphview
    std::string mName;

    /// @brief Set of nodes included in the GraphView
    std::set<NodePtr> mNodes;

    /// @brief Set of nodes included in the graphview with names
    std::map<std::string, NodePtr> mNodeRegistry;

    /// @brief Nodes without input link
    std::set<NodePtr> mInputNodes;

    /// @brief Nodes without output link
    std::set<NodePtr> mOutputNodes;

public:
    /**
     * @brief Create an empty GraphView.
     * @param name Name of the GraphView. Default: empty string.
     */
    GraphView(std::string name="")
        : mName(std::move(name))  // by-value sink parameter: move it rather than copy it again
    {
        // ctor
    }

    // GraphView(std::set<NodePtr> nodes, std::string name="")
    //     : mName(name)
    // {
    //     add(nodes);
    // }

    /**
     * @brief Two GraphViews compare equal when they contain exactly the same
     * set of Nodes. The GraphView names are not compared.
     */
    bool operator==(const GraphView &gv) const
    {
        return mNodes == gv.mNodes;
    }

    /**
     * @brief Access a Node of the GraphView by its name.
     * @param name Name of the Node.
     * @return NodePtr
     * @warning Asserts in debug builds if no Node has this name; when asserts
     * are disabled, std::map::at throws std::out_of_range instead.
     */
    NodePtr operator[](std::string name)
    {
        assert(mNodeRegistry.find(name) != mNodeRegistry.end() && "Could not find Node in the GraphView.");
        return mNodeRegistry.at(name);
    }

///////////////////////////////////////////////////////
//        FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////

    Connector operator()(const std::vector<Connector> ctors);

///////////////////////////////////////////////////////
//        INNER
///////////////////////////////////////////////////////
public:
    /**
     * @brief Name of the GraphView.
     * @return std::string
     */
    std::string name() const;

    /**
     * @brief Set the GraphView name.
     * @warning Undefined behaviour when several Nodes have the same name.
     * @param name New name for the GraphView.
     */
    void setName(const std::string &name);

    /**
     * @brief Save the GraphView as a Mermaid graph in a .md file at the
     * specified location.
     * @param path Location of the output file.
     * @param verbose Whether to emit a more detailed output. Default: false.
     */
    void save(std::string path, bool verbose = false) const;

    /**
     * @brief Assess if the given Node belongs to the GraphView object.
     * @param nodePtr Node to look for.
     * @return true if the Node is in the GraphView, false otherwise.
     */
    inline bool inView(NodePtr nodePtr) const {
        return mNodes.find(nodePtr) != mNodes.end();
    }

///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////
public:
    /** @brief Get reference to the set of input Nodes. */
    inline const std::set<NodePtr>& inputNodes() const noexcept { return mInputNodes; }
    /** @brief Get reference to the set of output Nodes. */
    inline const std::set<NodePtr>& outputNodes() const noexcept { return mOutputNodes; }

    /** @brief Assess if the given Node is an input Node of the GraphView object. */
    inline bool isInputNode(NodePtr nodePtr) const {
        return mInputNodes.find(nodePtr) != mInputNodes.end();
    }
    /** @brief Assess if the given Node is an output Node of the GraphView object. */
    inline bool isOutputNode(NodePtr nodePtr) const {
        return mOutputNodes.find(nodePtr) != mOutputNodes.end();
    }

    /**
     * @brief List outside dataInput connections of the GraphView object's inputNodes.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const;

    /**
     * @brief List dataInput connections of the Node with the given name.
     * @param name Name of the Node. It must be registered in the GraphView,
     * otherwise std::map::at throws std::out_of_range.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); }

    /**
     * @brief List outside input connections of the GraphView object's inputNodes.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;

    /**
     * @brief List input connections of the specified GraphView object's inputNode.
     * @param name Name of the Node.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    std::vector<std::pair<NodePtr, IOIndex_t>> inputs(std::string name) const;

    /**
     * @brief List output connections of the GraphView object's outputNodes.
     * @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>
     */
    std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const;

    /**
     * @brief List output connections of the specified Node of the GraphView object.
     * @param nodeName Name of the Node of which to show the output.
     * @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>
     */
    std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs(
            std::string nodeName) const;

    /**
     * @brief Compute dimensions of input/output Tensors for each Operator of the
     * GraphView object's Nodes.
     */
    void forwardDims();

    /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
    void setBackend(const std::string &backend);
    /** @brief Set the same data type for each Operator of the GraphView object's Nodes. */
    void setDatatype(const DataType &datatype);

///////////////////////////////////////////////////////
//        TOPOLOGY
///////////////////////////////////////////////////////
public:
    /**
     * @brief Get the parents Nodes of inputNodes.
     * @return std::set<NodePtr>
     */
    std::set<NodePtr> getParents() const;
    /**
     * @brief Get parents Nodes of the specified Node.
     * @param nodeName Name of the Node.
     * @return std::vector<NodePtr>
     */
    std::vector<NodePtr> getParents(const std::string nodeName) const;
    std::vector<std::vector<NodePtr>> getOrderedParents() const;

    /**
     * @brief Get the children Nodes of outputNodes.
     * @return std::set<NodePtr>
     */
    std::set<NodePtr> getChildren() const;
    /**
     * @brief Get children Nodes of the specified Node.
     * @param nodeName Name of the Node.
     * @return std::vector<std::vector<NodePtr>>
     */
    std::vector<std::vector<NodePtr>> getChildren(const std::string nodeName) const;
    std::set<NodePtr> getChildren(
            const NodePtr otherNode) const;  // TODO change it for a vector<vector> ?

    /**
     * @brief Get the Nodes pointed to by the GraphView object.
     * @return std::set<NodePtr>
     */
    inline const std::set<NodePtr>& getNodes() const { return mNodes; }

    /**
     * @brief Get the operator with the corresponding name if it is in the
     * GraphView.
     * @param nodeName Name of the node.
     * @return NodePtr returns a new empty node if the one asked for
     * was not found.
     */
    NodePtr getNode(const std::string& nodeName) const;

    /**
     * @brief Remove a Node from the current GraphView scope without affecting its connections.
     * @param nodePtr Node to remove
     * @param includeLearnableParam Whether learnable parameters should also be removed. Default true.
     */
    void remove(NodePtr nodePtr, bool includeLearnableParam = true);

    // Surrounding nodes management

    void setInputId(IOIndex_t inID, IOIndex_t newNodeOutID);

    /**
     * @brief Include a Node to the current GraphView object.
     * @param otherNode Node to add.
     * @param includeLearnableParam Include non-data inputs, like weights and biases
     * in the GraphView automatically. Default: true.
     */
    void add(NodePtr otherNode, bool includeLearnableParam = true);
    /**
     * @brief Include a set of Nodes to the current GraphView object.
     * @param otherNodes Set of Nodes to add.
     * @param includeLearnableParam Include non-data inputs, like weights and biases
     * in the GraphView automatically. Default: true.
     */
    void add(std::set<NodePtr> otherNodes,
             bool includeLearnableParam = true);

    /**
     * @brief Include every Node inside another GraphView to the current
     * GraphView.
     * @param otherGraph GraphView containing the Nodes to include.
     */
    void add(std::shared_ptr<GraphView> otherGraph);

    /**
     * @brief Include a Node in the current GraphView and link it to another
     * already contained Node.
     *
     * @param toOtherNode Pointer to the Node to add.
     * @param fromOutNode Pointer to the already included Node the new Node will
     * be linked to (it will become a parent of the new Node). If the GraphView
     * only has one output Node, then default to this Node.
     * @param fromTensor Output Tensor ID of the already included Node. Default to
     * 0.
     * @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning
     * first available data input for the Node.
     */
    void addChild(NodePtr toOtherNode, NodePtr fromOutNode = nullptr,
                  const IOIndex_t fromTensor = IOIndex_t(0),
                  IOIndex_t toTensor = gk_IODefaultIndex);

    /**
     * @brief Include a Node in the current GraphView and link it to another
     * already contained Node.
     *
     * @param toOtherNode Pointer to the Node to add.
     * @param fromOutNodeName Name of the already included Node the new Node will
     * be linked to (it will become a parent of the new Node). As a name is
     * optional, ensure such Node is in the GraphView or it will send back an
     * error message.
     * @param fromTensor Output Tensor ID of the already included Node. Default to
     * 0.
     * @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning
     * first available data input for the Node.
     */
    inline void addChild(NodePtr toOtherNode, std::string fromOutNodeName,
                         const IOIndex_t fromTensor = IOIndex_t(0),
                         IOIndex_t toTensor = gk_IODefaultIndex) {
        assert(mNodeRegistry.find(fromOutNodeName) != mNodeRegistry.end() &&
               "No Node with this name found in the GraphView.");
        // at() still guards the lookup (throws std::out_of_range) when asserts are compiled out.
        addChild(toOtherNode, mNodeRegistry.at(fromOutNodeName), fromTensor, toTensor);
    }

    /**
     * @brief Include a GraphView content in the current GraphView and link
     * the two sets by linking one Node from each GraphView.
     * @param toOtherView Pointer to the GraphView whose content should be added.
     * @param fromOutNode Pair of pointer to Node and Tensor ID for specifying the
     * connection. If the GraphView including the other one has only one output
     * Node, then it defaults to the first output Tensor of this Node.
     * @param toNode Pair of pointer to Node and Tensor ID for specifying the
     * connection. If the GraphView whose content is included has only one input
     * Node, then it defaults to the first available data input Tensor of this
     * Node.
     */
    void addChild(std::shared_ptr<GraphView> toOtherView,
                  std::pair<NodePtr, IOIndex_t> fromOutNode =
                          std::pair<NodePtr, IOIndex_t>(nullptr, IOIndex_t(0)),
                  std::pair<NodePtr, IOIndex_t> toNode =
                          std::pair<NodePtr, IOIndex_t>(nullptr, gk_IODefaultIndex));

    /**
     * @brief Swap two Node instances if possible.
     * @param node
     * @param otherNode
     * @return true
     * @return false
     */
    bool swap(Node &node, Node &otherNode);

    void link(std::string name1_inID, std::string name2_outID);

    /**
     * @brief Insert a node (newParentNode) as a parent of the passed node (childNode).
     *
     * @param childNode Node that gets a new parent.
     * @param newParentNode Inserted Node.
     * @param childInputTensorIdx Index of the input Tensor for the childNode linked to the inserted Node output.
     * @param newParentInputTensorIdx Index of the input Tensor for the newParentNode linked to the former parent of childNode.
     * @param newParentOutputTensorIdx Index of the output Tensor for the newParentNode linked to the childNode's input Tensor.
     */
    void insertParent(NodePtr childNode,
                        NodePtr newParentNode,
                        IOIndex_t childInputTensorIdx,
                        IOIndex_t newParentInputTensorIdx,
                        IOIndex_t newParentOutputTensorIdx);

    /**
     * @brief Replace the current GraphView with the set of given Nodes if possible
     * @param newNodes Set of Nodes.
     * @return true
     * @return false
     */
    bool replaceWith(std::set<NodePtr> newNodes);

    /**
     * @brief Replace a set of Nodes in the current GraphView with a new set of Nodes if possible.
     * Both sets should include all the necessary Producers.
     * @details Replaced Nodes are only removed from the current GraphView. Other GraphView containing
     * them will not be affected by the replacement. The oldNodes set should have only one input/output
     * Node for automatic connections of newNodes set.
     * @param oldNodes actual set of shared_ptr<Node> to replace.
     * @param newNodes new set of shared_ptr<Node>.
     * @return true
     * @return false
     */
    bool replace(std::set<NodePtr>& oldNodes, std::set<NodePtr>& newNodes);

    void updateInputNodes();
    /**
     * @brief Recompute from scratch the set of output Nodes.
     */
    void updateOutputNodes();

    /**
     * @brief Clone the GraphView with shared Operators. It is a new GraphView, with cloned Nodes, but the new Nodes refer to the same Operators as the original ones.
     * @return std::shared_ptr<GraphView>
     */
    inline std::shared_ptr<GraphView> cloneSharedOperators() const {
        return cloneCallback(&Node::cloneSharedOperators);
    }

    /**
     * @brief Clone the GraphView with shared Producers. All the other Operators are copied.
     * @return std::shared_ptr<GraphView>
     */
    inline std::shared_ptr<GraphView> cloneSharedProducers() const {
        return cloneCallback(&Node::cloneSharedProducers);
    }

    /**
     * @brief Clone the GraphView. Everything is cloned: Nodes and Operators.
     * @return std::shared_ptr<GraphView>
     */
    inline std::shared_ptr<GraphView> clone() const {
        return cloneCallback(&Node::clone);
    }

    /**
     * @brief Clone the current GraphView using a callback function for the Node cloning, allowing to specify how each Node should be cloned or replaced by another Node type, or removed (i.e. replaced by identity). When a Node is removed, the clone() method automatically finds the next valid parent in line, going backward in the graph and connects it if that makes sense without ambiguity (effectively treating the removed Node as an identity operation).
     * @param cloneNode Callback function to clone a node
     * @return std::shared_ptr<GraphView>
     */
    std::shared_ptr<GraphView> cloneCallback(NodePtr(*cloneNode)(NodePtr)) const;

    /**
     * @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object.
     * @return IOIndex_t
     */
    IOIndex_t getNbFreeDataInputs() const;

private:
///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////

    /**
     * @brief Get the sum of the number of dataInput Nodes for all inputNodes of the GraphView object.
     * @return IOIndex_t
     */
    IOIndex_t getNbDataInputs() const;

    /**
     * @brief Update the set of inputNodes with a new Node, checking if it can be
     * added and removing any Node not part of mInputNode anymore.
     * @param node Node to consider for the set of input Nodes.
     */
    void updateInputNodes(NodePtr node);

    /**
     * @brief Update the set of outputNodes with a new Node, checking if it can be
     * added and removing any Node not part of mOutputNode anymore.
     * @param node Node to consider for the set of output Nodes.
     */
    void updateOutputNodes(NodePtr node);

    ///////////////////////////////////////////////////////
    //        TOPOLOGY
    ///////////////////////////////////////////////////////

    void _forwardDims(std::set<NodePtr> listNodes);

    void removeInputNode(const std::string nodeName);
    void removeOutputNode(const std::string nodeName);
};
}  // namespace Aidge

#endif /* AIDGE_CORE_GRAPH_GRAPHVIEW_H_ */