From 090c5050c5675fea7de3ccfaf34dfdcb44467c42 Mon Sep 17 00:00:00 2001
From: Antoni Olivier <olivier.antoni@cea.fr>
Date: Tue, 1 Apr 2025 11:05:37 +0200
Subject: [PATCH] Adam optimizer: make clean

---
 include/aidge/learning/optimizer/Adam.hpp | 43 ++++++++++-------------
 unit_tests/optimizer/Test_Adam.cpp        | 21 +++++++----
 2 files changed, 33 insertions(+), 31 deletions(-)

diff --git a/include/aidge/learning/optimizer/Adam.hpp b/include/aidge/learning/optimizer/Adam.hpp
index a018d6e..2932918 100644
--- a/include/aidge/learning/optimizer/Adam.hpp
+++ b/include/aidge/learning/optimizer/Adam.hpp
@@ -35,12 +35,10 @@ class Adam: public Optimizer, public StaticAttributes<AdamAttr, float, float, fl
 private:
     std::vector<Tensor> mMomentum1;
     std::vector<Tensor> mMomentum2;
-    Tensor mLR{1.0f};
     Tensor mBeta1;
     Tensor mReversedBeta1;
     Tensor mBeta2;
     Tensor mReversedBeta2;
-    Tensor mEpsilon;
 
 public:
     using Attributes_ = StaticAttributes<AdamAttr, float, float, float>;
@@ -51,39 +49,30 @@ public:
         : Optimizer(),
           Attributes_(attr<AdamAttr::Beta1>(beta1),
                       attr<AdamAttr::Beta2>(beta2),
-                      attr<AdamAttr::Epsilon>(epsilon)),
-          mBeta1(beta1),
-          mReversedBeta1(1.0f - beta1),
-          mBeta2(beta2),
-          mReversedBeta2(1.0f - beta2),
-          mEpsilon(epsilon)
+                      attr<AdamAttr::Epsilon>(epsilon))
     {
+        mBeta1 = Tensor(beta1);
+        mReversedBeta1 = Tensor(1.0f - beta1);
+        mBeta2 = Tensor(beta2);
+        mReversedBeta2 = Tensor(1.0f - beta2);
     }
 
     void update() override final {
-        mLR = Tensor(learningRate());
-        mLR.setBackend(mParameters[0]->getImpl()->backend());
+        float mBeta1Power = std::pow(this->getAttr<AdamAttr::Beta1>(), static_cast<float>(mLRScheduler.step() + 1));
+        float mBeta2Power = std::pow(this->getAttr<AdamAttr::Beta2>(), static_cast<float>(mLRScheduler.step() + 1));
+        float mReversedBeta1Power = 1.0f - mBeta1Power;
+        float mSqrtReversedBeta2Power = std::sqrt(1.0f - mBeta2Power);
 
-        if (mParameters[0]->getImpl()->backend() != mBeta1.getImpl()->backend()) {
-            mBeta1.setBackend(mParameters[0]->getImpl()->backend());
-            mReversedBeta1.setBackend(mParameters[0]->getImpl()->backend());
-            mBeta2.setBackend(mParameters[0]->getImpl()->backend());
-            mReversedBeta2.setBackend(mParameters[0]->getImpl()->backend());
-        }
-
-        Tensor alpha = Tensor(learningRate() * std::sqrt(1.0f - std::pow(this->getAttr<AdamAttr::Beta2>(), static_cast<float>(mLRScheduler.step() + 1)))
-                              / (1.0f - std::pow(this->getAttr<AdamAttr::Beta1>(), static_cast<float>(mLRScheduler.step() + 1))));
+        Tensor alpha = Tensor(learningRate() * mSqrtReversedBeta2Power / mReversedBeta1Power);
         alpha.setBackend(mParameters[0]->getImpl()->backend());
 
-        Tensor epsilon = Tensor(this->getAttr<AdamAttr::Epsilon>() * std::sqrt(1.0f - std::pow(this->getAttr<AdamAttr::Beta2>(), static_cast<float>(mLRScheduler.step() + 1))));
-        epsilon.setBackend(mParameters[0]->getImpl()->backend());
+        Tensor epsilon_hat = Tensor(this->getAttr<AdamAttr::Epsilon>() * mSqrtReversedBeta2Power);
+        epsilon_hat.setBackend(mParameters[0]->getImpl()->backend());
 
         if (mLRScheduler.step() == 0) {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
-                mMomentum1[i].setBackend(mParameters[i]->getImpl()->backend());
                 mMomentum1[i].setDataType(mParameters[i]->grad()->dataType());
                 mMomentum1[i].zeros();
-                mMomentum2[i].setBackend(mParameters[i]->getImpl()->backend());
                 mMomentum2[i].setDataType(mParameters[i]->grad()->dataType());
                 mMomentum2[i].zeros();
             }
@@ -92,7 +81,7 @@ public:
         for (std::size_t i = 0; i < mParameters.size(); ++i) {
             mMomentum1[i] = mBeta1 * mMomentum1[i] + mReversedBeta1 * (*mParameters[i]->grad());
             mMomentum2[i] = mBeta2 * mMomentum2[i] + mReversedBeta2 * (*mParameters[i]->grad()) * (*mParameters[i]->grad());
-            *mParameters[i] -= alpha * mMomentum1[i] / (mMomentum2[i].sqrt() + epsilon);
+            *mParameters[i] -= alpha * mMomentum1[i] / (mMomentum2[i].sqrt() + epsilon_hat);
         }
 
         mLRScheduler.update();
@@ -108,6 +97,12 @@ public:
             mMomentum2[i] = Tensor(parameters[i]->dims());
             mMomentum2[i].setBackend(parameters[i]->getImpl()->backend());
         }
+        if (parameters.size() > 0) {
+            mBeta1.setBackend(mParameters[0]->getImpl()->backend());
+            mReversedBeta1.setBackend(mParameters[0]->getImpl()->backend());
+            mBeta2.setBackend(mParameters[0]->getImpl()->backend());
+            mReversedBeta2.setBackend(mParameters[0]->getImpl()->backend());
+        }
     }
 };
 
diff --git a/unit_tests/optimizer/Test_Adam.cpp b/unit_tests/optimizer/Test_Adam.cpp
index 632cba9..3fa6606 100644
--- a/unit_tests/optimizer/Test_Adam.cpp
+++ b/unit_tests/optimizer/Test_Adam.cpp
@@ -130,14 +130,18 @@ TEST_CASE("[learning/Adam] update", "[Optimizer][Adam]") {
 
         for (std::size_t step = 0; step < 10; ++step) {
             // truth
-            float lr2 = lr * std::sqrt(1.0f - std::pow(beta2, static_cast<float>(step + 1))) / (1.0f - std::pow(beta1, static_cast<float>(step + 1)));
-            float epsilon2 = epsilon * std::sqrt(1.0f - std::pow(beta2, static_cast<float>(step + 1)));
+            float beta1_power = std::pow(beta1, static_cast<float>(step + 1));
+            float beta2_power = std::pow(beta2, static_cast<float>(step + 1));
+            float Reversed_beta1_power = 1.0f - beta1_power;
+            float sqrtReversed_beta2_power = std::sqrt(1.0f - beta2_power);
+            float alpha = lr * sqrtReversed_beta2_power / Reversed_beta1_power;
+            float epsilon_hat = epsilon * sqrtReversed_beta2_power;
             for (std::size_t t = 0; t < nb_tensors; ++t) {
                 for (std::size_t i = 0; i < size_tensors[t]; ++i) {
                     val_momentum1_tensors[t][i] = beta1 * val_momentum1_tensors[t][i] + (1.0f - beta1) * val_grad_tensors[t][i];
                     val_momentum2_tensors[t][i] = beta2 * val_momentum2_tensors[t][i] + (1.0f - beta2) * val_grad_tensors[t][i] * val_grad_tensors[t][i];
                     val_tensors[t][i] = val_tensors[t][i]
-                                        - lr2 * val_momentum1_tensors[t][i] / (std::sqrt(val_momentum2_tensors[t][i]) + epsilon2);
+                                        - alpha * val_momentum1_tensors[t][i] / (std::sqrt(val_momentum2_tensors[t][i]) + epsilon_hat);
                 }
             }
             // optimizer
@@ -257,20 +261,23 @@ TEST_CASE("[learning/Adam] update", "[Optimizer][Adam]") {
 
         for (std::size_t step = 0; step < 10; ++step) {
             // truth
-            float lr2 = lr * std::sqrt(1.0f - std::pow(beta2, step + 1)) / (1.0f - std::pow(beta1, step + 1));
-            float epsilon2 = epsilon * std::sqrt(1.0f - std::pow(beta2, step + 1));
+            float beta1_power = std::pow(beta1, static_cast<float>(step + 1));
+            float beta2_power = std::pow(beta2, static_cast<float>(step + 1));
+            float Reversed_beta1_power = 1.0f - beta1_power;
+            float sqrtReversed_beta2_power = std::sqrt(1.0f - beta2_power);
+            float alpha = lr * sqrtReversed_beta2_power / Reversed_beta1_power;
+            float epsilon_hat = epsilon * sqrtReversed_beta2_power;
             for (std::size_t t = 0; t < nb_tensors; ++t) {
                 for (std::size_t i = 0; i < size_tensors[t]; ++i) {
                     val_momentum1_tensors[t][i] = beta1 * val_momentum1_tensors[t][i] + (1.0f - beta1) * val_grad_tensors[t][i];
                     val_momentum2_tensors[t][i] = beta2 * val_momentum2_tensors[t][i] + (1.0f - beta2) * val_grad_tensors[t][i] * val_grad_tensors[t][i];
                     val_tensors[t][i] = val_tensors[t][i]
-                                        - lr2 * val_momentum1_tensors[t][i] / (std::sqrt(val_momentum2_tensors[t][i]) + epsilon2);
+                                        - alpha * val_momentum1_tensors[t][i] / (std::sqrt(val_momentum2_tensors[t][i]) + epsilon_hat);
                 }
                 cudaMemcpy(d_val_momentum1_tensors[t], val_momentum1_tensors[t].get(), size_tensors[t] * sizeof(float), cudaMemcpyHostToDevice);
                 cudaMemcpy(d_val_momentum2_tensors[t], val_momentum2_tensors[t].get(), size_tensors[t] * sizeof(float), cudaMemcpyHostToDevice);
                 cudaMemcpy(d_val_tensors[t], val_tensors[t].get(), size_tensors[t] * sizeof(float), cudaMemcpyHostToDevice);
             }
-
             // optimizer
             opt.update();
             // tests
--
GitLab
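
Background note on the refactor above: the patch folds Adam's bias correction into the step size and the epsilon term, computing alpha_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t) and epsilon_hat = epsilon * sqrt(1 - beta2^t) once per step. The update theta -= alpha_t * m / (sqrt(v) + epsilon_hat) is then algebraically identical to the textbook form lr * m_hat / (sqrt(v_hat) + epsilon), with m_hat = m / (1 - beta1^t) and v_hat = v / (1 - beta2^t), without ever materializing m_hat or v_hat. The scalar sketch below mirrors the "truth" reference loop of Test_Adam.cpp; it is a standalone illustration that does not use the aidge API, and every name in it is hypothetical.

    // Minimal scalar sketch of Adam with the folded bias correction.
    // Standalone illustration only; independent of the aidge library.
    #include <cmath>
    #include <cstdio>

    int main() {
        const float lr = 0.01f, beta1 = 0.9f, beta2 = 0.999f, epsilon = 1.0e-8f;
        float param = 1.0f;              // single parameter
        float m = 0.0f, v = 0.0f;        // first and second raw moments
        const float grad = 0.5f;         // fixed dummy gradient for illustration

        for (int step = 1; step <= 10; ++step) {
            // Fold the bias correction into the step size and epsilon,
            // mirroring what the patched Adam::update() computes per step:
            const float beta1_power  = std::pow(beta1, static_cast<float>(step));
            const float beta2_power  = std::pow(beta2, static_cast<float>(step));
            const float alpha        = lr * std::sqrt(1.0f - beta2_power) / (1.0f - beta1_power);
            const float epsilon_hat  = epsilon * std::sqrt(1.0f - beta2_power);

            // Biased moment estimates (same recurrences as mMomentum1 / mMomentum2):
            m = beta1 * m + (1.0f - beta1) * grad;
            v = beta2 * v + (1.0f - beta2) * grad * grad;

            // Parameter update without materializing m_hat / v_hat:
            param -= alpha * m / (std::sqrt(v) + epsilon_hat);
        }
        std::printf("param after 10 steps: %f\n", param);
        return 0;
    }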