Skip to content
Snippets Groups Projects
Commit 48ce14d2 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Changed GraphView inputs() and dataInputs() behavior

parent 9b3c6ec2
No related branches found
No related tags found
1 merge request!11Removed padding from conv and pool and added Pad operator
......@@ -124,7 +124,7 @@ public:
}
/**
* @brief List dataInput connections of the GraphView object's inputNodes.
* @brief List outside dataInput connections of the GraphView object's inputNodes.
* @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/
std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const;
......@@ -137,7 +137,7 @@ public:
inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); }
/**
* @brief List input connections of the GraphView object's inputNodes.
* @brief List outside input connections of the GraphView object's inputNodes.
* @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/
std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;
......
......@@ -128,21 +128,17 @@ Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const {
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::dataInputs() const {
IOIndex_t nbDataIn = 0U;
for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
nbDataIn += inputNode->nbDataInputs();
}
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res =
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbDataIn);
nbDataIn = 0U;
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
inputNode->dataInputs();
std::move(inputNodeinputs.begin(), inputNodeinputs.end(),
res.begin() + nbDataIn);
nbDataIn += inputNode->nbDataInputs();
// res.insert(res.end(), (inputNode -> inputs()).begin(), (inputNode ->
// inputs()).end());
for (const auto& input : inputNodeinputs) {
if (mNodes.find(input.first) == mNodes.end()) {
res.push_back(input);
}
}
}
return res;
}
......@@ -150,21 +146,17 @@ Aidge::GraphView::dataInputs() const {
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
Aidge::GraphView::inputs() const {
std::size_t nbIn = 0U;
for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
nbIn += inputNode->nbInputs();
}
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res =
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbIn);
nbIn = 0U;
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res;
for (const std::shared_ptr<Node>& inputNode : mInputNodes) {
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
inputNode->inputs();
std::move(inputNodeinputs.begin(), inputNodeinputs.end(),
res.begin() + nbIn);
nbIn += inputNode->nbInputs();
// res.insert(res.end(), (inputNode -> inputs()).begin(), (inputNode ->
// inputs()).end());
for (const auto& input : inputNodeinputs) {
if (mNodes.find(input.first) == mNodes.end()) {
res.push_back(input);
}
}
}
return res;
}
......
......@@ -161,7 +161,7 @@ TEST_CASE("[core/graph] GraphView(addChild)") {
TEST_CASE("[core/graph] GraphView(inputs)") {
auto g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> conv = Conv(3, 32, {3, 3});
g1->add(conv);
g1->add(conv, false);
REQUIRE(g1->inputs() == conv->inputs());
}
......
......@@ -21,15 +21,14 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") {
SECTION("PaddedConv") {
auto op = PaddedConv(1, 3, {3, 3}, "padded_conv", {1, 1}, {{{1, 1}, {1, 1}}});
// 4 nodes: Pad + Conv + 2x Producer (for weight and bias)
auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraph();
REQUIRE(microGraph->getNodes().size() == 4);
REQUIRE(microGraph->inputNodes().size() == 1);
REQUIRE(microGraph->getNodes().size() == 2);
REQUIRE(microGraph->inputNodes().size() == 2); // 2 because Conv has inputs outside the meta-op (Producers for weight and bias)
REQUIRE((*microGraph->inputNodes().begin())->getOperator()->type() == "Pad");
REQUIRE(microGraph->outputNodes().size() == 1);
REQUIRE((*microGraph->outputNodes().begin())->getOperator()->type() == "Conv");
REQUIRE(op->nbInputs() == 1);
REQUIRE(op->nbInputs() == 3);
REQUIRE(op->nbDataInputs() == 1);
REQUIRE(op->nbOutputs() == 1);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment