Skip to content
Snippets Groups Projects
Commit 59964335 authored by Maxence Naud's avatar Maxence Naud
Browse files

Upd unit-test for SequentialScheduler::backward() member function

parent ac8554fd
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!39Scheduler backprop
...@@ -349,25 +349,41 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") { ...@@ -349,25 +349,41 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") {
} }
TEST_CASE("[cpu/scheduler] SequentialScheduler(backward)", "[scheduler][backward]") {
    // Build a small GraphView: ReLU -> Sqrt -> ReLU.
    // NOTE(review): node name "srqt0" looks like a typo for "sqrt0" — kept as-is
    // since other code may reference the node by this exact name.
    std::shared_ptr<GraphView> gv = Sequential({ReLU("relu0"), Sqrt("srqt0"), ReLU("relu1")});

    // Input: 2x1x5x5 tensor holding the values 0..49 as floats.
    std::shared_ptr<Tensor> inputTensor =
            std::make_shared<Tensor>(Array4D<float, 2, 1, 5, 5>{{{{{ 0.0f,  1.0f,  2.0f,  3.0f,  4.0f},
                                                                   { 5.0f,  6.0f,  7.0f,  8.0f,  9.0f},
                                                                   {10.0f, 11.0f, 12.0f, 13.0f, 14.0f},
                                                                   {15.0f, 16.0f, 17.0f, 18.0f, 19.0f},
                                                                   {20.0f, 21.0f, 22.0f, 23.0f, 24.0f}}},
                                                                 {{{25.0f, 26.0f, 27.0f, 28.0f, 29.0f},
                                                                   {30.0f, 31.0f, 32.0f, 33.0f, 34.0f},
                                                                   {35.0f, 36.0f, 37.0f, 38.0f, 39.0f},
                                                                   {40.0f, 41.0f, 42.0f, 43.0f, 44.0f},
                                                                   {45.0f, 46.0f, 47.0f, 48.0f, 49.0f}}}}});
    auto label = inputTensor;

    // Implementation already set to default; feed the graph from a Producer.
    auto myProd = Producer(inputTensor, "prod");
    myProd->addChild(gv);

    // Float32 is required here: Sqrt produces non-integer values and the
    // backward pass computes floating-point gradients.
    gv->compile("cpu", DataType::Float32);
    compile_gradient(gv);

    SequentialScheduler scheduler(gv);
    // A forward pass must run first so that backward() has activations to use.
    scheduler.forward();
    auto predictedOutput = gv->getOrderedOutputs()[0];

    // Expected forward result: element-wise floor-like values of sqrt(input)
    // after the ReLU chain (all inputs are non-negative, so ReLU is identity).
    std::shared_ptr<Tensor> targetOutput =
            std::make_shared<Tensor>(Array4D<float, 2, 1, 5, 5>{{{{{0.0f, 1.0f, 1.0f, 2.0f, 2.0f},
                                                                   {2.0f, 2.0f, 3.0f, 3.0f, 3.0f},
                                                                   {3.0f, 3.0f, 3.0f, 4.0f, 4.0f},
                                                                   {4.0f, 4.0f, 4.0f, 4.0f, 4.0f},
                                                                   {4.0f, 5.0f, 5.0f, 5.0f, 5.0f}}},
                                                                 {{{5.0f, 5.0f, 5.0f, 5.0f, 5.0f},
                                                                   {5.0f, 6.0f, 6.0f, 6.0f, 6.0f},
                                                                   {6.0f, 6.0f, 6.0f, 6.0f, 6.0f},
                                                                   {6.0f, 6.0f, 6.0f, 7.0f, 7.0f},
                                                                   {7.0f, 7.0f, 7.0f, 7.0f, 7.0f}}}}});

    // Backward pass seeded with the output error (prediction - target).
    // NOTE(review): getOrderedOutputs() returns (node, output-index) pairs in
    // Aidge — confirm that operator- is defined for this expression, or that
    // the output Tensor should be fetched from the node first.
    REQUIRE_NOTHROW(scheduler.backward(predictedOutput - targetOutput));
}
0% — Loading.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment