Skip to content
Snippets Groups Projects
OpArgs.cpp 4.37 KiB
Newer Older
Cyril Moineau's avatar
Cyril Moineau committed
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::vector<OpArgs> inputs) {
Cyril Moineau's avatar
Cyril Moineau committed
    std::shared_ptr<GraphView> gv = std::make_shared<GraphView>();
    for (const OpArgs& elt : inputs) {
        if(elt.node() != nullptr) {
            // Connect the first output (ordered) of each output node (ordered) 
            // to the next available input of the input node.
            AIDGE_ASSERT(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size(),
                "Sequential(): not enough free data inputs ({}) for input node {} (of type {}) to connect to all previous output nodes ({})",
                elt.node()->getNbFreeDataInputs(), elt.node()->name(), elt.node()->type(), gv->outputNodes().size());
            std::set<NodePtr> connectedOutputs;
            for (const auto& node_out : gv->getOrderedOutputs()) {
                if (connectedOutputs.find(node_out.first) == connectedOutputs.end()) {
                    node_out.first -> addChild(elt.node(), node_out.second); // already checks that node_out->nbOutput() == 1
                    connectedOutputs.insert(node_out.first);
                }
Cyril Moineau's avatar
Cyril Moineau committed
            }
            gv->add(elt.node());
        }
        else {
            // For each input node, connect the first output (ordered) of each 
            // output node (ordered) to the next available input
            std::set<NodePtr> connectedInputs;
            for (const auto& node_in : elt.view()->getOrderedInputs()) {
                if (connectedInputs.find(node_in.first) == connectedInputs.end()) {
                    AIDGE_ASSERT(static_cast<std::size_t>(node_in.first->getNbFreeDataInputs()) >= gv->outputNodes().size(),
                        "Sequential(): not enough free data inputs ({}) for input node {} (of type {}) to connect to all previous output nodes ({})",
                        node_in.first->getNbFreeDataInputs(), node_in.first->name(), node_in.first->type(), gv->outputNodes().size());
                    std::set<NodePtr> connectedOutputs;
                    for (const auto& node_out : gv->getOrderedOutputs()) {
                        if (connectedOutputs.find(node_out.first) == connectedOutputs.end()) {
                            node_out.first -> addChild(node_in.first, node_out.second); // assert one output Tensor per output Node
                            connectedOutputs.insert(node_out.first);
                        }
                    }
                    connectedInputs.insert(node_in.first);
Cyril Moineau's avatar
Cyril Moineau committed
                }
            }
            gv->add(elt.view());
        }
    }
    return gv;
}


std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::vector<OpArgs> inputs) {
Cyril Moineau's avatar
Cyril Moineau committed
    std::shared_ptr<GraphView> gv = std::make_shared<GraphView>();
    for(const OpArgs& elt : inputs) {
        if (elt.node()!=nullptr)
            gv->add(elt.node());
        else
            gv->add(elt.view());
    }
    return gv;
}


std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::vector<OpArgs> inputs) {
Cyril Moineau's avatar
Cyril Moineau committed
    std::shared_ptr<GraphView> gv = Sequential(inputs);
    AIDGE_ASSERT(gv->outputNodes().size() == 1U,
        "Residual(): Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection");
Cyril Moineau's avatar
Cyril Moineau committed
    std::shared_ptr<Node> lastNode = *gv->outputNodes().begin();
    AIDGE_ASSERT(gv->inputNodes().size() == 2U,
        "Residual(): Zero or more than one input Node for the GraphView, don't know which one to choose from for the residual connection");
Cyril Moineau's avatar
Cyril Moineau committed
    std::shared_ptr<Node> firstNode = nullptr;
    for (const std::shared_ptr<Node>& node_ptr : gv->inputNodes()) {
Cyril Moineau's avatar
Cyril Moineau committed
        if (node_ptr != lastNode) {
            firstNode = node_ptr;
        }
    }
    AIDGE_ASSERT(lastNode->getNbFreeDataInputs()>=1, "Residual(): missing a free data input for the output Node in order to connect the residual branch");
Cyril Moineau's avatar
Cyril Moineau committed
    gv->addChild(lastNode, firstNode, 0U, gk_IODefaultIndex);
    return gv;