// GraphView.hpp
// Last committed by Cyril Moineau.
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef __AIDGE_CORE_GRAPH_GRAPHVIEW_H__
#define __AIDGE_CORE_GRAPH_GRAPHVIEW_H__

#include <cassert>
#include <initializer_list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "graph/Node.hpp"
#include "utils/Types.h"

namespace Aidge {
enum class DataType;
class GraphView : public std::enable_shared_from_this<GraphView> {
private:
    /// @brief Name of the graphview
    std::string mName; 

    /// @brief Set of nodes included in the GraphView
    std::set<NodePtr> mNodes; 

    /// @brief Set of nodes included in the graphview with names
    std::map<std::string, NodePtr> mNodeRegistry;
    
    /// @brief Nodes without input link
    std::set<NodePtr> mInputNodes;

    /// @brief Nodes without output link
    std::set<NodePtr> mOutputNodes;

public:
    GraphView(std::string name="")
        : mName(name) 
    {
        // ctor
    }

    GraphView(std::set<NodePtr> nodes, std::string name="")
        : mName(name) 
    {
        add(nodes);
    }

    bool operator==(const GraphView &gv) const 
    {
        return mNodes == gv.mNodes;
    }

    NodePtr operator[](std::string name) 
    {
        assert(mNodeRegistry.find(name) != mNodeRegistry.end() && "Could not find Node in the GraphView.");
        return mNodeRegistry.at(name);
    }

///////////////////////////////////////////////////////
//        FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////

    Connector operator()(const std::vector<Connector> ctors);

///////////////////////////////////////////////////////
//        INNER
///////////////////////////////////////////////////////
public:
    /**
     * @brief Name of the node.
     * @return std::string
     */
    std::string name() const;

    /**
     * @brief Set the node name.
     * @warning Undefined behaviour when several Nodes have the same name.
     * @param name New name for the node.
     */
    void setName(const std::string &name);

    /**
     * @brief Save the GraphView as a Mermaid graph in a .md file at the
     * specified location.
     * @param path
     */
    void save(std::string path, bool verbose = false) const;

    inline bool inView(NodePtr nodePtr) const {
        return mNodes.find(nodePtr) != mNodes.end();
    }

///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////
public:
    inline std::set<NodePtr> inputNodes() const noexcept { return mInputNodes; }
    inline std::set<NodePtr> outputNodes() const noexcept { return mOutputNodes; }

    inline bool isInputNode(NodePtr nodePtr) const {
        return (mInputNodes.find(nodePtr) != mInputNodes.end()) ? true : false;
    }
    inline bool isOutputNode(NodePtr nodePtr) const {
        return (mOutputNodes.find(nodePtr) != mOutputNodes.end()) ? true : false;
    }

    /**
     * @brief List data input Tensors of the graph input nodes.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const;

    /**
     * @brief List data input Tensors of the graph input nodes.
     * @param name Name of the Node.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    inline auto dataInputs(std::string name) const { return mNodeRegistry.at(name)->dataInputs(); }

    /**
     * @brief List input Tensors of the graph input nodes.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;

    std::vector<std::pair<NodePtr, IOIndex_t>> inputs(std::string name) const;

    /**
     * @brief List output Tensors of the node.
     * @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>
     */
    std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const;

    /**
     * @brief Specific i-th output Tensor of the GraphView.
     * @param nodeName Name of the Node of which to show the output.
     * @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>
     */
    std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs(
            std::string nodeName) const;

    void forwardDims();

    void setBackend(const std::string &backend);
    void setDatatype(const DataType &datatype);

///////////////////////////////////////////////////////
//        TOPOLOGY
///////////////////////////////////////////////////////
public:
    /**
     * @brief Get the Parents of inputNodes.
     * @return std::vector<NodePtr>
     */
    std::set<NodePtr> getParents() const;
    std::vector<NodePtr> getParents(const std::string nodeName) const;
    std::vector<std::vector<NodePtr>> getOrderedParents() const;

    /**
     * @brief Get the Children of outputNodes.
     * @return std::set<NodePtr>
     */
    std::set<NodePtr> getChildren() const;
    std::vector<std::vector<NodePtr>> getChildren(const std::string nodeName) const;
    std::set<NodePtr> getChildren(
            const NodePtr otherNode) const;  // TODO change it for a vector<vector> ?

    /**
     * @brief Getter for Operators of the GraphView.
     * @return std::set<NodePtr>
     */
    inline std::set<NodePtr> getNodes() const { return mNodes; }

    /**
     * @brief Get the operator with the corresponding name if it is in the
     * GraphView.
     * @param nodeName name of the node.
     * @return NodePtr return a new empty node if the one asked for
     * was not found.
     */
    NodePtr getNode(const char *nodeName) const;

    /**
     * @brief Remove a Node from the current GraphView scope without affecting its connections
     * @param nodePtr Node to remove
     * @param includeLearnableParam Whether learnable parameters should also be removed. Default true.
     */
    void remove(NodePtr nodePtr, bool includeLearnableParam = true);

