Commit aae638bb authored by Olivier BICHLER

Merge branch 'weight_decay' into 'dev'

[Feat] Add Weight Decay to the SGD Optimizer

See merge request !35
parents f1261398 f49e2f40
Included in 2 merge requests: !44 "Update 0.2.3 -> 0.3.0" and !35 "[Feat] Add Weight Decay to the SGD Optimizer"
Pipeline #65493 passed
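
For context (the notation here is mine, not from the commit): with parameters θ, gradient g, momentum buffer v, learning rate η, momentum μ, dampening d, and weight decay λ, the change below implements the classic "coupled" L2 weight decay, folding the decay term into the gradient before the momentum update:

$$
g_t = \nabla_\theta \mathcal{L}(\theta_{t-1}) + \lambda\,\theta_{t-1},\qquad
v_t = \mu\,v_{t-1} + (1-d)\,g_t,\qquad
\theta_t = \theta_{t-1} - \eta\,v_t
$$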
Changes to the SGD optimizer header:

@@ -27,33 +27,38 @@ namespace Aidge {
 enum class SGDAttr {
     Momentum,
-    Dampening
+    Dampening,
+    WeightDecay
 };

-class SGD: public Optimizer, public StaticAttributes<SGDAttr, float, float> {
+class SGD: public Optimizer, public StaticAttributes<SGDAttr, float, float, float> {
 private:
     std::vector<Tensor> mGradientInertia;
     Tensor mLR{std::vector<std::size_t>({1})};
     Tensor mMomentum{std::vector<std::size_t>({1})};
     Tensor mReversedDampening{std::vector<std::size_t>({1})};
+    Tensor mWeightDecay{std::vector<std::size_t>({1})};

 public:
-    using Attributes_ = StaticAttributes<SGDAttr, float, float>;
+    using Attributes_ = StaticAttributes<SGDAttr, float, float, float>;
     template <SGDAttr e>
     using attr = typename Attributes_::template attr<e>;

-    SGD(const float momentum = 0.0f, const float dampening = 0.0f)
+    SGD(const float momentum = 0.0f, const float dampening = 0.0f, const float weightDecay = 0.0f)
         : Optimizer(),
           Attributes_(attr<SGDAttr::Momentum>(momentum),
-                      attr<SGDAttr::Dampening>(dampening))
+                      attr<SGDAttr::Dampening>(dampening),
+                      attr<SGDAttr::WeightDecay>(weightDecay))
     {
         mMomentum = Tensor(momentum);
         mReversedDampening = Tensor(1.0f - dampening);
+        mWeightDecay = Tensor(weightDecay);
     }

     void update() override final {
         mLR = Tensor(learningRate());
         mLR.setBackend(mParameters[0]->getImpl()->backend());
+        mWeightDecay.setBackend(mParameters[0]->getImpl()->backend());

         if (mLRScheduler.step() == 0) {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
@@ -62,8 +67,9 @@ public:
             }
         } else {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
-                mGradientInertia[i] = mMomentum*mGradientInertia[i] + mReversedDampening*(*mParameters[i]->grad());
-                *mParameters[i] -= mLR*mGradientInertia[i];
+                (*mParameters[i]->grad()) += mWeightDecay * (*mParameters[i]);
+                mGradientInertia[i] = mMomentum * mGradientInertia[i] + mReversedDampening * (*mParameters[i]->grad());
+                *mParameters[i] -= mLR * mGradientInertia[i];
             }
         }
         mLRScheduler.update();
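
For readers who want the update rule without the Tensor machinery, here is a minimal self-contained sketch of the same step on plain std::vector<float> buffers (names and structure are mine, and the special first-step initialization elided from the diff is ignored):

```cpp
#include <cstddef>
#include <vector>

// Standalone illustration of the step performed in SGD::update() above;
// one call updates one parameter tensor, flattened to a float vector.
void sgdStep(std::vector<float>& param,
             const std::vector<float>& grad,
             std::vector<float>& inertia,   // momentum buffer
             float lr, float momentum, float dampening, float weightDecay) {
    for (std::size_t i = 0; i < param.size(); ++i) {
        // L2 weight decay: fold lambda * theta into the gradient.
        const float g = grad[i] + weightDecay * param[i];
        // Momentum with dampening: v = mu * v + (1 - d) * g.
        inertia[i] = momentum * inertia[i] + (1.0f - dampening) * g;
        // Parameter step: theta = theta - lr * v.
        param[i] -= lr * inertia[i];
    }
}
```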
Changes to the Python binding (pybind11):

@@ -20,7 +20,7 @@ namespace Aidge {
 void init_SGD(py::module& m) {
     py::class_<SGD, std::shared_ptr<SGD>, Attributes, Optimizer>(m, "SGD", py::multiple_inheritance())
-        .def(py::init<float, float>(), py::arg("momentum") = 0.0f, py::arg("dampening") = 0.0f)
+        .def(py::init<float, float, float>(), py::arg("momentum") = 0.0f, py::arg("dampening") = 0.0f, py::arg("weight_decay") = 0.0f)
         .def("update", &SGD::update);
 }
 // } // namespace learning
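
Design note: weight_decay defaults to 0.0f on both the C++ and Python sides, so existing callers of SGD(momentum, dampening) keep their behavior (zero decay); the Python keyword uses snake_case (weight_decay) while the C++ parameter stays camelCase (weightDecay).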