Commit e6ce1e97 authored by Maxence Naud's avatar Maxence Naud

[Add][Test] Unit-test for LRScheduler

parent 4967b322
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>

#include <cstddef>  // std::size_t
#include <cstdint>  // std::uint16_t
#include <random>   // std::random_device, std::mt19937,
                    // std::uniform_int_distribution, std::uniform_real_distribution
#include <vector>

#include "aidge/optimizer/LR/LRScheduler.hpp"
#include "aidge/optimizer/LR/LRSchedulerList.hpp"
namespace Aidge {

TEST_CASE("[core/optimizer/LR] LRScheduler", "[LRScheduler]") {
    constexpr std::uint16_t NBTRIALS = 10;

    // Random number generator and distributions shared by both sections
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<std::size_t> nbStepsDist(1, 1000);
    std::uniform_real_distribution<float> initValDist(1.0e-7f, 1.0f);
SECTION("ConstantLR") {
for (std::size_t trial = 0; trial < NBTRIALS; ++trial) {
const std::size_t nbSteps = nbStepsDist(gen);
// truth
const float truth = initValDist(gen);
// create learning rate scheduler
LRScheduler myLR = ConstantLR(truth);
// prediction
std::vector<float> profile = myLR.lr_profiling(nbSteps);
// learning rate computation
std::size_t step = 0;
for (; (step < nbSteps) && (truth == profile[step]) && (truth == myLR.learning_rate()); ++step) {
myLR.update();
}
REQUIRE(step == nbSteps);
}
}
SECTION("StepLR") {
std::uniform_int_distribution<std::size_t> stepSizeDist(1, 100);
std::uniform_real_distribution<float> gammaDist(0.0001f, 1.0f);
for (std::size_t trial = 0; trial < NBTRIALS; ++trial) {
const float initialLR = initValDist(gen);
const std::size_t nbSteps = nbStepsDist(gen);
const float gamma = gammaDist(gen);
const std::size_t stepSize = stepSizeDist(gen);
LRScheduler myLR = StepLR(initialLR, stepSize, gamma);
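
            // Expected schedule: the learning rate is multiplied by gamma
            // every stepSize steps, i.e. truth[i] = initialLR * gamma^(i / stepSize)
            // with integer division; built iteratively below.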
            std::vector<float> truth(nbSteps);
            truth[0] = initialLR;
            for (std::size_t i = 1; i < nbSteps; ++i) {
                truth[i] = (i % stepSize == 0) ? truth[i - 1] * gamma : truth[i - 1];
            }
            // profile the learning rate over nbSteps steps
            std::vector<float> profile = myLR.lr_profiling(nbSteps);

            // walk the schedule, checking profiled and live values against
            // the expected ones (exact comparison: same arithmetic on both sides)
            std::size_t step = 0;
            for (; (step < nbSteps) && (truth[step] == profile[step]) && (truth[step] == myLR.learning_rate()); ++step) {
                myLR.update();
            }
            REQUIRE(step == nbSteps);
        }
    }
}

} // namespace Aidge
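
For context, here is a minimal sketch of how a scheduler built this way might drive a training loop. It assumes only the API exercised by the test above (StepLR, learning_rate(), update()); the optimizer step itself is a hypothetical placeholder.

#include <cstddef>
#include <iostream>

#include "aidge/optimizer/LR/LRScheduler.hpp"
#include "aidge/optimizer/LR/LRSchedulerList.hpp"

int main() {
    // Halve the learning rate every 30 steps, starting from 1.0e-3f
    Aidge::LRScheduler scheduler = Aidge::StepLR(1.0e-3f, 30, 0.5f);

    for (std::size_t step = 0; step < 100; ++step) {
        const float lr = scheduler.learning_rate();
        // ... hypothetical optimizer step using `lr` would go here ...
        if (step % 30 == 0) {
            std::cout << "step " << step << ": lr = " << lr << '\n';
        }
        scheduler.update();  // advance the schedule by one step
    }
    return 0;
}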