Commit d298f23d authored by Olivier Antoni

Fix Adam optimizer unit test

parent c02cc1af
2 merge requests: !10 version 0.1.2, !9 Add Adam optimizer
Pipeline #49065 failed
@@ -70,7 +70,7 @@ TEST_CASE("[learning/Adam] update", "[Optimizer][Adam]") {
        val_tensors[i] = std::make_unique<float[]>(size_tensors[i]);
        val_grad_tensors[i] = std::make_unique<float[]>(size_tensors[i]);
        val_momentum1_tensors[i] = std::make_unique<float[]>(size_tensors[i]);
        val_momentum2_tensors[i] = std::make_unique<float[]>(size_tensors[i]);
        for (std::size_t j = 0; j < size_tensors[i]; ++j) {
            val_tensors[i][j] = valueDist(gen);
            val_grad_tensors[i][j] = valueDist(gen);
@@ -92,7 +92,7 @@ TEST_CASE("[learning/Adam] update", "[Optimizer][Adam]") {
        momentum_tensors[i] = std::make_shared<Tensor>(dims);
        momentum_tensors[i]->setBackend("cpu");
        momentum_tensors[i]->getImpl()->setRawPtr(val_momentum1_tensors[i].get(), size_tensors[i]);
        momentum_tensors[i]->getImpl()->setRawPtr(val_momentum2_tensors[i].get(), size_tensors[i]);
        REQUIRE((tensors[i]->hasImpl() &&
                 optim_tensors[i]->hasImpl() &&
@@ -102,7 +102,7 @@ TEST_CASE("[learning/Adam] update", "[Optimizer][Adam]") {
    // generate parameters
    float lr = paramDist(gen);
    float beta1 = paramDist(gen);
    float beta2 = paramDist(gen);
    float epsilon = paramDist(gen);
    // set Optimizer
@@ -121,13 +121,14 @@ TEST_CASE("[learning/Adam] update", "[Optimizer][Adam]") {
        // truth
        for (std::size_t step = 0; step < 10; ++step) {
+           float lr2 = lr * std::sqrt(1.0f - std::pow(beta2, step + 1)) / (1.0f - std::pow(beta1, step + 1));
+           float epsilon2 = epsilon * std::sqrt(1.0f - std::pow(beta2, step + 1));
            for (std::size_t t = 0; t < nb_tensors; ++t) {
                for (std::size_t i = 0; i < size_tensors[t]; ++i) {
                    val_momentum1_tensors[t][i] = beta1 * val_momentum1_tensors[t][i] + (1.0f - beta1) * val_grad_tensors[t][i];
                    val_momentum2_tensors[t][i] = beta2 * val_momentum2_tensors[t][i] + (1.0f - beta2) * val_grad_tensors[t][i] * val_grad_tensors[t][i];
                    val_tensors[t][i] = val_tensors[t][i]
-                                       - lr * val_momentum1_tensors[t][i] / (1.0f - std::pow(beta1, step + 1))
-                                       / (std::sqrt(val_momentum2_tensors[t][i] / (1.0f - std::pow(beta2, step + 1))) + epsilon);
+                                       - lr2 * val_momentum1_tensors[t][i] / (std::sqrt(val_momentum2_tensors[t][i]) + epsilon2);
                }
            }
            // optimizer
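The rewritten reference loop computes the same Adam step as before: with bias-corrected moments m_hat = m / (1 - beta1^t) and v_hat = v / (1 - beta2^t), the update -lr * m_hat / (sqrt(v_hat) + epsilon) can be folded into lr2 = lr * sqrt(1 - beta2^t) / (1 - beta1^t) and epsilon2 = epsilon * sqrt(1 - beta2^t), giving -lr2 * m / (sqrt(v) + epsilon2). A minimal standalone sketch of that equivalence (not part of the test; the scalar variable names and constants below are chosen purely for illustration):

    #include <cmath>
    #include <cstdio>

    int main() {
        // Hypothetical scalar walkthrough: one weight, constant gradient.
        const float lr = 0.01f, beta1 = 0.9f, beta2 = 0.999f, epsilon = 1e-8f;
        const float grad = 0.3f;
        float m = 0.0f, v = 0.0f;   // first and second moment estimates
        float w1 = 0.5f;            // updated with the explicit bias-correction form
        float w2 = 0.5f;            // updated with the folded lr2 / epsilon2 form

        for (int step = 0; step < 10; ++step) {
            m = beta1 * m + (1.0f - beta1) * grad;
            v = beta2 * v + (1.0f - beta2) * grad * grad;

            // Old reference form: bias-correct the moments explicitly.
            w1 -= lr * (m / (1.0f - std::pow(beta1, step + 1)))
                     / (std::sqrt(v / (1.0f - std::pow(beta2, step + 1))) + epsilon);

            // New reference form: fold the corrections into lr2 and epsilon2.
            const float lr2 = lr * std::sqrt(1.0f - std::pow(beta2, step + 1))
                                 / (1.0f - std::pow(beta1, step + 1));
            const float epsilon2 = epsilon * std::sqrt(1.0f - std::pow(beta2, step + 1));
            w2 -= lr2 * m / (std::sqrt(v) + epsilon2);

            std::printf("step %d: w1 = %.7f, w2 = %.7f\n", step + 1, w1, w2);
        }
        return 0;
    }

Both columns should agree to within float rounding at every step.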