Skip to content
Snippets Groups Projects
Commit eefd0f21 authored by Maxence Naud's avatar Maxence Naud Committed by Maxence Naud
Browse files

Minor optimizations and add default values to 'GraphView::compile()' member function

parent 02ffa260
No related branches found
No related tags found
No related merge requests found
...@@ -214,7 +214,7 @@ public: ...@@ -214,7 +214,7 @@ public:
* If not, add a Transpose Operator. * If not, add a Transpose Operator.
* 4 - Propagate Tensor dimensions through the consecutive Operators. * 4 - Propagate Tensor dimensions through the consecutive Operators.
*/ */
void compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device = 0); void compile(const std::string& backend = "cpu", const Aidge::DataType datatype = DataType::Float32, DeviceIdx_t device = 0);
/** /**
* @brief Compute dimensions of input/output Tensors for each Operator of the * @brief Compute dimensions of input/output Tensors for each Operator of the
...@@ -283,7 +283,7 @@ public: ...@@ -283,7 +283,7 @@ public:
* added to the list, and so on. * added to the list, and so on.
* - Any remaining nodes have no path to the root node and are added in * - Any remaining nodes have no path to the root node and are added in
* arbitrary order. In this case, the ranking is not garanteed to be unique. * arbitrary order. In this case, the ranking is not garanteed to be unique.
* *
* If the ranking cannot be garanteed to be unique, the second item indicates * If the ranking cannot be garanteed to be unique, the second item indicates
* the rank from which unicity cannot be garanteed. * the rank from which unicity cannot be garanteed.
* @return std::pair<std::vector<NodePtr>, size_t> Pair with the list of ranked * @return std::pair<std::vector<NodePtr>, size_t> Pair with the list of ranked
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include <functional> #include <functional>
#include <map> #include <map>
#include <cassert> #include <vector>
namespace Aidge { namespace Aidge {
#ifdef PYBIND #ifdef PYBIND
...@@ -72,19 +72,18 @@ struct Registrar { ...@@ -72,19 +72,18 @@ struct Registrar {
} }
static bool exists(const registrar_key& key) { static bool exists(const registrar_key& key) {
const auto it = C::registry().find(key); return (C::registry().find(key) != C::registry().cend());
return (it != C::registry().end());
} }
static auto create(const registrar_key& key){ static auto create(const registrar_key& key){
const auto it = C::registry().find(key); const auto it = C::registry().find(key);
AIDGE_ASSERT(it != C::registry().end(), "missing or invalid registrar key: {}\nDid you include/import the corresponding module?", key); AIDGE_ASSERT(it != C::registry().cend(), "missing or invalid registrar key: {}\nDid you include/import the corresponding module?", key);
return (*it).second; return (*it).second;
} }
static std::vector<registrar_key> getKeys(){ static std::vector<registrar_key> getKeys(){
std::vector<registrar_key> keys; std::vector<registrar_key> keys;
for(auto keyValue : C::registry()) for(const auto& keyValue : C::registry())
keys.push_back(keyValue.first); keys.push_back(keyValue.first);
return keys; return keys;
} }
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
using namespace Aidge; using namespace Aidge;
TEST_CASE("[core/operators] MetaOperator", "[Operator]") { TEST_CASE("[core/operators] MetaOperator", "[Operator][MetaOperator]") {
SECTION("PaddedConv") { SECTION("PaddedConv") {
auto op = PaddedConv(1, 3, {3, 3}, "padded_conv", {1, 1}, {1, 1, 1, 1}); auto op = PaddedConv(1, 3, {3, 3}, "padded_conv", {1, 1}, {1, 1, 1, 1});
...@@ -108,14 +108,14 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { ...@@ -108,14 +108,14 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") {
// Weights X // Weights X
myLSTM->input(1).first->getOperator()->setOutput(0, myInitW); myLSTM->input(1).first->getOperator()->setOutput(0, myInitW);
myLSTM->input(2).first->getOperator()->setOutput(0, myInitW); op->setInput(2, myInitW);
myLSTM->input(3).first->getOperator()->setOutput(0, myInitW); op->setInput(3, myInitW);
myLSTM->input(4).first->getOperator()->setOutput(0, myInitW); op->setInput(4, myInitW);
// Weights H // Weights H
myLSTM->input(5).first->getOperator()->setOutput(0, myInitR); op->setInput(5, myInitR);
myLSTM->input(6).first->getOperator()->setOutput(0, myInitR); op->setInput(6, myInitR);
myLSTM->input(7).first->getOperator()->setOutput(0, myInitR); op->setInput(7, myInitR);
myLSTM->input(8).first->getOperator()->setOutput(0, myInitR); op->setInput(8, myInitR);
auto g = getConnectedGraphView(myLSTM); auto g = getConnectedGraphView(myLSTM);
g->save("lstm_before_expand", true, true); g->save("lstm_before_expand", true, true);
......
...@@ -55,7 +55,7 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { ...@@ -55,7 +55,7 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
} }
g1->save("schedule"); g1->save("schedule");
g1->forwardDims(); g1->compile();
auto scheduler = SequentialScheduler(g1); auto scheduler = SequentialScheduler(g1);
scheduler.generateScheduling(true); scheduler.generateScheduling(true);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment