From abbfe683b5a035bb4bc55d4cbe4846822c8f89f3 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Wed, 7 Feb 2024 14:28:59 +0000 Subject: [PATCH] [Upd] Scheduler unit_test for 'backward()' --- unit_tests/scheduler/Test_Scheduler.cpp | 27 ++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 025ca8ba..dc13c6e5 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -20,6 +20,7 @@ #include "aidge/scheduler/Scheduler.hpp" #include "aidge/backend/cpu.hpp" +#include "aidge/recipies/GraphViewHelper.hpp" using namespace Aidge; @@ -300,7 +301,7 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") { std::vector<std::shared_ptr<Aidge::Tensor>> dataIn = {inputTensor}; REQUIRE_NOTHROW(scheduler.forward(true, false, dataIn)); - + scheduler.saveSchedulingDiagram("schedulingSequential"); std::shared_ptr<Tensor> expectedOutput1 = std::make_shared<Tensor>(Array4D<int, 2, 3, 3, 3>{ @@ -345,4 +346,28 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") { bool equal4 = (*other4 == expectedOutput4); REQUIRE(equal4); } +} + +TEST_CASE("[cpu/scheduler] SequentialScheduler(backward)", "[scheduler][backward]") { + std::shared_ptr<GraphView> gv = Sequential({ReLU(), ReLU()}); + + std::shared_ptr<Tensor> inputTensor = + std::make_shared<Tensor>(Array4D<int, 2, 1, 5, 5>{{{{{0, 1, 2, 3, 4}, + {5, 6, 7, 8, 9}, + {10, 11, 12, 13, 14}, + {15, 16, 17, 18, 19}, + {20, 21, 22, 23, 24}}}, + {{{25, 26, 27, 28, 29}, + {30, 31, 32, 33, 34}, + {35, 36, 37, 38, 39}, + {40, 41, 42, 43, 44}, + {45, 46, 47, 48, 49}}}}}); + auto label = inputTensor; + // implem already set to default + auto myProd = Producer(inputTensor, "prod"); + myProd -> addChild(gv); + gv -> compile("cpu", DataType::Float32); + compile_gradient(gv); + SequentialScheduler scheduler(gv); + scheduler.backward(); } \ No newline at end of file -- GitLab