    // Surrounding nodes management

    void setInputId(IOIndex_t inID, IOIndex_t newNodeOutID);

    /**
     * @brief Includes a Node to the current GraphView
     * @param other_node Node to add.
     * @param includeLearnableParam Should non-data inputs, like weights and biases
     * be included in the GraphView automatically. Default: true.
     */
    void add(NodePtr otherNode, bool includeLearnableParam = true);
    void add(std::set<NodePtr> otherNodes,
             bool includeLearnableParam = true);

    /**
     * @brief Include every Node inside another GraphView to the current
     * GraphView.
     * @param other_graph GraphView containing the Nodes to include.
     */
    void add(std::shared_ptr<GraphView> otherGraph);

    /**
     * @brief Include a Node in the current GraphView and link it to another
     * already contained Node.
     *
     * @param toOtherNode Pointer to the Node to add.
     * @param fromOutNode Pointer to the already included Node the new Node will
     * be linked to (it will become a parent of the new Node). If the GraphView
     * only has one output Node, then default to this Node.
     * @param fromTensor Ouput Tensor ID of the already included Node. Default to
     * 0.
     * @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning
     * first available data input for the Node.
     */
    void addChild(NodePtr toOtherNode, NodePtr fromOutNode = nullptr,
                  const IOIndex_t fromTensor = IOIndex_t(0),
                  IOIndex_t toTensor = gk_IODefaultIndex);

    /**
     * @brief Include a Node in the current GraphView and link it to another
     * already contained Node.
     *
     * @param toOtherNode Pointer to the Node to add.
     * @param fromOutNodeName Name of the already included Node the new Node will
     * be linked to (it will become a parent of the new Node). As a name is
     * optional, ensure such Node is in the GraphView or it will send back an
     * error message.
     * @param fromTensor Ouput Tensor ID of the already included Node. Default to
     * 0.
     * @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning
     * first available data input for the Node.
     */
    inline void addChild(NodePtr toOtherNode, std::string fromOutNodeName,
                         const IOIndex_t fromTensor = IOIndex_t(0),
                         IOIndex_t toTensor = gk_IODefaultIndex) {
        assert(mNodeRegistry.find(fromOutNodeName) != mNodeRegistry.end() &&
               "No Node with this name found in the GraphView.");
        addChild(toOtherNode, mNodeRegistry.at(fromOutNodeName), fromTensor, toTensor);
    }

    /**
     * @brief Include a GraphView content in the current GraphView and link
     * the two sets by linking one Node from each GraphView.
     * @param toOtherView Pointer to the GraphView whose content should be added.
     * @param fromOutNode Pair of pointer to Node and Tensor ID for specifying the
     * connection. If the GraphView including the other one has only one output
     * Node, then it defaults to the first output Tensor of this Node.
     * @param toNode Pair of pointer to Node and Tensor ID for specifying the
     * connection. If the GraphView whose content is included has only one input
     * Node, then it defaults to the first available data input Tensor of this
     * Node.
     */
    void addChild(std::shared_ptr<GraphView> toOtherView,
                  std::pair<NodePtr, IOIndex_t> fromOutNode =
                          std::pair<NodePtr, IOIndex_t>(nullptr, IOIndex_t(0)),
                  std::pair<NodePtr, IOIndex_t> toNode =
                          std::pair<NodePtr, IOIndex_t>(nullptr, gk_IODefaultIndex));

    /**
     * @brief Swap two Node instances if possible.
     * @param node
     * @param otherNode
     * @return true
     * @return false
     */
    bool swap(Node &node, Node &otherNode);

    void link(std::string name1_inID, std::string name2_outID);

    void insert(Node &newNode, Node &inNode, std::initializer_list<Node> outNodes,
                IOIndex_t tensorIdx);

    /**
     * @brief Replace the current GraphView with the set of given Nodes if possible
     * @param newNodes Set of Nodes.
     * @return true 
     * @return false 
     */
    bool replaceWith(std::set<NodePtr> newNodes);
    void updateInputNodes();
    /**
     * @brief Process from zero the set of output Nodes.
     */
    void updateOutputNodes();

private:
///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////

    IONb_t getNbDataInputs() const;

    IONb_t getNbFreeDataInputs() const;


    void updateInputNodes(NodePtr node);

    /**
     * @brief Update the set of output Nodes with a new Node,checking if it can be
     * added and removing any Node not part of mOutputNode anymore.
     * @param nodePtr
     */
    void updateOutputNodes(NodePtr node);

    ///////////////////////////////////////////////////////
    //        TOPOLOGY
    ///////////////////////////////////////////////////////

    void _forwardDims(std::set<NodePtr> listNodes);

    void removeInputNode(const std::string nodeName);
    void removeOutputNode(const std::string nodeName);
};
}  // namespace Aidge

#endif /* __AIDGE_CORE_GRAPH_GRAPHVIEW_H__ */