From c6c329215a6e3df670169cadac8cf20c5e6bee9b Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Wed, 7 Feb 2024 14:28:59 +0000 Subject: [PATCH] [Upd] Scheduler unit_test for 'backward()' --- unit_tests/scheduler/Test_Scheduler.cpp | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 8ea8e726..1dd294a3 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -20,6 +20,7 @@ #include "aidge/scheduler/Scheduler.hpp" #include "aidge/backend/cpu.hpp" +#include "aidge/recipies/GraphViewHelper.hpp" using namespace Aidge; @@ -206,4 +207,28 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") { } SECTION("Test Recurrent graph") {} +} + +TEST_CASE("[cpu/scheduler] SequentialScheduler(backward)", "[scheduler][backward]") { + std::shared_ptr<GraphView> gv = Sequential({ReLU(), ReLU()}); + + std::shared_ptr<Tensor> inputTensor = + std::make_shared<Tensor>(Array4D<int, 2, 1, 5, 5>{{{{{0, 1, 2, 3, 4}, + {5, 6, 7, 8, 9}, + {10, 11, 12, 13, 14}, + {15, 16, 17, 18, 19}, + {20, 21, 22, 23, 24}}}, + {{{25, 26, 27, 28, 29}, + {30, 31, 32, 33, 34}, + {35, 36, 37, 38, 39}, + {40, 41, 42, 43, 44}, + {45, 46, 47, 48, 49}}}}}); + auto label = inputTensor; + // implem already set to default + auto myProd = Producer(inputTensor, "prod"); + myProd -> addChild(gv); + gv -> compile("cpu", DataType::Float32); + compile_gradient(gv); + SequentialScheduler scheduler(gv); + scheduler.backward(); } \ No newline at end of file -- GitLab