Skip to content
Snippets Groups Projects
Commit abbfe683 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd] Scheduler unit_test for 'backward()'

parent 736bb9da
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!39Scheduler backprop
This commit is part of merge request !39. Comments created here will be created in the context of that merge request.
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "aidge/scheduler/Scheduler.hpp" #include "aidge/scheduler/Scheduler.hpp"
#include "aidge/backend/cpu.hpp" #include "aidge/backend/cpu.hpp"
#include "aidge/recipies/GraphViewHelper.hpp"
using namespace Aidge; using namespace Aidge;
...@@ -300,7 +301,7 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") { ...@@ -300,7 +301,7 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") {
std::vector<std::shared_ptr<Aidge::Tensor>> dataIn = {inputTensor}; std::vector<std::shared_ptr<Aidge::Tensor>> dataIn = {inputTensor};
REQUIRE_NOTHROW(scheduler.forward(true, false, dataIn)); REQUIRE_NOTHROW(scheduler.forward(true, false, dataIn));
scheduler.saveSchedulingDiagram("schedulingSequential"); scheduler.saveSchedulingDiagram("schedulingSequential");
std::shared_ptr<Tensor> expectedOutput1 = std::make_shared<Tensor>(Array4D<int, 2, 3, 3, 3>{ std::shared_ptr<Tensor> expectedOutput1 = std::make_shared<Tensor>(Array4D<int, 2, 3, 3, 3>{
...@@ -345,4 +346,28 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") { ...@@ -345,4 +346,28 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") {
bool equal4 = (*other4 == expectedOutput4); bool equal4 = (*other4 == expectedOutput4);
REQUIRE(equal4); REQUIRE(equal4);
} }
}
TEST_CASE("[cpu/scheduler] SequentialScheduler(backward)", "[scheduler][backward]") {
std::shared_ptr<GraphView> gv = Sequential({ReLU(), ReLU()});
std::shared_ptr<Tensor> inputTensor =
std::make_shared<Tensor>(Array4D<int, 2, 1, 5, 5>{{{{{0, 1, 2, 3, 4},
{5, 6, 7, 8, 9},
{10, 11, 12, 13, 14},
{15, 16, 17, 18, 19},
{20, 21, 22, 23, 24}}},
{{{25, 26, 27, 28, 29},
{30, 31, 32, 33, 34},
{35, 36, 37, 38, 39},
{40, 41, 42, 43, 44},
{45, 46, 47, 48, 49}}}}});
auto label = inputTensor;
// implem already set to default
auto myProd = Producer(inputTensor, "prod");
myProd -> addChild(gv);
gv -> compile("cpu", DataType::Float32);
compile_gradient(gv);
SequentialScheduler scheduler(gv);
scheduler.backward();
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment