Skip to content
Snippets Groups Projects
Commit 5b2d9b4b authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Use the ordered nodes API in Sequential() in order to preserve the user declaration nodes order

parent be0ce7f3
No related branches found
No related tags found
No related merge requests found
...@@ -18,23 +18,37 @@ std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::vector<OpArgs> inputs) ...@@ -18,23 +18,37 @@ std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::vector<OpArgs> inputs)
std::shared_ptr<GraphView> gv = std::make_shared<GraphView>(); std::shared_ptr<GraphView> gv = std::make_shared<GraphView>();
for (const OpArgs& elt : inputs) { for (const OpArgs& elt : inputs) {
if(elt.node() != nullptr) { if(elt.node() != nullptr) {
// >= to allow incomplete graphViews // Connect the first output (ordered) of each output node (ordered)
assert(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size()); // to the next available input of the input node.
/* AIDGE_ASSERT(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size(),
* /!\ mn.view()->outputNodes() is a set, order of Nodes cannot be guaranted. "Sequential(): not enough free data inputs ({}) for input node {} (of type {}) to connect to all previous output nodes ({})",
* Prefer a functional description for detailed inputs elt.node()->getNbFreeDataInputs(), elt.node()->name(), elt.node()->type(), gv->outputNodes().size());
*/ std::set<NodePtr> connectedOutputs;
for (const std::shared_ptr<Node>& node_ptr : gv->outputNodes()) { for (const auto& node_out : gv->getOrderedOutputs()) {
node_ptr -> addChild(elt.node()); // already checks that node_ptr->nbOutput() == 1 if (connectedOutputs.find(node_out.first) == connectedOutputs.end()) {
node_out.first -> addChild(elt.node(), node_out.second); // already checks that node_out->nbOutput() == 1
connectedOutputs.insert(node_out.first);
}
} }
gv->add(elt.node()); gv->add(elt.node());
} }
else { else {
for (std::shared_ptr<Node> node_in : elt.view()->inputNodes()) { // For each input node, connect the first output (ordered) of each
// >= to allow incomplete graphViews // output node (ordered) to the next available input
assert(static_cast<std::size_t>(node_in->getNbFreeDataInputs()) >= gv->outputNodes().size()); std::set<NodePtr> connectedInputs;
for (std::shared_ptr<Node> node_out : gv->outputNodes()) { for (const auto& node_in : elt.view()->getOrderedInputs()) {
node_out -> addChild(node_in); // assert one output Tensor per output Node if (connectedInputs.find(node_in.first) == connectedInputs.end()) {
AIDGE_ASSERT(static_cast<std::size_t>(node_in.first->getNbFreeDataInputs()) >= gv->outputNodes().size(),
"Sequential(): not enough free data inputs ({}) for input node {} (of type {}) to connect to all previous output nodes ({})",
node_in.first->getNbFreeDataInputs(), node_in.first->name(), node_in.first->type(), gv->outputNodes().size());
std::set<NodePtr> connectedOutputs;
for (const auto& node_out : gv->getOrderedOutputs()) {
if (connectedOutputs.find(node_out.first) == connectedOutputs.end()) {
node_out.first -> addChild(node_in.first, node_out.second); // assert one output Tensor per output Node
connectedOutputs.insert(node_out.first);
}
}
connectedInputs.insert(node_in.first);
} }
} }
gv->add(elt.view()); gv->add(elt.view());
...@@ -58,16 +72,18 @@ std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::vector<OpArgs> inputs) { ...@@ -58,16 +72,18 @@ std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::vector<OpArgs> inputs) {
std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::vector<OpArgs> inputs) { std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::vector<OpArgs> inputs) {
std::shared_ptr<GraphView> gv = Sequential(inputs); std::shared_ptr<GraphView> gv = Sequential(inputs);
assert(gv->outputNodes().size() == 1U && "Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection"); AIDGE_ASSERT(gv->outputNodes().size() == 1U,
"Residual(): Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection");
std::shared_ptr<Node> lastNode = *gv->outputNodes().begin(); std::shared_ptr<Node> lastNode = *gv->outputNodes().begin();
assert(gv->inputNodes().size() == 2U && "Zero or more than one input Node for the GraphView, don't know which one to choose from for the residual connection"); AIDGE_ASSERT(gv->inputNodes().size() == 2U,
"Residual(): Zero or more than one input Node for the GraphView, don't know which one to choose from for the residual connection");
std::shared_ptr<Node> firstNode = nullptr; std::shared_ptr<Node> firstNode = nullptr;
for (const std::shared_ptr<Node>& node_ptr : gv->inputNodes()) { for (const std::shared_ptr<Node>& node_ptr : gv->inputNodes()) {
if (node_ptr != lastNode) { if (node_ptr != lastNode) {
firstNode = node_ptr; firstNode = node_ptr;
} }
} }
assert(lastNode->getNbFreeDataInputs()>=1); AIDGE_ASSERT(lastNode->getNbFreeDataInputs()>=1, "Residual(): missing a free data input for the output Node in order to connect the residual branch");
gv->addChild(lastNode, firstNode, 0U, gk_IODefaultIndex); gv->addChild(lastNode, firstNode, 0U, gk_IODefaultIndex);
return gv; return gv;
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment