Skip to content
Snippets Groups Projects
Commit 5b2d9b4b authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Use the ordered nodes API in Sequential() in order to preserve the user declaration nodes order

parent be0ce7f3
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!128Use the ordered nodes API in Sequential() in order to preserve the user declaration nodes order
Pipeline #47009 canceled
This commit is part of merge request !128. Comments created here will be created in the context of that merge request.
......@@ -18,23 +18,37 @@ std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::vector<OpArgs> inputs)
std::shared_ptr<GraphView> gv = std::make_shared<GraphView>();
for (const OpArgs& elt : inputs) {
if(elt.node() != nullptr) {
// >= to allow incomplete graphViews
assert(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size());
/*
* /!\ mn.view()->outputNodes() is a set, order of Nodes cannot be guaranted.
* Prefer a functional description for detailed inputs
*/
for (const std::shared_ptr<Node>& node_ptr : gv->outputNodes()) {
node_ptr -> addChild(elt.node()); // already checks that node_ptr->nbOutput() == 1
// Connect the first output (ordered) of each output node (ordered)
// to the next available input of the input node.
AIDGE_ASSERT(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size(),
"Sequential(): not enough free data inputs ({}) for input node {} (of type {}) to connect to all previous output nodes ({})",
elt.node()->getNbFreeDataInputs(), elt.node()->name(), elt.node()->type(), gv->outputNodes().size());
std::set<NodePtr> connectedOutputs;
for (const auto& node_out : gv->getOrderedOutputs()) {
if (connectedOutputs.find(node_out.first) == connectedOutputs.end()) {
node_out.first -> addChild(elt.node(), node_out.second); // already checks that node_out->nbOutput() == 1
connectedOutputs.insert(node_out.first);
}
}
gv->add(elt.node());
}
else {
for (std::shared_ptr<Node> node_in : elt.view()->inputNodes()) {
// >= to allow incomplete graphViews
assert(static_cast<std::size_t>(node_in->getNbFreeDataInputs()) >= gv->outputNodes().size());
for (std::shared_ptr<Node> node_out : gv->outputNodes()) {
node_out -> addChild(node_in); // assert one output Tensor per output Node
// For each input node, connect the first output (ordered) of each
// output node (ordered) to the next available input
std::set<NodePtr> connectedInputs;
for (const auto& node_in : elt.view()->getOrderedInputs()) {
if (connectedInputs.find(node_in.first) == connectedInputs.end()) {
AIDGE_ASSERT(static_cast<std::size_t>(node_in.first->getNbFreeDataInputs()) >= gv->outputNodes().size(),
"Sequential(): not enough free data inputs ({}) for input node {} (of type {}) to connect to all previous output nodes ({})",
node_in.first->getNbFreeDataInputs(), node_in.first->name(), node_in.first->type(), gv->outputNodes().size());
std::set<NodePtr> connectedOutputs;
for (const auto& node_out : gv->getOrderedOutputs()) {
if (connectedOutputs.find(node_out.first) == connectedOutputs.end()) {
node_out.first -> addChild(node_in.first, node_out.second); // assert one output Tensor per output Node
connectedOutputs.insert(node_out.first);
}
}
connectedInputs.insert(node_in.first);
}
}
gv->add(elt.view());
......@@ -58,16 +72,18 @@ std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::vector<OpArgs> inputs) {
std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::vector<OpArgs> inputs) {
std::shared_ptr<GraphView> gv = Sequential(inputs);
assert(gv->outputNodes().size() == 1U && "Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection");
AIDGE_ASSERT(gv->outputNodes().size() == 1U,
"Residual(): Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection");
std::shared_ptr<Node> lastNode = *gv->outputNodes().begin();
assert(gv->inputNodes().size() == 2U && "Zero or more than one input Node for the GraphView, don't know which one to choose from for the residual connection");
AIDGE_ASSERT(gv->inputNodes().size() == 2U,
"Residual(): Zero or more than one input Node for the GraphView, don't know which one to choose from for the residual connection");
std::shared_ptr<Node> firstNode = nullptr;
for (const std::shared_ptr<Node>& node_ptr : gv->inputNodes()) {
if (node_ptr != lastNode) {
firstNode = node_ptr;
}
}
assert(lastNode->getNbFreeDataInputs()>=1);
AIDGE_ASSERT(lastNode->getNbFreeDataInputs()>=1, "Residual(): missing a free data input for the output Node in order to connect the residual branch");
gv->addChild(lastNode, firstNode, 0U, gk_IODefaultIndex);
return gv;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment