Skip to content
Snippets Groups Projects
GraphView.cpp 54.9 KiB
Newer Older
Cyril Moineau's avatar
Cyril Moineau committed
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/graph/GraphView.hpp"
Cyril Moineau's avatar
Cyril Moineau committed

#include <algorithm>     // std::find, std::set_intersection, std::transform
Cyril Moineau's avatar
Cyril Moineau committed
#include <cassert>
#include <stdexcept>     // std::runtime_error
#include <cstddef>       // std::size_t
#include <cstdio>        // std::fclose, std::fopen
#include <iterator>      // std::back_inserter, std::distance, std::inserter,
                         // std::next
Maxence Naud's avatar
Maxence Naud committed
#include <map>
#include <memory>        // std::dynamic_pointer_cast, std::static_pointer_cast
Maxence Naud's avatar
Maxence Naud committed
#include <set>
#include <string>        // std::to_string
#include <utility>       // std::make_pair, std::pair
Maxence Naud's avatar
Maxence Naud committed
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/MetaOperator.hpp"
Maxence Naud's avatar
Maxence Naud committed
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Directories.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
Cyril Moineau's avatar
Cyril Moineau committed


/**
 * @brief Look up a Node of this GraphView by its registered name.
 * @param nodeName Name under which the Node is stored in mNodeRegistry.
 * @return The matching Node, or nullptr if no Node is registered under that name.
 */
const std::shared_ptr<Aidge::Node> Aidge::GraphView::operator[](const std::string& nodeName) const {
    // Single lookup: reuse the iterator instead of searching again with at().
    const auto it = mNodeRegistry.find(nodeName);
    return (it != mNodeRegistry.cend()) ? it->second : nullptr;
}
Cyril Moineau's avatar
Cyril Moineau committed

///////////////////////////////////////////////////////
//        FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////

/**
 * @brief Functional-style call: connect the given Connectors to the data
 * inputs of this GraphView's single input Node.
 *
 * For the k-th element of @p ctors, a child link is added (through this
 * GraphView) from output `ctor.index()` of `ctor.node()` to input #k of the
 * input Node.
 *
 * All preconditions are checked before any link is created, so an invalid
 * argument does not leave the graph partially connected.
 *
 * @param ctors One Connector per data input of the input Node
 *              (must match `inNode->nbData()`).
 * @return A Connector wrapping the first Node returned by outputNodes(),
 *         queried after the links have been created.
 * @note Failures are reported through AIDGE_THROW_OR_ABORT with
 *       std::runtime_error (throws or aborts depending on how that macro is
 *       configured). Failures occur when the view does not have exactly one
 *       input Node, the number of Connectors is wrong, an input of the input
 *       Node is already connected, a Connector has no Node, or the view has
 *       no output Node.
 */
Aidge::Connector Aidge::GraphView::operator()(
    const std::vector<Aidge::Connector> ctors) {
  // TODO: allow for multiple inputNodes?
  // inputNodes() builds a new container on each call: query it only once.
  const auto inNodes = inputNodes();
  if (inNodes.size() != 1U) {
    AIDGE_THROW_OR_ABORT(std::runtime_error,
        "GraphView functional call requires exactly 1 input Node, got {}.",
        inNodes.size());
  }
  const std::shared_ptr<Node> inNode = *inNodes.begin();

  const std::size_t nbExpected = static_cast<std::size_t>(inNode->nbData());
  if (ctors.size() != nbExpected) {
    AIDGE_THROW_OR_ABORT(std::runtime_error,
        "Wrong number of arguments: got {}, expected {}.",
        ctors.size(), nbExpected);
  }

  // Every input of the input Node must still be free (no parent output bound).
  std::size_t inputPos = 0;
  for (const auto& input : inNode->inputs()) {
    if (input.second != gk_IODefaultIndex) {
      AIDGE_THROW_OR_ABORT(std::runtime_error,
          "Input connection #{} of the input Node is not free.", inputPos);
    }
    ++inputPos;
  }

  // Validate all Connectors before linking anything.
  std::size_t ctorPos = 0;
  for (const Connector& ctor : ctors) {
    if (ctor.node() == nullptr) {
      AIDGE_THROW_OR_ABORT(std::runtime_error,
          "Input Connector #{} must be associated with a node.", ctorPos);
    }
    ++ctorPos;
  }

  IOIndex_t inID = 0;
  for (const Connector& ctor : ctors) {
    ctor.node()->addChild(shared_from_this(), static_cast<std::size_t>(ctor.index()),
                          {inNode, inID++});
  }

  // Queried after linking, exactly like the original code path.
  const auto outNodes = outputNodes();
  if (outNodes.empty()) {
    AIDGE_THROW_OR_ABORT(std::runtime_error,
        "GraphView has {} output Node; at least one is required.", outNodes.size());
  }
  return Connector(*outNodes.begin());
}

///////////////////////////////////////////////////////
//        INNER
///////////////////////////////////////////////////////

/**
 * @brief Tell whether a Node is part of this GraphView.
 * @param nodePtr Node to search for in mNodes.
 * @return true if nodePtr is an element of mNodes, false otherwise.
 */
bool Aidge::GraphView::inView(const std::shared_ptr<Aidge::Node>& nodePtr) const {
    const bool isMember = (mNodes.count(nodePtr) != 0);
    return isMember;
}
Olivier BICHLER's avatar
Olivier BICHLER committed
void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProducers) const {
Olivier BICHLER's avatar
Olivier BICHLER committed
    auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((path + ".mmd").c_str(), "w"), &std::fclose);

    if (!fp) {
        AIDGE_THROW_OR_ABORT(std::runtime_error,
            "Could not create graph view log file: {}", path + ".mmd");
    }

    fmt::print(fp.get(),
                "%%{{init: {{'flowchart': {{ 'curve': 'monotoneY'}}, "
                "'fontFamily': 'Verdana' }} }}%%\nflowchart TB\n\n");
Cyril Moineau's avatar
Cyril Moineau committed

    // Start by creating every node
    const auto namePtrTable = getRankedNodesName("{3}");
Cyril Moineau's avatar
Cyril Moineau committed

    for (const std::shared_ptr<Node> &node_ptr : mNodes) {
Olivier BICHLER's avatar
Olivier BICHLER committed
        std::string givenName =
Cyril Moineau's avatar
Cyril Moineau committed
            (node_ptr->name().empty())
                ? "<em>" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + "</em>"
                : "\"" + node_ptr->name() + "\\n<sub><em>(" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + ")</em></sub>\"";

        std::string nodeCls = "";
        if (node_ptr->type() == "Producer") {
          nodeCls = ":::producerCls";
        }
        else if (std::dynamic_pointer_cast<GenericOperator_Op>(node_ptr->getOperator())) {
          nodeCls = ":::genericCls";
        }
        else if (const auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(node_ptr->getOperator())) {
          nodeCls = ":::metaCls";

          if (verbose) {
            metaOp->getMicroGraph()->save(path + "_" + node_ptr->type() + "#" + namePtrTable.at(node_ptr), verbose, showProducers);
          }
        }

        if (node_ptr == mRootNode) {
Olivier BICHLER's avatar
Olivier BICHLER committed
          if (nodeCls.empty()) {
            nodeCls = ":::rootCls";
          }
          else {
            nodeCls += "_rootCls";
          }
Maxence Naud's avatar
Maxence Naud committed
        if (node_ptr->type() != "Producer" || showProducers) {
            // if (node_ptr == mRootNode) {
            fmt::print(fp.get(), "{}_{}({}){}\n", node_ptr->type(), namePtrTable.at(node_ptr),
                        givenName, nodeCls);
            // }
Cyril Moineau's avatar
Cyril Moineau committed
    }
Cyril Moineau's avatar
Cyril Moineau committed
    // Write every link
    for (const std::shared_ptr<Node> &node_ptr : mNodes) {
      if ((node_ptr -> type() == "Producer") && !showProducers) {
        continue;
      }
      IOIndex_t outputIdx = 0;
Maxence Naud's avatar
Maxence Naud committed
      for (const auto& childs : node_ptr->getOrderedChildren()) {
        for (const auto& child : childs) {
          if (child != nullptr) {
            IOIndex_t inputIdx = 0;
            for (auto parent : child->inputs()) {
              if (parent.first == node_ptr && parent.second == outputIdx) {
                // Add-on to display the operator's output dimensions
                std::string dims = "";
                const auto op = std::dynamic_pointer_cast<OperatorTensor>(node_ptr->getOperator());
                if (op && !op->getOutput(outputIdx)->dims().empty()) {
                  dims += " " + fmt::format("{}", op->getOutput(outputIdx)->dims());
                if (mNodes.find(child) != mNodes.end()) {
Olivier BICHLER's avatar
Olivier BICHLER committed
                  fmt::print(fp.get(), "{}_{}-->|\"{}{}&rarr;{}\"|{}_{}\n", node_ptr->type(), namePtrTable.at(node_ptr),
                              outputIdx, dims, inputIdx, child->type(), namePtrTable.at(child));
                }
                else if (verbose) {
Olivier BICHLER's avatar
Olivier BICHLER committed
                  fmt::print(fp.get(), "{}_{}-->|\"{}{}&rarr;{}\"|{}:::externalCls\n", node_ptr->type(), namePtrTable.at(node_ptr),
                              outputIdx, dims, inputIdx, static_cast<void*>(child.get()));
Cyril Moineau's avatar
Cyril Moineau committed
        }
        ++outputIdx;
      }
    }

    size_t inputIdx = 0;
    for (auto input : mInputNodes) {
      if (input.first != nullptr) {
        fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"&rarr;{}\"|{}_{}\n", inputIdx, inputIdx,
Olivier BICHLER's avatar
Olivier BICHLER committed
                    input.second, input.first->type(), namePtrTable.at(input.first));
Olivier BICHLER's avatar
Olivier BICHLER committed
        fmt::print(fp.get(), "input{}((in#{})):::inputCls\n", inputIdx, inputIdx);
      ++inputIdx;
Cyril Moineau's avatar
Cyril Moineau committed
    }

    size_t outputIdx = 0;
    for (auto output : mOutputNodes) {
      if (output.first != nullptr) {
        // Add-on to display the operator's output dimensions
        std::string dims = "";
        const auto op = std::dynamic_pointer_cast<OperatorTensor>(output.first->getOperator());
        if (op && op->getOutput(output.second) && !op->getOutput(output.second)->dims().empty()) {
          dims += " " + fmt::format("{}", op->getOutput(output.second)->dims());
        }
Olivier BICHLER's avatar
Olivier BICHLER committed
        fmt::print(fp.get(), "{}_{}--->|\"{}{}&rarr;\"|output{}((out#{})):::outputCls\n",
                    output.first->type(), namePtrTable.at(output.first), output.second,
                    dims, outputIdx, outputIdx);
Olivier BICHLER's avatar
Olivier BICHLER committed
        fmt::print(fp.get(), "output{}((out#{})):::outputCls\n", outputIdx, outputIdx);
Olivier BICHLER's avatar
Olivier BICHLER committed
    fmt::print(fp.get(), "classDef inputCls fill:#afa\n");
    fmt::print(fp.get(), "classDef outputCls fill:#ffa\n");
Loading
Loading full blame...