From 7dbf6d79448c52f363426778aca00b08ea8b06d8 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Wed, 22 Nov 2023 17:00:18 +0000 Subject: [PATCH] fix GraphRegex test and update OperatorTensor output functions for Identity operator --- include/aidge/operator/OperatorTensor.hpp | 10 +++------- unit_tests/graphRegex/Test_GraphRegex.cpp | 22 +++++++++++----------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp index 0ecccb371..5f8317966 100644 --- a/include/aidge/operator/OperatorTensor.hpp +++ b/include/aidge/operator/OperatorTensor.hpp @@ -75,18 +75,14 @@ public: void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final; void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final; const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const; - inline Tensor& input(const IOIndex_t inputIdx) const { return *getInput(inputIdx); } inline std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { return std::static_pointer_cast<Data>(getInput(inputIdx)); } // output management - void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override final; - void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override final; - const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const; - inline Tensor& output(const IOIndex_t outputIdx) const { - return *getOutput(outputIdx); - } + void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override; + void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override; + virtual const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const; inline std::shared_ptr<Aidge::Data> getRawOutput(const Aidge::IOIndex_t outputIdx) const override final { return std::static_pointer_cast<Data>(getOutput(outputIdx)); } diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp index 4383f3bd7..924aac79e 100644 --- a/unit_tests/graphRegex/Test_GraphRegex.cpp +++ b/unit_tests/graphRegex/Test_GraphRegex.cpp @@ -9,7 +9,7 @@ #include "aidge/operator/FC.hpp" #include "aidge/operator/MatMul.hpp" #include "aidge/operator/Producer.hpp" -#include "aidge/utils/Recipies.hpp" +#include "aidge/recipies/Recipies.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/operator/GenericOperator.hpp" @@ -94,10 +94,10 @@ TEST_CASE("GraphRegexUser") { std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); - std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); - std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); - std::shared_ptr<Node> conv3 = GenericOperator("FC", 1, 1, 1, "c3"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("FC", 1, 0, 1, "c3"); g1->add(conv); g1->addChild(conv1, "c"); @@ -126,10 +126,10 @@ TEST_CASE("GraphRegexUser") { SECTION("Applied Recipes"){ // generate the original GraphView - auto matmul0 = MatMul(5, "matmul0"); - auto add0 = Add<2>("add0"); - auto matmul1 = MatMul(5, "matmul1"); - auto add1 = Add<2>("add1"); + auto matmul0 = MatMul(5, 5, "matmul0"); + auto add0 = Add(2, "add0"); + auto matmul1 = MatMul(5, 5, "matmul1"); + auto add1 = Add(2, "add1"); auto b0 = Producer({5}, "B0"); auto w0 = Producer({5, 5}, "W0"); @@ -149,8 +149,8 @@ TEST_CASE("GraphRegexUser") { matmul1->addChild(add1, 0, 0); b1->addChild(add1, 0, 1); - auto fc = GenericOperator("FC", 1, 1, 1, "c"); - auto fl = GenericOperator("Flatten", 1, 1, 1, "c"); + auto fc = GenericOperator("FC", 1, 0, 1, "c"); + auto fl = GenericOperator("Flatten", 1, 0, 1, "c"); auto g = std::make_shared<GraphView>(); -- GitLab