Skip to content
Snippets Groups Projects
Forked from Eclipse Projects / aidge / aidge_core
2165 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
GraphView.cpp 43.97 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <iterator>
#include <memory>
#include <utility>

#include "aidge/utils/Types.h"
#include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"

///////////////////////////////////////////////////////
//        FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////

// Functional composition: plug the given Connectors, in order, into the data
// inputs of the (unique) input Node of this GraphView and return a Connector on
// the first output Node.
// Preconditions (checked with assert only): exactly one input Node, as many
// Connectors as that Node has data inputs, and all its inputs still free.
Aidge::Connector Aidge::GraphView::operator()(
    const std::vector<Aidge::Connector> ctors) {
  // TODO: allow for multiple inputNodes?
  assert((inputNodes().size() == 1U) && "Too many input Nodes for the GraphView, undefined behaviour");
  const std::shared_ptr<Node> inNode = *inputNodes().begin();
  assert((ctors.size() == static_cast<std::size_t>(inNode->nbDataInputs())) && "Wrong number of arguments.\n");

  for (const std::pair<std::shared_ptr<Node>, IOIndex_t>& input : inNode->inputs()) {
    assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n");
    (void)input; // only inspected by assert()
  }

  // The k-th Connector feeds input k of inNode.
  for (std::size_t k = 0; k < ctors.size(); ++k) {
    const Connector& ctor = ctors[k];
    assert((ctor.node() != nullptr) &&
           "Input Connector must be associated with a node");
    ctor.node()->addChild(shared_from_this(), static_cast<std::size_t>(ctor.index()),
                          {inNode, static_cast<IOIndex_t>(k)});
  }

  const std::shared_ptr<Node> outNode = *(outputNodes().begin());
  return Connector(outNode);
}

///////////////////////////////////////////////////////
//        INNER
///////////////////////////////////////////////////////

// Return (a copy of) the name of this GraphView.
std::string Aidge::GraphView::name() const {
    const std::string& current = mName;
    return current;
}

// Replace the name of this GraphView.
void Aidge::GraphView::setName(const std::string &name) {
    mName.assign(name);
}


void Aidge::GraphView::save(std::string path, bool verbose) const {
    FILE *fp = std::fopen((path + ".mmd").c_str(), "w");
    std::fprintf(fp,
                "%%%%{init: {'flowchart': { 'curve': 'monotoneY'}, "
                "'fontFamily': 'Verdana' } }%%%%\nflowchart TB\n\n");

    std::map<const std::string, std::size_t> typeCounter;
    std::map<std::shared_ptr<Node>, std::string> namePtrTable;

    // Start by creating every node
    for (const std::shared_ptr<Node> &node_ptr : mNodes) {
        const std::string currentType = node_ptr->type();
        if (typeCounter.find(currentType) == typeCounter.end())
        typeCounter[currentType] = 0;
        ++typeCounter[currentType];

        const std::string givenName =
            (node_ptr->name().empty())
                ? currentType + std::to_string(typeCounter[currentType])
                : node_ptr->name();
        namePtrTable[node_ptr] =
            (currentType + "_" + std::to_string(typeCounter[currentType]));
        std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
                    givenName.c_str());
    }
    // Write every link
    std::size_t emptyInputCounter = 0;
    for (const std::shared_ptr<Node> &node_ptr : mNodes) {
        for (const std::shared_ptr<Node> &pa_ptr : node_ptr->getParents()) {
        if ((pa_ptr == nullptr) || !inView(pa_ptr)) {
            std::fprintf(fp, "input%zu((in - %zu))-->%s\n", emptyInputCounter,
                        emptyInputCounter, namePtrTable[node_ptr].c_str());
            ++emptyInputCounter;
        } else {
            std::fprintf(fp, "%s-->%s\n", namePtrTable[pa_ptr].c_str(),
                        namePtrTable[node_ptr].c_str());
        }
        }
    }
    if (verbose) {
        for (const auto &c : typeCounter) {
        std::printf("%s - %zu\n", c.first.c_str(), c.second);
        }
    }

    std::fprintf(fp, "\n");
    std::fclose(fp);
}

///////////////////////////////////////////////////////
//        TENSOR MANAGEMENT
///////////////////////////////////////////////////////

void Aidge::GraphView::setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs) {
  AIDGE_ASSERT(inputs.size() <= mInputNodes.size(), "too many specified number of inputs");

  std::vector<std::pair<NodePtr, IOIndex_t>> ignoredInputs(mInputNodes);
  for (auto input : inputs) {
    auto it = std::find(ignoredInputs.begin(), ignoredInputs.end(), input);
    AIDGE_ASSERT(it != ignoredInputs.end(), "unknown or duplicate input");
    ignoredInputs.erase(it);
  }

  mInputNodes = inputs;
  mInputNodes.insert(mInputNodes.end(), ignoredInputs.begin(), ignoredInputs.end());
}

void Aidge::GraphView::setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs) {
  AIDGE_ASSERT(outputs.size() <= mOutputNodes.size(), "too many specified number of outputs");

  std::vector<std::pair<NodePtr, IOIndex_t>> ignoredOutputs(mOutputNodes);
  for (auto output : outputs) {
    auto it = std::find(ignoredOutputs.begin(), ignoredOutputs.end(), output);
    AIDGE_ASSERT(it != ignoredOutputs.end(), "unknown or duplicate output");
    ignoredOutputs.erase(it);
  }

  mOutputNodes = outputs;
  mOutputNodes.insert(mOutputNodes.end(), ignoredOutputs.begin(), ignoredOutputs.end());
}

// Count the data inputs of the input Nodes that are fed from outside the view
// (no parent at all, or a parent that is not part of this GraphView).
Aidge::IOIndex_t Aidge::GraphView::getNbDataInputs() const {
  using Conn = std::pair<std::shared_ptr<Node>, IOIndex_t>;
  // An input connected to another node of the view is internal, hence not an
  // input of the GraphView itself.
  const auto isExternal = [this](const Conn& in) {
    return (in.first == nullptr) || (mNodes.find(in.first) == mNodes.end());
  };

  IOIndex_t total = 0;
  for (const std::shared_ptr<Node>& inNode : inputNodes()) {
    const std::vector<Conn> nodeDataInputs = inNode->dataInputs();
    total += static_cast<IOIndex_t>(
        std::count_if(nodeDataInputs.begin(), nodeDataInputs.end(), isExternal));
  }
  return total;
}

// Sum of the free data inputs of every input Node. A free input inside the
// view is necessarily a free input of the view as seen from outside.
Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const {
  const auto inNodes = inputNodes();
  IOIndex_t freeInputs = 0;
  for (auto it = inNodes.cbegin(); it != inNodes.cend(); ++it) {
    freeInputs += (*it)->getNbFreeDataInputs();
  }
  return freeInputs;
}


// List the data-input connections of the input Nodes that come from outside the
// view (null parent or parent not part of this GraphView), input Node by input
// Node, in each node's input order.
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::dataInputs() const {
  using Conn = std::pair<std::shared_ptr<Node>, IOIndex_t>;
  const auto isExternal = [this](const Conn& in) {
    return (in.first == nullptr) || (mNodes.find(in.first) == mNodes.end());
  };

  std::vector<Conn> outsideInputs;
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    const std::vector<Conn> candidates = inputNode->dataInputs();
    std::copy_if(candidates.begin(), candidates.end(),
                 std::back_inserter(outsideInputs), isExternal);
  }
  return outsideInputs;
}


// List all input connections (data and parameters) of the input Nodes that come
// from outside the view (null parent or parent not part of this GraphView).
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs() const {
  using Conn = std::pair<std::shared_ptr<Node>, IOIndex_t>;
  const auto isExternal = [this](const Conn& in) {
    return (in.first == nullptr) || (mNodes.find(in.first) == mNodes.end());
  };

  std::vector<Conn> outsideInputs;
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    const std::vector<Conn> candidates = inputNode->inputs();
    std::copy_if(candidates.begin(), candidates.end(),
                 std::back_inserter(outsideInputs), isExternal);
  }
  return outsideInputs;
}


// Input connections of the node registered under `name`.
// std::map::at throws std::out_of_range if no node has that name.
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs(std::string name) const {
  const std::shared_ptr<Node>& node = mNodeRegistry.at(name);
  return node->inputs();
}

void Aidge::GraphView::forwardDims() {
    // setInputs
    // Link every tensor to the right pointer
    // following parent - children informations
    for (std::shared_ptr<Node> nodePtr : getNodes()) {
        for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) {
            // assess if the input was not already set and is a Tensor then link it to parent output
            std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i);
            if (inputI.first) {
              if ( std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) {
                  if ((strcmp(nodePtr->getOperator()->getRawInput(i)->type(), Tensor::Type) == 0) && (strcmp(inputI.first->getOperator()->getRawOutput(inputI.second)->type(), Tensor::Type)==0)) {
                    // assert provided Data is of "Tensor" type
                    nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second));
                  }
                  else {
                    assert(false && "Non-tensor entries not handled yet.\n");
                  }
              }
            } else
            {
              assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty());
            }

        }
    }
    // Compute dimensions of every node
    _forwardDims(inputNodes());
}

void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) {
  // TODO: support multi-inputs/outputs
  std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>();
  for (std::shared_ptr<Node> nodePtr : listNodes) {
    if (!nodePtr->getOperator()->outputDimsForwarded()) {
      nodePtr->getOperator()->computeOutputDims();
    }
    if (!nodePtr->getOperator()->outputDimsForwarded()) {
      nextList.insert(nodePtr);
    } else {
      std::set<std::shared_ptr<Node>> children = nodePtr->getChildren();
      nextList.insert(children.begin(), children.end());
    }
  }
  if (nextList.empty()) {
    for (std::shared_ptr<Node> nodePtr : getNodes()) {
      if (!nodePtr->getOperator()->outputDimsForwarded()) {
        nextList.insert(nodePtr);
      }
    }
  }
  if (!nextList.empty()) {
    _forwardDims(nextList);
  }
}

void Aidge::GraphView::setBackend(const std::string &backend) {
  for (auto node : getNodes()) {
    node->getOperator()->setBackend(backend);
  }
}

void Aidge::GraphView::setDatatype(const DataType &datatype) {
  for (auto node : getNodes()) {
    node->getOperator()->setDatatype(datatype);
  }
}
/*
void Aidge::GraphView::updateOutputNodes() {
  mOutputNodes.clear();
  for (const std::shared_ptr<Node>& go_it : mNodes) {
    if (go_it->nbOutputs() !=
        go_it->nbValidOutputs()) { // an output linked to nothing
      mOutputNodes.insert(go_it);
      continue;
    }
    for (const std::shared_ptr<Node>& ch_ptr : go_it->getChildren()) {
      if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
        mOutputNodes.insert(go_it);
        break;
      }
    }
  }
}

void Aidge::GraphView::updateOutputNodes(std::shared_ptr<Node> node) {
  if (node->nbOutputs() !=
      node->nbValidOutputs()) { // an output linked to nothing
    mOutputNodes.insert(node);
  } else { // don't enter if was already added to outputNodes
    for (const std::shared_ptr<Node> &ch_ptr : node->getChildren()) {
      if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
        mOutputNodes.insert(node);
        break;
      }
    }
  }
  // update other outputNodes
  for (const std::shared_ptr<Node> &pa_ptr :
       node->getParents()) { // check if any parent is in OutputNodes too
    if ((pa_ptr != nullptr) &&
        (mOutputNodes.find(pa_ptr) !=
         mOutputNodes.end())) { // it's a match! Must check if the outputNode
                                // found is still an outputNode
      bool remove = (pa_ptr->nbOutputs() == pa_ptr->nbValidOutputs());
      for (const std::shared_ptr<Node>& ch_ptr : pa_ptr->getChildren()) {
        if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
          remove = false;
          break;
        }
      }
      if (remove) {
        mOutputNodes.erase(pa_ptr);
      }
    }
  }
}
*/
// For every output position of every output Node, list the connections going to
// nodes outside the view. One (possibly empty) vector is produced per output
// position, so positions with only internal connections still appear.
std::vector<
    std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs() const {
  using Conn = std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>;
  const auto isOutside = [this](const Conn& c) {
    return mNodes.find(c.first) == mNodes.end();
  };

  std::vector<std::vector<Conn>> result;
  for (const std::shared_ptr<Node>& outNode : outputNodes()) {
    for (const std::vector<Conn>& connections : outNode->outputs()) {
      std::vector<Conn> external;
      std::copy_if(connections.begin(), connections.end(),
                   std::back_inserter(external), isOutside);
      result.push_back(std::move(external));
    }
  }
  return result;
}

// Output connections of the node registered under `nodeName`.
// std::map::at throws std::out_of_range if no node has that name.
std::vector<
    std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::GraphView::outputs(std::string nodeName) const {
  const std::shared_ptr<Node>& node = mNodeRegistry.at(nodeName);
  return node->outputs();
}

// Placeholder: not implemented, only prints a notice on stdout.
void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/,
                               Aidge::IOIndex_t /*newNodeOutID*/) {
  static const char* const kNotice = "Not implemented yet.\n";
  std::printf("%s", kNotice);
}

void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) {
  // add to the GraphView nodes
  node->addView(shared_from_this());
  mNodes.insert(node);
  if (!(node->name()).empty())
    mNodeRegistry.insert(std::make_pair(node->name(), node));

  // check if the node is an input/output node
  updateInputsOutputsNew(node);

  // add learnable parameters to the graph
  if (includeLearnableParam) {
    for (IOIndex_t i = node->nbDataInputs(); i < node->nbInputs(); ++i) {
      std::shared_ptr<Node> parentNode = node->getParent(static_cast<IOIndex_t>(i));
      if (parentNode) {
          parentNode->addView(shared_from_this());
          mNodes.insert(parentNode);
          if (!(parentNode->name()).empty())
            mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode));
          // check if the parentNode is an input/output node
          updateInputsOutputsNew(parentNode);
      }
    }
  }
}

void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
  // List only the nodes that are not already present in current graph
  std::set<NodePtr> nodesToAdd;
  std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::inserter(nodesToAdd, nodesToAdd.begin()));

  do {
    std::set<NodePtr> nextNodesToAdd;

    // Find nodes that are direct parent of current GraphView and add them first
    // such that the obtained GraphView inputs list will be the same, regardless 
    // of the evaluation order of those nodes
    // (i.e. one of their child is in current GraphView)
    for (auto it = nodesToAdd.begin(); it != nodesToAdd.end(); ++it) {
      for (auto child : (*it)->getChildren()) {
        if (mNodes.find(child) != mNodes.end()) {
          nextNodesToAdd.insert(*it);
          it = nodesToAdd.erase(it);
          break;
        }
      }
      if (it == nodesToAdd.end()) {
        break;
      }
    }

    // If there is no more parent, find nodes that are direct children of current GraphView,
    // such that the obtained GraphView outputs list will be the same, regardless 
    // of the evaluation order of those nodes
    // (i.e. one of their parent is in current GraphView)
    // TODO: this might be done simultaneously with direct parents, by removing
    // the empty() condition, but there might be edge cases that may change
    // the resulting inputs/outputs order depending on evaluation order (???)
    if (nextNodesToAdd.empty()) {
      for (auto it = nodesToAdd.begin(); it != nodesToAdd.end(); ++it) {
        for (auto parent : (*it)->getParents()) {
          if (mNodes.find(parent) != mNodes.end()) {
            nextNodesToAdd.insert(*it);
            it = nodesToAdd.erase(it);
            break;
          }
        }
        if (it == nodesToAdd.end()) {
          break;
        }
      }
    }

    // If no node if found, there might be remaining nodes that form an independant sub-graph
    // In this case, additionnal inputs/outputs will be added at the end of
    // the GraphView inputs/outputs list, in no particular order.
    // TODO: we might try to preserve the initial inputs/ouputs relative order of those nodes
    // if they actually comes from a GraphView, but I think that would be a far-fetched expectation
    // from the users...
    if (nextNodesToAdd.empty()) {
      nodesToAdd.swap(nextNodesToAdd);
    }

    // Add selected nodes in the current GraphView, in no particular order
    for (auto node_ptr : nextNodesToAdd) {
      add(node_ptr, includeLearnableParam);
    }
  }
  while (!nodesToAdd.empty());
}

void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
  add(graph->getNodes(), false);
}

// Connect output `fromTensor` of `fromOutNode` to input `toTensor` of
// `toOtherNode`, then add `toOtherNode` to the view. With a null `fromOutNode`,
// the view's single output Node is used (assert-checked).
void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode,
                               std::shared_ptr<Node> fromOutNode,
                               const Aidge::IOIndex_t fromTensor,
                               Aidge::IOIndex_t toTensor) {
  if (!fromOutNode) {
    assert((outputNodes().size() == 1U) &&
           "Must specify an outputNode or have only one.");
    fromOutNode = *(outputNodes().begin());
  } else {
    assert(inView(fromOutNode) && "Output Node not found in the GraphView.");
  }
  fromOutNode->addChild(toOtherNode, fromTensor, toTensor);
  add(toOtherNode);
}

// Connect an output of this view to an input of `toOtherView`, then merge
// `toOtherView` into this view. A null node in `fromOutNode` (resp. `toNode`)
// selects the unique output Node of this view (resp. the unique input Node of
// the other view). Consistency is only checked with assert.
void Aidge::GraphView::addChild(
    std::shared_ptr<GraphView> toOtherView,
    std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> fromOutNode,
    std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> toNode) {
  // Resolve the source node
  if (fromOutNode.first) {
    assert(inView(fromOutNode.first));
  } else {
    assert(outputNodes().size() == 1U &&
           "If no output node is provided, the graph should have only one to "
           "make the choice explicit.");
    fromOutNode.first = *(outputNodes().begin());
  }

  // Resolve the destination node
  if (toNode.first) {
    assert(toOtherView->inView(toNode.first));
  } else {
    assert(toOtherView->inputNodes().size() == 1U &&
           "If no intput node is provided, the other graph should have only "
           "one to make the choice explicit.");
    toNode.first = *(toOtherView->inputNodes().begin());
  }

  // Tensor-level checks are left to Node::addChild
  fromOutNode.first->addChild(toNode.first, fromOutNode.second, toNode.second);
  // Once linked, merge the other graph into this one
  add(toOtherView);
}

// Union of the parents of every input Node of the view. Unconnected inputs
// contribute a nullptr entry, and parents may themselves be nodes of the view.
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents() const {
  // TODO: choose if we return a set or a vector
  std::set<std::shared_ptr<Node>> parents;
  for (const std::shared_ptr<Node>& inputNode : inputNodes()) {
    // Bind the result once: if Node::getParents() returns by value, calling it
    // twice would pair begin()/end() of two different temporaries (UB).
    const auto& nodeParents = inputNode->getParents();
    parents.insert(nodeParents.begin(), nodeParents.end());
  }
  return parents;
}

// Parents of the node registered under `nodeName`.
// Unknown name: prints a message and terminates the process with exit(-1).
std::vector<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents(const std::string nodeName) const {
  const auto found = mNodeRegistry.find(nodeName);
  if (found != mNodeRegistry.end()) {
    return found->second->getParents();
  }
  printf("No such node a %s in %s graph.\n", nodeName.c_str(), name().c_str());
  exit(-1);
}

// Parents of each input Node, one vector per input Node, in inputNodes() order.
std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
Aidge::GraphView::getOrderedParents() const {
  const auto inNodes = inputNodes();
  std::vector<std::vector<std::shared_ptr<Node>>> parents;
  std::transform(inNodes.begin(), inNodes.end(), std::back_inserter(parents),
                 [](const std::shared_ptr<Node>& n) { return n->getParents(); });
  return parents;
}

// Union of the children of every output Node of the view. The children are not
// filtered, so nodes that belong to the view may be included.
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getChildren() const {
  std::set<std::shared_ptr<Node>> children;
  for (const std::shared_ptr<Node>& outputNode : outputNodes()) {
    // Bind the result once: if Node::getChildren() returns by value, calling it
    // twice would pair begin()/end() of two different temporaries (UB).
    const auto& nodeChildren = outputNode->getChildren();
    children.insert(nodeChildren.begin(), nodeChildren.end());
  }
  return children;
}

// Children of the node registered under `nodeName`, grouped by output.
// Unknown name: prints a message and terminates the process with exit(-1).
std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
Aidge::GraphView::getChildren(const std::string nodeName) const {
  const auto found = mNodeRegistry.find(nodeName);
  if (found != mNodeRegistry.end()) {
    return found->second->getOrderedChildren();
  }
  printf("No such node a %s in %s graph.\n", nodeName.c_str(),
         name().c_str());
  exit(-1);
}

// Children of `otherNode`, which must belong to the view.
// Otherwise: prints a message and terminates the process with exit(-1).
std::set<std::shared_ptr<Aidge::Node>>
Aidge::GraphView::getChildren(const std::shared_ptr<Node> otherNode) const {
  const auto found = mNodes.find(otherNode);
  if (found != mNodes.end()) {
    return (*found)->getChildren();
  }
  printf("No such node in graph.\n");
  exit(-1);
}


// Node registered under `nodeName`.
// Unknown name: prints a message and terminates the process with exit(-1).
std::shared_ptr<Aidge::Node>
Aidge::GraphView::getNode(const std::string& nodeName) const {
  const auto found = mNodeRegistry.find(nodeName);
  if (found == mNodeRegistry.end()) {
    printf("No Node named %s in the current GraphView.\n", nodeName.c_str());
    exit(-1);
  }
  return found->second;
}


void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnableParam) {
  // remove learnable params
  if (includeLearnableParam) {
    for (IOIndex_t i = nodePtr->nbDataInputs(); i < nodePtr->nbInputs(); ++i) {
      auto inputI = nodePtr->input(i);
      bool removeNode = true;
      for (const auto& parentOutput : inputI.first->outputs()) {
        for (const auto& childOfParentOutput : parentOutput) {
          // only remove the learnable parameter if not related to any other Node in the GraphView
          if (childOfParentOutput.first != nodePtr) {
            removeNode = false;
            break;
          }
        }
      }
      if (removeNode) {
        // assert Learnable Parameter in the GraphView scope
        if (mNodes.find(inputI.first) != mNodes.end()) {
          mNodes.erase(inputI.first);
          inputI.first->removeView(shared_from_this());
        }
        if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); }

        // check if the node was an input/output node
        updateInputsOutputsDelete(inputI.first);
      }
    }
  }

  if (mNodes.find(nodePtr) != mNodes.end()) {
    mNodes.erase(nodePtr);
    nodePtr->removeView(shared_from_this());
  }
  if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); }

  // check if the nodePtr was an input/output node
  updateInputsOutputsDelete(nodePtr);
}


// Placeholder: not implemented. Prints a notice and always reports failure.
bool Aidge::GraphView::swap(Node & /*node*/, Node & /*otherNode*/) {
  constexpr bool swapped = false;
  printf("Swap() not implementated yet. Return false.\n");
  return swapped;
}

// Placeholder: not implemented, only prints a notice on stdout.
void Aidge::GraphView::link(std::string /*name1_inID*/,
                           std::string /*name2_outID*/) {
  static const char* const kNotice = "Not implemented yet.\n";
  printf("%s", kNotice);
}

// Insert `newParentNode` between `childNode` (input `childInputTensorIdx`) and
// its current parent. The current parent's output then feeds input
// `newParentInputTensorIdx` of the new node, and output
// `newParentOutputTensorIdx` of the new node feeds the child.
// If that child input has no parent yet, the new node is simply connected to it.
// The new node is added to the view.
void Aidge::GraphView::insertParent(NodePtr childNode,
                  NodePtr newParentNode,
                  IOIndex_t childInputTensorIdx,
                  IOIndex_t newParentInputTensorIdx,
                  IOIndex_t newParentOutputTensorIdx){
  NodePtr currentParentNode = childNode->getParent(childInputTensorIdx);
  if (currentParentNode) {
    const IOIndex_t currentParentOutputTensorIdx = childNode->input(childInputTensorIdx).second;
    // Remove child from current parent & current Parent from child
    currentParentNode->removeChild(childNode, currentParentOutputTensorIdx);
    // The former parent now feeds the inserted node
    currentParentNode->addChild(newParentNode, currentParentOutputTensorIdx, newParentInputTensorIdx);
  }

  // Add child
  newParentNode->addChild(childNode, newParentOutputTensorIdx, childInputTensorIdx);

  add(newParentNode);
}


// Replace the sub-graph `oldNodes` by `newNodes` in every GraphView that holds
// all of `oldNodes`.
// Returns false (without modifying anything) when the pattern is unsupported:
//   - oldNodes has no input Node or not exactly one output Node;
//   - a non-empty newNodes has no input Node or not exactly one output Node;
//   - several input Nodes are fed by different external sources;
//   - the old output Node has more than one output.
// With an empty `newNodes`, the old nodes are removed and their external input
// is reconnected directly to their former children.
// Fails through AIDGE_ASSERT if the first input Node has more than one external
// input.
bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) {

    // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes)
    // How to distinguish it from data input?
    // TODO: Parameter Tensors could be identified with their dimensions
    // TODO: Take GraphView as input parameters since new Nodes should be connected whatever.
    // It also avoids specifying each producer since they are automatically included

    auto oldG = std::make_shared<GraphView>("oldG");
    oldG->add(oldNodes, false);
    auto newG = std::make_shared<GraphView>("newG");
    newG->add(newNodes, false);

    if ((oldG->inputNodes().size() == 0) || (oldG->outputNodes().size() != 1)) {
        return false;
    }
    if (!(newNodes.empty()) && ((newG->inputNodes().size() == 0) ||
                                (newG->outputNodes().size() != 1))) {
        return false;
    }

    // there is at least one inputNode in the old/new GraphView
    std::shared_ptr<Node> firstPreviousInputNode = (*(oldG->inputNodes()).begin());
    std::shared_ptr<Node> firstPreviousOutputNode = (*(oldG->outputNodes()).begin());

    // find Node to link to new input Node
    //compute number of input for firstPreviousInputNode not in oldNodes set
    std::size_t nbExternalInputs = 0;
    std::shared_ptr<Node> externalInput = nullptr;
    IOIndex_t externalInputId = gk_IODefaultIndex;
    for (const auto& input : firstPreviousInputNode->inputs()) {
        if (oldNodes.find(input.first) == oldNodes.end()) { // Node connected to another Node outside of oldG
            nbExternalInputs++;
            externalInput = input.first;
            externalInputId = input.second;
        }
    }
    // Previously AIDGE_INTERNAL_ASSERT("...") was given a string literal, which
    // is always true, so this check never fired. Only one external source can be
    // reconnected.
    AIDGE_ASSERT(nbExternalInputs <= 1, "Too many inputs to link for oldNodes set");

    if (oldG->inputNodes().size() > 1){
        // one or no input has been identified. Checking every input points to the same source
        for (const auto& previousInputNode : oldG->inputNodes()) {
            for (const auto& input : previousInputNode->inputs()) {
                if (oldNodes.find(input.first) == oldNodes.end()) {
                    if ( (externalInput != input.first) || (externalInputId != input.second) ) {
                        return false; // an inputNode points to an external Node different from the registered one
                    }
                }
            }
        }
    }

    if (firstPreviousOutputNode->nbOutputs() != 1) {
        return false;
    }

    // find Node to replicate output connections
    std::shared_ptr<Node> newOutputNode = newNodes.empty() ? externalInput : *(newG->outputNodes().begin());

    auto copyOutputs = firstPreviousOutputNode->outputs();
    // manage Views for newNodes
    // only keep common views to each node for the new set
    std::set<std::shared_ptr<GraphView>> commonGraphViews =  (*oldNodes.begin())->views();
    for (const auto& nodePtr : oldNodes) {
      const auto nodeView = nodePtr->views();
      std::set<std::shared_ptr<GraphView>> intersection;
      std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(),
                          nodeView.begin(), nodeView.end(),
                          std::inserter(intersection, intersection.begin()));
      commonGraphViews = intersection;
    }
    commonGraphViews.erase(oldG);
    commonGraphViews.erase(newG);

    // clean Nodes to replace
    // Do not include common Nodes to avoid cleaning Producers linked to newNodes
    std::set<std::shared_ptr<Node>> nodesToClean;
    std::set_difference(oldNodes.begin(), oldNodes.end(),
                          newNodes.begin(), newNodes.end(),
                          std::inserter(nodesToClean, nodesToClean.begin()));
    for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); }

    // copy output connections
    if (newOutputNode) {
        for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) {
            // When nodes are only removed, the external input node takes over the
            // old outputs from the output slot it used to feed oldNodes with,
            // not from slot `o` of the removed output Node.
            const IOIndex_t fromOutputId = newNodes.empty() ? externalInputId : o;
            const auto& outputPairs = copyOutputs[o];
            for (const auto& onePair : outputPairs) {
                newOutputNode->addChild(onePair.first, fromOutputId, onePair.second);
            }
        }
    }

    // copy input connections
    if (!newNodes.empty() && externalInput) {
        for (const auto& newInputNode : newG->inputNodes()) {
            IOIndex_t inputId = 0;
            for (const auto& input : newInputNode->inputs()) {
                if (newNodes.find(input.first) == newNodes.end()) {
                    externalInput->addChild(newInputNode, externalInputId, inputId);
                }
                inputId++;
            }
        }
    }

    // insert new Nodes in the right GraphViews
    for (const auto& graphPtr : commonGraphViews) {
        graphPtr->add(newNodes, false);
        if (newNodes.empty()) {
            // TODO: FIXME: this function should not be called anymore!
            graphPtr->updateInputsOutputsNodes();
        }
    }

    // Temporary views must not stay referenced by the nodes
    for (const auto& node : oldNodes) {
      node->removeView(oldG);
    }
    for (const auto& node : newNodes) {
      node->removeView(newG);
    }
    return true;
}

