From 7dbf6d79448c52f363426778aca00b08ea8b06d8 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Wed, 22 Nov 2023 17:00:18 +0000
Subject: [PATCH] fix GraphRegex test and update OperatorTensor output
 functions for Identity operator

---
 include/aidge/operator/OperatorTensor.hpp | 10 +++-------
 unit_tests/graphRegex/Test_GraphRegex.cpp | 22 +++++++++++-----------
 2 files changed, 14 insertions(+), 18 deletions(-)

diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp
index 0ecccb371..5f8317966 100644
--- a/include/aidge/operator/OperatorTensor.hpp
+++ b/include/aidge/operator/OperatorTensor.hpp
@@ -75,18 +75,14 @@ public:
     void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
     void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
     const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const;
-    inline Tensor& input(const IOIndex_t inputIdx) const { return *getInput(inputIdx); }
     inline std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
         return std::static_pointer_cast<Data>(getInput(inputIdx));
     }
 
     // output management
-    void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override final;
-    void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override final;
-    const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const;
-    inline Tensor& output(const IOIndex_t outputIdx) const {
-        return *getOutput(outputIdx);
-    }
+    void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override;
+    void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override;
+    virtual const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const;
     inline std::shared_ptr<Aidge::Data> getRawOutput(const Aidge::IOIndex_t outputIdx) const override final {
         return std::static_pointer_cast<Data>(getOutput(outputIdx));
     }
diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp
index 4383f3bd7..924aac79e 100644
--- a/unit_tests/graphRegex/Test_GraphRegex.cpp
+++ b/unit_tests/graphRegex/Test_GraphRegex.cpp
@@ -9,7 +9,7 @@
 #include "aidge/operator/FC.hpp"
 #include "aidge/operator/MatMul.hpp"
 #include "aidge/operator/Producer.hpp"
-#include "aidge/utils/Recipies.hpp"
+#include "aidge/recipies/Recipies.hpp"
 
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/GenericOperator.hpp"
@@ -94,10 +94,10 @@ TEST_CASE("GraphRegexUser") {
         std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();
 
         std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
-        std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c");
-        std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1");
-        std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2");
-        std::shared_ptr<Node> conv3 = GenericOperator("FC", 1, 1, 1, "c3");
+        std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
+        std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1");
+        std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
+        std::shared_ptr<Node> conv3 = GenericOperator("FC", 1, 0, 1, "c3");
 
         g1->add(conv);
         g1->addChild(conv1, "c");
@@ -126,10 +126,10 @@ TEST_CASE("GraphRegexUser") {
     SECTION("Applied Recipes"){
 
       // generate the original GraphView
-        auto matmul0 = MatMul(5, "matmul0");
-        auto add0 = Add<2>("add0");
-        auto matmul1 = MatMul(5, "matmul1");
-        auto add1 = Add<2>("add1");
+        auto matmul0 = MatMul(5, 5, "matmul0");
+        auto add0 = Add(2, "add0");
+        auto matmul1 = MatMul(5, 5, "matmul1");
+        auto add1 = Add(2, "add1");
 
         auto b0 = Producer({5}, "B0");
         auto w0 = Producer({5, 5}, "W0");
@@ -149,8 +149,8 @@ TEST_CASE("GraphRegexUser") {
         matmul1->addChild(add1, 0, 0);
         b1->addChild(add1, 0, 1);
 
-        auto fc = GenericOperator("FC", 1, 1, 1, "c");
-        auto fl = GenericOperator("Flatten", 1, 1, 1, "c");
+        auto fc = GenericOperator("FC", 1, 0, 1, "c");
+        auto fl = GenericOperator("Flatten", 1, 0, 1, "c");
 
 
         auto g = std::make_shared<GraphView>();
-- 
GitLab