Skip to content
Snippets Groups Projects
Commit 5ff8d30b authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'memorize' into 'dev'

Support for recurrent networks

See merge request eclipse/aidge/aidge_backend_cpu!37
parents b7d782f6 dd7b1b22
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!37Support for recurrent networks
Pipeline #40596 passed
......@@ -18,14 +18,14 @@
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipies/Recipies.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/data/Tensor.hpp"
namespace Aidge {
TEST_CASE("[core/recipies] FuseBatchNorm", "[recipies][FuseBatchNorm]") {
TEST_CASE("[core/recipes] FuseBatchNorm", "[recipes][FuseBatchNorm]") {
auto myProd = Producer({2, 3, 3, 3}, "dataProvider");
auto myConv = Conv(3, 3, {1, 1}, "conv1");
auto myBN = BatchNorm<2>(32, 1.0e-5F, 0.1F, "batchnorm1");
......@@ -86,14 +86,11 @@ TEST_CASE("[core/recipies] FuseBatchNorm", "[recipies][FuseBatchNorm]") {
myBNOp -> setInput(4, std::make_shared<Tensor>(Array1D<float,3> {{0.4470, 0.3064, 0.7061}}));
auto g1 = Sequential({
myProd,
myConv,
myBN
});
g1 -> setName("fuseBNGraph");
myProd -> addChild(myConv); // set graph input
myProdOp -> setDataType(DataType::Float32);
myProdOp -> setBackend("cpu");
g1 -> compile("cpu", DataType::Float32);
auto s = SequentialScheduler(g1);
......@@ -107,7 +104,7 @@ TEST_CASE("[core/recipies] FuseBatchNorm", "[recipies][FuseBatchNorm]") {
std::shared_ptr<Tensor> res2 = std::make_shared<Tensor>(*(myConvOp -> getOutput(0)));
REQUIRE(g1 -> outputNodes().size() == 1);
REQUIRE(g1 -> inputNodes().size() == 1);
REQUIRE(g1 -> inputNodes().size() == 0);
bool eq = true;
for (std::size_t i = 0; i < res1->size(); ++i) {
eq &= std::abs(res1->get<float>(i) - res2->get<float>(i)) < 1.0e-06;
......
......@@ -16,14 +16,14 @@
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/recipies/Recipies.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/operator/Concat.hpp"
namespace Aidge {
TEST_CASE("[core/recipies] Tiling(transformation)", "[Tiling][Recipies]") {
TEST_CASE("[core/recipes] Tiling(transformation)", "[Tiling][Recipes]") {
SECTION("Transform a pre-generated GraphView") {
......
......@@ -19,7 +19,7 @@
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/recipies/Recipies.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/backend/cpu.hpp"
......
......@@ -205,7 +205,46 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") {
// Placeholder: residual-graph scheduling is not yet exercised.
// TODO(review): add a residual (skip-connection) graph and assert its schedule.
SECTION("Test Residual graph") {
}
SECTION("Test Recurrent graph") {}
// Builds a small recurrent graph (Add -> Memorize -> Add) and checks that the
// SequentialScheduler can schedule and execute it end-to-end on the CPU backend,
// comparing the final output of "add2" against a precomputed expected tensor.
SECTION("Test Recurrent graph") {
// External input fed into the first Add.
std::shared_ptr<Tensor> in = std::make_shared<Tensor>(
Array2D<int, 2, 3>{{{1, 2, 3}, {4, 5, 6}}});
// Initial state for the Memorize node (value of the memory before step 0).
std::shared_ptr<Tensor> initTensor = std::make_shared<Tensor>(
Array2D<int, 2, 3>{{{0, 0, 0}, {1, 1, 1}}});
// Constant bias added after the recurrence.
std::shared_ptr<Tensor> biasTensor = std::make_shared<Tensor>(
Array2D<int, 2, 3>{{{2, 0, 0}, {1, 0, 0}}});
auto add1 = Add(2, "add1"); // 2-input element-wise add
// NOTE(review): the argument 3 presumably sets the number of recurrence
// steps the Memorize node runs for — confirm against the Memorize operator docs.
auto mem = Memorize(3, "mem1");
auto add2 = Add(2, "add2"); // 2-input element-wise add
auto bias = Producer(biasTensor, "bias");
auto init = Producer(initTensor, "init");
auto input = Producer(in, "input");
// Main forward chain: add1 -> mem -> add2.
std::shared_ptr<GraphView> g = Sequential({add1, mem, add2});
// Wire the recurrence and the constants. addChild(child, outIdx, inIdx)
// presumably connects this node's output outIdx to child's input inIdx —
// verify against Node::addChild.
init->addChild(mem, 0, 1); // initial state into mem's second input
mem->addChild(add1, 1, 1); // feedback edge: mem's output 1 back into add1
bias->addChild(add2, 0, 1); // bias into add2's second input
input->addChild(add1, 0, 0); // data input into add1's first input
// Update GraphView inputs/outputs following previous connections:
g->add({mem, add1, add2, init, bias, input});
g->setBackend("cpu");
g->setDataType(Aidge::DataType::Int32);
g->save("graphRecurrent"); // dump the graph for debugging/inspection
g->forwardDims(); // propagate tensor dimensions through the graph
SequentialScheduler scheduler(g);
// The forward pass over the cyclic graph must not throw.
// NOTE(review): the meaning of the two boolean flags is not visible here —
// confirm against SequentialScheduler::forward's signature.
REQUIRE_NOTHROW(scheduler.forward(true, true));
scheduler.saveSchedulingDiagram("schedulingRecurrent");
// Expected result of the unrolled recurrence for the inputs above.
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(
Array2D<int, 2, 3>{{{5, 6, 9}, {14, 16, 19}}});
std::shared_ptr<Tensor> result =
std::static_pointer_cast<Tensor>(g->getNode("add2")->getOperator()->getRawOutput(0));
result->print(); // printed to ease debugging on failure
expectedOutput->print();
bool equal = (*result == *expectedOutput);
REQUIRE(equal);
}
SECTION("Test ConnectInput graph") {
std::shared_ptr<GraphView> g =
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment