Commit 4c67e7ed authored by Maxence Naud's avatar Maxence Naud

Merge branch 'lrscheduler' into 'learning'

Resolve "[Add][Feature][Learning] Learning Rate Scheduler"

See merge request !87
parents 03ce8a78 4b2b7c9a
Related merge requests: !105 version 0.2.0, !88 Basic supervised learning, !87 Resolve "[Add][Feature][Learning] Learning Rate Scheduler"
Pipeline #40021 passed
@@ -62,7 +62,12 @@
#include "aidge/operator/Sqrt.hpp" #include "aidge/operator/Sqrt.hpp"
#include "aidge/operator/Sub.hpp" #include "aidge/operator/Sub.hpp"
#include "aidge/operator/Transpose.hpp" #include "aidge/operator/Transpose.hpp"
#include "aidge/optimizer/LR/LRSchedulerList.hpp"
#include "aidge/optimizer/LR/LRScheduler.hpp"
#include "aidge/scheduler/Scheduler.hpp" #include "aidge/scheduler/Scheduler.hpp"
#include "aidge/stimuli/Stimulus.hpp" #include "aidge/stimuli/Stimulus.hpp"
#include "aidge/recipies/Recipies.hpp" #include "aidge/recipies/Recipies.hpp"
...
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPTIMIZER_LRSCHEDULER_H_
#define AIDGE_CORE_OPTIMIZER_LRSCHEDULER_H_
#include <cstddef> // std::size_t
#include <functional> // std::function
#include <vector>
namespace Aidge {
/**
 * @class LRScheduler
 * @brief Manages the evolution of the learning rate during training.
 */
class LRScheduler {
private:
    /// @brief Current step
    std::size_t mStep = 0;
    /// @brief Learning rate update function. Takes the current learning rate and
    /// the current step as parameters and returns the new learning rate.
    const std::function<float(float, const std::size_t)> mStepFunc;
    /// @brief Current learning rate
    float mLR;
    /// @brief Initial learning rate passed to the constructor
    const float mInitialLR;
    /// @brief Step at which the scheduler switches from warm-up to the step function (>= 1)
    std::size_t mSwitchStep;
    /// @brief Warm-up increment: the amount added to the learning rate at each warm-up step
    float mInitialWarmUp;
public:
    LRScheduler() = delete;

    /**
     * @brief Construct a new LRScheduler object. The default behaviour is a
     * constant learning rate.
     *
     * @param initialLR Initial learning rate.
     *        Clamped to 0 if a negative value is passed.
     * @param stepFunc Update function computing the next learning rate from the
     *        current value and the current step. Defaults to the identity
     *        function, i.e. a constant learning rate.
     * @param nb_warmup_steps Number of warm-up steps before ``stepFunc`` takes
     *        over the learning rate update. If specified, the learning rate
     *        increases linearly from 0 to ``initialLR``.
     *        Default is 0.
     */
    LRScheduler(const float initialLR,
                std::function<float(float, const std::size_t)> stepFunc = [](float val, const std::size_t /*step*/) { return val; },
                const std::size_t nb_warmup_steps = 0)
        : mStepFunc(stepFunc),
          mLR((initialLR > 0.0f) ? initialLR : 0.0f),
          mInitialLR(mLR),
          mSwitchStep(nb_warmup_steps + 1),
          mInitialWarmUp(mLR / static_cast<float>(nb_warmup_steps + 1))
    {
        // ctor
    }
public:
    /**
     * @brief Update the learning rate to its next value.
     * @note If the current step is lower than the switch step, the learning rate
     * follows a linear ramp from 0 to the initial learning rate.
     * @note Otherwise, the learning rate is updated with the provided step function.
     */
    constexpr void update() {
        mLR = (mStep++ < mSwitchStep) ?
                static_cast<float>(mStep) * mInitialWarmUp :
                mStepFunc(mLR, mStep);
    }

    constexpr float learning_rate() const noexcept { return mLR; }

    constexpr void set_nb_warmup_steps(const std::size_t nb_warmup_steps) noexcept {
        mSwitchStep = nb_warmup_steps + 1;
        mInitialWarmUp = mLR / static_cast<float>(nb_warmup_steps + 1);
    }

    /// @brief Compute the first ``nbStep`` learning rate values without
    /// modifying the scheduler's internal state.
    std::vector<float> lr_profiling(const std::size_t nbStep) const;

    constexpr void reset() noexcept {
        mStep = 0;
        mLR = mInitialLR;
    }
};
} // namespace Aidge
#endif /* AIDGE_CORE_OPTIMIZER_LRSCHEDULER_H_ */
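To make the warm-up arithmetic concrete, a minimal usage sketch follows (illustration only, not part of the merge request; the 0.9f decay lambda and the loop bound are assumptions). With initialLR = 1.0f and 4 warm-up steps, mSwitchStep is 5 and the warm-up increment is 1.0f / 5 = 0.2f, so successive update() calls yield 0.2, 0.4, 0.6, 0.8, 1.0 before the step function takes over:

#include "aidge/optimizer/LR/LRScheduler.hpp"
#include <cstddef>
#include <cstdio>

int main() {
    // 4 warm-up steps, then multiply the learning rate by 0.9 at each step
    // (the decay lambda is an illustrative assumption, not part of the diff)
    Aidge::LRScheduler scheduler(1.0f,
        [](float val, const std::size_t /*step*/) { return val * 0.9f; },
        4);
    for (int i = 0; i < 8; ++i) {
        scheduler.update();
        std::printf("update %d: lr = %f\n", i + 1, scheduler.learning_rate());
    }
    return 0;
}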
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPTIMIZER_LRSCHEDULERSLIST_H_
#define AIDGE_CORE_OPTIMIZER_LRSCHEDULERSLIST_H_
#include "aidge/optimizer/LR/LRScheduler.hpp"
#include <cstddef> // std::size_t
namespace Aidge {
// 'inline' is required: these are function definitions in a header and would
// otherwise violate the one-definition rule when included in several translation units.
inline LRScheduler ConstantLR(const float initialLR) {
    return LRScheduler(initialLR);
}

/// @brief Decay the learning rate by ``gamma`` every ``stepSize`` steps.
inline LRScheduler StepLR(const float initialLR, const std::size_t stepSize, float gamma = 0.1f) {
    return LRScheduler(initialLR,
        [stepSize, gamma](float val, const std::size_t step) {
            return (step % stepSize == 0) ? val * gamma : val;
        });
}
} // namespace Aidge
#endif /* AIDGE_CORE_OPTIMIZER_LRSCHEDULERSLIST_H_ */
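Additional schedules can follow the same factory pattern, capturing their hyper-parameters in the update lambda. A hedged sketch of what an exponential decay could look like (the ExponentialLR name and its formula are illustrative assumptions, not part of this merge request):

inline Aidge::LRScheduler ExponentialLR(const float initialLR, const float gamma = 0.9f) {
    return Aidge::LRScheduler(initialLR,
        [gamma](float val, const std::size_t /*step*/) {
            return val * gamma; // decay at every step
        });
}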
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/optimizer/LR/LRScheduler.hpp"
#include <vector>
std::vector<float> Aidge::LRScheduler::lr_profiling(const std::size_t nbStep) const {
    // instantiate the returned array
    std::vector<float> profile(nbStep);
    if (nbStep == 0) {
        return profile; // avoid an out-of-bounds write below
    }
    profile[0] = mInitialWarmUp; // equal to mInitialLR if no warm-up
    for (std::size_t step = 1; step < nbStep; ++step) {
        profile[step] = (step < mSwitchStep) ?
                static_cast<float>(step + 1) * mInitialWarmUp :
                mStepFunc(profile[step - 1], step);
    }
    return profile;
}
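Because lr_profiling is const, it can preview a schedule without advancing the scheduler's internal step counter. A short usage sketch (illustrative only; the StepLR arguments are arbitrary):

#include "aidge/optimizer/LR/LRSchedulerList.hpp"
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    // dump the first 10 values of a StepLR schedule: 0.01 halved every 3 steps
    Aidge::LRScheduler scheduler = Aidge::StepLR(0.01f, /*stepSize=*/3, /*gamma=*/0.5f);
    const std::vector<float> profile = scheduler.lr_profiling(10);
    for (std::size_t i = 0; i < profile.size(); ++i) {
        std::printf("step %zu: lr = %g\n", i, profile[i]);
    }
    return 0;
}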
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>

#include <cstddef>  // std::size_t
#include <cstdint>  // std::uint16_t
#include <random>   // std::random_device, std::mt19937,
                    // std::uniform_int_distribution, std::uniform_real_distribution
#include <vector>

#include "aidge/optimizer/LR/LRScheduler.hpp"
#include "aidge/optimizer/LR/LRSchedulerList.hpp"
namespace Aidge {
TEST_CASE("[core/optimizer/LR] LRSchduler(computeOutputDims)", "[LRScheduler]") {
constexpr std::uint16_t NBTRIALS = 10;
// Create a random number generator
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<std::size_t> nbStepsDist(1, 1000);
std::uniform_real_distribution<float> initValDist(1.0e-7f, 1.0f);
SECTION("ConstantLR") {
for (std::size_t trial = 0; trial < NBTRIALS; ++trial) {
const std::size_t nbSteps = nbStepsDist(gen);
// truth
const float truth = initValDist(gen);
// create learning rate scheduler
LRScheduler myLR = ConstantLR(truth);
// prediction
std::vector<float> profile = myLR.lr_profiling(nbSteps);
// learning rate computation
std::size_t step = 0;
for (; (step < nbSteps) && (truth == profile[step]) && (truth == myLR.learning_rate()); ++step) {
myLR.update();
}
REQUIRE(step == nbSteps);
}
}
SECTION("StepLR") {
std::uniform_int_distribution<std::size_t> stepSizeDist(1, 100);
std::uniform_real_distribution<float> gammaDist(0.0001f, 1.0f);
        for (std::size_t trial = 0; trial < NBTRIALS; ++trial) {
            const float initialLR = initValDist(gen);
            const std::size_t nbSteps = nbStepsDist(gen);
            const float gamma = gammaDist(gen);
            const std::size_t stepSize = stepSizeDist(gen);

            LRScheduler myLR = StepLR(initialLR, stepSize, gamma);

            // expected values: decay by gamma every stepSize steps
            std::vector<float> truth(nbSteps);
            truth[0] = initialLR;
            for (std::size_t i = 1; i < nbSteps; ++i) {
                truth[i] = (i % stepSize == 0) ? truth[i - 1] * gamma : truth[i - 1];
            }

            // prediction over the whole schedule
            std::vector<float> profile = myLR.lr_profiling(nbSteps);

            // compare the profile and the live learning rate at each step
            std::size_t step = 0;
            for (; (step < nbSteps) && (truth[step] == profile[step]) && (truth[step] == myLR.learning_rate()); ++step) {
                myLR.update();
            }
            REQUIRE(step == nbSteps);
        }
    }
}
} // namespace Aidge