Skip to content
Snippets Groups Projects
Commit 48ce14d2 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Changed GraphView inputs() and dataInputs() behavior

parent 9b3c6ec2
No related branches found
No related tags found
No related merge requests found
...@@ -124,7 +124,7 @@ public: ...@@ -124,7 +124,7 @@ public:
} }
/** /**
* @brief List dataInput connections of the GraphView object's inputNodes. * @brief List outside dataInput connections of the GraphView object's inputNodes.
* @return std::vector<std::pair<NodePtr, IOIndex_t>> * @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/ */
std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const;
...@@ -137,7 +137,7 @@ public: ...@@ -137,7 +137,7 @@ public:
inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); } inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); }
/** /**
* @brief List input connections of the GraphView object's inputNodes. * @brief List outside input connections of the GraphView object's inputNodes.
* @return std::vector<std::pair<NodePtr, IOIndex_t>> * @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/ */
std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;
......
...@@ -128,21 +128,17 @@ Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const { ...@@ -128,21 +128,17 @@ Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const {
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::dataInputs() const { Aidge::GraphView::dataInputs() const {
IOIndex_t nbDataIn = 0U; std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
nbDataIn += inputNode->nbDataInputs();
}
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res =
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbDataIn);
nbDataIn = 0U;
for (const std::shared_ptr<Node>& inputNode : mInputNodes) { for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
inputNode->dataInputs(); inputNode->dataInputs();
std::move(inputNodeinputs.begin(), inputNodeinputs.end(),
res.begin() + nbDataIn); for (const auto& input : inputNodeinputs) {
nbDataIn += inputNode->nbDataInputs(); if (mNodes.find(input.first) == mNodes.end()) {
// res.insert(res.end(), (inputNode -> inputs()).begin(), (inputNode -> res.push_back(input);
// inputs()).end()); }
}
} }
return res; return res;
} }
...@@ -150,21 +146,17 @@ Aidge::GraphView::dataInputs() const { ...@@ -150,21 +146,17 @@ Aidge::GraphView::dataInputs() const {
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs() const { Aidge::GraphView::inputs() const {
std::size_t nbIn = 0U; std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
nbIn += inputNode->nbInputs();
}
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res =
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbIn);
nbIn = 0U;
for (const std::shared_ptr<Node>& inputNode : mInputNodes) { for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
inputNode->inputs(); inputNode->inputs();
std::move(inputNodeinputs.begin(), inputNodeinputs.end(),
res.begin() + nbIn); for (const auto& input : inputNodeinputs) {
nbIn += inputNode->nbInputs(); if (mNodes.find(input.first) == mNodes.end()) {
// res.insert(res.end(), (inputNode -> inputs()).begin(), (inputNode -> res.push_back(input);
// inputs()).end()); }
}
} }
return res; return res;
} }
......
...@@ -161,7 +161,7 @@ TEST_CASE("[core/graph] GraphView(addChild)") { ...@@ -161,7 +161,7 @@ TEST_CASE("[core/graph] GraphView(addChild)") {
TEST_CASE("[core/graph] GraphView(inputs)") { TEST_CASE("[core/graph] GraphView(inputs)") {
auto g1 = std::make_shared<GraphView>("TestGraph"); auto g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> conv = Conv(3, 32, {3, 3}); std::shared_ptr<Node> conv = Conv(3, 32, {3, 3});
g1->add(conv); g1->add(conv, false);
REQUIRE(g1->inputs() == conv->inputs()); REQUIRE(g1->inputs() == conv->inputs());
} }
......
...@@ -21,15 +21,14 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { ...@@ -21,15 +21,14 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") {
SECTION("PaddedConv") { SECTION("PaddedConv") {
auto op = PaddedConv(1, 3, {3, 3}, "padded_conv", {1, 1}, {{{1, 1}, {1, 1}}}); auto op = PaddedConv(1, 3, {3, 3}, "padded_conv", {1, 1}, {{{1, 1}, {1, 1}}});
// 4 nodes: Pad + Conv + 2x Producer (for weight and bias)
auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraph(); auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraph();
REQUIRE(microGraph->getNodes().size() == 4); REQUIRE(microGraph->getNodes().size() == 2);
REQUIRE(microGraph->inputNodes().size() == 1); REQUIRE(microGraph->inputNodes().size() == 2); // 2 because Conv has inputs outside the meta-op (Producers for weight and bias)
REQUIRE((*microGraph->inputNodes().begin())->getOperator()->type() == "Pad"); REQUIRE((*microGraph->inputNodes().begin())->getOperator()->type() == "Pad");
REQUIRE(microGraph->outputNodes().size() == 1); REQUIRE(microGraph->outputNodes().size() == 1);
REQUIRE((*microGraph->outputNodes().begin())->getOperator()->type() == "Conv"); REQUIRE((*microGraph->outputNodes().begin())->getOperator()->type() == "Conv");
REQUIRE(op->nbInputs() == 1); REQUIRE(op->nbInputs() == 3);
REQUIRE(op->nbDataInputs() == 1); REQUIRE(op->nbDataInputs() == 1);
REQUIRE(op->nbOutputs() == 1); REQUIRE(op->nbOutputs() == 1);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment