Skip to content
Snippets Groups Projects
GraphView.cpp 54.7 KiB
Newer Older
Cyril Moineau's avatar
Cyril Moineau committed
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/graph/GraphView.hpp"
Cyril Moineau's avatar
Cyril Moineau committed

#include <algorithm>     // std::find, std::set_intersection, std::transform
Cyril Moineau's avatar
Cyril Moineau committed
#include <cassert>
#include <stdexcept>     // std::runtime_error
#include <cstddef>       // std::size_t
#include <cstdio>        // std::fclose, std::fopen
#include <iterator>      // std::back_inserter, std::distance, std::inserter,
                         // std::next
Maxence Naud's avatar
Maxence Naud committed
#include <map>
#include <memory>        // std::dynamic_pointer_cast, std::static_pointer_cast
Maxence Naud's avatar
Maxence Naud committed
#include <set>
#include <string>        // std::to_string
#include <utility>       // std::make_pair, std::pair
Maxence Naud's avatar
Maxence Naud committed
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/MetaOperator.hpp"
Maxence Naud's avatar
Maxence Naud committed
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Directories.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
Cyril Moineau's avatar
Cyril Moineau committed


/**
 * @brief Look up a node of this view by its name.
 * @param nodeName Name under which the node was registered (only nodes with a
 *                 non-empty name are registered, see add()).
 * @return The node, or nullptr if no node of this view has that name.
 */
const std::shared_ptr<Aidge::Node> Aidge::GraphView::operator[](const std::string& nodeName) const {
    // Single lookup: reuse the iterator instead of find() followed by at().
    const auto it = mNodeRegistry.find(nodeName);
    return (it != mNodeRegistry.cend()) ? it->second : nullptr;
}
Cyril Moineau's avatar
Cyril Moineau committed

///////////////////////////////////////////////////////
//        FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////

/**
 * @brief Functional-style call: connect each given Connector, in order, to
 *        consecutive inputs of this view's single input node.
 * @param ctors One Connector per data input (nbData()) of the input node; each
 *              must be bound to a node.
 * @return A Connector built on the first node of outputNodes().
 */
Aidge::Connector Aidge::GraphView::operator()(
    const std::vector<Aidge::Connector> ctors) {
  // TODO: allow for multiple inputNodes?
  // AIDGE_ASSERT (unlike assert) is still checked in release builds, where
  // dereferencing begin() of an empty set would be undefined behaviour.
  const std::set<std::shared_ptr<Node>> inNodes = inputNodes();
  AIDGE_ASSERT(inNodes.size() == 1U, "Too many input Nodes for the GraphView, undefined behaviour");
  const std::shared_ptr<Node> inNode = *inNodes.begin();
  AIDGE_ASSERT(ctors.size() == static_cast<std::size_t>(inNode->nbData()), "Wrong number of arguments.\n");
  for (const auto& input : inNode->inputs()) {
    AIDGE_ASSERT(gk_IODefaultIndex == input.second, "At least one input connection is not free.\n");
  }

  IOIndex_t inID = 0;
  for (const Connector &ctor : ctors) {
    AIDGE_ASSERT(ctor.node() != nullptr, "Input Connector must be associated with a node");
    ctor.node()->addChild(shared_from_this(), static_cast<std::size_t>(ctor.index()),
                          {inNode, inID++});
  }

  const std::set<std::shared_ptr<Node>> outNodes = outputNodes();
  AIDGE_ASSERT(!outNodes.empty(), "The GraphView has no output node.");
  return Connector(*outNodes.begin());
}

///////////////////////////////////////////////////////
//        INNER
///////////////////////////////////////////////////////

/// @brief Tell whether @p nodePtr is one of the nodes held by this view.
/// @param nodePtr Node to look for.
/// @return true if @p nodePtr is stored in mNodes.
bool Aidge::GraphView::inView(const std::shared_ptr<Aidge::Node>& nodePtr) const {
    const bool isMember = (mNodes.count(nodePtr) != 0);
    return isMember;
}
Olivier BICHLER's avatar
Olivier BICHLER committed
void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProducers) const {
Olivier BICHLER's avatar
Olivier BICHLER committed
    auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((path + ".mmd").c_str(), "w"), &std::fclose);

    if (!fp) {
        AIDGE_THROW_OR_ABORT(std::runtime_error,
            "Could not create graph view log file: {}", path + ".mmd");
    }

    fmt::print(fp.get(),
                "%%{{init: {{'flowchart': {{ 'curve': 'monotoneY'}}, "
                "'fontFamily': 'Verdana' }} }}%%\nflowchart TB\n\n");
Cyril Moineau's avatar
Cyril Moineau committed

    // Start by creating every node
    const auto namePtrTable = getRankedNodesName("{3}");
Cyril Moineau's avatar
Cyril Moineau committed

    for (const std::shared_ptr<Node> &node_ptr : mNodes) {
Olivier BICHLER's avatar
Olivier BICHLER committed
        std::string givenName =
Cyril Moineau's avatar
Cyril Moineau committed
            (node_ptr->name().empty())
                ? "<em>" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + "</em>"
                : "\"" + node_ptr->name() + "\\n<sub><em>(" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + ")</em></sub>\"";

        std::string nodeCls = "";
        if (node_ptr->type() == "Producer") {
          nodeCls = ":::producerCls";
        }
        else if (std::dynamic_pointer_cast<GenericOperator_Op>(node_ptr->getOperator())) {
          nodeCls = ":::genericCls";
        }
        else if (const auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(node_ptr->getOperator())) {
          nodeCls = ":::metaCls";

          if (verbose) {
            metaOp->getMicroGraph()->save(path + "_" + node_ptr->type() + "#" + namePtrTable.at(node_ptr), verbose, showProducers);
          }
        }

        if (node_ptr == mRootNode) {
Olivier BICHLER's avatar
Olivier BICHLER committed
          if (nodeCls.empty()) {
            nodeCls = ":::rootCls";
          }
          else {
            nodeCls += "_rootCls";
          }

        if (node_ptr == mRootNode || node_ptr->type() != "Producer" || showProducers) {
Olivier BICHLER's avatar
Olivier BICHLER committed
          fmt::print(fp.get(), "{}_{}({}){}\n", node_ptr->type(), namePtrTable.at(node_ptr),
                      givenName, nodeCls);
Cyril Moineau's avatar
Cyril Moineau committed
    }
Cyril Moineau's avatar
Cyril Moineau committed
    // Write every link
    for (const std::shared_ptr<Node> &node_ptr : mNodes) {
      if ((node_ptr -> type() == "Producer") && !showProducers) {
        continue;
      }
      IOIndex_t outputIdx = 0;
Maxence Naud's avatar
Maxence Naud committed
      for (const auto& childs : node_ptr->getOrderedChildren()) {
        for (const auto& child : childs) {
          if (child != nullptr) {
            IOIndex_t inputIdx = 0;
            for (auto parent : child->inputs()) {
              if (parent.first == node_ptr && parent.second == outputIdx) {
                // Add-on to display the operator's output dimensions
                std::string dims = "";
                const auto op = std::dynamic_pointer_cast<OperatorTensor>(node_ptr->getOperator());
                if (op && !op->getOutput(outputIdx)->dims().empty()) {
                  dims += " " + fmt::format("{}", op->getOutput(outputIdx)->dims());
                if (mNodes.find(child) != mNodes.end()) {
Olivier BICHLER's avatar
Olivier BICHLER committed
                  fmt::print(fp.get(), "{}_{}-->|\"{}{}&rarr;{}\"|{}_{}\n", node_ptr->type(), namePtrTable.at(node_ptr),
                              outputIdx, dims, inputIdx, child->type(), namePtrTable.at(child));
                }
                else if (verbose) {
Olivier BICHLER's avatar
Olivier BICHLER committed
                  fmt::print(fp.get(), "{}_{}-->|\"{}{}&rarr;{}\"|{}:::externalCls\n", node_ptr->type(), namePtrTable.at(node_ptr),
                              outputIdx, dims, inputIdx, static_cast<void*>(child.get()));
Cyril Moineau's avatar
Cyril Moineau committed
        }
        ++outputIdx;
      }
    }

    size_t inputIdx = 0;
    for (auto input : mInputNodes) {
      if (input.first != nullptr) {
        fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"&rarr;{}\"|{}_{}\n", inputIdx, inputIdx,
Olivier BICHLER's avatar
Olivier BICHLER committed
                    input.second, input.first->type(), namePtrTable.at(input.first));
Olivier BICHLER's avatar
Olivier BICHLER committed
        fmt::print(fp.get(), "input{}((in#{})):::inputCls\n", inputIdx, inputIdx);
      ++inputIdx;
Cyril Moineau's avatar
Cyril Moineau committed
    }

    size_t outputIdx = 0;
    for (auto output : mOutputNodes) {
      if (output.first != nullptr) {
        // Add-on to display the operator's output dimensions
        std::string dims = "";
        const auto op = std::dynamic_pointer_cast<OperatorTensor>(output.first->getOperator());
        if (op && op->getOutput(output.second) && !op->getOutput(output.second)->dims().empty()) {
          dims += " " + fmt::format("{}", op->getOutput(output.second)->dims());
        }
Olivier BICHLER's avatar
Olivier BICHLER committed
        fmt::print(fp.get(), "{}_{}--->|\"{}{}&rarr;\"|output{}((out#{})):::outputCls\n",
                    output.first->type(), namePtrTable.at(output.first), output.second,
                    dims, outputIdx, outputIdx);
Olivier BICHLER's avatar
Olivier BICHLER committed
        fmt::print(fp.get(), "output{}((out#{})):::outputCls\n", outputIdx, outputIdx);
Olivier BICHLER's avatar
Olivier BICHLER committed
    fmt::print(fp.get(), "classDef inputCls fill:#afa\n");
    fmt::print(fp.get(), "classDef outputCls fill:#ffa\n");
    fmt::print(fp.get(), "classDef externalCls fill:#ccc\n");
    fmt::print(fp.get(), "classDef producerCls fill:#ccf\n");
    fmt::print(fp.get(), "classDef genericCls fill:#f9f9ff,stroke-width:1px,stroke-dasharray: 5 5\n");
    fmt::print(fp.get(), "classDef metaCls stroke-width:5px\n");
    fmt::print(fp.get(), "classDef rootCls stroke:#f00\n");
    fmt::print(fp.get(), "classDef producerCls_rootCls stroke:#f00,fill:#ccf\n");
    fmt::print(fp.get(), "classDef genericCls_rootCls stroke:#f00,fill:#f9f9ff,stroke-width:1px,stroke-dasharray: 5 5\n");
    fmt::print(fp.get(), "classDef metaCls_rootCls stroke:#f00,stroke-width:5px\n");
    fmt::print(fp.get(), "\n");
void Aidge::GraphView::logOutputs(const std::string& dirName) const {
  if (!Aidge::createDirectories(dirName)){
    AIDGE_THROW_OR_ABORT(std::runtime_error, "Failed to create directory: {}.", dirName);
  }
  for (std::shared_ptr<Node> nodePtr : getNodes()) {

    const std::string& nodePath = dirName + "/" + Aidge::filePath(nodePtr->name()) +"/";
    if (!Aidge::createDirectories(nodePath)){
      AIDGE_THROW_OR_ABORT(std::runtime_error, "Failed to create directory: {}.", nodePath);
    }

    for (IOIndex_t outIdx = 0; outIdx < nodePtr->nbOutputs(); ++outIdx) {
      const std::string& inputPath = nodePath +"output_" + std::to_string(outIdx) + ".log";
      auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen(inputPath.c_str(), "w"), &std::fclose);
      if (!fp) {
        AIDGE_THROW_OR_ABORT(std::runtime_error,
            "Could not create graph view log file: {}", inputPath);
      }
      fmt::print(fp.get(), "{}\n", nodePtr->getOperator()->getRawOutput(outIdx)->toString().c_str());
    }
  }
}

/// @brief Make @p node the root node of this view (start of node ranking).
/// @param node Node that must already belong to this view.
void Aidge::GraphView::setRootNode(NodePtr node) {
  AIDGE_ASSERT(inView(node), "Root node is not in the GraphView!");
  mRootNode = std::move(node);
}

Cyril Moineau's avatar
Cyril Moineau committed
///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////

/**
 * @brief Distinct nodes that hold the ordered inputs of this view.
 *
 * Dummy inputs (nullptr entries that setOrderedInputs() allows in
 * mInputNodes) are not nodes and are excluded. Callers such as
 * getNbFreeDataInputs() or getParents() dereference every returned node.
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::inputNodes() const {
    std::set<std::shared_ptr<Aidge::Node>> nodes;
    for (const auto& input : mInputNodes) {
        if (input.first != nullptr) {
            nodes.insert(input.first);
        }
    }
    return nodes;
}

/**
 * @brief Distinct nodes that hold the ordered outputs of this view.
 *
 * Dummy outputs (nullptr entries that setOrderedOutputs() allows in
 * mOutputNodes) are not nodes and are excluded. Callers such as outputs()
 * dereference every returned node.
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::outputNodes() const {
    std::set<std::shared_ptr<Aidge::Node>> nodes;
    for (const auto& output : mOutputNodes) {
        if (output.first != nullptr) {
            nodes.insert(output.first);
        }
    }
    return nodes;
}

/// @brief Tell whether @p nodePtr is one of the nodes returned by inputNodes().
bool Aidge::GraphView::isInputNode(const std::shared_ptr<Aidge::Node>& nodePtr) const {
    return inputNodes().count(nodePtr) != 0;
}

/// @brief Tell whether @p nodePtr is one of the nodes returned by outputNodes().
bool Aidge::GraphView::isOutputNode(const std::shared_ptr<Aidge::Node>& nodePtr) const {
    return outputNodes().count(nodePtr) != 0;
}


Olivier BICHLER's avatar
Olivier BICHLER committed
void Aidge::GraphView::setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs) {
Olivier BICHLER's avatar
Olivier BICHLER committed
  std::vector<std::pair<NodePtr, IOIndex_t>> ignoredInputs(mInputNodes);
  for (auto input : inputs) {
    // Allow to specify dummy inputs (nullptr), but this will only be reflected
    // in mInputNodes. All other functions (nbInputs(), inputs()) will not take
    // it into account.
    if (input.first != nullptr) {
      auto it = std::find(ignoredInputs.begin(), ignoredInputs.end(), input);
      AIDGE_ASSERT(it != ignoredInputs.end(), "unknown or duplicate input");
      ignoredInputs.erase(it);
      ++nbInputs;
    }
  AIDGE_ASSERT(nbInputs <= mInputNodes.size(), "too many specified number of inputs");

Olivier BICHLER's avatar
Olivier BICHLER committed
  mInputNodes = inputs;
  mInputNodes.insert(mInputNodes.end(), ignoredInputs.begin(), ignoredInputs.end());
}

void Aidge::GraphView::setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs) {
Olivier BICHLER's avatar
Olivier BICHLER committed
  std::vector<std::pair<NodePtr, IOIndex_t>> ignoredOutputs(mOutputNodes);
  for (auto output : outputs) {
    // Allow to specify dummy outputs (nullptr), but this will only be reflected
    // in mOutputNodes. All other functions (nbOutputs(), outputs()) will not take
    // it into account.
    if (output.first != nullptr) {
      auto it = std::find(ignoredOutputs.begin(), ignoredOutputs.end(), output);
      AIDGE_ASSERT(it != ignoredOutputs.end(), "unknown or duplicate output");
      ignoredOutputs.erase(it);
      ++nbOutputs;
    }
  AIDGE_ASSERT(nbOutputs <= mOutputNodes.size(), "too many specified number of outputs");

Olivier BICHLER's avatar
Olivier BICHLER committed
  mOutputNodes = outputs;
  mOutputNodes.insert(mOutputNodes.end(), ignoredOutputs.begin(), ignoredOutputs.end());
}

/**
 * @brief Number of data inputs of this view, seen from outside.
 *
 * A data input of an input node counts if it is unconnected or fed by a node
 * outside the view.
 */
Aidge::IOIndex_t Aidge::GraphView::getNbDataInputs() const {
  IOIndex_t nbDataInput = 0;
  for (const std::shared_ptr<Node> &inNode : inputNodes()) {
    // We cannot simply add inNode->nbDataInputs(), as input nodes may already
    // have some inputs connected within the GraphView, which would therefore not
    // constitute inputs (from outside) for the GraphView!
    for (const auto& input : inNode->dataInputs()) {
      if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) {
        ++nbDataInput;
      }
    }
  }
  return nbDataInput;
}

/**
 * @brief Sum of getNbFreeDataInputs() over all input nodes of this view.
 */
Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const {
  IOIndex_t nbIn = 0;
  // Free inputs within the GraphView are logically also free inputs from outside
  // the GraphView.
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    nbIn += inputNode->getNbFreeDataInputs();
  }
  return nbIn;
}


/**
 * @brief Data inputs of this view, seen from outside.
 *
 * @return For each data input of each input node that is unconnected or fed
 *         from outside the view, its (parent node, parent output index) pair
 *         as returned by Node::dataInputs().
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::dataInputs() const {
  std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;

  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    for (const auto& input : inputNode->dataInputs()) {
      // Inputs fed from inside the view are internal connections
      if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) {
        res.push_back(input);
      }
    }
  }
  return res;
}


/**
 * @brief All inputs (data and non-data) of this view, seen from outside.
 *
 * @return For each input of each input node that is unconnected or fed from
 *         outside the view, its (parent node, parent output index) pair as
 *         returned by Node::inputs().
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs() const {
  std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;

  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    for (const auto& input : inputNode->inputs()) {
      // Inputs fed from inside the view are internal connections
      if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) {
        res.push_back(input);
      }
    }
  }
  return res;
}


/**
 * @brief Inputs of the node registered under @p name in this view.
 * @throws std::out_of_range (from std::map::at) if no node has that name.
 */
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs(const std::string& name) const {
  return mNodeRegistry.at(name)->inputs();
}

/**
 * @brief Prepare the view for execution.
 *
 * Sets the backend/device and the data type on every node, then forwards
 * dimensions using the tensors already set on the inputs.
 */
void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device) {
    setBackend(backend, device);
    // Data type
    // TODO: manage Datatype attribute in OperatorImpl
    setDataType(datatype);
    // Data Format
    // TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary
    // Forward dimensions
    forwardDims();
}

/**
 * @brief Propagate tensor dimensions through the whole view.
 *
 * 1. If @p dims is not empty, a new Tensor of dims[i] is set on the i-th
 *    ordered input of the view (one entry per element of mInputNodes).
 * 2. Every input that has a parent is associated with that parent's output.
 *    Inputs without a parent must already hold a non-empty tensor.
 * 3. computeOutputDims() is called repeatedly on every Tensor operator until
 *    all of them report outputDimsForwarded().
 *
 * @param dims Optional dimensions for the view inputs; empty to keep the
 *             tensors already set on the inputs.
 * @throws via AIDGE_ASSERT / AIDGE_THROW_OR_ABORT on an inconsistent dims
 *         count, a missing input, a non-Tensor operator with a parent, or
 *         when an iteration makes no progress.
 */
void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>> dims) {
    // setInputs
    // Link every tensor to the right pointer
    // following parent - children informations
    if (!dims.empty()) {
      AIDGE_ASSERT(dims.size() == mInputNodes.size(), "GraphView forwardDims error - Inconsistent number of given dimensions ({}) and graph inputs ({})", dims.size(), mInputNodes.size());
      for (std::size_t i = 0; i < dims.size(); ++i) {
        // Dummy inputs (nullptr, see setOrderedInputs()) have no operator to
        // feed; their dims entry only keeps positions aligned.
        if (mInputNodes[i].first == nullptr) {
          continue;
        }
        auto tensor = std::make_shared<Tensor>(dims[i]);
        mInputNodes[i].first->getOperator()->setInput(mInputNodes[i].second, tensor);
      }
    }

    // Ensure every node in the graph is correctly connected
    for (const std::shared_ptr<Node>& nodePtr : getNodes()) {
        for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) {
            // assess if the input was not already set and is a Tensor then link it to parent output
            const std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i);
            if (inputI.first) {
                if (std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) {
                    if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
                        // assert provided Data is of "Tensor" type
                        nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second));
                    }
                    else {
                        AIDGE_ASSERT(false, "Non-tensor entries not handled yet, for node {} (of type {}).", nodePtr->name(), nodePtr->type());
                    }
                }
            } else {
                AIDGE_ASSERT(nodePtr->getOperator()->getRawInput(i)
                    && !std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty(),
                  "Missing input#{} for node {} ({})", i, nodePtr->name(), nodePtr->type());
            }
        }
    }

    // Compute dimensions of every node
    std::set<std::shared_ptr<Node>> listNodes = getNodes();
    do {
        std::set<std::shared_ptr<Node>> nextList;
        for (const std::shared_ptr<Node>& nodePtr : listNodes) {
            if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
              const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator());
              // Recompute everytime, even if it was already computed in a
              // previous call of forwardDims(), as the graph may have changed!
              op->computeOutputDims();
              if (!op->outputDimsForwarded()) {
                  nextList.insert(nodePtr);
              }
            }
        }
        // Internal check to make sure we won't enter in an infinite loop!
        if (nextList == listNodes) {
            // We are stuck!
            std::vector<std::string> nodesName;
            std::transform(nextList.begin(), nextList.end(),
                std::back_inserter(nodesName),
                [](const auto& val){ return val->name() + " (" + val->type() + ")"; });
            AIDGE_THROW_OR_ABORT(std::runtime_error, "Unable to forward dimensions (circular dependency and/or wrong dimensions?). Unable to compute output dims for nodes {}.", nodesName);
        }
        listNodes.swap(nextList);
    }
    while (!listNodes.empty());
}
/**
 * @brief Set the backend and device on the operator of every node of the view.
 */
void Aidge::GraphView::setBackend(const std::string &backend, const DeviceIdx_t device) const {
    for (const auto& node : getNodes()) {
        node->getOperator()->setBackend(backend, device);
    }
}
/**
 * @brief Set the data type on the operator of every node of the view.
 */
void Aidge::GraphView::setDataType(const Aidge::DataType &datatype) const {
    for (const auto& node : getNodes()) {
        node->getOperator()->setDataType(datatype);
    }
}

/**
 * @brief Connections from the view's output nodes to the outside.
 *
 * For each output position of each output node, keep only the
 * (child, child input index) pairs whose child is nullptr or outside the view.
 * A position with no connection at all is kept as an empty list; a position
 * whose connections all stay inside the view is dropped.
 */
std::vector<
    std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs() const {
  std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>>
      outsideOutputs;
  for (const std::shared_ptr<Node>& outputNode : outputNodes()) {
    for (const auto& outputPos : outputNode->outputs()) {
      // Keep only the nodes connected at this output position that are outside the GraphView
      std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>> outsideOutputPos;
      for (const auto& output : outputPos) {
        if (output.first == nullptr || mNodes.find(output.first) == mNodes.end()) {
          outsideOutputPos.push_back(output);
        }
      }

      if (outputPos.empty() || !outsideOutputPos.empty()) {
        outsideOutputs.push_back(std::move(outsideOutputPos));
      }
    }
  }
  return outsideOutputs;
}

/**
 * @brief Outputs of the node registered under @p nodeName in this view.
 * @throws std::out_of_range (from std::map::at) if no node has that name.
 */
std::vector<
    std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs(const std::string& nodeName) const {
  return mNodeRegistry.at(nodeName)->outputs();
}

/**
 * @brief Not implemented: always fails.
 * @throws std::runtime_error (via AIDGE_THROW_OR_ABORT).
 */
void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/,
                               Aidge::IOIndex_t /*newNodeOutID*/) {
  AIDGE_THROW_OR_ABORT(std::runtime_error, "Not implemented yet.");
}

void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) {
Olivier BICHLER's avatar
Olivier BICHLER committed
  AIDGE_ASSERT(node != nullptr, "Trying to add non-existant node!");

  // first node to be added to the graph is the root node by default
  if (mRootNode == nullptr) {
    mRootNode = node;
  }

Cyril Moineau's avatar
Cyril Moineau committed
  // add to the GraphView nodes
  node->addView(shared_from_this());
  mNodes.insert(node);
  if (!(node->name()).empty())
    mNodeRegistry.insert(std::make_pair(node->name(), node));
Olivier BICHLER's avatar
Olivier BICHLER committed

  // check if the node is an input/output node
  updateInputsOutputsNew(node);

Cyril Moineau's avatar
Cyril Moineau committed
  // add learnable parameters to the graph
  if (includeLearnableParam) {
    for (IOIndex_t i = node->nbData(); i < node->nbInputs(); ++i) {
      std::shared_ptr<Node> parentNode = node->getParent(static_cast<IOIndex_t>(i));
Cyril Moineau's avatar
Cyril Moineau committed
      if (parentNode) {
          parentNode->addView(shared_from_this());
          mNodes.insert(parentNode);
          if (!(parentNode->name()).empty())
            mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode));
Olivier BICHLER's avatar
Olivier BICHLER committed
          // check if the parentNode is an input/output node
          updateInputsOutputsNew(parentNode);
std::pair<std::vector<Aidge::NodePtr>, size_t> Aidge::GraphView::getRankedNodes() const {
  std::set<NodePtr> nodesToRank(mNodes);
  nodesToRank.erase(mRootNode);
  std::vector<NodePtr> rankedNodes;
  rankedNodes.push_back(mRootNode);

  for (size_t curNodeIdx = 0; curNodeIdx < rankedNodes.size(); ++curNodeIdx) {
    NodePtr curNode = rankedNodes[curNodeIdx];

    for (auto childs : curNode->getOrderedChildren()) {
      for (auto child : childs) {
Olivier BICHLER's avatar
Olivier BICHLER committed
        if (child != nullptr && nodesToRank.find(child) != nodesToRank.end()) {
          rankedNodes.push_back(child);
          nodesToRank.erase(child);
        }
      }
    }

    for (auto parent : curNode->getParents()) {
Olivier BICHLER's avatar
Olivier BICHLER committed
      if (parent != nullptr && nodesToRank.find(parent) != nodesToRank.end()) {
        rankedNodes.push_back(parent);
        nodesToRank.erase(parent);
      }
    }
  }

  const size_t orderUnicityLimit = rankedNodes.size();
  if (!nodesToRank.empty()) {
    rankedNodes.insert(rankedNodes.end(), nodesToRank.begin(), nodesToRank.end());
  }

  return std::make_pair(rankedNodes, orderUnicityLimit);
}

/**
 * @brief Format a label for every node, following getRankedNodes() order.
 *
 * @param format         fmt format string receiving, in order: {0} node name,
 *                       {1} node type, {2} global rank, {3} rank among nodes
 *                       of the same type.
 * @param markNonUnicity If true, ranks past the order unicity limit (i.e.
 *                       not uniquely determined) are passed as "?<rank>".
 *                       If false, plain ranks are used for every node.
 * @return Map from each node to its formatted label.
 */
std::map<Aidge::NodePtr, std::string> Aidge::GraphView::getRankedNodesName(const std::string& format, bool markNonUnicity) const {
  const auto rankedNodes = getRankedNodes();
  std::map<NodePtr, std::string> rankedNodesName;
  size_t rank = 0;
  // Per-type counters used for {3}
  std::map<std::string, size_t> typeRank;
  for (const auto& rankedNode : rankedNodes.first) {
    // Inserts a 0 counter on the first occurrence of this type
    const auto it = typeRank.insert(std::make_pair(rankedNode->type(), size_t(0))).first;

    // Previously the condition was (markNonUnicity && unique), which marked
    // every node with '?' exactly when marking was NOT requested.
    const bool uniqueRank = (rank < rankedNodes.second);
    const auto name = (!markNonUnicity || uniqueRank)
      ? fmt::format(format, rankedNode->name(), rankedNode->type(), rank, it->second)
      : fmt::format(format, rankedNode->name(), rankedNode->type(), fmt::format("?{}", rank), fmt::format("?{}", it->second));
    rankedNodesName.insert(std::make_pair(rankedNode, name));
    ++it->second;
    ++rank;
  }
  return rankedNodesName;
}

bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
  if (otherNodes.empty()) {
    return true;
  }

  bool orderUnicity = true;

Olivier BICHLER's avatar
Olivier BICHLER committed
  // List only the nodes that are not already present in current graph
  std::set<NodePtr> nodesToAdd;
  std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::inserter(nodesToAdd, nodesToAdd.begin()));
Olivier BICHLER's avatar
Olivier BICHLER committed

  // List the nodes to rank, initially all the nodes in the GraphView
  std::set<NodePtr> nodesToRank(mNodes);
  nodesToRank.insert(nodesToAdd.begin(), nodesToAdd.end());
  std::vector<NodePtr> rankedNodesToAdd;

  if (mRootNode == nullptr) {
    std::set<NodePtr> noParentNodes;

    // If no root node is defined, check nodes without parents
    for (auto node : nodesToRank) {
      bool noParent = true;
      for (auto parent : node->getParents()) {
        if (parent != nullptr && nodesToRank.find(parent) != nodesToRank.end()) {
          noParent = false;
Olivier BICHLER's avatar
Olivier BICHLER committed
          break;
        }
      }

      if (noParent) {
        noParentNodes.insert(node);
    // Take the first one found (this is an arbitrary choice)
    mRootNode = *noParentNodes.begin();

    if (noParentNodes.size() > 1) {
      // If there is more than one, order unicity cannot be garanteed!
      orderUnicity = false;
    }

    rankedNodesToAdd.push_back(mRootNode);
  }

  nodesToRank.erase(mRootNode);
  std::vector<NodePtr> rankedNodes;
  rankedNodes.push_back(mRootNode);

  for (size_t curNodeIdx = 0; curNodeIdx < rankedNodes.size(); ++curNodeIdx) {
    NodePtr curNode = rankedNodes[curNodeIdx];

    for (auto childs : curNode->getOrderedChildren()) {
      for (auto child : childs) {
Olivier BICHLER's avatar
Olivier BICHLER committed
        if (child != nullptr && nodesToRank.find(child) != nodesToRank.end()) {
          rankedNodes.push_back(child);
          nodesToRank.erase(child);

          if (nodesToAdd.find(child) != nodesToAdd.end()) {
            rankedNodesToAdd.push_back(child);
            nodesToAdd.erase(child);
    for (auto parent : curNode->getParents()) {
Olivier BICHLER's avatar
Olivier BICHLER committed
      if (parent != nullptr && nodesToRank.find(parent) != nodesToRank.end()) {
        rankedNodes.push_back(parent);
        nodesToRank.erase(parent);

        if (nodesToAdd.find(parent) != nodesToAdd.end()) {
          rankedNodesToAdd.push_back(parent);
          nodesToAdd.erase(parent);
        }
      }
Olivier BICHLER's avatar
Olivier BICHLER committed
    }
Olivier BICHLER's avatar
Olivier BICHLER committed

  if (!nodesToAdd.empty()) {
    // There are remaining nodes without path to the root node
    orderUnicity = false;

    while (!nodesToAdd.empty()) {
      const auto it = nodesToAdd.begin();
      rankedNodesToAdd.push_back(*it);
      nodesToAdd.erase(it);
Olivier BICHLER's avatar
Olivier BICHLER committed
    }
  }

  for (auto node_ptr : rankedNodesToAdd) {
    add(node_ptr, includeLearnableParam);
  }

  return orderUnicity;
/**
 * @brief Add a set of nodes with an explicit root node.
 * @param nodes (root node, other nodes). If the root is not nullptr it becomes
 *              the root node of this view and is added first.
 * @param includeLearnableParam Forwarded to the single-node and set overloads.
 * @return The order unicity flag returned by add(std::set, bool).
 */
bool Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) {
  if (nodes.first != nullptr) {
    mRootNode = nodes.first;
    add(nodes.first, includeLearnableParam);
  }
  return add(nodes.second, includeLearnableParam);
}
bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
    // set the rootNode to the other graphView rootNode if no rootNode yet
Maxence Naud's avatar
Maxence Naud committed
    mRootNode = mRootNode ? mRootNode : graph->rootNode();
    return add(graph->getNodes(), false);
Cyril Moineau's avatar
Cyril Moineau committed
}

/**
 * @brief Connect @p toOtherNode as a child of a node of this view, then add it.
 *
 * @param toOtherNode Node to connect and add.
 * @param fromOutNode Parent node in this view; if nullptr, the view must have
 *                    exactly one output node, which is used.
 * @param fromTensor  Output index on the parent.
 * @param toTensor    Input index on @p toOtherNode.
 */
void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode,
                               std::shared_ptr<Node> fromOutNode,
                               const Aidge::IOIndex_t fromTensor,
                               Aidge::IOIndex_t toTensor) {
  // AIDGE_ASSERT (unlike assert) is still checked in release builds, where
  // dereferencing begin() of an empty set would be undefined behaviour.
  if (fromOutNode) {
    AIDGE_ASSERT(inView(fromOutNode), "Output Node not found in the GraphView.");
  }
  else {
    const std::set<std::shared_ptr<Node>> outNodes = outputNodes();
    AIDGE_ASSERT(outNodes.size() == 1U, "Must specify an outputNode or have only one.");
    fromOutNode = *outNodes.begin();
  }
  fromOutNode->addChild(toOtherNode, fromTensor, toTensor);
  add(toOtherNode);
}

void Aidge::GraphView::addChild(
    std::shared_ptr<GraphView> toOtherView,
    std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> fromOutNode,
    std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> toNode) {
  // assert output node is valid
  if (!fromOutNode.first) {
    assert(outputNodes().size() == 1U &&
           "If no output node is provided, the graph should have only one to "
           "make the choice explicit.");
    fromOutNode.first = *(outputNodes().begin());
  } else
    assert(inView(fromOutNode.first));
  // assert input node is valid
  if (!toNode.first) {
    assert(toOtherView->inputNodes().size() == 1U &&
           "If no intput node is provided, the other graph should have only "
           "one to make the choice explicit.");
    toNode.first = *(toOtherView->inputNodes().begin());
  } else {
    assert(toOtherView->inView(toNode.first));
  }
  // Tensor assertions are performed in the Node adChild method
  fromOutNode.first->addChild(toNode.first, fromOutNode.second, toNode.second);
  // once linking performed, add other graph to current graph
  add(toOtherView);
}

// Collect, without duplicates, the parents of every input node of this view.
// Entries are taken as returned by Node::getParents() and are not filtered.
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents() const {
  // TODO: choose if we return a set or a vector
  std::set<std::shared_ptr<Node>> parents;
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    // Bind the result once: calling getParents() twice for begin()/end()
    // pairs iterators of two distinct temporaries when it returns by value,
    // which is undefined behavior.
    const auto& nodeParents = inputNode->getParents();
    parents.insert(nodeParents.begin(), nodeParents.end());
  }
  return parents;
}

// Return the parents of the node registered under `nodeName` in this view.
// An unknown name is reported through AIDGE_ASSERT.
std::vector<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents(const std::string nodeName) const {
  const auto it = mNodeRegistry.find(nodeName);
  AIDGE_ASSERT(it != mNodeRegistry.end(), "No node named {} in graph {}.", nodeName, name());

  const std::shared_ptr<Node>& namedNode = it->second;
  return namedNode->getParents();
}

// For each input node of this view (in inputNodes() iteration order), return
// the list of its parents as given by Node::getParents().
std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
Aidge::GraphView::getOrderedParents() const {
  std::vector<std::vector<std::shared_ptr<Node>>> orderedParents;
  const auto& inNodes = inputNodes();
  std::transform(inNodes.begin(), inNodes.end(), std::back_inserter(orderedParents),
                 [](const std::shared_ptr<Node>& node) { return node->getParents(); });
  return orderedParents;
}

// Collect, without duplicates, the children of every output node of this view.
// Entries are taken as returned by Node::getChildren() and are not filtered.
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getChildren() const {
  std::set<std::shared_ptr<Node>> children;
  for (const std::shared_ptr<Node>& outputNode : outputNodes()) {
    // Bind the result once: two separate getChildren() calls for begin()/end()
    // would mix iterators of two distinct temporaries when it returns by value.
    const auto& nodeChildren = outputNode->getChildren();
    children.insert(nodeChildren.begin(), nodeChildren.end());
  }
  return children;
}

// Return the children of the node registered under `nodeName`, grouped as
// Node::getOrderedChildren() groups them. An unknown name is reported through
// AIDGE_ASSERT.
std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
Aidge::GraphView::getChildren(const std::string nodeName) const {
  const auto it = mNodeRegistry.find(nodeName);
  AIDGE_ASSERT(it != mNodeRegistry.end(), "No node named {} in graph {}.", nodeName, name());

  const std::shared_ptr<Node>& namedNode = it->second;
  return namedNode->getOrderedChildren();
}

// Return the children of `otherNode`, which must belong to this view.
// The error message labels a null pointer as "#nullptr".
std::set<std::shared_ptr<Aidge::Node>>
Aidge::GraphView::getChildren(const std::shared_ptr<Node> otherNode) const {
  const auto it = mNodes.find(otherNode);
  AIDGE_ASSERT(it != mNodes.end(), "The node {} (of type {}) is not in graph {}.",
    (otherNode) ? otherNode->name() : "#nullptr", (otherNode) ? otherNode->type() : "", name());

  const std::shared_ptr<Node>& member = *it;
  return member->getChildren();
}


std::shared_ptr<Aidge::Node>
Aidge::GraphView::getNode(const std::string& nodeName) const {
Cyril Moineau's avatar
Cyril Moineau committed
  std::map<std::string, std::shared_ptr<Node>>::const_iterator it =
      mNodeRegistry.find(nodeName);
Cyril Moineau's avatar
Cyril Moineau committed
    return it->second;
  } else {
    Log::warn("No Node named {} in the current GraphView {}.", nodeName, name());
Cyril Moineau's avatar
Cyril Moineau committed
  }
}


void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnableParam) {
Olivier BICHLER's avatar
Olivier BICHLER committed
  // remove learnable params
Cyril Moineau's avatar
Cyril Moineau committed
  if (includeLearnableParam) {
    for (IOIndex_t i = nodePtr->nbData(); i < nodePtr->nbInputs(); ++i) {
Cyril Moineau's avatar
Cyril Moineau committed
      auto inputI = nodePtr->input(i);
      if (inputI.first != nullptr) {
        bool removeNode = true;
        for (const auto& parentOutput : inputI.first->outputs()) {
          for (const auto& childOfParentOutput : parentOutput) {
            // only remove the learnable parameter if not related to any other Node in the GraphView
            if (childOfParentOutput.first != nodePtr) {
              removeNode = false;
              break;
            }
Cyril Moineau's avatar
Cyril Moineau committed
          }
        }
        if (removeNode) {
          // assert Learnable Parameter in the GraphView scope
          if (mNodes.find(inputI.first) != mNodes.end()) {
            mNodes.erase(inputI.first);
            inputI.first->removeView(shared_from_this());
          }
          if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); }
Olivier BICHLER's avatar
Olivier BICHLER committed

          // check if the node was an input/output node
          updateInputsOutputsDelete(inputI.first);
        }
Olivier BICHLER's avatar
Olivier BICHLER committed

  if (mNodes.find(nodePtr) != mNodes.end()) {
    mNodes.erase(nodePtr);
    nodePtr->removeView(shared_from_this());

    // check if the nodePtr was an input/output node
    updateInputsOutputsDelete(nodePtr);
Olivier BICHLER's avatar
Olivier BICHLER committed
  }
  if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); }
Cyril Moineau's avatar
Cyril Moineau committed
}


// Placeholder for swapping two nodes: not supported yet.
// Both arguments are ignored; a notice is printed and false is returned.
bool Aidge::GraphView::swap(Node & /*node*/, Node & /*otherNode*/) {
  constexpr bool kSwapSupported = false;
  fmt::print("Swap() not implementated yet. Return false.\n");
  return kSwapSupported;
}

Olivier BICHLER's avatar
Olivier BICHLER committed
void Aidge::GraphView::link(const std::string& /*name1_inID*/,
                           const std::string& /*name2_outID*/) {
Olivier BICHLER's avatar
Olivier BICHLER committed
  fmt::print("Not implemented yet.\n");
// Insert `newParentNode` between `childNode` and the node currently feeding
// its input `childInputTensorIdx`, then add `newParentNode` to this view.
// - newParentInputTensorIdx: input of the new parent fed by the old parent.
// - newParentOutputTensorIdx: output of the new parent feeding the child.
// The child input must already be connected; otherwise there is no existing
// link to split, and this is reported through AIDGE_ASSERT instead of
// dereferencing a null parent.
void Aidge::GraphView::insertParent(NodePtr childNode,
                  NodePtr newParentNode,
                  IOIndex_t childInputTensorIdx,
                  IOIndex_t newParentInputTensorIdx,
                  IOIndex_t newParentOutputTensorIdx){
  AIDGE_ASSERT(childNode != nullptr && newParentNode != nullptr,
               "insertParent() requires non-null child and new parent nodes.");

  NodePtr currentParentNode = childNode->getParent(childInputTensorIdx);
  AIDGE_ASSERT(currentParentNode != nullptr,
               "Input #{} of node {} has no parent: cannot insert node {} before it.",
               childInputTensorIdx, childNode->name(), newParentNode->name());
  const IOIndex_t currentParentOutputTensorIdx = childNode->input(childInputTensorIdx).second;

  // Remove child from current parent & current Parent from child
  currentParentNode->removeChild(childNode, currentParentOutputTensorIdx);

  // old parent -> new parent -> child
  currentParentNode->addChild(newParentNode, currentParentOutputTensorIdx, newParentInputTensorIdx);
  newParentNode->addChild(childNode, newParentOutputTensorIdx, childInputTensorIdx);

  add(newParentNode);
}
// Replace the nodes in `oldNodes` by the nodes in `newNodes`.
// Both sets are wrapped into temporary GraphViews, and the work is delegated
// to the GraphView-based replace() overload.
// Previously `newNodes` was never added to `newG`, so the replacement graph was
// always empty and the call only deleted `oldNodes`.
bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) {
    // (1) create GraphViews from both sets of Nodes
    auto oldG = std::make_shared<GraphView>("oldG");
    oldG->add(oldNodes, false);
    auto newG = std::make_shared<GraphView>("newG");
    newG->add(newNodes, false);

    // (2) delegate to the GraphView-based overload
    return GraphView::replace(oldG, newG);
}

bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const std::shared_ptr<GraphView>& newGraph) {
    // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes)
    // How to distinguish it from data input?
    // TODO: Parameter Tensors could be identified with their dimensions
    // TODO: Take GraphView as input parameters since new Nodes should be connected whatever.
    // It also avoids specifying each producer since they are automatically included
    const std::set<NodePtr>&  oldNodes = oldGraph->getNodes();
    const std::set<NodePtr>&  newNodes = newGraph->getNodes();

    const std::vector<std::pair<NodePtr, IOIndex_t>> oldOIn =
                                                     oldGraph->getOrderedInputs();
    const std::vector<std::pair<NodePtr, IOIndex_t>> oldOOut =
                                                     oldGraph->getOrderedOutputs();
    const std::vector<std::pair<NodePtr, IOIndex_t>> newOIn =
                                                     newGraph->getOrderedInputs();
    const std::vector<std::pair<NodePtr, IOIndex_t>> newOOut =
                                                     newGraph->getOrderedOutputs();

    auto inputParents = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOIn.size());
    auto outputChildren = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOOut.size());

    // keep in memory every node related to the node to replace :
    // Parent
    for (std::size_t i = 0; i < oldOIn.size(); ++i) {
        std::pair<NodePtr, IOIndex_t> inputParent = 
                  oldOIn[i].first -> input(oldOIn[i].second);
        inputParents[i]= inputParent;
        // inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second);
    // Children
    for (std::size_t i = 0; i < oldOOut.size();) {
        std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> outputChild = 
              oldOOut[i].first -> output(oldOOut[i].second);
        if (outputChild.empty()) {
            outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex});
            ++i;
            for (const auto& child : outputChild) {
                if (oldNodes.find(child.first) == oldNodes.cend()) {
                    outputChildren[i] = child;
                    ++i;
    // only keep common views to each node for the new set
    // set of common GraphView for oldNodes' Nodes
    std::set<std::shared_ptr<GraphView>> commonGraphViews =  (*oldNodes.begin())->views();
    for (const auto& nodePtr : oldNodes) {
        const std::set<std::shared_ptr<GraphView>> nodeView = nodePtr->views();
        std::set<std::shared_ptr<GraphView>> intersection;
        std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(),
                            nodeView.begin(), nodeView.end(),
                            std::inserter(intersection, intersection.begin()));
        commonGraphViews = intersection;
    commonGraphViews.erase(oldGraph);
    commonGraphViews.erase(newGraph);
    if ((newNodes.size() > 0) && (oldOIn.size() != newOIn.size()) && (oldOOut.size() != newOOut.size())) {
        for (const auto& nodePtr : oldNodes) {
            nodePtr->removeView(oldGraph);
        }
        for (const auto& nodePtr : newNodes) {
            nodePtr->removeView(newGraph);
    if ((oldOIn.size() == newOIn.size()) &&
        (oldOOut.size() == newOOut.size())) {
        for (std::size_t i = 0; i < oldOIn.size(); ++i) {
            if (inputParents[i].first) {
                inputParents[i].first -> addChild(newOIn[i].first, inputParents[i].second, newOIn[i].second);
        for (std::size_t o = 0; o < oldOOut.size(); ++o) {
            if (outputChildren[o].first) {
                newOOut[o].first -> addChild(outputChildren[o].first, newOOut[o].second, outputChildren[o].second);
            }
        }
    }
    else {
        // get the number of Parents for oldG->inputNodes()
        // get the number of Children for oldg->outputNodes()
        if (newNodes.size() == 0) {
            // Case 3
            if (oldOIn.size() == oldOOut.size()) {
                for (std::size_t i = 0; i < oldOIn.size(); ++i) {
                    if (inputParents[i].first) {
                      inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second);
            else if ((oldOIn.size() == 1) && (inputParents[0].first)) {
                for (std::size_t i = 0; i < oldOIn.size(); ++i) {
                    inputParents[0].first -> addChild(outputChildren[i].first, inputParents[0].second, outputChildren[i].second);
                }
            }
        }
        else if ( // for tiling-like cases. The number of inputNodes changes but not outputNodes
            ((oldOIn.size() == 1) || (newOIn.size() == 1)) && // (oldOIn.size() == newOI.size()) already handled in Case 1
            ((oldOOut.size() == newOOut.size()))
            if ((oldOIn.size() == 1) && (inputParents[0].first)) {