Skip to content
Snippets Groups Projects
Commit 0d33ca05 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'seq_with_nodes_order' into 'dev'

Use the ordered nodes API in Sequential() in order to preserve the user declaration nodes order

See merge request !128
parents be0ce7f3 7f85a75c
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!128Use the ordered nodes API in Sequential() in order to preserve the user declaration nodes order
Pipeline #47069 passed
...@@ -42,8 +42,8 @@ Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::si ...@@ -42,8 +42,8 @@ Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::si
// Compute the number of bacthes depending on mDropLast boolean // Compute the number of bacthes depending on mDropLast boolean
mNbBatch = (mDropLast) ? mNbBatch = (mDropLast) ?
static_cast<std::size_t>(std::floor(mNbItems / mBatchSize)) : (mNbItems / mBatchSize) :
static_cast<std::size_t>(std::ceil(mNbItems / mBatchSize)); static_cast<std::size_t>(std::ceil(mNbItems / static_cast<float>(mBatchSize)));
} }
std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() const std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() const
......
...@@ -18,23 +18,37 @@ std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::vector<OpArgs> inputs) ...@@ -18,23 +18,37 @@ std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::vector<OpArgs> inputs)
std::shared_ptr<GraphView> gv = std::make_shared<GraphView>(); std::shared_ptr<GraphView> gv = std::make_shared<GraphView>();
for (const OpArgs& elt : inputs) { for (const OpArgs& elt : inputs) {
if(elt.node() != nullptr) { if(elt.node() != nullptr) {
// >= to allow incomplete graphViews // Connect the first output (ordered) of each output node (ordered)
assert(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size()); // to the next available input of the input node.
/* AIDGE_ASSERT(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size(),
* /!\ mn.view()->outputNodes() is a set, order of Nodes cannot be guaranted. "Sequential(): not enough free data inputs ({}) for input node {} (of type {}) to connect to all previous output nodes ({})",
* Prefer a functional description for detailed inputs elt.node()->getNbFreeDataInputs(), elt.node()->name(), elt.node()->type(), gv->outputNodes().size());
*/ std::set<NodePtr> connectedOutputs;
for (const std::shared_ptr<Node>& node_ptr : gv->outputNodes()) { for (const auto& node_out : gv->getOrderedOutputs()) {
node_ptr -> addChild(elt.node()); // already checks that node_ptr->nbOutput() == 1 if (connectedOutputs.find(node_out.first) == connectedOutputs.end()) {
node_out.first -> addChild(elt.node(), node_out.second); // already checks that node_out->nbOutput() == 1
connectedOutputs.insert(node_out.first);
}
} }
gv->add(elt.node()); gv->add(elt.node());
} }
else { else {
for (std::shared_ptr<Node> node_in : elt.view()->inputNodes()) { // For each input node, connect the first output (ordered) of each
// >= to allow incomplete graphViews // output node (ordered) to the next available input
assert(static_cast<std::size_t>(node_in->getNbFreeDataInputs()) >= gv->outputNodes().size()); std::set<NodePtr> connectedInputs;
for (std::shared_ptr<Node> node_out : gv->outputNodes()) { for (const auto& node_in : elt.view()->getOrderedInputs()) {
node_out -> addChild(node_in); // assert one output Tensor per output Node if (connectedInputs.find(node_in.first) == connectedInputs.end()) {
AIDGE_ASSERT(static_cast<std::size_t>(node_in.first->getNbFreeDataInputs()) >= gv->outputNodes().size(),
"Sequential(): not enough free data inputs ({}) for input node {} (of type {}) to connect to all previous output nodes ({})",
node_in.first->getNbFreeDataInputs(), node_in.first->name(), node_in.first->type(), gv->outputNodes().size());
std::set<NodePtr> connectedOutputs;
for (const auto& node_out : gv->getOrderedOutputs()) {
if (connectedOutputs.find(node_out.first) == connectedOutputs.end()) {
node_out.first -> addChild(node_in.first, node_out.second); // assert one output Tensor per output Node
connectedOutputs.insert(node_out.first);
}
}
connectedInputs.insert(node_in.first);
} }
} }
gv->add(elt.view()); gv->add(elt.view());
...@@ -58,16 +72,18 @@ std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::vector<OpArgs> inputs) { ...@@ -58,16 +72,18 @@ std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::vector<OpArgs> inputs) {
std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::vector<OpArgs> inputs) { std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::vector<OpArgs> inputs) {
std::shared_ptr<GraphView> gv = Sequential(inputs); std::shared_ptr<GraphView> gv = Sequential(inputs);
assert(gv->outputNodes().size() == 1U && "Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection"); AIDGE_ASSERT(gv->outputNodes().size() == 1U,
"Residual(): Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection");
std::shared_ptr<Node> lastNode = *gv->outputNodes().begin(); std::shared_ptr<Node> lastNode = *gv->outputNodes().begin();
assert(gv->inputNodes().size() == 2U && "Zero or more than one input Node for the GraphView, don't know which one to choose from for the residual connection"); AIDGE_ASSERT(gv->inputNodes().size() == 2U,
"Residual(): Zero or more than one input Node for the GraphView, don't know which one to choose from for the residual connection");
std::shared_ptr<Node> firstNode = nullptr; std::shared_ptr<Node> firstNode = nullptr;
for (const std::shared_ptr<Node>& node_ptr : gv->inputNodes()) { for (const std::shared_ptr<Node>& node_ptr : gv->inputNodes()) {
if (node_ptr != lastNode) { if (node_ptr != lastNode) {
firstNode = node_ptr; firstNode = node_ptr;
} }
} }
assert(lastNode->getNbFreeDataInputs()>=1); AIDGE_ASSERT(lastNode->getNbFreeDataInputs()>=1, "Residual(): missing a free data input for the output Node in order to connect the residual branch");
gv->addChild(lastNode, firstNode, 0U, gk_IODefaultIndex); gv->addChild(lastNode, firstNode, 0U, gk_IODefaultIndex);
return gv; return gv;
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment