Commit 8dec7ca3 authored by Maxence Naud

Update project name

parent f6389572
1 merge request: !3 Dev - learning - v0.1.0

aidge/optimizer/LR/LRScheduler.hpp:
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPTIMIZER_LRSCHEDULER_H_
#define AIDGE_CORE_OPTIMIZER_LRSCHEDULER_H_
#include <cstddef> // std::size_t
#include <functional> // std::function
#include <vector>
namespace Aidge {
/**
 * @class LRScheduler
 * @brief Manages the evolution of the learning rate during training.
 */
class LRScheduler {
private:
    /// @brief Current step
    std::size_t mStep = 0;
    /// @brief LR update function. Takes the current learning rate and the current step as parameters.
    std::function<float(float, const std::size_t)> mStepFunc;
    /// @brief Current learning rate
    float mLR;
    /// @brief Initial learning rate passed to the constructor
    float mInitialLR;
    /// @brief Step at which the LRScheduler switches from warm-up to ``mStepFunc`` (>= 1)
    std::size_t mSwitchStep;
    /// @brief Linear warm-up increment, computed from the initial learning rate and the number of warm-up steps
    float mInitialWarmUp;
public:
    LRScheduler() = delete;

    /**
     * @brief Construct a new LRScheduler object. Default is ConstantLR.
     *
     * @param initialLR Initial learning rate.
     *        Defaults to 0 if a negative value is passed.
     * @param stepFunc Recursive update function for the learning rate value.
     *        Default is the constant function.
     * @param nb_warmup_steps Number of warm-up steps before ``stepFunc`` takes over
     *        the learning-rate update. If specified, the learning rate increases
     *        linearly from 0 to ``initialLR``.
     *        Default is 0.
     */
    LRScheduler(const float initialLR,
                std::function<float(float, const std::size_t)> stepFunc = [](float val, const std::size_t /*step*/) { return val; },
                const std::size_t nb_warmup_steps = 0)
        : mStepFunc(stepFunc),
          mLR((initialLR > 0.0f) ? initialLR : 0.0f),
          mInitialLR(mLR),
          mSwitchStep(nb_warmup_steps + 1),
          mInitialWarmUp(mLR / static_cast<float>(nb_warmup_steps + 1))
    {
        // ctor
    }
    // Copy constructor: member-wise copy, including the current step
    LRScheduler(const LRScheduler& other)
        : mStep(other.mStep),
          mStepFunc(other.mStepFunc),
          mLR(other.mLR),
          mInitialLR(other.mInitialLR),
          mSwitchStep(other.mSwitchStep),
          mInitialWarmUp(other.mInitialWarmUp)
    {
        // nothing else to do
    }
    LRScheduler& operator=(const LRScheduler&) = default;
    // LRScheduler(LRScheduler&&) = default;

public:
    // Getters & setters
    constexpr float learningRate() const noexcept { return mLR; }
    constexpr void setNbWarmupSteps(const std::size_t nb_warmup_steps) noexcept {
        mSwitchStep = nb_warmup_steps + 1;
        mInitialWarmUp = mLR / static_cast<float>(nb_warmup_steps + 1);
    }
    /**
     * @brief Update the learning rate to its next value.
     * @note If the current step is lower than the switch step, the learning rate
     * follows a linear ramp from 0 to the initial learning rate.
     * @note Otherwise, the learning rate is updated with the provided function.
     * (A usage sketch follows this header.)
     */
    constexpr void update() {
        mLR = (mStep++ < mSwitchStep) ?
              static_cast<float>(mStep) * mInitialWarmUp :
              mStepFunc(mLR, mStep);
    }

    /// @brief Compute the learning-rate values the next ``nbStep`` updates would
    /// produce, without modifying the scheduler.
    std::vector<float> lr_profiling(const std::size_t nbStep) const;
    constexpr void reset() noexcept {
        mStep = 0;
        mLR = mInitialLR;
    }
};
} // namespace Aidge
#endif /* AIDGE_CORE_OPTIMIZER_LRSCHEDULER_H_ */
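For reference, a minimal usage sketch of the header above; the main() below is illustrative and not part of the commit. It builds a constant-rate scheduler, enables a 5-step warm-up with setNbWarmupSteps(), and steps it a few times. With nb_warmup_steps = 5 the switch step is 6, so the first six calls to update() raise the rate in increments of 0.1/6 before it settles at 0.1.

#include <cstddef>
#include <iostream>

#include "aidge/optimizer/LR/LRScheduler.hpp"

int main() {
    // Constant scheduler (default stepFunc) with initial LR 0.1
    Aidge::LRScheduler scheduler(0.1f);
    // Linear warm-up: switch step becomes 6, increment 0.1f / 6
    scheduler.setNbWarmupSteps(5);

    for (std::size_t i = 0; i < 8; ++i) {
        scheduler.update();
        std::cout << "update " << i + 1 << ": lr = " << scheduler.learningRate() << '\n';
    }
    return 0;
}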
aidge/optimizer/LR/LRSchedulersList.hpp:

/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPTIMIZER_LRSCHEDULERSLIST_H_
#define AIDGE_CORE_OPTIMIZER_LRSCHEDULERSLIST_H_
#include "aidge/optimizer/LR/LRScheduler.hpp"
#include <cstddef> // std::size_t
namespace Aidge {
/// @brief Build a scheduler whose learning rate stays constant.
inline LRScheduler ConstantLR(const float initialLR) {
    return LRScheduler(initialLR);
}
/// @brief Build a scheduler that multiplies the learning rate by ``gamma``
/// every ``stepSize`` steps. ``stepSize`` must be greater than 0.
inline LRScheduler StepLR(const float initialLR, const std::size_t stepSize, float gamma = 0.1f) {
    return LRScheduler(initialLR,
                       [stepSize, gamma](float val, const std::size_t step) {
                           return (step % stepSize == 0) ? val * gamma : val;
                       });
}
} // namespace Aidge
#endif /* AIDGE_CORE_OPTIMIZER_LRSCHEDULERSLIST_H_ */
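And a short sketch of the factory functions in this second header, again illustrative rather than part of the commit: StepLR() halves the rate every 10 steps, and lr_profiling(), which is declared const in LRScheduler.hpp, previews the rates the next 30 updates would produce without advancing the scheduler. The include path mirrors the sibling header's layout, and the snippet assumes it is linked against the library that defines lr_profiling().

#include <cstddef>
#include <iostream>
#include <vector>

#include "aidge/optimizer/LR/LRSchedulersList.hpp"

int main() {
    // Start at 0.01 and multiply by 0.5 every 10 steps
    Aidge::LRScheduler scheduler = Aidge::StepLR(0.01f, 10, 0.5f);

    // Preview the schedule over the next 30 updates
    const std::vector<float> profile = scheduler.lr_profiling(30);
    for (std::size_t i = 0; i < profile.size(); ++i) {
        std::cout << "update " << i + 1 << ": lr = " << profile[i] << '\n';
    }
    return 0;
}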
aidge_learning