Newer
Older
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <algorithm> // std::find, std::set_intersection, std::transform
#include <stdexcept> // std::runtime_error
#include <cstddef> // std::size_t
#include <cstdio> // std::fclose, std::fopen
Olivier BICHLER
committed
#include <fmt/format.h>
#include <iterator> // std::back_inserter, std::distance, std::inserter,
// std::next
#include <memory> // std::dynamic_pointer_cast, std::static_pointer_cast
#include <string> // std::to_string
#include <utility> // std::make_pair, std::pair
Olivier BICHLER
committed
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/utils/Directories.hpp"
#include "aidge/utils/ErrorHandling.hpp"
/**
 * @brief Look up a node of this GraphView by its name.
 * @param nodeName Key searched in mNodeRegistry.
 * @return The node registered under nodeName, or nullptr when the registry
 *         has no entry for that name.
 */
const std::shared_ptr<Aidge::Node> Aidge::GraphView::operator[](const std::string& nodeName) const {
    // One lookup only: reuse the iterator returned by find() instead of
    // searching the registry a second time with at().
    const auto entry = mNodeRegistry.find(nodeName);
    return (entry != mNodeRegistry.cend()) ? entry->second : nullptr;
}
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////
/**
 * @brief Functional-style call: plug the given Connectors into this GraphView.
 *
 * The GraphView must have exactly one input node, and the number of
 * Connectors must equal that node's nbData(). Each Connector's node gets
 * this GraphView added as a child through addChild(), targeting the input
 * node with an input index that starts at 0 and increases by one per
 * Connector.
 *
 * All preconditions are checked with assert(), so they are only enforced
 * in builds where assertions are enabled.
 *
 * @param ctors Connectors to attach, in input order.
 * @return A Connector built from the first element of outputNodes().
 */
Aidge::Connector Aidge::GraphView::operator()(
    const std::vector<Aidge::Connector> ctors) {
    // TODO: allow for multiple inputNodes?
    assert((inputNodes().size() == 1U) && "Too many input Nodes for the GraphView, undefined behaviour");
    std::shared_ptr<Node> inNode = *inputNodes().begin();
    assert((ctors.size() == static_cast<std::size_t>(inNode->nbData())) && "Wrong number of arguments.\n");

    // Every input slot of the input node must still be unconnected.
    for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inNode->inputs()) {
        assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n");
        (void)input; // silences the unused-variable warning when asserts are compiled out
    }

    // Wire each Connector to the next input index of the input node.
    IOIndex_t inID = 0;
    for (const Connector &ctor : ctors) {
        assert((ctor.node() != nullptr) &&
               "Input Connector must be associated with a node");
        ctor.node()->addChild(shared_from_this(), static_cast<std::size_t>(ctor.index()),
                              {inNode, inID++});
    }
    return Connector(*(outputNodes().begin()));
}
///////////////////////////////////////////////////////
// INNER
///////////////////////////////////////////////////////
/**
 * @brief Tell whether a node belongs to this GraphView.
 * @param nodePtr Node searched in mNodes.
 * @return true when nodePtr is an element of mNodes, false otherwise.
 */
bool Aidge::GraphView::inView(const std::shared_ptr<Aidge::Node>& nodePtr) const {
    const bool isMember = (mNodes.count(nodePtr) != 0);
    return isMember;
}
void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProducers) const {
auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((path + ".mmd").c_str(), "w"), &std::fclose);
if (!fp) {
AIDGE_THROW_OR_ABORT(std::runtime_error,
"Could not create graph view log file: {}", path + ".mmd");
}
fmt::print(fp.get(),
"%%{{init: {{'flowchart': {{ 'curve': 'monotoneY'}}, "
"'fontFamily': 'Verdana' }} }}%%\nflowchart TB\n\n");
const auto namePtrTable = getRankedNodesName("{3}");
Olivier BICHLER
committed
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
? "<em>" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + "</em>"
: "\"" + node_ptr->name() + "\\n<sub><em>(" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + ")</em></sub>\"";
std::string nodeCls = "";
if (node_ptr->type() == "Producer") {
nodeCls = ":::producerCls";
}
else if (std::dynamic_pointer_cast<GenericOperator_Op>(node_ptr->getOperator())) {
nodeCls = ":::genericCls";
}
else if (const auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(node_ptr->getOperator())) {
nodeCls = ":::metaCls";
if (verbose) {
metaOp->getMicroGraph()->save(path + "_" + node_ptr->type() + "#" + namePtrTable.at(node_ptr), verbose, showProducers);
}
}
if (node_ptr == mRootNode) {
if (nodeCls.empty()) {
nodeCls = ":::rootCls";
}
else {
nodeCls += "_rootCls";
}
if (node_ptr->type() != "Producer" || showProducers) {
// if (node_ptr == mRootNode) {
fmt::print(fp.get(), "{}_{}({}){}\n", node_ptr->type(), namePtrTable.at(node_ptr),
givenName, nodeCls);
// }
// Write every link
for (const std::shared_ptr<Node> &node_ptr : mNodes) {

Maxence Naud
committed
if ((node_ptr -> type() == "Producer") && !showProducers) {
continue;
}
for (const auto& childs : node_ptr->getOrderedChildren()) {
for (const auto& child : childs) {
for (auto parent : child->inputs()) {
if (parent.first == node_ptr && parent.second == outputIdx) {
// Add-on to display the operator's output dimensions
std::string dims = "";
const auto op = std::dynamic_pointer_cast<OperatorTensor>(node_ptr->getOperator());
if (op && !op->getOutput(outputIdx)->dims().empty()) {
Olivier BICHLER
committed
dims += " " + fmt::format("{}", op->getOutput(outputIdx)->dims());
if (mNodes.find(child) != mNodes.end()) {
fmt::print(fp.get(), "{}_{}-->|\"{}{}→{}\"|{}_{}\n", node_ptr->type(), namePtrTable.at(node_ptr),
outputIdx, dims, inputIdx, child->type(), namePtrTable.at(child));
fmt::print(fp.get(), "{}_{}-->|\"{}{}→{}\"|{}:::externalCls\n", node_ptr->type(), namePtrTable.at(node_ptr),
outputIdx, dims, inputIdx, static_cast<void*>(child.get()));
++outputIdx;
}
}
size_t inputIdx = 0;
for (auto input : mInputNodes) {
if (input.first != nullptr) {
fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"→{}\"|{}_{}\n", inputIdx, inputIdx,
input.second, input.first->type(), namePtrTable.at(input.first));
fmt::print(fp.get(), "input{}((in#{})):::inputCls\n", inputIdx, inputIdx);
size_t outputIdx = 0;
for (auto output : mOutputNodes) {
if (output.first != nullptr) {
// Add-on to display the operator's output dimensions
std::string dims = "";
const auto op = std::dynamic_pointer_cast<OperatorTensor>(output.first->getOperator());
if (op && op->getOutput(output.second) && !op->getOutput(output.second)->dims().empty()) {
dims += " " + fmt::format("{}", op->getOutput(output.second)->dims());
}
fmt::print(fp.get(), "{}_{}--->|\"{}{}→\"|output{}((out#{})):::outputCls\n",
output.first->type(), namePtrTable.at(output.first), output.second,
dims, outputIdx, outputIdx);
fmt::print(fp.get(), "output{}((out#{})):::outputCls\n", outputIdx, outputIdx);
fmt::print(fp.get(), "classDef inputCls fill:#afa\n");
fmt::print(fp.get(), "classDef outputCls fill:#ffa\n");
Loading
Loading full blame...