diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index f11136adaaa3d23fa9d3dc5749dd5d6771cbc42c..31afeb43a786e81d22d4098b42cc1a5d1b167b98 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -124,7 +124,7 @@ public: } /** - * @brief List dataInput connections of the GraphView object's inputNodes. + * @brief List outside dataInput connections of the GraphView object's inputNodes. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; @@ -137,7 +137,7 @@ public: inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); } /** - * @brief List input connections of the GraphView object's inputNodes. + * @brief List outside input connections of the GraphView object's inputNodes. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 7cb4e1dcf33b71bec87ea883aceb8c8a3c49a5ba..dbb64177676e414b131aa00af898b8542024bad9 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -128,21 +128,17 @@ Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::dataInputs() const { - IOIndex_t nbDataIn = 0U; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { - nbDataIn += inputNode->nbDataInputs(); - } - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbDataIn); - nbDataIn = 0U; + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; + for (const std::shared_ptr<Node>& inputNode : mInputNodes) { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->dataInputs(); - std::move(inputNodeinputs.begin(), inputNodeinputs.end(), - res.begin() + nbDataIn); - nbDataIn += inputNode->nbDataInputs(); - // res.insert(res.end(), (inputNode -> inputs()).begin(), (inputNode -> - // inputs()).end()); + + for (const auto& input : inputNodeinputs) { + if (mNodes.find(input.first) == mNodes.end()) { + res.push_back(input); + } + } } return res; } @@ -150,21 +146,17 @@ Aidge::GraphView::dataInputs() const { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::inputs() const { - std::size_t nbIn = 0U; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { - nbIn += inputNode->nbInputs(); - } - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbIn); - nbIn = 0U; + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; + for (const std::shared_ptr<Node>& inputNode : mInputNodes) { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->inputs(); - std::move(inputNodeinputs.begin(), inputNodeinputs.end(), - res.begin() + nbIn); - nbIn += inputNode->nbInputs(); - // res.insert(res.end(), (inputNode -> inputs()).begin(), (inputNode -> - // inputs()).end()); + + for (const auto& input : inputNodeinputs) { + if (mNodes.find(input.first) == mNodes.end()) { + res.push_back(input); + } + } } return res; } diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index dc693193c6606c99b1628d23ad253015f8f8dbe6..a37c9441723f019f3cae858578984a7c13b5929d 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -161,7 +161,7 @@ TEST_CASE("[core/graph] GraphView(addChild)") { TEST_CASE("[core/graph] GraphView(inputs)") { auto g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> conv = Conv(3, 32, {3, 3}); - g1->add(conv); + g1->add(conv, false); REQUIRE(g1->inputs() == conv->inputs()); } diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index 24aaf6cf489b762198fbe576e1bb901cc1cdf7f9..ef2223aa1fc7377bdeafb16b21dad1d0314a7a72 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -21,15 +21,14 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { SECTION("PaddedConv") { auto op = PaddedConv(1, 3, {3, 3}, "padded_conv", {1, 1}, {{{1, 1}, {1, 1}}}); - // 4 nodes: Pad + Conv + 2x Producer (for weight and bias) auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraph(); - REQUIRE(microGraph->getNodes().size() == 4); - REQUIRE(microGraph->inputNodes().size() == 1); + REQUIRE(microGraph->getNodes().size() == 2); + REQUIRE(microGraph->inputNodes().size() == 2); // 2 because Conv has inputs outside the meta-op (Producers for weight and bias) REQUIRE((*microGraph->inputNodes().begin())->getOperator()->type() == "Pad"); REQUIRE(microGraph->outputNodes().size() == 1); REQUIRE((*microGraph->outputNodes().begin())->getOperator()->type() == "Conv"); - REQUIRE(op->nbInputs() == 1); + REQUIRE(op->nbInputs() == 3); REQUIRE(op->nbDataInputs() == 1); REQUIRE(op->nbOutputs() == 1);