Skip to content
Snippets Groups Projects
Commit 4362a3e8 authored by Maxence Naud's avatar Maxence Naud
Browse files

Minor optimizations and add default values to 'GraphView::compile()' member function

parent f1d8130c
No related branches found
No related tags found
3 merge requests!105version 0.2.0,!88Basic supervised learning,!82Resolve "Optimizer to update gradients"
Pipeline #41867 failed
......@@ -214,7 +214,7 @@ public:
* If not, add a Transpose Operator.
* 4 - Propagate Tensor dimensions through the consecutive Operators.
*/
void compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device = 0);
void compile(const std::string& backend = "cpu", const Aidge::DataType datatype = DataType::Float32, DeviceIdx_t device = 0);
/**
* @brief Compute dimensions of input/output Tensors for each Operator of the
......@@ -283,7 +283,7 @@ public:
* added to the list, and so on.
* - Any remaining nodes have no path to the root node and are added in
* arbitrary order. In this case, the ranking is not garanteed to be unique.
*
*
* If the ranking cannot be garanteed to be unique, the second item indicates
* the rank from which unicity cannot be garanteed.
* @return std::pair<std::vector<NodePtr>, size_t> Pair with the list of ranked
......
......@@ -23,7 +23,7 @@
#include <functional>
#include <map>
#include <cassert>
#include <vector>
namespace Aidge {
#ifdef PYBIND
......@@ -72,19 +72,18 @@ struct Registrar {
}
static bool exists(const registrar_key& key) {
const auto it = C::registry().find(key);
return (it != C::registry().end());
return (C::registry().find(key) != C::registry().cend());
}
static auto create(const registrar_key& key){
const auto it = C::registry().find(key);
AIDGE_ASSERT(it != C::registry().end(), "missing or invalid registrar key: {}\nDid you include/import the corresponding module?", key);
AIDGE_ASSERT(it != C::registry().cend(), "missing or invalid registrar key: {}\nDid you include/import the corresponding module?", key);
return (*it).second;
}
static std::vector<registrar_key> getKeys(){
std::vector<registrar_key> keys;
for(auto keyValue : C::registry())
for(const auto& keyValue : C::registry())
keys.push_back(keyValue.first);
return keys;
}
......
......@@ -21,7 +21,7 @@
using namespace Aidge;
TEST_CASE("[core/operators] MetaOperator", "[Operator]") {
TEST_CASE("[core/operators] MetaOperator", "[Operator][MetaOperator]") {
SECTION("PaddedConv") {
auto op = PaddedConv(1, 3, {3, 3}, "padded_conv", {1, 1}, {1, 1, 1, 1});
......@@ -108,14 +108,14 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") {
// Weights X
myLSTM->input(1).first->getOperator()->setOutput(0, myInitW);
myLSTM->input(2).first->getOperator()->setOutput(0, myInitW);
myLSTM->input(3).first->getOperator()->setOutput(0, myInitW);
myLSTM->input(4).first->getOperator()->setOutput(0, myInitW);
op->setInput(2, myInitW);
op->setInput(3, myInitW);
op->setInput(4, myInitW);
// Weights H
myLSTM->input(5).first->getOperator()->setOutput(0, myInitR);
myLSTM->input(6).first->getOperator()->setOutput(0, myInitR);
myLSTM->input(7).first->getOperator()->setOutput(0, myInitR);
myLSTM->input(8).first->getOperator()->setOutput(0, myInitR);
op->setInput(5, myInitR);
op->setInput(6, myInitR);
op->setInput(7, myInitR);
op->setInput(8, myInitR);
auto g = getConnectedGraphView(myLSTM);
g->save("lstm_before_expand", true, true);
......
......@@ -55,7 +55,7 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
}
g1->save("schedule");
g1->forwardDims();
g1->compile();
auto scheduler = SequentialScheduler(g1);
scheduler.generateScheduling(true);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment