Commit 8de2b659 authored by Maxence Naud

[Add] learning rate scheduler class LRScheduler

parent 2866d391
2 merge requests: !105 version 0.2.0, !88 Basic supervised learning
aidge/optimizer/LR/LRScheduler.hpp:

/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef AIDGE_CORE_OPTIMIZER_LRSCHEDULER_H_
#define AIDGE_CORE_OPTIMIZER_LRSCHEDULER_H_

#include <cstddef>    // std::size_t
#include <functional> // std::function
#include <vector>

namespace Aidge {

/**
 * @class LRScheduler
 * @brief Manages the evolution of the learning rate over training steps.
 */
class LRScheduler {
private:
    /// @brief Current step
    std::size_t mStep = 0;

    /// @brief LR update function. Takes the current learning rate and the
    /// current step as parameters and returns the new learning rate.
    const std::function<float(float, const std::size_t)> mStepFunc;

    /// @brief Current learning rate
    float mLR;

    /// @brief Initial learning rate passed to the constructor
    const float mInitialLR;

    /// @brief Step at which the LRScheduler switches from warm-up to the
    /// update function (>= 1)
    std::size_t mSwitchStep;

    /// @brief Warm-up increment added to the learning rate at each warm-up step
    float mInitialWarmUp;

public:
    LRScheduler() = delete;

    /**
     * @brief Construct a new LRScheduler object. The default behaviour is a
     * constant learning rate (ConstantLR).
     *
     * @param initialLR Initial learning rate.
     * Defaults to 0 if a negative value is passed.
     * @param stepFunc Recursive update function for the learning rate value.
     * Defaults to the constant function.
     * @param nb_warmup_steps Number of warm-up steps before ``stepFunc`` takes
     * over the learning rate update. If specified, the learning rate increases
     * linearly from 0 to ``initialLR``.
     * Defaults to 0.
     */
    LRScheduler(const float initialLR,
                std::function<float(float, const std::size_t)> stepFunc
                    = [](float val, const std::size_t /*step*/) { return val; },
                const std::size_t nb_warmup_steps = 0)
        : mStepFunc(stepFunc),
          mLR((initialLR > 0.0f) ? initialLR : 0.0f),
          mInitialLR(mLR),
          mSwitchStep(nb_warmup_steps + 1),
          mInitialWarmUp(mLR / static_cast<float>(nb_warmup_steps + 1))
    {
        // ctor
    }

public:
    /**
     * @brief Update the learning rate to its next value.
     * @note If the current step is lower than the switch step, the learning
     * rate follows a linear ramp from 0 to the initial learning rate.
     * @note Otherwise, the learning rate is updated using the provided function.
     */
    void update() {
        mLR = (mStep++ < mSwitchStep) ?
                static_cast<float>(mStep) * mInitialWarmUp :
                mStepFunc(mLR, mStep);
    }

    constexpr float learning_rate() const noexcept { return mLR; }

    constexpr void set_nb_warmup_steps(const std::size_t nb_warmup_steps) noexcept {
        mSwitchStep = nb_warmup_steps + 1;
        mInitialWarmUp = mLR / static_cast<float>(nb_warmup_steps + 1);
    }

    /// @brief Compute the learning rate profile over ``nbStep`` steps without
    /// modifying the scheduler's internal state.
    std::vector<float> lr_profiling(const std::size_t nbStep) const;

    constexpr void reset() noexcept {
        mStep = 0;
        mLR = mInitialLR;
    }
};
} // namespace Aidge
#endif /* AIDGE_CORE_OPTIMIZER_LRSCHEDULER_H_ */
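
For illustration, here is a minimal usage sketch of this header. It is not part of the commit; the 0.01 initial rate, the 0.95 exponential-decay lambda, and the step counts are arbitrary assumptions:

#include <cstddef>
#include <cstdio>

#include "aidge/optimizer/LR/LRScheduler.hpp"

int main() {
    // Assumed, illustrative setup: initial LR of 0.01 with a multiplicative
    // 0.95 decay per step, preceded by a 5-step linear warm-up.
    Aidge::LRScheduler scheduler(0.01f,
        [](float lr, const std::size_t /*step*/) { return lr * 0.95f; },
        5);

    for (int i = 0; i < 10; ++i) {
        scheduler.update();
        std::printf("step %2d: lr = %f\n", i + 1, scheduler.learning_rate());
    }
    return 0;
}

Note that since mSwitchStep is nb_warmup_steps + 1, the warm-up ramp here spans six calls to update(), reaching initialLR exactly at the switch step before stepFunc takes over.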
LRScheduler.cpp:

/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/optimizer/LR/LRScheduler.hpp"

#include <vector>

std::vector<float> Aidge::LRScheduler::lr_profiling(const std::size_t nbStep) const {
    // instantiate the returned array
    std::vector<float> profile(nbStep);
    if (nbStep > 0) {
        profile[0] = mInitialWarmUp; // equal to mInitialLR if there is no warm-up
        for (std::size_t step = 1; step < nbStep; ++step) {
            profile[step] = (step < mSwitchStep) ?
                    static_cast<float>(step + 1) * mInitialWarmUp :
                    mStepFunc(profile[step - 1], step);
        }
    }
    return profile;
}
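
To sanity-check a schedule before training, lr_profiling() replays the update rule into a vector without mutating the scheduler's state. A minimal sketch, reusing the same assumed configuration as the example above (the values are illustrative, not from the commit):

#include <cstddef>
#include <cstdio>
#include <vector>

#include "aidge/optimizer/LR/LRScheduler.hpp"

int main() {
    Aidge::LRScheduler scheduler(0.01f,
        [](float lr, const std::size_t /*step*/) { return lr * 0.95f; },
        5);

    // Replay the first 20 steps; the scheduler itself is left untouched.
    const std::vector<float> profile = scheduler.lr_profiling(20);
    for (std::size_t step = 0; step < profile.size(); ++step) {
        std::printf("step %zu: lr = %f\n", step, profile[step]);
    }
    return 0;
}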