Skip to content
Snippets Groups Projects
Commit 0ec99fd5 authored by Maxence Naud's avatar Maxence Naud
Browse files

upd SGD

parent 408be6e9
No related branches found
No related tags found
2 merge requests!3Dev - learning - v0.1.0,!1[Add] loss function system, MSE loss function, unit-test associated with MSE
......@@ -52,7 +52,7 @@ public:
mReversedDampening.set<float>(0, 1.0f - dampening);
}
void update() override {
void update() override final {
mLR.setBackend(mParameters[0]->getImpl()->backend());
mLR.set<float>(0, learningRate());
if (mParameters[0]->getImpl()->backend() != mMomentum.getImpl()->backend()) {
......@@ -74,7 +74,7 @@ public:
mLRScheduler.update();
}
void setParameters(const std::vector<std::shared_ptr<Tensor>>& parameters) {
void setParameters(const std::vector<std::shared_ptr<Tensor>>& parameters) override final {
Optimizer::setParameters(parameters);
mGradientInertia = std::vector<Tensor>(parameters.size());
for (std::size_t i = 0; i < parameters.size(); ++i) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment