Skip to content
Snippets Groups Projects
Commit 866ceb9a authored by Maxence Naud's avatar Maxence Naud
Browse files

Update 'SequentialScheduler' test

parent e9e8c07e
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!39Scheduler backprop
Pipeline #42432 passed
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "aidge/scheduler/Scheduler.hpp" #include "aidge/scheduler/Scheduler.hpp"
#include "aidge/backend/cpu.hpp" #include "aidge/backend/cpu.hpp"
#include "aidge/recipies/GraphViewHelper.hpp" #include "aidge/recipes/GraphViewHelper.hpp"
using namespace Aidge; using namespace Aidge;
...@@ -372,7 +372,7 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(backward)", "[scheduler][backward ...@@ -372,7 +372,7 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(backward)", "[scheduler][backward
compile_gradient(gv); compile_gradient(gv);
SequentialScheduler scheduler(gv); SequentialScheduler scheduler(gv);
scheduler.forward(); scheduler.forward();
auto predictedOutput = gv->getOrderedOutputs()[0]; auto predictedOutput = gv->getOrderedOutputs()[0].first;
std::shared_ptr<Tensor> targetOutput = std::shared_ptr<Tensor> targetOutput =
std::make_shared<Tensor>(Array4D<float, 2, 1, 5, 5>{{{{{0.0f, 1.0f, 1.0f, 2.0f, 2.0f}, std::make_shared<Tensor>(Array4D<float, 2, 1, 5, 5>{{{{{0.0f, 1.0f, 1.0f, 2.0f, 2.0f},
...@@ -386,4 +386,5 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(backward)", "[scheduler][backward ...@@ -386,4 +386,5 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(backward)", "[scheduler][backward
{6.0f, 6.0f, 6.0f, 7.0f, 7.0f}, {6.0f, 6.0f, 6.0f, 7.0f, 7.0f},
{7.0f, 7.0f, 7.0f, 7.0f, 7.0f}}}}}); {7.0f, 7.0f, 7.0f, 7.0f, 7.0f}}}}});
REQUIRE_NOTHROW(scheduler.backward(predictedOutput - targetOutput)); REQUIRE_NOTHROW(scheduler.backward({targetOutput}));
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment