From 4362a3e80706fae9e9d674d64e8c95337a8287c6 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Thu, 21 Mar 2024 14:15:07 +0000 Subject: [PATCH] Minor optimizations and add default values to 'GraphView::compile()' member function --- include/aidge/graph/GraphView.hpp | 4 ++-- include/aidge/utils/Registrar.hpp | 9 ++++----- unit_tests/operator/Test_MetaOperator.cpp | 16 ++++++++-------- unit_tests/scheduler/Test_Scheduler.cpp | 2 +- 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 3311797d8..fcf5250b2 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -214,7 +214,7 @@ public: * If not, add a Transpose Operator. * 4 - Propagate Tensor dimensions through the consecutive Operators. */ - void compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device = 0); + void compile(const std::string& backend = "cpu", const Aidge::DataType datatype = DataType::Float32, DeviceIdx_t device = 0); /** * @brief Compute dimensions of input/output Tensors for each Operator of the @@ -283,7 +283,7 @@ public: * added to the list, and so on. * - Any remaining nodes have no path to the root node and are added in * arbitrary order. In this case, the ranking is not garanteed to be unique. - * + * * If the ranking cannot be garanteed to be unique, the second item indicates * the rank from which unicity cannot be garanteed. * @return std::pair<std::vector<NodePtr>, size_t> Pair with the list of ranked diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index a5bd260ec..b7abfb979 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -23,7 +23,7 @@ #include <functional> #include <map> -#include <cassert> +#include <vector> namespace Aidge { #ifdef PYBIND @@ -72,19 +72,18 @@ struct Registrar { } static bool exists(const registrar_key& key) { - const auto it = C::registry().find(key); - return (it != C::registry().end()); + return (C::registry().find(key) != C::registry().cend()); } static auto create(const registrar_key& key){ const auto it = C::registry().find(key); - AIDGE_ASSERT(it != C::registry().end(), "missing or invalid registrar key: {}\nDid you include/import the corresponding module?", key); + AIDGE_ASSERT(it != C::registry().cend(), "missing or invalid registrar key: {}\nDid you include/import the corresponding module?", key); return (*it).second; } static std::vector<registrar_key> getKeys(){ std::vector<registrar_key> keys; - for(auto keyValue : C::registry()) + for(const auto& keyValue : C::registry()) keys.push_back(keyValue.first); return keys; } diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index 3ff2a3c6c..331b9380a 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -21,7 +21,7 @@ using namespace Aidge; -TEST_CASE("[core/operators] MetaOperator", "[Operator]") { +TEST_CASE("[core/operators] MetaOperator", "[Operator][MetaOperator]") { SECTION("PaddedConv") { auto op = PaddedConv(1, 3, {3, 3}, "padded_conv", {1, 1}, {1, 1, 1, 1}); @@ -108,14 +108,14 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { // Weights X myLSTM->input(1).first->getOperator()->setOutput(0, myInitW); - myLSTM->input(2).first->getOperator()->setOutput(0, myInitW); - myLSTM->input(3).first->getOperator()->setOutput(0, myInitW); - myLSTM->input(4).first->getOperator()->setOutput(0, myInitW); + op->setInput(2, myInitW); + op->setInput(3, myInitW); + op->setInput(4, myInitW); // Weights H - myLSTM->input(5).first->getOperator()->setOutput(0, myInitR); - myLSTM->input(6).first->getOperator()->setOutput(0, myInitR); - myLSTM->input(7).first->getOperator()->setOutput(0, myInitR); - myLSTM->input(8).first->getOperator()->setOutput(0, myInitR); + op->setInput(5, myInitR); + op->setInput(6, myInitR); + op->setInput(7, myInitR); + op->setInput(8, myInitR); auto g = getConnectedGraphView(myLSTM); g->save("lstm_before_expand", true, true); diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 7e28f1fad..3ef70bcfb 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -55,7 +55,7 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { } g1->save("schedule"); - g1->forwardDims(); + g1->compile(); auto scheduler = SequentialScheduler(g1); scheduler.generateScheduling(true); -- GitLab