From e8820d66990e620a4ce0cad5dd20ac95816feef5 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Fri, 10 Nov 2023 09:32:28 +0000 Subject: [PATCH] Also change nbDataInputs() to nbData() in Node.hpp --- include/aidge/graph/Node.hpp | 10 +++++----- src/graph/GraphView.cpp | 10 +++++----- src/graph/Node.cpp | 10 +++++----- unit_tests/graph/Test_GraphView.cpp | 2 +- unit_tests/operator/Test_MetaOperator.cpp | 4 ++-- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index f1d0a39d4..384aa64cd 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -178,9 +178,9 @@ public: */ inline IOIndex_t getFirstFreeDataInput() const { IOIndex_t i = 0; - for (; (i < nbDataInputs()) && (input(i).second != gk_IODefaultIndex); ++i) {} - // assert((i<nbDataInputs()) && "No free data input for Node"); - return (i < nbDataInputs()) ? i : gk_IODefaultIndex; + for (; (i < nbData()) && (input(i).second != gk_IODefaultIndex); ++i) {} + // assert((i<nbData()) && "No free data input for Node"); + return (i < nbData()) ? i : gk_IODefaultIndex; } @@ -214,8 +214,8 @@ public: * @details [data, data, weight, bias] => 2 * @return IOIndex_t */ - inline IOIndex_t nbDataInputs() const noexcept { - return getOperator()->nbDataInputs(); + inline IOIndex_t nbData() const noexcept { + return getOperator()->nbData(); } /** diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 367c9f10d..13b9d63dc 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -28,7 +28,7 @@ Aidge::Connector Aidge::GraphView::operator()( // TODO: allow for multiple inputNodes? assert((inputNodes().size() == 1U) && "Too many input Nodes for the GraphView, undefined behaviour"); std::shared_ptr<Node> inNode = *inputNodes().begin(); - assert((ctors.size() == static_cast<std::size_t>(inNode->nbDataInputs())) && "Wrong number of arguments.\n"); + assert((ctors.size() == static_cast<std::size_t>(inNode->nbData())) && "Wrong number of arguments.\n"); for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inNode->inputs()) { assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n"); (void)input; // avoid unused warning @@ -110,7 +110,7 @@ Aidge::IOIndex_t Aidge::GraphView::getNbDataInputs() const { IOIndex_t nbDataInput = 0; // assert(outputNodes().size() == static_cast<std::size_t>(1)); for (const std::shared_ptr<Node> &inNode : inputNodes()) { - nbDataInput += inNode->nbDataInputs(); + nbDataInput += inNode->nbData(); } return nbDataInput; } @@ -315,7 +315,7 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara mNodeRegistry.insert(std::make_pair(node->name(), node)); // add learnable parameters to the graph if (includeLearnableParam) { - for (IOIndex_t i = node->nbDataInputs(); i < node->nbInputs(); ++i) { + for (IOIndex_t i = node->nbData(); i < node->nbInputs(); ++i) { std::shared_ptr<Node> parentNode = node->getParent(static_cast<IOIndex_t>(i)); if (parentNode) { parentNode->addView(shared_from_this()); @@ -475,7 +475,7 @@ void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnab // same for learnable params if (includeLearnableParam) { - for (IOIndex_t i = nodePtr->nbDataInputs(); i < nodePtr->nbInputs(); ++i) { + for (IOIndex_t i = nodePtr->nbData(); i < nodePtr->nbInputs(); ++i) { auto inputI = nodePtr->input(i); bool removeNode = true; for (const auto& parentOutput : inputI.first->outputs()) { @@ -757,7 +757,7 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone for (auto parent : oldToNewNode.first->inputs()) { while (oldToNewNodes[parent.first] == nullptr) { // Find next valid parent in line, going backward in the graph - assert(parent.first->nbDataInputs() <= 1 && "deleted nodes in GraphView::clone() cannot have multiple data inputs"); + assert(parent.first->nbData() <= 1 && "deleted nodes in GraphView::clone() cannot have multiple data inputs"); const auto& parents = parent.first->inputs(); if (!parents.empty() && parents[0].first != nullptr // a valid parent exists diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index e6a53c871..6dca8eaaf 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -34,7 +34,7 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) /////////////////////////////////////////////////////// Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) { - assert((ctors.size() == nbDataInputs()) && "Wrong number of arguments.\n"); + assert((ctors.size() == nbData()) && "Wrong number of arguments.\n"); for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inputs()) { assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n"); (void) input; // avoid unused warning @@ -94,8 +94,8 @@ Aidge::IOIndex_t Aidge::Node::getNbFreeDataInputs() const { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::dataInputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbDataInputs()); - for (std::size_t i = 0; i < static_cast<std::size_t>(nbDataInputs()); ++i) { + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbData()); + for (std::size_t i = 0; i < static_cast<std::size_t>(nbData()); ++i) { res[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]); } return res; @@ -295,7 +295,7 @@ bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, const void Aidge::Node::resetConnections(bool includeLearnableParam) { // remove every parents reference to it - IOIndex_t nbRemovedInputs = includeLearnableParam ? nbInputs() : nbDataInputs(); + IOIndex_t nbRemovedInputs = includeLearnableParam ? nbInputs() : nbData(); for (IOIndex_t i = 0; i < nbRemovedInputs; ++i) { std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i); if (parent.first) { @@ -367,7 +367,7 @@ std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta,std::set<Aidge::Nod } } } - + return out; } ///////////////////////////////////////////////////////////////////////////////////////////// diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index dbba1a7d6..e7f0d6095 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -210,7 +210,7 @@ TEST_CASE("[core/graph] GraphView(resetConnections)") { conv1->resetConnections(false); REQUIRE(conv->output(0).size() == 0); - for (std::size_t i = 0; i < conv1->nbDataInputs(); ++i) { + for (std::size_t i = 0; i < conv1->nbData(); ++i) { REQUIRE((conv1->input(i) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); } REQUIRE((conv1->input(1) == std::pair<std::shared_ptr<Node>, IOIndex_t>(prod1, 0))); diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index c09042791..79f6979c9 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -31,7 +31,7 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { REQUIRE(microGraph->outputNodes().size() == 1); REQUIRE((*microGraph->outputNodes().begin())->getOperator()->type() == "Conv"); REQUIRE(op->nbInputs() == 3); - REQUIRE(op->nbDataInputs() == 1); + REQUIRE(op->nbData() == 1); REQUIRE(op->nbOutputs() == 1); std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(); @@ -45,7 +45,7 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { // Order not garanteed by the GraphView //REQUIRE((*microGraph->inputNodes().begin())->getOperator()->getInput(0) == myInput); REQUIRE(op->getOperator()->getOutput(0) == (*microGraph->outputNodes().begin())->getOperator()->getOutput(0)); - + //op->getOperator()->updateConsummerProducer(); // require implementation //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraphScheduler(); //REQUIRE(microGraphScheduler->getStaticScheduling().size() == 2); -- GitLab