/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef AIDGE_CORE_GRAPH_GRAPHVIEW_H_
#define AIDGE_CORE_GRAPH_GRAPHVIEW_H_

#include <algorithm>
#include <cassert>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"

namespace Aidge {
enum class DataType;

/**
 * @brief Groupement of Nodes forming a computational graph on which properties and functions
 * can easily and safely be applied or run.
 */
Cyril Moineau's avatar
Cyril Moineau committed
class GraphView : public std::enable_shared_from_this<GraphView> {
private:
    /// @brief Name of the graphview
Cyril Moineau's avatar
Cyril Moineau committed

    /// @brief GraphView root node
    NodePtr mRootNode;

Cyril Moineau's avatar
Cyril Moineau committed
    /// @brief Set of nodes included in the GraphView
Cyril Moineau's avatar
Cyril Moineau committed

    /// @brief Set of nodes included in the graphview with names
    std::map<std::string, NodePtr> mNodeRegistry;
Olivier BICHLER's avatar
Olivier BICHLER committed
    /// @brief GraphView inputs
    std::vector<std::pair<NodePtr, IOIndex_t>> mInputNodes;
Cyril Moineau's avatar
Cyril Moineau committed

Olivier BICHLER's avatar
Olivier BICHLER committed
    /// @brief GraphView outputs
    std::vector<std::pair<NodePtr, IOIndex_t>> mOutputNodes;
Cyril Moineau's avatar
Cyril Moineau committed

public:
    GraphView(std::string name="")
Cyril Moineau's avatar
Cyril Moineau committed
    {
        // ctor
    }

    bool operator==(const GraphView &gv) const
Cyril Moineau's avatar
Cyril Moineau committed
    {
        return mNodes == gv.mNodes;
    }

    NodePtr operator[](std::string name)
Cyril Moineau's avatar
Cyril Moineau committed
    {
        assert(mNodeRegistry.find(name) != mNodeRegistry.end() && "Could not find Node in the GraphView.");
        return mNodeRegistry.at(name);
    }

///////////////////////////////////////////////////////
//        FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////

    Connector operator()(const std::vector<Connector> ctors);

///////////////////////////////////////////////////////
//        INNER
///////////////////////////////////////////////////////
public:
    /**
     * @brief Name of the node.
     * @return std::string
     */
    std::string name() const;

    /**
     * @brief Set the node name.
     * @warning Undefined behaviour when several Nodes have the same name.
     * @param name New name for the node.
     */
    void setName(const std::string &name);

    /**
     * @brief Save the GraphView as a Mermaid graph in a .md file at the
     * specified location.
     * @param path
     */
    void save(std::string path, bool verbose = false) const;

    inline bool inView(NodePtr nodePtr) const {
        return mNodes.find(nodePtr) != mNodes.end();
    }

    NodePtr getRootNode() {
        return mRootNode;
    }

Cyril Moineau's avatar
Cyril Moineau committed
///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////
public:
    /** @brief Get reference to the set of input Nodes. */
    inline std::set<NodePtr> inputNodes() const noexcept {
Olivier BICHLER's avatar
Olivier BICHLER committed
        std::set<NodePtr> nodes;
        for (auto node : mInputNodes) {
            nodes.insert(node.first);
        }
        return nodes;
    }
    /** @brief Get reference to the set of output Nodes. */
    inline std::set<NodePtr> outputNodes() const noexcept {
Olivier BICHLER's avatar
Olivier BICHLER committed
        std::set<NodePtr> nodes;
        for (auto node : mOutputNodes) {
            nodes.insert(node.first);
        }
        return nodes;
    }
    /** @brief Assess if the given Node is an input Node of the GraphView object. */
Cyril Moineau's avatar
Cyril Moineau committed
    inline bool isInputNode(NodePtr nodePtr) const {
Olivier BICHLER's avatar
Olivier BICHLER committed
        const auto nodes = inputNodes();
        return (nodes.find(nodePtr) != nodes.end()) ? true : false;
Cyril Moineau's avatar
Cyril Moineau committed
    }
    /** @brief Assess if the given Node is an output Node of the GraphView object. */
Cyril Moineau's avatar
Cyril Moineau committed
    inline bool isOutputNode(NodePtr nodePtr) const {
Olivier BICHLER's avatar
Olivier BICHLER committed
        const auto nodes = outputNodes();
        return (nodes.find(nodePtr) != nodes.end()) ? true : false;
Olivier BICHLER's avatar
Olivier BICHLER committed
    void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs);
    void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs);

    inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() { return mInputNodes; };
    inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() { return mOutputNodes; };

Cyril Moineau's avatar
Cyril Moineau committed
    /**
     * @brief List outside data input connections of the GraphView.
     * Data inputs exclude inputs expecting parameters (weights or bias).
     * The vector size is garanteed to match the number of outside data inputs of the GraphView. If there is
     * no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned.
Cyril Moineau's avatar
Cyril Moineau committed
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const;

    /**
     * @brief List all dataInput connections (within and outside) of the specified GraphView node named "name".
     * Data inputs exclude inputs expecting parameters (weights or bias).
Cyril Moineau's avatar
Cyril Moineau committed
     * @param name Name of the Node.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); }
Cyril Moineau's avatar
Cyril Moineau committed

    /**
     * @brief List outside input connections of the GraphView. The vector
     * size is garanteed to match the number of outside inputs of the GraphView. If there is
     * no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned.
Cyril Moineau's avatar
Cyril Moineau committed
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;

     * @brief List all input connections (within and outside) of the specified GraphView node named "name".
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
Cyril Moineau's avatar
Cyril Moineau committed
    std::vector<std::pair<NodePtr, IOIndex_t>> inputs(std::string name) const;

    /**
     * @brief List outside output connections of the GraphView. The vector
     * size is garanteed to match the number of outputs of the GraphView. If there is
     * no connection to a given output, the corresponding sub-vector will be empty.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
Cyril Moineau's avatar
Cyril Moineau committed
     */
    std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const;

    /**
     * @brief List all output connections (within and outside) of the specified GraphView node named "name".
Cyril Moineau's avatar
Cyril Moineau committed
     * @param nodeName Name of the Node of which to show the output.
     * @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>
     */
    std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs(
            std::string nodeName) const;

    /**
     * @brief Assert Datatype, Backend, data format and dimensions along the GraphView are coherent.
     * If not, apply the required transformations.
     * @details Sets the GraphView ready for computation in four steps:
     * 1 - Assert input Tensors' datatype is compatible with each Operator's datatype.
     * If not, a conversion Operator is inserted.
     * 2 - Assert input Tensors' backend is compatible with each Operator's backend.
     * If not, add a Transmitter Operator.
     * 3 - Assert data format (NCHW, NHWC, ...) of each Operator's input Tensor is
     * compatible with the selected kernel.
     * If not, add a Transpose Operator.
     * 4 - Propagate Tensor dimensions through the consecutive Operators.
     */
    void compile(const std::string& backend, const Aidge::DataType datatype);

    /**
     * @brief Compute dimensions of input/output Tensors for each Operator of the
     * GraphView object's Nodes.
     */
Cyril Moineau's avatar
Cyril Moineau committed
    void forwardDims();

    /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
Cyril Moineau's avatar
Cyril Moineau committed
    void setBackend(const std::string &backend);
    /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
    void setDataType(const DataType &datatype);
Cyril Moineau's avatar
Cyril Moineau committed

///////////////////////////////////////////////////////
//        TOPOLOGY
///////////////////////////////////////////////////////
public:
    /**
     * @brief Get the parents Nodes of inputNodes.
     * @return std::set<NodePtr>
Cyril Moineau's avatar
Cyril Moineau committed
     */
    std::set<NodePtr> getParents() const;
    /**
     * @brief Get parents Nodes of the specified Node.
     * @param nodeName Name of the Node.
Cyril Moineau's avatar
Cyril Moineau committed
    std::vector<NodePtr> getParents(const std::string nodeName) const;
    std::vector<std::vector<NodePtr>> getOrderedParents() const;

    /**
     * @brief Get the children Nodes of outputNodes.
Cyril Moineau's avatar
Cyril Moineau committed
     * @return std::set<NodePtr>
     */
    std::set<NodePtr> getChildren() const;
    /**
     * @brief Get children Nodes of the specified Node.
     * @param nodeName Name of the Node.
     * @return std::vector<std::vector<NodePtr>>
     */
Cyril Moineau's avatar
Cyril Moineau committed
    std::vector<std::vector<NodePtr>> getChildren(const std::string nodeName) const;
    std::set<NodePtr> getChildren(
            const NodePtr otherNode) const;  // TODO change it for a vector<vector> ?

    /**
     * @brief Get the Nodes pointed to by the GraphView object.
Cyril Moineau's avatar
Cyril Moineau committed
     */
    inline const std::set<NodePtr>& getNodes() const { return mNodes; }
Cyril Moineau's avatar
Cyril Moineau committed

    /**
     * @brief Get the operator with the corresponding name if it is in the
     * GraphView.
     * @param nodeName Name of the node.
     * @return NodePtr returns a new empty node if the one asked for
Cyril Moineau's avatar
Cyril Moineau committed
     * was not found.
     */
    NodePtr getNode(const std::string& nodeName) const;
Cyril Moineau's avatar
Cyril Moineau committed

    /**
     * @brief Remove a Node from the current GraphView scope without affecting its connections.
Cyril Moineau's avatar
Cyril Moineau committed
     * @param nodePtr Node to remove
     * @param includeLearnableParam Whether learnable parameters should also be removed. Default true.
     */
    void remove(NodePtr nodePtr, bool includeLearnableParam = true);

    // Surrounding nodes management

    void setInputId(IOIndex_t inID, IOIndex_t newNodeOutID);

    /**
     * @brief Include a Node to the current GraphView object.
     * @param other_Nde Node to add.
     * @param includeLearnableParam Include non-data inputs, like weights and biases
     * in the GraphView automatically. Default: true.
Cyril Moineau's avatar
Cyril Moineau committed
     */
    void add(NodePtr otherNode, bool includeLearnableParam = true);
    /**
     * @brief Include a set of Nodes to the current GraphView object.
     * @param otherNodes
     * @param includeLearnableParam
     * @return true if graph ordering is unique (meaning inputs/outputs order is well defined).
    bool add(std::set<NodePtr> otherNodes,
Cyril Moineau's avatar
Cyril Moineau committed
             bool includeLearnableParam = true);

    /**
     * @brief Include a set of Nodes to the current GraphView object.
     * The first element of the otherNodes pair is the start node and
     * the second is the remaining nodes to add.
     * @param otherNodes
     * @param includeLearnableParam
     * @return true if graph ordering is unique (meaning inputs/outputs order is well defined).
    bool add(std::pair<NodePtr, std::set<NodePtr>> otherNodes,
             bool includeLearnableParam = true);

Cyril Moineau's avatar
Cyril Moineau committed
    /**
     * @brief Include every Node inside another GraphView to the current
     * GraphView.
     * @param other_graph GraphView containing the Nodes to include.
     * @return true if graph ordering is unique (meaning inputs/outputs order is well defined).
Cyril Moineau's avatar
Cyril Moineau committed
     */
    bool add(std::shared_ptr<GraphView> otherGraph);
Cyril Moineau's avatar
Cyril Moineau committed

    /**
     * @brief Include a Node in the current GraphView and link it to another
     * already contained Node.
     *
     * @param toOtherNode Pointer to the Node to add.
     * @param fromOutNode Pointer to the already included Node the new Node will
     * be linked to (it will become a parent of the new Node). If the GraphView
     * only has one output Node, then default to this Node.
     * @param fromTensor Ouput Tensor ID of the already included Node. Default to
     * 0.
     * @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning
     * first available data input for the Node.
     */
    void addChild(NodePtr toOtherNode, NodePtr fromOutNode = nullptr,
                  const IOIndex_t fromTensor = IOIndex_t(0),
                  IOIndex_t toTensor = gk_IODefaultIndex);

    /**
     * @brief Include a Node in the current GraphView and link it to another
     * already contained Node.
     *
     * @param toOtherNode Pointer to the Node to add.
     * @param fromOutNodeName Name of the already included Node the new Node will
     * be linked to (it will become a parent of the new Node). As a name is
     * optional, ensure such Node is in the GraphView or it will send back an
     * error message.
     * @param fromTensor Ouput Tensor ID of the already included Node. Default to
     * 0.
     * @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning
     * first available data input for the Node.
     */
    inline void addChild(NodePtr toOtherNode, std::string fromOutNodeName,
                         const IOIndex_t fromTensor = IOIndex_t(0),
                         IOIndex_t toTensor = gk_IODefaultIndex) {
        assert(mNodeRegistry.find(fromOutNodeName) != mNodeRegistry.end() &&
               "No Node with this name found in the GraphView.");
        addChild(toOtherNode, mNodeRegistry.at(fromOutNodeName), fromTensor, toTensor);
    }

    /**
     * @brief Include a GraphView content in the current GraphView and link
     * the two sets by linking one Node from each GraphView.
     * @param toOtherView Pointer to the GraphView whose content should be added.
     * @param fromOutNode Pair of pointer to Node and Tensor ID for specifying the
     * connection. If the GraphView including the other one has only one output
     * Node, then it defaults to the first output Tensor of this Node.
     * @param toNode Pair of pointer to Node and Tensor ID for specifying the
     * connection. If the GraphView whose content is included has only one input
     * Node, then it defaults to the first available data input Tensor of this
     * Node.
     */
    void addChild(std::shared_ptr<GraphView> toOtherView,
                  std::pair<NodePtr, IOIndex_t> fromOutNode =
                          std::pair<NodePtr, IOIndex_t>(nullptr, IOIndex_t(0)),
                  std::pair<NodePtr, IOIndex_t> toNode =
                          std::pair<NodePtr, IOIndex_t>(nullptr, gk_IODefaultIndex));

    /**
     * @brief Swap two Node instances if possible.
     * @param node
     * @param otherNode
     * @return true
     * @return false
     */
    bool swap(Node &node, Node &otherNode);

    void link(std::string name1_inID, std::string name2_outID);

