Skip to content
Snippets Groups Projects
Commit 0d33ca05 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'seq_with_nodes_order' into 'dev'

Use the ordered nodes API in Sequential() in order to preserve the user declaration nodes order

See merge request eclipse/aidge/aidge_core!128
parents be0ce7f3 7f85a75c
No related branches found
No related tags found
No related merge requests found
...@@ -42,8 +42,8 @@ Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::si ...@@ -42,8 +42,8 @@ Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::si
// Compute the number of bacthes depending on mDropLast boolean // Compute the number of bacthes depending on mDropLast boolean
mNbBatch = (mDropLast) ? mNbBatch = (mDropLast) ?
static_cast<std::size_t>(std::floor(mNbItems / mBatchSize)) : (mNbItems / mBatchSize) :
static_cast<std::size_t>(std::ceil(mNbItems / mBatchSize)); static_cast<std::size_t>(std::ceil(mNbItems / static_cast<float>(mBatchSize)));
} }
std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() const std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() const
......
...@@ -18,23 +18,37 @@ std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::vector<OpArgs> inputs) ...@@ -18,23 +18,37 @@ std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::vector<OpArgs> inputs)
std::shared_ptr<GraphView> gv = std::make_shared<GraphView>(); std::shared_ptr<GraphView> gv = std::make_shared<GraphView>();
for (const OpArgs& elt : inputs) { for (const OpArgs& elt : inputs) {
if(elt.node() != nullptr) { if(elt.node() != nullptr) {
// >= to allow incomplete graphViews // Connect the first output (ordered) of each output node (ordered)
assert(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size()); // to the next available input of the input node.
/* AIDGE_ASSERT(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size(),
* /!\ mn.view()->outputNodes() is a set, order of Nodes cannot be guaranted. "Sequential(): not enough free data inputs ({}) for input node {} (of type {}) to connect to all previous output nodes ({})",
* Prefer a functional description for detailed inputs elt.node()->getNbFreeDataInputs(), elt.node()->name(), elt.node()->type(), gv->outputNodes().size());
*/ std::set<NodePtr> connectedOutputs;
for (const std::shared_ptr<Node>& node_ptr : gv->outputNodes()) { for (const auto& node_out : gv->getOrderedOutputs()) {
node_ptr -> addChild(elt.node()); // already checks that node_ptr->nbOutput() == 1 if (connectedOutputs.find(node_out.first) == connectedOutputs.end()) {
node_out.first -> addChild(elt.node(), node_out.second); // already checks that node_out->nbOutput() == 1
connectedOutputs.insert(node_out.first);
}
} }
gv->add(elt.node()); gv->add(elt.node());
} }
else { else {
for (std::shared_ptr<Node> node_in : elt.view()->inputNodes()) { // For each input node, connect the first output (ordered) of each
// >= to allow incomplete graphViews // output node (ordered) to the next available input
assert(static_cast<std::size_t>(node_in->getNbFreeDataInputs()) >= gv->outputNodes().size()); std::set<NodePtr> connectedInputs;
for (std::shared_ptr<Node> node_out : gv->outputNodes()) { for (const auto& node_in : elt.view()->getOrderedInputs()) {
node_out -> addChild(node_in); // assert one output Tensor per output Node if (connectedInputs.find(node_in.first) == connectedInputs.end()) {
AIDGE_ASSERT(static_cast<std::size_t>(node_in.first->getNbFreeDataInputs()) >= gv->outputNodes().size(),
"Sequential(): not enough free data inputs ({}) for input node {} (of type {}) to connect to all previous output nodes ({})",
node_in.first->getNbFreeDataInputs(), node_in.first->name(), node_in.first->type(), gv->outputNodes().size());
std::set<NodePtr> connectedOutputs;
for (const auto& node_out : gv->getOrderedOutputs()) {
if (connectedOutputs.find(node_out.first) == connectedOutputs.end()) {
node_out.first -> addChild(node_in.first, node_out.second); // assert one output Tensor per output Node
connectedOutputs.insert(node_out.first);
}
}
connectedInputs.insert(node_in.first);
} }
} }
gv->add(elt.view()); gv->add(elt.view());
...@@ -58,16 +72,18 @@ std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::vector<OpArgs> inputs) { ...@@ -58,16 +72,18 @@ std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::vector<OpArgs> inputs) {
std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::vector<OpArgs> inputs) { std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::vector<OpArgs> inputs) {
std::shared_ptr<GraphView> gv = Sequential(inputs); std::shared_ptr<GraphView> gv = Sequential(inputs);
assert(gv->outputNodes().size() == 1U && "Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection"); AIDGE_ASSERT(gv->outputNodes().size() == 1U,
"Residual(): Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection");
std::shared_ptr<Node> lastNode = *gv->outputNodes().begin(); std::shared_ptr<Node> lastNode = *gv->outputNodes().begin();
assert(gv->inputNodes().size() == 2U && "Zero or more than one input Node for the GraphView, don't know which one to choose from for the residual connection"); AIDGE_ASSERT(gv->inputNodes().size() == 2U,
"Residual(): Zero or more than one input Node for the GraphView, don't know which one to choose from for the residual connection");
std::shared_ptr<Node> firstNode = nullptr; std::shared_ptr<Node> firstNode = nullptr;
for (const std::shared_ptr<Node>& node_ptr : gv->inputNodes()) { for (const std::shared_ptr<Node>& node_ptr : gv->inputNodes()) {
if (node_ptr != lastNode) { if (node_ptr != lastNode) {
firstNode = node_ptr; firstNode = node_ptr;
} }
} }
assert(lastNode->getNbFreeDataInputs()>=1); AIDGE_ASSERT(lastNode->getNbFreeDataInputs()>=1, "Residual(): missing a free data input for the output Node in order to connect the residual branch");
gv->addChild(lastNode, firstNode, 0U, gk_IODefaultIndex); gv->addChild(lastNode, firstNode, 0U, gk_IODefaultIndex);
return gv; return gv;
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment