diff --git a/src/data/DataProvider.cpp b/src/data/DataProvider.cpp index 5c3d1d7ef3b3dd8c779cf9cda737f1a2b2f6e01f..fc6b842edef17c80a4ef80667fc814bf85df25a4 100644 --- a/src/data/DataProvider.cpp +++ b/src/data/DataProvider.cpp @@ -42,8 +42,8 @@ Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::si // Compute the number of bacthes depending on mDropLast boolean mNbBatch = (mDropLast) ? - static_cast<std::size_t>(std::floor(mNbItems / mBatchSize)) : - static_cast<std::size_t>(std::ceil(mNbItems / mBatchSize)); + (mNbItems / mBatchSize) : + static_cast<std::size_t>(std::ceil(mNbItems / static_cast<float>(mBatchSize))); } std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() const diff --git a/src/graph/OpArgs.cpp b/src/graph/OpArgs.cpp index 124878fc45fe632d4a584e76a0eae6e7acfd53b9..e1a378c3db0d79d7816e9882f790540cdc26cd88 100644 --- a/src/graph/OpArgs.cpp +++ b/src/graph/OpArgs.cpp @@ -18,23 +18,37 @@ std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::vector<OpArgs> inputs) std::shared_ptr<GraphView> gv = std::make_shared<GraphView>(); for (const OpArgs& elt : inputs) { if(elt.node() != nullptr) { - // >= to allow incomplete graphViews - assert(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size()); - /* - * /!\ mn.view()->outputNodes() is a set, order of Nodes cannot be guaranted. - * Prefer a functional description for detailed inputs - */ - for (const std::shared_ptr<Node>& node_ptr : gv->outputNodes()) { - node_ptr -> addChild(elt.node()); // already checks that node_ptr->nbOutput() == 1 + // Connect the first output (ordered) of each output node (ordered) + // to the next available input of the input node. + AIDGE_ASSERT(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size(), + "Sequential(): not enough free data inputs ({}) for input node {} (of type {}) to connect to all previous output nodes ({})", + elt.node()->getNbFreeDataInputs(), elt.node()->name(), elt.node()->type(), gv->outputNodes().size()); + std::set<NodePtr> connectedOutputs; + for (const auto& node_out : gv->getOrderedOutputs()) { + if (connectedOutputs.find(node_out.first) == connectedOutputs.end()) { + node_out.first -> addChild(elt.node(), node_out.second); // already checks that node_out->nbOutput() == 1 + connectedOutputs.insert(node_out.first); + } } gv->add(elt.node()); } else { - for (std::shared_ptr<Node> node_in : elt.view()->inputNodes()) { - // >= to allow incomplete graphViews - assert(static_cast<std::size_t>(node_in->getNbFreeDataInputs()) >= gv->outputNodes().size()); - for (std::shared_ptr<Node> node_out : gv->outputNodes()) { - node_out -> addChild(node_in); // assert one output Tensor per output Node + // For each input node, connect the first output (ordered) of each + // output node (ordered) to the next available input + std::set<NodePtr> connectedInputs; + for (const auto& node_in : elt.view()->getOrderedInputs()) { + if (connectedInputs.find(node_in.first) == connectedInputs.end()) { + AIDGE_ASSERT(static_cast<std::size_t>(node_in.first->getNbFreeDataInputs()) >= gv->outputNodes().size(), + "Sequential(): not enough free data inputs ({}) for input node {} (of type {}) to connect to all previous output nodes ({})", + node_in.first->getNbFreeDataInputs(), node_in.first->name(), node_in.first->type(), gv->outputNodes().size()); + std::set<NodePtr> connectedOutputs; + for (const auto& node_out : gv->getOrderedOutputs()) { + if (connectedOutputs.find(node_out.first) == connectedOutputs.end()) { + node_out.first -> addChild(node_in.first, node_out.second); // assert one output Tensor per output Node + connectedOutputs.insert(node_out.first); + } + } + connectedInputs.insert(node_in.first); } } gv->add(elt.view()); @@ -58,16 +72,18 @@ std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::vector<OpArgs> inputs) { std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::vector<OpArgs> inputs) { std::shared_ptr<GraphView> gv = Sequential(inputs); - assert(gv->outputNodes().size() == 1U && "Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection"); + AIDGE_ASSERT(gv->outputNodes().size() == 1U, + "Residual(): Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection"); std::shared_ptr<Node> lastNode = *gv->outputNodes().begin(); - assert(gv->inputNodes().size() == 2U && "Zero or more than one input Node for the GraphView, don't know which one to choose from for the residual connection"); + AIDGE_ASSERT(gv->inputNodes().size() == 2U, + "Residual(): Zero or more than one input Node for the GraphView, don't know which one to choose from for the residual connection"); std::shared_ptr<Node> firstNode = nullptr; for (const std::shared_ptr<Node>& node_ptr : gv->inputNodes()) { if (node_ptr != lastNode) { firstNode = node_ptr; } } - assert(lastNode->getNbFreeDataInputs()>=1); + AIDGE_ASSERT(lastNode->getNbFreeDataInputs()>=1, "Residual(): missing a free data input for the output Node in order to connect the residual branch"); gv->addChild(lastNode, firstNode, 0U, gk_IODefaultIndex); return gv; }