Commit c5e67b0e authored by Maxence Naud

move files to learning dir, add .cpp, upd includes

parent 8dec7ca3
1 merge request: !3 Dev - learning - v0.1.0
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPTIMIZER_LRSCHEDULER_H_
#define AIDGE_CORE_OPTIMIZER_LRSCHEDULER_H_

#include <cstddef>    // std::size_t
#include <functional> // std::function
#include <vector>

namespace Aidge {
/**
* @class LRScheduler
* @brief Manage the learning rate evolution.
*
*/
class LRScheduler {
private:
    /// @brief Current step
    std::size_t mStep = 0;

    /// @brief Learning rate update function. Takes the current learning rate
    /// and the current step as parameters and returns the new learning rate.
    std::function<float(float, const std::size_t)> mStepFunc;

    /// @brief Current learning rate
    float mLR;

    /// @brief Initial learning rate passed in the constructor
    float mInitialLR;

    /// @brief Step at which the LRScheduler switches from warm-up to the
    /// regular update function (always >= 1)
    std::size_t mSwitchStep;

    /// @brief Warm-up increment, deduced from the initial learning rate
    /// passed in the constructor and the number of warm-up steps
    float mInitialWarmUp;
public:
    LRScheduler() = delete;

    /**
     * @brief Construct a new LRScheduler object. Default is ConstantLR.
     *
     * @param initialLR Initial learning rate.
     * Clamped to 0 if a negative value is passed.
     * @param stepFunc Update function mapping the current learning rate and
     * the current step to the next learning rate value.
     * Default is the constant (identity) function.
     * @param nb_warmup_steps Number of warm-up steps before ``stepFunc``
     * takes over the learning rate update. If specified, the learning rate
     * increases linearly from 0 to ``initialLR`` over the first
     * ``nb_warmup_steps + 1`` updates.
     * Default is 0.
     */
    LRScheduler(const float initialLR,
                std::function<float(float, const std::size_t)> stepFunc
                    = [](float val, const std::size_t /*step*/) { return val; },
                const std::size_t nb_warmup_steps = 0)
        : mStepFunc(stepFunc),
          mLR((initialLR > 0.0f) ? initialLR : 0.0f),
          mInitialLR(mLR),
          mSwitchStep(nb_warmup_steps + 1),
          mInitialWarmUp(mLR / static_cast<float>(nb_warmup_steps + 1))
    {
        // ctor
    }
    // Copy constructor
    LRScheduler(const LRScheduler& other)
        : mStep(other.mStep),
          mStepFunc(other.mStepFunc),
          mLR(other.mLR),
          mInitialLR(other.mInitialLR),
          mSwitchStep(other.mSwitchStep),
          mInitialWarmUp(other.mInitialWarmUp)
    {
        // copy ctor
    }

    LRScheduler& operator=(const LRScheduler&) = default;
    // LRScheduler(LRScheduler&&) = default;
public:
    // Getters & setters
    constexpr inline std::size_t step() const noexcept { return mStep; }
    constexpr inline float learningRate() const noexcept { return mLR; }

    constexpr void setNbWarmupSteps(const std::size_t nb_warmup_steps) noexcept {
        mSwitchStep = nb_warmup_steps + 1;
        // Derive the increment from the initial learning rate so that the
        // warm-up ramp always targets the value passed in the constructor,
        // even if update() has already changed mLR.
        mInitialWarmUp = mInitialLR / static_cast<float>(nb_warmup_steps + 1);
    }
    /**
     * @brief Update the learning rate to its next value.
     * @note If the current step is lower than the switch step, the learning
     * rate follows a linear ramp from 0 to the initial learning rate.
     * @note Otherwise, the learning rate is updated with the provided function.
     */
    constexpr void update() {
        // mStep is already incremented when the selected branch is evaluated
        mLR = (mStep++ < mSwitchStep)
                  ? static_cast<float>(mStep) * mInitialWarmUp
                  : mStepFunc(mLR, mStep);
    }
    /// @brief Learning rate values over ``nbStep`` consecutive updates,
    /// computed without modifying the scheduler's state.
    std::vector<float> lr_profiling(const std::size_t nbStep) const;

    /// @brief Reset the scheduler to its initial step and learning rate.
    constexpr void reset() noexcept {
        mStep = 0;
        mLR = mInitialLR;
    }
};
} // namespace Aidge
#endif /* AIDGE_CORE_OPTIMIZER_LRSCHEDULER_H_ */
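
For illustration, here is a minimal, hypothetical usage sketch of the scheduler above (it assumes only what the header shows: the default constant step function and the warm-up behaviour of update()). With 2 warm-up steps, the first three updates ramp the learning rate linearly to its initial value, after which it stays constant:

#include <cstddef>
#include <cstdio>

#include "aidge/learning/learningRate/LRScheduler.hpp"

int main() {
    // Default step function is constant, so after the warm-up ramp
    // the learning rate keeps its initial value.
    Aidge::LRScheduler scheduler(0.3f);
    scheduler.setNbWarmupSteps(2); // mSwitchStep = 3, increment = 0.3 / 3 = 0.1

    for (std::size_t i = 0; i < 5; ++i) {
        scheduler.update();
        // Prints 0.10, 0.20, 0.30, then 0.30 on every further update
        std::printf("step %zu: lr = %.2f\n", scheduler.step(), scheduler.learningRate());
    }
    return 0;
}
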
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPTIMIZER_LRSCHEDULERSLIST_H_
#define AIDGE_CORE_OPTIMIZER_LRSCHEDULERSLIST_H_

#include "aidge/learning/learningRate/LRScheduler.hpp"

#include <cstddef> // std::size_t

namespace Aidge {
namespace learning {

/// @brief Build a scheduler whose learning rate stays constant.
LRScheduler ConstantLR(const float initialLR);

/// @brief Build a scheduler that multiplies the learning rate by ``gamma``
/// every ``stepSize`` steps.
LRScheduler StepLR(const float initialLR, const std::size_t stepSize, float gamma = 0.1f);

} // namespace learning
} // namespace Aidge
#endif /* AIDGE_CORE_OPTIMIZER_LRSCHEDULERSLIST_H_ */
@@ -9,8 +9,9 @@
  *
  ********************************************************************************/

-#include "aidge/optimizer/LR/LRScheduler.hpp"
+#include "aidge/learning/learningRate/LRScheduler.hpp"

+#include <cstddef> // std::size_t
 #include <vector>

 std::vector<float> Aidge::LRScheduler::lr_profiling(const std::size_t nbStep) const {
...
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/learning/learningRate/LRSchedulerList.hpp"
#include <cstddef> // std::size_t
#include "aidge/learning/learningRate/LRScheduler.hpp"
Aidge::LRScheduler Aidge::learning::ConstantLR(const float initialLR) {
return LRScheduler(initialLR);
}
Aidge::LRScheduler Aidge::learning::StepLR(const float initialLR, const std::size_t stepSize, float gamma) {
return LRScheduler(initialLR,
[stepSize, gamma](float val, const std::size_t step) {
return (step % stepSize == 0) ? val*gamma : val;
});
}
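
As a companion sketch (hypothetical, and assuming lr_profiling(nbStep), whose body is elided in the diff above, returns the learning rate after each of nbStep consecutive updates), StepLR with stepSize = 2 and gamma = 0.5f halves the learning rate on every even step:

#include <cstddef>
#include <cstdio>
#include <vector>

#include "aidge/learning/learningRate/LRSchedulerList.hpp"

int main() {
    const Aidge::LRScheduler scheduler = Aidge::learning::StepLR(1.0f, 2, 0.5f);

    // Under the assumption above, the profile should read
    // 1.0, 0.5, 0.5, 0.25, 0.25, 0.125: the first update ends the (empty)
    // warm-up, then the lambda fires whenever step % 2 == 0.
    const std::vector<float> profile = scheduler.lr_profiling(6);
    for (std::size_t i = 0; i < profile.size(); ++i) {
        std::printf("update %zu: lr = %.3f\n", i + 1, profile[i]);
    }
    return 0;
}
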