diff --git a/include/aidge/learning/optimizer/SGD.hpp b/include/aidge/learning/optimizer/SGD.hpp
index 768a3d05604909c248ac105c444d4cda1aee93c2..cef3573bc6d90211014ea171bdf32e04ec89fb86 100644
--- a/include/aidge/learning/optimizer/SGD.hpp
+++ b/include/aidge/learning/optimizer/SGD.hpp
@@ -27,33 +27,38 @@ namespace Aidge {
 
 enum class SGDAttr {
     Momentum,
-    Dampening
+    Dampening,
+    WeightDecay
 };
 
-class SGD: public Optimizer, public StaticAttributes<SGDAttr, float, float> {
+class SGD: public Optimizer, public StaticAttributes<SGDAttr, float, float, float> {
 private:
     std::vector<Tensor> mGradientInertia;
     Tensor mLR{std::vector<std::size_t>({1})};
     Tensor mMomentum{std::vector<std::size_t>({1})};
     Tensor mReversedDampening{std::vector<std::size_t>({1})};
+    Tensor mWeightDecay{std::vector<std::size_t>({1})};
 
 public:
-    using Attributes_ = StaticAttributes<SGDAttr, float, float>;
+    using Attributes_ = StaticAttributes<SGDAttr, float, float, float>;
     template <SGDAttr e>
     using attr = typename Attributes_::template attr<e>;
 
-    SGD(const float momentum = 0.0f, const float dampening = 0.0f)
+    SGD(const float momentum = 0.0f, const float dampening = 0.0f, const float weightDecay = 0.0f)
         : Optimizer(),
           Attributes_(attr<SGDAttr::Momentum>(momentum),
-                      attr<SGDAttr::Dampening>(dampening))
+                      attr<SGDAttr::Dampening>(dampening),
+                      attr<SGDAttr::WeightDecay>(weightDecay))
     {
         mMomentum = Tensor(momentum);
         mReversedDampening = Tensor(1.0f - dampening);
+        mWeightDecay = Tensor(weightDecay);
     }
 
     void update() override final {
         mLR = Tensor(learningRate());
         mLR.setBackend(mParameters[0]->getImpl()->backend());
+        mWeightDecay.setBackend(mParameters[0]->getImpl()->backend());
 
         if (mLRScheduler.step() == 0) {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
@@ -62,8 +67,9 @@ public:
             }
         } else {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
-                mGradientInertia[i] = mMomentum*mGradientInertia[i] + mReversedDampening*(*mParameters[i]->grad());
-                *mParameters[i] -= mLR*mGradientInertia[i];
+                (*mParameters[i]->grad()) += mWeightDecay * (*mParameters[i]);
+                mGradientInertia[i] = mMomentum * mGradientInertia[i] + mReversedDampening * (*mParameters[i]->grad());
+                *mParameters[i] -= mLR * mGradientInertia[i];
             }
         }
         mLRScheduler.update();
diff --git a/python_binding/learning/optimizer/pybind_SGD.cpp b/python_binding/learning/optimizer/pybind_SGD.cpp
index b9c50ad5e0c765721dc37c0a3109fa08e819e3bb..9ddb6c29fb4c96c9f4d275f741352839ca41eb28 100644
--- a/python_binding/learning/optimizer/pybind_SGD.cpp
+++ b/python_binding/learning/optimizer/pybind_SGD.cpp
@@ -20,7 +20,7 @@ namespace Aidge {
 
 void init_SGD(py::module& m) {
     py::class_<SGD, std::shared_ptr<SGD>, Attributes, Optimizer>(m, "SGD", py::multiple_inheritance())
-        .def(py::init<float, float>(), py::arg("momentum") = 0.0f, py::arg("dampening") = 0.0f)
+        .def(py::init<float, float, float>(), py::arg("momentum") = 0.0f, py::arg("dampening") = 0.0f, py::arg("weight_decay") = 0.0f)
         .def("update", &SGD::update);
 }
 // } // namespace learning
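
Note: for reference, with learning rate eta, momentum mu, dampening tau, and the new weight decay lambda, the per-parameter update performed by the patched update() loop (second hunk of SGD.hpp) is:

    g_t      = \nabla_\theta L(\theta_{t-1}) + \lambda \theta_{t-1}
    v_t      = \mu v_{t-1} + (1 - \tau) g_t
    \theta_t = \theta_{t-1} - \eta v_t

The decay term is folded into the gradient before the momentum/dampening step, i.e. plain L2 regularization (the same coupled formulation used by PyTorch's SGD), not a decoupled AdamW-style decay.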
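
A minimal usage sketch of the extended binding follows; "aidge_learning" is an assumed module name (not part of this patch), and attaching parameters is left to the existing Optimizer API:

    # Hypothetical usage sketch; "aidge_learning" is an assumed
    # module name, not shown in this patch.
    import aidge_learning

    # weight_decay adds lambda * param to each gradient before the
    # momentum update (see the SGD.hpp hunk above); the default of
    # 0.0 preserves the previous behaviour.
    opt = aidge_learning.SGD(momentum=0.9, dampening=0.0, weight_decay=1e-4)

    # ... attach parameters and set a learning-rate scheduler, then
    # call opt.update() once per training step.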