[Feat] Add Weight Decay to the SGD Optimizer

Merged Benjamin Halimi requested to merge weight_decay into dev
2 files changed: +14 −8
@@ -27,33 +27,38 @@ namespace Aidge {
 enum class SGDAttr {
     Momentum,
-    Dampening
+    Dampening,
+    WeightDecay
 };
 
-class SGD: public Optimizer, public StaticAttributes<SGDAttr, float, float> {
+class SGD: public Optimizer, public StaticAttributes<SGDAttr, float, float, float> {
 private:
     std::vector<Tensor> mGradientInertia;
     Tensor mLR{std::vector<std::size_t>({1})};
     Tensor mMomentum{std::vector<std::size_t>({1})};
     Tensor mReversedDampening{std::vector<std::size_t>({1})};
+    Tensor mWeightDecay{std::vector<std::size_t>({1})};
 
 public:
-    using Attributes_ = StaticAttributes<SGDAttr, float, float>;
+    using Attributes_ = StaticAttributes<SGDAttr, float, float, float>;
     template <SGDAttr e>
     using attr = typename Attributes_::template attr<e>;
 
-    SGD(const float momentum = 0.0f, const float dampening = 0.0f)
+    SGD(const float momentum = 0.0f, const float dampening = 0.0f, const float weightDecay = 0.0f)
         : Optimizer(),
           Attributes_(attr<SGDAttr::Momentum>(momentum),
-                      attr<SGDAttr::Dampening>(dampening))
+                      attr<SGDAttr::Dampening>(dampening),
+                      attr<SGDAttr::WeightDecay>(weightDecay))
     {
         mMomentum = Tensor(momentum);
         mReversedDampening = Tensor(1.0f - dampening);
+        mWeightDecay = Tensor(weightDecay);
     }
 
     void update() override final {
         mLR = Tensor(learningRate());
         mLR.setBackend(mParameters[0]->getImpl()->backend());
+        mWeightDecay.setBackend(mParameters[0]->getImpl()->backend());
 
         if (mLRScheduler.step() == 0) {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
@@ -62,8 +67,9 @@ public:
             }
         } else {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
-                mGradientInertia[i] = mMomentum*mGradientInertia[i] + mReversedDampening*(*mParameters[i]->grad());
-                *mParameters[i] -= mLR*mGradientInertia[i];
+                (*mParameters[i]->grad()) += mWeightDecay * (*mParameters[i]);
+                mGradientInertia[i] = mMomentum * mGradientInertia[i] + mReversedDampening * (*mParameters[i]->grad());
+                *mParameters[i] -= mLR * mGradientInertia[i];
             }
         }
         mLRScheduler.update();
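
Note on usage (not part of the diff): the new third constructor argument defaults to 0.0f, so existing call sites keep their current behaviour. A hypothetical call enabling weight decay, with illustrative hyper-parameter values, would look like:

    Aidge::SGD opt(/*momentum=*/0.9f, /*dampening=*/0.0f, /*weightDecay=*/1e-4f);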
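
For clarity, here is a minimal sketch of the per-parameter arithmetic the updated else-branch of update() performs, written with plain scalars instead of Aidge tensors and ignoring the first-step (step == 0) branch that is collapsed in the diff; the name sgdStep and the scalar form are only illustrative:

    // Per-step update with weight decay lambda, momentum mu, dampening d, learning rate lr:
    //   grad    += lambda * param;                 // L2-style (coupled) weight decay
    //   inertia  = mu * inertia + (1 - d) * grad;  // momentum accumulation with dampening
    //   param   -= lr * inertia;                   // parameter step
    void sgdStep(float& param, float& inertia, float grad,
                 float lr, float mu, float d, float lambda) {
        grad += lambda * param;
        inertia = mu * inertia + (1.0f - d) * grad;
        param -= lr * inertia;
    }

This is the coupled form of weight decay (L2 regularization folded into the gradient before the momentum update), as opposed to the decoupled, AdamW-style variant where the decay term is applied directly to the parameter.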