    /**
     * @brief Insert a node (newParentNode) as a parent of the passed node (childNode).
     * @param childNode Node that gets a new parent.
     * @param newParentNode Inserted Node.
     * @param childInputTensorIdx Index of the input Tensor for the childNode linked to the inserted Node output.
     * @param newParentInputTensorIdx Index of the input Tensor for the newParentNode linked to the former parent of childNode.
     * @param newParentOutputTensorIdx Index of the output Tensor for the newParentNode linked to the childNode's input Tensor.
    void insertParent(NodePtr childNode,
                        NodePtr newParentNode,
                        IOIndex_t childInputTensorIdx,
                        IOIndex_t newParentInputTensorIdx,
                        IOIndex_t newParentOutputTensorIdx);
Cyril Moineau's avatar
Cyril Moineau committed

     * @brief Replace a set of Nodes in every available GraphView with a new set of Nodes if possible.
     * Both sets should include all the necessary Producers.
     * @details There are 3 cases of replacement:
     * Case 1: same number of input/output connections for oldNodes and newNodes sets.
     *     - input/output connections are replacated according to their IDs.
     * Case 2: different number of input/output connections for oldNodes and newNodes sets.
     *     - only a single parent/child node for the newNodes set, every input/output is
     *       connected to it.
     *     - several parents/children nodes for newNodes set => impossible to know, return false
     * Case 3: newNodes set is empty
     *     - same number of input/output connections in oldNodes, parents and children are linked according
     *       to these connections IDs
     *     - different number of input/output connections in oldNodes => return false
     * @param oldNodes
     * @param newNodes
     * @return true replacement has been performed
     * @return false no replacement has been performed
    static bool replace(const std::set<NodePtr>& oldNodes, const std::set<NodePtr>& newNodes);
Olivier BICHLER's avatar
Olivier BICHLER committed
     * @brief Clone the GraphView with shared Operators. It is a new GraphView, with cloned Nodes, but the new Nodes refer to the same Operators as the original ones.
     * @return std::shared_ptr<GraphView>
     */
    inline std::shared_ptr<GraphView> cloneSharedOperators() const {
Olivier BICHLER's avatar
Olivier BICHLER committed
        return cloneCallback(&Node::cloneSharedOperators);
Olivier BICHLER's avatar
Olivier BICHLER committed
     * @brief Clone the GraphView with shared Producers. All the other Operators are copied.
Olivier BICHLER's avatar
Olivier BICHLER committed
     * @return std::shared_ptr<GraphView>
     */
    inline std::shared_ptr<GraphView> cloneSharedProducers() const {
Olivier BICHLER's avatar
Olivier BICHLER committed
        return cloneCallback(&Node::cloneSharedProducers);
Olivier BICHLER's avatar
Olivier BICHLER committed
     * @brief Clone the GraphView. Everything is cloned: Nodes and Operators.
Olivier BICHLER's avatar
Olivier BICHLER committed
     * @return std::shared_ptr<GraphView>
     */
    inline std::shared_ptr<GraphView> clone() const {
        return cloneCallback(&Node::clone);
    }

    /**
Olivier BICHLER's avatar
Olivier BICHLER committed
     * @brief Clone the current GraphView using a callback function for the Node cloning, allowing to specify how each Node should be cloned or replaced by another Node type, or removed (i.e. replaced by identity). When a Node is removed, the clone() method automatically finds the next valid parent in line, going backward in the graph and connects it if that makes sense without ambiguity (effectively treating the removed Node as an identity operation).
Olivier BICHLER's avatar
Olivier BICHLER committed
     * @param cloneNode Callback function to clone a node
     * @return std::shared_ptr<GraphView>
Olivier BICHLER's avatar
Olivier BICHLER committed
    std::shared_ptr<GraphView> cloneCallback(NodePtr(*cloneNode)(NodePtr)) const;
    /**
     * @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object.
     * Data inputs exclude inputs expecting parameters (weights or bias).
     * @return IOIndex_t
     */
    IOIndex_t getNbFreeDataInputs() const;

protected:
    /**
     * @brief Update inputs/outputs of the GraphView, with no particular order.
     * This function DOES NOT preserve inputs/outputs order and should NOT BE USED.
     * It is here only to leave time to adapt the replace() function.
     */
    void updateInputsOutputsNodes_DEPRECATED();
Cyril Moineau's avatar
Cyril Moineau committed
private:
///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////

     * @brief Get the number of dataInput that are outside the GraphView.
     * Data inputs exclude inputs expecting parameters (weights or bias).
     * This number matches the size of the vector returned by GraphView::dataInputs().
     */
    IOIndex_t getNbDataInputs() const;
Cyril Moineau's avatar
Cyril Moineau committed

Olivier BICHLER's avatar
Olivier BICHLER committed
    /**
     * @brief Automatically update GraphView inputs/outputs with a new Node, checking if
     * it this Node becomes an input/output for the graph and if previous inputs are still
     * inputs/outputs after adding this node.
Olivier BICHLER's avatar
Olivier BICHLER committed
    void updateInputsOutputsNew(NodePtr newNode);
Cyril Moineau's avatar
Cyril Moineau committed

    /**
Olivier BICHLER's avatar
Olivier BICHLER committed
     * @brief Automatically update GraphView inputs/outputs with a Node removed, checking if
     * it this Node was an input/output for the graph and if this node childs become new inputs/outputs
     * for the graph.
Cyril Moineau's avatar
Cyril Moineau committed
     * @param nodePtr
     */
Olivier BICHLER's avatar
Olivier BICHLER committed
    void updateInputsOutputsDelete(NodePtr deletedNode);
Cyril Moineau's avatar
Cyril Moineau committed

    ///////////////////////////////////////////////////////
    //        TOPOLOGY
    ///////////////////////////////////////////////////////

    void _forwardDims(std::set<NodePtr> listNodes);
};
}  // namespace Aidge

#endif /* AIDGE_CORE_GRAPH_GRAPHVIEW_H_ */