/*
void Aidge::GraphView::updateInputNodes() {
  std::set<std::pair<NodePtr, IOIndex_t>> inputNodes;
  for (const std::shared_ptr<Node>& go_ptr : mNodes) {
    size_t inputIdx = 0;
    for (const std::shared_ptr<Node>& pa_ptr : go_ptr->getParents()) {
      if ((pa_ptr == nullptr) ||
          (mNodes.find(pa_ptr) ==
           mNodes.end())) { // Parent doesn't exist || Parent not in the graph
        inputNodes.insert(std::make_pair(go_ptr, inputIdx));
      }
      ++inputIdx;
    }
  }

  // Remove inputs that are not input anymore (deleted node or input connected internally)
  for (auto it = mInputNodes.begin(); it != mInputNodes.end(); ++it) {
    if (inputNodes.find(*it) == inputNodes.end()) {
      it = mInputNodes.erase(it);
    }
  }

  // Add remaining new inputs
  for (auto inputNode : inputNodes) {
    mInputNodes.push_back(inputNode);
  }
}
*/

void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
  // Can be called several times with the same node, e.g. when addChild() is
  // called on a node already part of the GraphView. In this case, inputs/outputs
  // need to be updated!
  std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end();

  // Remove inputs that are not input anymore because connected to newNode
  for (auto orderedChilds : newNode->getOrderedChildren()) {
    for (auto ch_ptr : orderedChilds) {
      // Check that newNode child is in current GraphView
      if (mNodes.find(ch_ptr) != mNodes.end()) {
        IOIndex_t inputIdx = 0;
        for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) {
          // If newNode is connected to it
          if (pa_ptr == newNode) {
            const auto val = std::make_pair(ch_ptr, inputIdx);
            const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val);

            // Check that it was not already the case (if node UPDATE)
            if (iter != mInputNodes.end()) {
              // The first old (removed) input becomes the insertion point for newNode GraphView inputs
              if (std::distance(newInputsInsertionPoint, iter) <= 0) {
                newInputsInsertionPoint = mInputNodes.erase(iter);
              }
              else {
                mInputNodes.erase(iter);
              }
            }
          }
          ++inputIdx;
        }
      }
    }
  }

  // Check if node inputs are inputs for the GraphView and add them to the input list if so
  // Inputs addition order follows node inputs order
  // Inputs are inserted at the position of the first input removed
  IOIndex_t inputIdx = 0U;
  for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) {
    if ((pa_ptr == nullptr) ||
        (mNodes.find(pa_ptr) ==
        mNodes.end())) { // Parent doesn't exist || Parent not in the graph
      const auto val = std::make_pair(newNode, inputIdx);
      if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) {
        newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
      }
    }
    ++inputIdx;
  }

  // (if node UPDATE)
  // newNode may already exists in the graph and may have been updated
  // Check and remove inputs that are not inputs anymore
  inputIdx = 0U;
  for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) {
    if ((pa_ptr != nullptr) &&
        (mNodes.find(pa_ptr) !=
        mNodes.end())) {
      const auto val = std::make_pair(newNode, inputIdx);
      auto it = std::find(mInputNodes.begin(), mInputNodes.end(), val);
      if (it != mInputNodes.end()) {
        mInputNodes.erase(it);
      }
    }
    ++inputIdx;
  }

  std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end();

  // Remove outputs that are not output anymore because connected to newNode
  for (const std::shared_ptr<Node>& parent : newNode->getParents()) {
    // Check that newNode parent is in current GraphView
    if (mNodes.find(parent) != mNodes.end()) {
      for (auto orderedChilds : parent->getOrderedChildren()) {
        IOIndex_t outputIdx = 0;
        for (auto ch_ptr : orderedChilds) {
          // If newNode is connected to it
          if (ch_ptr == newNode) {
            const auto val = std::make_pair(parent, outputIdx);
            const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val);

            if (iter != mOutputNodes.end()) {
              // The first old (removed) output becomes the insertion point for newNode GraphView outputs
              if (std::distance(newOutputsInsertionPoint, iter) <= 0) {
                newOutputsInsertionPoint = mOutputNodes.erase(iter);
              }
              else {
                mOutputNodes.erase(iter);
              }
            }
          }
        }
        ++outputIdx;
      }
    }
  }

  // Check if node outputs are outputs for the GraphView and add them to the output list if so
  IOIndex_t outputIdx = 0;
  for (auto orderedChilds : newNode->getOrderedChildren()) {
    bool noInsideConnection = true;
    for (auto ch_ptr : orderedChilds) {
      if (mNodes.find(ch_ptr) != mNodes.end()) {
        noInsideConnection = false;
        break;
      }
    }

    if (noInsideConnection) {
      const auto val = std::make_pair(newNode, outputIdx);
      // Output may be already be present (see addChild() with a node already in GraphView)
      if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) {
        newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val);
      }
    }
    ++outputIdx;
  }
}
/**
 * @brief Incrementally update the ordered GraphView inputs (mInputNodes) and
 * outputs (mOutputNodes) after @p deletedNode has been removed.
 *
 * - All (deletedNode, idx) entries are removed from mInputNodes/mOutputNodes.
 * - Inputs of deletedNode's children (inside the GraphView) that were fed by
 *   deletedNode become GraphView inputs, inserted at the position of the first
 *   removed input (or appended), following deletedNode outputs order.
 * - Outputs of deletedNode's parents (inside the GraphView) that no longer have
 *   any consumer inside the GraphView become GraphView outputs, inserted at the
 *   position of the first removed output (or appended), following deletedNode
 *   inputs order.
 *
 * Entries are never duplicated.
 * NOTE(review): appears to expect deletedNode to already be erased from mNodes
 * (otherwise it counts as an inside consumer of its parents) — confirm in remove().
 *
 * @param deletedNode Node removed from this GraphView.
 */
void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNode) {
  std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newInputsInsertionPoint = mInputNodes.end();

  // Fetched once: getParents() is used by two loops below
  const auto& deletedParents = deletedNode->getParents();

  // Check if node inputs were inputs for the GraphView and remove them from the list if so
  for (IOIndex_t inputIdx = 0; inputIdx < deletedParents.size(); ++inputIdx) {
    const auto val = std::make_pair(deletedNode, inputIdx);
    const auto iter = std::find(mInputNodes.begin(), mInputNodes.end(), val);

    if (iter != mInputNodes.end()) {
      // The first old (removed) input becomes the insertion point for newNode GraphView inputs
      if (std::distance(newInputsInsertionPoint, iter) <= 0) {
        newInputsInsertionPoint = mInputNodes.erase(iter);
      }
      else {
        mInputNodes.erase(iter);
      }
    }
  }

  // Add child node inputs that become GraphView input following the removal of the node
  // Inputs addition order follows deletedNode outputs order
  for (const auto& orderedChilds : deletedNode->getOrderedChildren()) {
    for (const auto& ch_ptr : orderedChilds) {
      // Check that deletedNode child is in current GraphView
      if (mNodes.find(ch_ptr) != mNodes.end()) {
        IOIndex_t inputIdx = 0;
        for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) {
          // If deletedNode was connected to it
          if (pa_ptr == deletedNode) {
            const auto val = std::make_pair(ch_ptr, inputIdx);
            // The same child can be reached several times (e.g. through several
            // deletedNode outputs): only register each input once
            if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) {
              // Advance past the inserted element to preserve addition order
              newInputsInsertionPoint = std::next(mInputNodes.insert(newInputsInsertionPoint, val));
            }
          }
          ++inputIdx;
        }
      }
    }
  }

  std::vector<std::pair<NodePtr, IOIndex_t>>::iterator newOutputsInsertionPoint = mOutputNodes.end();

  // Check if node outputs were outputs for the GraphView and remove them from the list if so
  for (IOIndex_t outputIdx = 0; outputIdx < deletedNode->getOrderedChildren().size(); ++outputIdx) {
    const auto val = std::make_pair(deletedNode, outputIdx);
    const auto iter = std::find(mOutputNodes.begin(), mOutputNodes.end(), val);

    if (iter != mOutputNodes.end()) {
      // The first old (removed) output becomes the insertion point for newNode GraphView outputs
      if (std::distance(newOutputsInsertionPoint, iter) <= 0) {
        newOutputsInsertionPoint = mOutputNodes.erase(iter);
      }
      else {
        mOutputNodes.erase(iter);
      }
    }
  }

  // Add parent node outputs that become GraphView output following the removal of the node
  // Outputs addition order follows deletedNode inputs order
  for (const std::shared_ptr<Node>& parent : deletedParents) {
    // Only parents inside this GraphView can provide GraphView outputs
    // (also skips nullptr, i.e. unconnected deletedNode inputs)
    if (mNodes.find(parent) == mNodes.end()) {
      continue;
    }

    IOIndex_t outputIdx = 0;
    for (const auto& orderedChilds : parent->getOrderedChildren()) {
      bool noInsideConnection = true;
      for (const auto& ch_ptr : orderedChilds) {
        if (mNodes.find(ch_ptr) != mNodes.end()) {
          noInsideConnection = false;
          break;
        }
      }

      if (noInsideConnection) {
        const auto val = std::make_pair(parent, outputIdx);
        // The output may already be a GraphView output (e.g. unconnected output,
        // or parent connected several times to deletedNode)
        if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) {
          newOutputsInsertionPoint = std::next(mOutputNodes.insert(newOutputsInsertionPoint, val));
        }
      }
      ++outputIdx;
    }
  }
}

void Aidge::GraphView::updateInputsOutputsNodes() {
  mInputNodes.clear();
  for (const std::shared_ptr<Node>& go_ptr : mNodes) {
    IOIndex_t inputIdx = 0;
    for (const std::shared_ptr<Node>& pa_ptr : go_ptr->getParents()) {
      if ((pa_ptr == nullptr) ||
          (mNodes.find(pa_ptr) ==
           mNodes.end())) { // Parent doesn't exist || Parent not in the graph
        mInputNodes.push_back(std::make_pair(go_ptr, inputIdx));
      }

      ++inputIdx;
    }
  }

  mOutputNodes.clear();
  for (const std::shared_ptr<Node>& go_ptr : mNodes) {
    IOIndex_t outputIdx = 0;
    for (auto orderedChilds : go_ptr->getOrderedChildren()) {
      for (auto ch_ptr : orderedChilds) {
        if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
          mOutputNodes.push_back(std::make_pair(go_ptr, outputIdx));
        }
      }
      if (orderedChilds.empty()) {
          // an output linked to nothing
          mOutputNodes.push_back(std::make_pair(go_ptr, outputIdx));
      }
      ++outputIdx;
    }
  }
}

/*
void Aidge::GraphView::updateInputNodes(std::shared_ptr<Node> node) {
  // add node_ptr to inputNode if it can
  std::size_t filledWithKnownInputs = 0U;
  bool wasAdded = mInputNodes.find(node) != mInputNodes.end();
  for (const std::shared_ptr<Node>& pa_ptr : node->getParents()) {
    if ((pa_ptr == nullptr) ||
        (mNodes.find(pa_ptr) ==
         mNodes.end())) { // Parent doesn't exist || Parent not in the graph
      mInputNodes.insert(node);
      wasAdded = true;
      break;
    }
    ++filledWithKnownInputs;
  }
  if (filledWithKnownInputs == node->nbInputs() && wasAdded) {
    mInputNodes.erase(node);
  }
  // update other inputNodes
  for (const std::shared_ptr<Node>& ch_ptr :
       node->getChildren()) { // check if any child is in InputNodes too
    if (mInputNodes.find(ch_ptr) !=
        mInputNodes.end()) { // it's a match! Must check if the inputNode found
                             // is still an inputNode
                             // change here
      bool remove = true;
      for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) {
        if (pa_ptr == nullptr ||
            mNodes.find(pa_ptr) ==
                mNodes
                    .end()) { // Parent doesn't exist || Parent not in the graph
          remove = false;
          break;
        }
      }
      if (remove) {
        mInputNodes.erase(ch_ptr);
      }
    }
  }
}
*/
/*
void Aidge::GraphView::removeInputNode(const std::string nodeName) {
  std::map<std::string, std::shared_ptr<Node>>::iterator it =
      mNodeRegistry.find(nodeName);
  if (it != mNodeRegistry.end()) {
    const std::shared_ptr<Node> val = (*it).second;
    if (mInputNodes.find(val) != mInputNodes.end()) {
      mInputNodes.erase(val);
    }
  }
}

void Aidge::GraphView::removeOutputNode(const std::string nodeName) {
  std::map<std::string, std::shared_ptr<Node>>::iterator it =
      mNodeRegistry.find(nodeName);
  if (it != mNodeRegistry.end()) {
    const std::shared_ptr<Node> val = (*it).second;
    if (mOutputNodes.find(val) != mOutputNodes.end()) {
      mOutputNodes.erase(val);
    }
  }
}
*/
/**
 * @brief Clone this GraphView, cloning each node with @p cloneNode.
 *
 * @p cloneNode may return nullptr to drop a node from the clone. Connections
 * through a dropped node are bridged: a dropped node must have a single child
 * and at most one data input, so that its (first) data input parent is
 * connected directly to its child. The GraphView ordered inputs/outputs are
 * re-mapped onto the cloned nodes; inputs/outputs that cannot be mapped are
 * dropped.
 *
 * @param cloneNode Function returning the clone of a node, or nullptr to remove it.
 * @return New GraphView (same name) containing the cloned nodes.
 */
std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const {
  std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName);

  // Map for old node -> new node correspondance
  std::map<NodePtr, NodePtr> oldToNewNodes;

  for (const std::shared_ptr<Node> &node_ptr : mNodes) {
    oldToNewNodes[node_ptr] = cloneNode(node_ptr);
  }

  // Clone of `node`; nullptr when `node` is nullptr, outside this GraphView, or
  // removed by cloneNode(). Uses find() so that the map is never modified
  // (operator[] would insert external/nullptr nodes as if they were removed).
  const auto clonedOf = [&oldToNewNodes](const NodePtr& node) -> NodePtr {
    const auto it = oldToNewNodes.find(node);
    return (it != oldToNewNodes.end()) ? it->second : nullptr;
  };
  // True only for nodes of this GraphView that cloneNode() removed
  const auto isRemoved = [&oldToNewNodes](const NodePtr& node) -> bool {
    const auto it = oldToNewNodes.find(node);
    return (it != oldToNewNodes.end()) && (it->second == nullptr);
  };

  // For each node, convert old node -> new node connections
  for (const auto& oldToNewNode : oldToNewNodes) {
    if (oldToNewNode.second == nullptr) {
      continue;  // deleted node
    }

    // Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr
    std::size_t parentId = 0;
    for (auto parent : oldToNewNode.first->inputs()) {
      // Find next valid parent in line, going backward in the graph
      // (unconnected inputs and parents outside the GraphView are not walked)
      while (isRemoved(parent.first)) {
        AIDGE_ASSERT(parent.first->getChildren().size() == 1, "deleted nodes in GraphView::clone() cannot have multiple children");
        AIDGE_ASSERT(parent.first->nbDataInputs() <= 1, "deleted nodes in GraphView::clone() cannot have multiple data input parents");
        const auto& parents = parent.first->dataInputs();

        if (!parents.empty() && parents[0].first != nullptr // a valid parent exists
          && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView
        {
          parent = parents[0];
        }
        else {
          break;
        }
      }

      const NodePtr newParent = clonedOf(parent.first);
      if (newParent) {
        AIDGE_ASSERT(newParent->nbOutputs() == parent.first->nbOutputs(),
          "next valid parent after deleted nodes in GraphView::clone() has wrong number of outputs");
        newParent->addChild(oldToNewNode.second, parent.second, parentId);
      }

      ++parentId;
    }
  }

  // Once connected, add each new nodes to new GraphView
  // This has to be done in a second step to ensure that new GraphView inputs/outputs
  // are properly set (otherwise, some node's inputs/outputs may be wrongly registered as
  // GraphView inputs/outputs because not yet connected to other nodes)
  for (const auto& oldToNewNode : oldToNewNodes) {
    if (oldToNewNode.second == nullptr) {
      continue;  // deleted node
    }

    newGraph->add(oldToNewNode.second, false);
  }

  // Update cloned graph inputs/outputs order to match initial graph order
  auto newInputNodes = mInputNodes;
  // The iterator is only advanced when the current element is kept: erase()
  // already returns the next element
  for (auto it = newInputNodes.begin(); it != newInputNodes.end(); ) {
    // If input node was removed, find next valid input
    while (isRemoved(it->first)) {
      // Removed node should have at most one connected output, otherwise cloning is invalid
      const auto& children = it->first->getChildren();
      AIDGE_INTERNAL_ASSERT(children.size() <= 1);

      bool found = false;
      if (children.size() == 1) {
        const NodePtr child = *children.begin();

        std::size_t inputIdx = 0;
        for (const auto& parent : child->getParents()) {
          if (parent == it->first) {
            it->first = child;
            it->second = static_cast<IOIndex_t>(inputIdx);
            found = true;
            break;
          }
          ++inputIdx;
        }
      }

      if (!found) {
        break;
      }
    }

    const NodePtr newInput = clonedOf(it->first);
    if (newInput == nullptr) {
      it = newInputNodes.erase(it);
    }
    else {
      it->first = newInput;
      ++it;
    }
  }
  newGraph->setOrderedInputs(newInputNodes);

  auto newOutputNodes = mOutputNodes;
  for (auto it = newOutputNodes.begin(); it != newOutputNodes.end(); ) {
    // If output node was removed, find previous valid output
    while (isRemoved(it->first)) {
      // Removed node should have only one connected data input, otherwise cloning is invalid
      AIDGE_INTERNAL_ASSERT(it->first->nbDataInputs() <= 1);
      const auto parents = it->first->dataInputs();
      // A nullptr parent (unconnected input) cannot be walked further
      if (!parents.empty() && parents[0].first != nullptr) {
        *it = parents[0];
      }
      else {
        break;
      }
    }

    const NodePtr newOutput = clonedOf(it->first);
    if (newOutput == nullptr) {
      it = newOutputNodes.erase(it);
    }
    else {
      it->first = newOutput;
      ++it;
    }
  }
  newGraph->setOrderedOutputs(newOutputNodes);

  return newGraph;
}