From 090c5050c5675fea7de3ccfaf34dfdcb44467c42 Mon Sep 17 00:00:00 2001
From: Antoni Olivier <olivier.antoni@cea.fr>
Date: Tue, 1 Apr 2025 11:05:37 +0200
Subject: [PATCH] Adam optimizer: code cleanup
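
Clean up the Adam optimizer implementation:

* factor the bias-correction terms out of the alpha and epsilon
  expressions in update(), so each std::pow is computed only once;
* drop the mLR and mEpsilon member tensors, which are no longer
  needed once alpha and epsilon_hat are built locally;
* remove the redundant setBackend() calls from the step-0 momentum
  initialization in update() and set the backend of the beta constant
  tensors once in setParameters() instead of on every update;
* mirror the same intermediate variables in the unit tests.

With t = mLRScheduler.step(), lr = learningRate() and m1, m2 the first
and second moment tensors, the precomputed quantities are

    alpha       = lr * sqrt(1 - beta2^(t+1)) / (1 - beta1^(t+1))
    epsilon_hat = epsilon * sqrt(1 - beta2^(t+1))

and each parameter p is updated as

    p -= alpha * m1 / (sqrt(m2) + epsilon_hat)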

---
 include/aidge/learning/optimizer/Adam.hpp | 43 ++++++++++-------------
 unit_tests/optimizer/Test_Adam.cpp        | 21 +++++++----
 2 files changed, 33 insertions(+), 31 deletions(-)

diff --git a/include/aidge/learning/optimizer/Adam.hpp b/include/aidge/learning/optimizer/Adam.hpp
index a018d6e..2932918 100644
--- a/include/aidge/learning/optimizer/Adam.hpp
+++ b/include/aidge/learning/optimizer/Adam.hpp
@@ -35,12 +35,10 @@ class Adam: public Optimizer, public StaticAttributes<AdamAttr, float, float, fl
 private:
     std::vector<Tensor> mMomentum1;
     std::vector<Tensor> mMomentum2;
-    Tensor mLR{1.0f};
     Tensor mBeta1;
     Tensor mReversedBeta1;
     Tensor mBeta2;
     Tensor mReversedBeta2;
-    Tensor mEpsilon;
 
 public:
     using Attributes_ = StaticAttributes<AdamAttr, float, float, float>;
@@ -51,39 +49,30 @@ public:
         : Optimizer(),
           Attributes_(attr<AdamAttr::Beta1>(beta1),
                       attr<AdamAttr::Beta2>(beta2),
-                      attr<AdamAttr::Epsilon>(epsilon)),
-          mBeta1(beta1),
-          mReversedBeta1(1.0f - beta1),
-          mBeta2(beta2),
-          mReversedBeta2(1.0f - beta2),
-          mEpsilon(epsilon)
+                      attr<AdamAttr::Epsilon>(epsilon))
     {
+        mBeta1 = Tensor(beta1);
+        mReversedBeta1 = Tensor(1.0f - beta1);
+        mBeta2 = Tensor(beta2);
+        mReversedBeta2 = Tensor(1.0f - beta2);
     }
 
     void update() override final {
-        mLR = Tensor(learningRate());
-        mLR.setBackend(mParameters[0]->getImpl()->backend());
+        float beta1Power = std::pow(this->getAttr<AdamAttr::Beta1>(), static_cast<float>(mLRScheduler.step() + 1));
+        float beta2Power = std::pow(this->getAttr<AdamAttr::Beta2>(), static_cast<float>(mLRScheduler.step() + 1));
+        float reversedBeta1Power = 1.0f - beta1Power;
+        float sqrtReversedBeta2Power = std::sqrt(1.0f - beta2Power);
 
-        if (mParameters[0]->getImpl()->backend() != mBeta1.getImpl()->backend()) {
-            mBeta1.setBackend(mParameters[0]->getImpl()->backend());
-            mReversedBeta1.setBackend(mParameters[0]->getImpl()->backend());
-            mBeta2.setBackend(mParameters[0]->getImpl()->backend());
-            mReversedBeta2.setBackend(mParameters[0]->getImpl()->backend());
-        }
-
-        Tensor alpha = Tensor(learningRate() * std::sqrt(1.0f - std::pow(this->getAttr<AdamAttr::Beta2>(), static_cast<float>(mLRScheduler.step() + 1)))
-                                           / (1.0f - std::pow(this->getAttr<AdamAttr::Beta1>(), static_cast<float>(mLRScheduler.step() + 1))));
+        Tensor alpha = Tensor(learningRate() * sqrtReversedBeta2Power / reversedBeta1Power);
         alpha.setBackend(mParameters[0]->getImpl()->backend());
 
-        Tensor epsilon = Tensor(this->getAttr<AdamAttr::Epsilon>() * std::sqrt(1.0f - std::pow(this->getAttr<AdamAttr::Beta2>(), static_cast<float>(mLRScheduler.step() + 1))));
-        epsilon.setBackend(mParameters[0]->getImpl()->backend());
+        Tensor epsilon_hat = Tensor(this->getAttr<AdamAttr::Epsilon>() * sqrtReversedBeta2Power);
+        epsilon_hat.setBackend(mParameters[0]->getImpl()->backend());
 
         if (mLRScheduler.step() == 0) {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
-                mMomentum1[i].setBackend(mParameters[i]->getImpl()->backend());
                 mMomentum1[i].setDataType(mParameters[i]->grad()->dataType());
                 mMomentum1[i].zeros();
-                mMomentum2[i].setBackend(mParameters[i]->getImpl()->backend());
                 mMomentum2[i].setDataType(mParameters[i]->grad()->dataType());
                 mMomentum2[i].zeros();
             }
@@ -92,7 +81,7 @@ public:
         for (std::size_t i = 0; i < mParameters.size(); ++i) {
             mMomentum1[i] = mBeta1 * mMomentum1[i] + mReversedBeta1 * (*mParameters[i]->grad());
             mMomentum2[i] = mBeta2 * mMomentum2[i] + mReversedBeta2 * (*mParameters[i]->grad()) * (*mParameters[i]->grad());
-            *mParameters[i] -= alpha * mMomentum1[i] / (mMomentum2[i].sqrt() +  epsilon);
+            *mParameters[i] -= alpha * mMomentum1[i] / (mMomentum2[i].sqrt() + epsilon_hat);
         }
 
         mLRScheduler.update();
@@ -108,6 +97,12 @@ public:
             mMomentum2[i] = Tensor(parameters[i]->dims());
             mMomentum2[i].setBackend(parameters[i]->getImpl()->backend());
         }
+        if (parameters.size() > 0) {
+            mBeta1.setBackend(parameters[0]->getImpl()->backend());
+            mReversedBeta1.setBackend(parameters[0]->getImpl()->backend());
+            mBeta2.setBackend(parameters[0]->getImpl()->backend());
+            mReversedBeta2.setBackend(parameters[0]->getImpl()->backend());
+        }
     }
 };
 
diff --git a/unit_tests/optimizer/Test_Adam.cpp b/unit_tests/optimizer/Test_Adam.cpp
index 632cba9..3fa6606 100644
--- a/unit_tests/optimizer/Test_Adam.cpp
+++ b/unit_tests/optimizer/Test_Adam.cpp
@@ -130,14 +130,18 @@ TEST_CASE("[learning/Adam] update", "[Optimizer][Adam]") {
 
             for (std::size_t step = 0; step < 10; ++step) {
                 // truth
-                float lr2 = lr * std::sqrt(1.0f - std::pow(beta2, static_cast<float>(step + 1))) / (1.0f - std::pow(beta1, static_cast<float>(step + 1)));
-                float epsilon2 = epsilon * std::sqrt(1.0f - std::pow(beta2, static_cast<float>(step + 1)));
+                float beta1_power = std::pow(beta1, static_cast<float>(step + 1));
+                float beta2_power = std::pow(beta2, static_cast<float>(step + 1));
+                float reversed_beta1_power = 1.0f - beta1_power;
+                float sqrt_reversed_beta2_power = std::sqrt(1.0f - beta2_power);
+                float alpha = lr * sqrt_reversed_beta2_power / reversed_beta1_power;
+                float epsilon_hat = epsilon * sqrt_reversed_beta2_power;
                 for (std::size_t t = 0; t < nb_tensors; ++t) {
                     for (std::size_t i = 0; i < size_tensors[t]; ++i) {
                         val_momentum1_tensors[t][i] = beta1 * val_momentum1_tensors[t][i] + (1.0f - beta1) * val_grad_tensors[t][i];
                         val_momentum2_tensors[t][i] = beta2 * val_momentum2_tensors[t][i] + (1.0f - beta2) * val_grad_tensors[t][i] * val_grad_tensors[t][i];
                         val_tensors[t][i] = val_tensors[t][i]
-                                        - lr2 * val_momentum1_tensors[t][i] / (std::sqrt(val_momentum2_tensors[t][i]) +  epsilon2);
+                                        - alpha * val_momentum1_tensors[t][i] / (std::sqrt(val_momentum2_tensors[t][i]) +  epsilon_hat);
                     }
                 }
                 // optimizer
@@ -257,20 +261,23 @@ TEST_CASE("[learning/Adam] update", "[Optimizer][Adam]") {
 
             for (std::size_t step = 0; step < 10; ++step) {
                 // truth
-                float lr2 = lr * std::sqrt(1.0f - std::pow(beta2, step + 1)) / (1.0f - std::pow(beta1, step + 1));
-                float epsilon2 = epsilon * std::sqrt(1.0f - std::pow(beta2, step + 1));
+                float beta1_power = std::pow(beta1, static_cast<float>(step + 1));
+                float beta2_power = std::pow(beta2, static_cast<float>(step + 1));
+                float reversed_beta1_power = 1.0f - beta1_power;
+                float sqrt_reversed_beta2_power = std::sqrt(1.0f - beta2_power);
+                float alpha = lr * sqrt_reversed_beta2_power / reversed_beta1_power;
+                float epsilon_hat = epsilon * sqrt_reversed_beta2_power;
                 for (std::size_t t = 0; t < nb_tensors; ++t) {
                     for (std::size_t i = 0; i < size_tensors[t]; ++i) {
                         val_momentum1_tensors[t][i] = beta1 * val_momentum1_tensors[t][i] + (1.0f - beta1) * val_grad_tensors[t][i];
                         val_momentum2_tensors[t][i] = beta2 * val_momentum2_tensors[t][i] + (1.0f - beta2) * val_grad_tensors[t][i] * val_grad_tensors[t][i];
                         val_tensors[t][i] = val_tensors[t][i]
-                                        - lr2 * val_momentum1_tensors[t][i] / (std::sqrt(val_momentum2_tensors[t][i]) +  epsilon2);
+                                        - alpha * val_momentum1_tensors[t][i] / (std::sqrt(val_momentum2_tensors[t][i]) +  epsilon_hat);
                     }
                     cudaMemcpy(d_val_momentum1_tensors[t], val_momentum1_tensors[t].get(), size_tensors[t] * sizeof(float), cudaMemcpyHostToDevice);
                     cudaMemcpy(d_val_momentum2_tensors[t], val_momentum2_tensors[t].get(), size_tensors[t] * sizeof(float), cudaMemcpyHostToDevice);
                     cudaMemcpy(d_val_tensors[t], val_tensors[t].get(), size_tensors[t] * sizeof(float), cudaMemcpyHostToDevice);
                 }
-
                 // optimizer
                 opt.update();
                 // tests
-- 
GitLab