diff --git a/include/aidge/learning/optimizer/Adam.hpp b/include/aidge/learning/optimizer/Adam.hpp
index b5a1f01e84ed279ff7963c1179cd9d207fc4dca8..125cfd792ca61c94f52c0238229f282d4d7f8e47 100644
--- a/include/aidge/learning/optimizer/Adam.hpp
+++ b/include/aidge/learning/optimizer/Adam.hpp
@@ -53,39 +53,33 @@ public:
                     attr<AdamAttr::Beta2>(beta2),
                     attr<AdamAttr::Epsilon>(epsilon))
     {
-        mBeta1.setBackend("cpu");
-        mBeta1.set<float>(0, beta1);
-        mReversedBeta1.setBackend("cpu");
-        mReversedBeta1.set<float>(0, 1.0f - beta1);
-
-        mBeta2.setBackend("cpu");
-        mBeta2.set<float>(0, beta2);
-        mReversedBeta2.setBackend("cpu");
-        mReversedBeta2.set<float>(0, 1.0f - beta2);
-
-        mEpsilon.setBackend("cpu");
-        mEpsilon.set<float>(0, epsilon);
+        mBeta1 = Tensor(Array1D<float, 1>{{beta1}});
+        mReversedBeta1 = Tensor(Array1D<float, 1>{{1.0f - beta1}});
+
+        mBeta2 = Tensor(Array1D<float, 1>{{beta2}});
+        mReversedBeta2 = Tensor(Array1D<float, 1>{{1.0f - beta2}});
+
+        mEpsilon = Tensor(Array1D<float, 1>{{epsilon}});
     }
 
-    void update() override final {
+    void update() override final {
+        mLR = Tensor(Array1D<float, 1>{{learningRate()}});
         mLR.setBackend(mParameters[0]->getImpl()->backend());
-        mLR.set<float>(0, learningRate());
+
         if (mParameters[0]->getImpl()->backend() != mBeta1.getImpl()->backend()) {
             mBeta1.setBackend(mParameters[0]->getImpl()->backend());
             mReversedBeta1.setBackend(mParameters[0]->getImpl()->backend());
             mBeta2.setBackend(mParameters[0]->getImpl()->backend());
             mReversedBeta2.setBackend(mParameters[0]->getImpl()->backend());
         }
-
-        Tensor alpha{std::vector<std::size_t>({1})};
+
+        Tensor alpha = Tensor(Array1D<float, 1>{{ static_cast<float>(learningRate() * std::sqrt(1.0f - std::pow(this->getAttr<AdamAttr::Beta2>(), mLRScheduler.step() + 1))
+                                                                    / (1.0f - std::pow(this->getAttr<AdamAttr::Beta1>(), mLRScheduler.step() + 1))) }});
         alpha.setBackend(mParameters[0]->getImpl()->backend());
-        alpha.set<float>(0, learningRate() * std::sqrt(1.0f - std::pow(mBeta2.get<float>(0), mLRScheduler.step() + 1))
-                           / (1.0f - std::pow(mBeta1.get<float>(0), mLRScheduler.step() + 1)));
 
-        Tensor epsilon{std::vector<std::size_t>({1})};
+        Tensor epsilon = Tensor(Array1D<float, 1>{{ static_cast<float>(this->getAttr<AdamAttr::Epsilon>() * std::sqrt(1.0f - std::pow(this->getAttr<AdamAttr::Beta2>(), mLRScheduler.step() + 1))) }});
         epsilon.setBackend(mParameters[0]->getImpl()->backend());
-        epsilon.set<float>(0, mEpsilon.get<float>(0) * std::sqrt(1.0f - std::pow(mBeta2.get<float>(0), mLRScheduler.step() + 1)));
-
+
         if (mLRScheduler.step() == 0) {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
                 mMomentum1[i].setBackend(mParameters[i]->getImpl()->backend());
@@ -112,7 +106,9 @@ public:
         mMomentum2 = std::vector<Tensor>(parameters.size());
         for (std::size_t i = 0; i < parameters.size(); ++i) {
             mMomentum1[i] = Tensor(parameters[i]->dims());
+            mMomentum1[i].setBackend(parameters[i]->getImpl()->backend());
             mMomentum2[i] = Tensor(parameters[i]->dims());
+            mMomentum2[i].setBackend(parameters[i]->getImpl()->backend());
         }
     }
 };
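
The two scalar tensors built in update() fold Adam's bias correction into the step size: with t = mLRScheduler.step() + 1, the textbook update theta <- theta - lr * m_hat / (sqrt(v_hat) + eps), where m_hat = m_t / (1 - beta1^t) and v_hat = v_t / (1 - beta2^t), is algebraically identical to theta <- theta - alpha_t * m_t / (sqrt(v_t) + eps_t), with alpha_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t) and eps_t = eps * sqrt(1 - beta2^t). Below is a minimal standalone sketch checking that equivalence on plain floats; it uses no Aidge types, and all names in it are illustrative, not part of the library.

#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
    const float lr = 1.0e-3f, beta1 = 0.9f, beta2 = 0.999f, eps = 1.0e-8f;
    float m = 0.0f, v = 0.0f;                  // first/second moment estimates
    float thetaRef = 1.0f, thetaFold = 1.0f;   // parameter, updated two ways

    for (int t = 1; t <= 5; ++t) {
        const float g = 0.5f / static_cast<float>(t);   // dummy gradient
        m = beta1 * m + (1.0f - beta1) * g;
        v = beta2 * v + (1.0f - beta2) * g * g;

        // Textbook Adam: explicitly bias-correct the moments.
        const float mHat = m / (1.0f - std::pow(beta1, t));
        const float vHat = v / (1.0f - std::pow(beta2, t));
        thetaRef -= lr * mHat / (std::sqrt(vHat) + eps);

        // Folded form, as in the diff: per-step scalars alpha and epsHat
        // applied to the raw moments.
        const float alpha  = lr * std::sqrt(1.0f - std::pow(beta2, t))
                                / (1.0f - std::pow(beta1, t));
        const float epsHat = eps * std::sqrt(1.0f - std::pow(beta2, t));
        thetaFold -= alpha * m / (std::sqrt(v) + epsHat);

        assert(std::fabs(thetaRef - thetaFold) < 1.0e-6f);
    }
    std::printf("textbook vs folded: %.8f vs %.8f\n", thetaRef, thetaFold);
    return 0;
}

Folding the correction into alpha and epsilon keeps the per-parameter work to elementwise multiply, sqrt-add, and divide kernels on the backend, with no separate bias-correction pass over the moment tensors.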