From 4362a3e80706fae9e9d674d64e8c95337a8287c6 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Thu, 21 Mar 2024 14:15:07 +0000
Subject: [PATCH] Minor optimizations and add default values to
 'GraphView::compile()' member function

---
 include/aidge/graph/GraphView.hpp         |  4 ++--
 include/aidge/utils/Registrar.hpp         |  9 ++++-----
 unit_tests/operator/Test_MetaOperator.cpp | 16 ++++++++--------
 unit_tests/scheduler/Test_Scheduler.cpp   |  2 +-
 4 files changed, 15 insertions(+), 16 deletions(-)

diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 3311797d8..fcf5250b2 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -214,7 +214,7 @@ public:
      * If not, add a Transpose Operator.
      * 4 - Propagate Tensor dimensions through the consecutive Operators.
      */
-    void compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device = 0);
+    void compile(const std::string& backend = "cpu", const Aidge::DataType datatype = DataType::Float32, DeviceIdx_t device = 0);
 
     /**
      * @brief Compute dimensions of input/output Tensors for each Operator of the
@@ -283,7 +283,7 @@ public:
      *   added to the list, and so on.
      * - Any remaining nodes have no path to the root node and are added in
      *   arbitrary order. In this case, the ranking is not garanteed to be unique.
-     * 
+     *
      * If the ranking cannot be garanteed to be unique, the second item indicates
      * the rank from which unicity cannot be garanteed.
      * @return std::pair<std::vector<NodePtr>, size_t> Pair with the list of ranked
diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp
index a5bd260ec..b7abfb979 100644
--- a/include/aidge/utils/Registrar.hpp
+++ b/include/aidge/utils/Registrar.hpp
@@ -23,7 +23,7 @@
 
 #include <functional>
 #include <map>
-#include <cassert>
+#include <vector>
 
 namespace Aidge {
 #ifdef PYBIND
@@ -72,19 +72,18 @@ struct Registrar {
     }
 
     static bool exists(const registrar_key& key) {
-        const auto it = C::registry().find(key);
-        return (it != C::registry().end());
+        return (C::registry().find(key) != C::registry().cend());
     }
 
     static auto create(const registrar_key& key){
         const auto it = C::registry().find(key);
-        AIDGE_ASSERT(it != C::registry().end(), "missing or invalid registrar key: {}\nDid you include/import the corresponding module?", key);
+        AIDGE_ASSERT(it != C::registry().cend(), "missing or invalid registrar key: {}\nDid you include/import the corresponding module?", key);
 
         return (*it).second;
     }
     static std::vector<registrar_key> getKeys(){
         std::vector<registrar_key> keys;
-        for(auto keyValue : C::registry())
+        for(const auto& keyValue : C::registry())
             keys.push_back(keyValue.first);
         return keys;
     }
diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp
index 3ff2a3c6c..331b9380a 100644
--- a/unit_tests/operator/Test_MetaOperator.cpp
+++ b/unit_tests/operator/Test_MetaOperator.cpp
@@ -21,7 +21,7 @@
 
 using namespace Aidge;
 
-TEST_CASE("[core/operators] MetaOperator", "[Operator]") {
+TEST_CASE("[core/operators] MetaOperator", "[Operator][MetaOperator]") {
     SECTION("PaddedConv") {
         auto op = PaddedConv(1, 3, {3, 3}, "padded_conv", {1, 1}, {1, 1, 1, 1});
 
@@ -108,14 +108,14 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") {
 
         // Weights X
         myLSTM->input(1).first->getOperator()->setOutput(0, myInitW);
-        myLSTM->input(2).first->getOperator()->setOutput(0, myInitW);
-        myLSTM->input(3).first->getOperator()->setOutput(0, myInitW);
-        myLSTM->input(4).first->getOperator()->setOutput(0, myInitW);
+        op->setInput(2, myInitW);
+        op->setInput(3, myInitW);
+        op->setInput(4, myInitW);
         // Weights H
-        myLSTM->input(5).first->getOperator()->setOutput(0, myInitR);
-        myLSTM->input(6).first->getOperator()->setOutput(0, myInitR);
-        myLSTM->input(7).first->getOperator()->setOutput(0, myInitR);
-        myLSTM->input(8).first->getOperator()->setOutput(0, myInitR);
+        op->setInput(5, myInitR);
+        op->setInput(6, myInitR);
+        op->setInput(7, myInitR);
+        op->setInput(8, myInitR);
 
         auto g = getConnectedGraphView(myLSTM);
         g->save("lstm_before_expand", true, true);
diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp
index 7e28f1fad..3ef70bcfb 100644
--- a/unit_tests/scheduler/Test_Scheduler.cpp
+++ b/unit_tests/scheduler/Test_Scheduler.cpp
@@ -55,7 +55,7 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
             }
 
             g1->save("schedule");
-            g1->forwardDims();
+            g1->compile();
 
             auto scheduler = SequentialScheduler(g1);
             scheduler.generateScheduling(true);
-- 
GitLab