diff --git a/include/aidge/learning/optimizer/Adam.hpp b/include/aidge/learning/optimizer/Adam.hpp
index 2932918c225e69d342ae59b47912c7bbd8c021f9..8c89e53f94dade675d3e7da139554561929f5e3d 100644
--- a/include/aidge/learning/optimizer/Adam.hpp
+++ b/include/aidge/learning/optimizer/Adam.hpp
@@ -65,15 +65,15 @@ public:

         Tensor alpha = Tensor(learningRate() * mSqrtReversedBeta2Power / mReversedBeta1Power);
         alpha.setBackend(mParameters[0]->getImpl()->backend());
+        alpha.setDataType(mParameters[0]->dataType());

         Tensor epsilon_hat = Tensor(this->getAttr<AdamAttr::Epsilon>() * mSqrtReversedBeta2Power);
         epsilon_hat.setBackend(mParameters[0]->getImpl()->backend());
+        epsilon_hat.setDataType(mParameters[0]->dataType());

         if (mLRScheduler.step() == 0) {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
-                mMomentum1[i].setDataType(mParameters[i]->grad()->dataType());
                 mMomentum1[i].zeros();
-                mMomentum2[i].setDataType(mParameters[i]->grad()->dataType());
                 mMomentum2[i].zeros();
             }
         }
@@ -94,14 +94,20 @@ public:
         for (std::size_t i = 0; i < parameters.size(); ++i) {
             mMomentum1[i] = Tensor(parameters[i]->dims());
             mMomentum1[i].setBackend(parameters[i]->getImpl()->backend());
+            mMomentum1[i].setDataType(parameters[i]->dataType());
             mMomentum2[i] = Tensor(parameters[i]->dims());
             mMomentum2[i].setBackend(parameters[i]->getImpl()->backend());
+            mMomentum2[i].setDataType(parameters[i]->dataType());
         }
         if (parameters.size() > 0) {
             mBeta1.setBackend(mParameters[0]->getImpl()->backend());
+            mBeta1.setDataType(parameters[0]->dataType());
             mReversedBeta1.setBackend(mParameters[0]->getImpl()->backend());
+            mReversedBeta1.setDataType(parameters[0]->dataType());
             mBeta2.setBackend(mParameters[0]->getImpl()->backend());
+            mBeta2.setDataType(parameters[0]->dataType());
             mReversedBeta2.setBackend(mParameters[0]->getImpl()->backend());
+            mReversedBeta2.setDataType(parameters[0]->dataType());
         }
     }
 };
diff --git a/include/aidge/learning/optimizer/SGD.hpp b/include/aidge/learning/optimizer/SGD.hpp
index cef3573bc6d90211014ea171bdf32e04ec89fb86..da029b36fae81af32aad79e668fab1e98e1a0076 100644
--- a/include/aidge/learning/optimizer/SGD.hpp
+++ b/include/aidge/learning/optimizer/SGD.hpp
@@ -58,7 +58,9 @@ public:
     void update() override final {
         mLR = Tensor(learningRate());
         mLR.setBackend(mParameters[0]->getImpl()->backend());
+        mLR.setDataType(mParameters[0]->dataType());
         mWeightDecay.setBackend(mParameters[0]->getImpl()->backend());
+        mWeightDecay.setDataType(mParameters[0]->dataType());

         if (mLRScheduler.step() == 0) {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
@@ -81,10 +83,13 @@ public:
         for (std::size_t i = 0; i < parameters.size(); ++i) {
             mGradientInertia[i] = Tensor(parameters[i]->dims());
             mGradientInertia[i].setBackend(parameters[i]->backend());
+            mGradientInertia[i].setDataType(parameters[i]->dataType());
         }
         if (parameters.size() > 0) {
             mReversedDampening.setBackend(mParameters[0]->getImpl()->backend());
+            mReversedDampening.setDataType(parameters[0]->dataType());
             mMomentum.setBackend(mParameters[0]->getImpl()->backend());
+            mMomentum.setDataType(parameters[0]->dataType());
         }
     }
 };
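
Note for reviewers (not part of the patch): the change pins every optimizer-internal Tensor to the parameters' data type in addition to their backend, and drops the per-step setDataType() calls from the step-0 branch of Adam::update(), since the momentum type is now fixed once in setParameters(). A minimal sketch of the resulting invariant, assuming the Aidge Tensor API as used in the hunks above (the helper name syncWithParam and the include path are illustrative, not from the patch):

    #include <memory>
    #include "aidge/data/Tensor.hpp"  // assumed header for Aidge::Tensor

    using Aidge::Tensor;

    // Align an optimizer-internal tensor with a reference parameter.
    // setBackend() moves the storage but does not touch the data type,
    // so the explicit setDataType() call is what keeps the arithmetic
    // in update() type-consistent with the parameters.
    void syncWithParam(Tensor& internal, const std::shared_ptr<Tensor>& param) {
        internal.setBackend(param->getImpl()->backend());
        internal.setDataType(param->dataType());
    }

With the type set once at registration, the step-0 initialization only needs the zeros() calls, which is exactly what the first Adam hunk reduces it to.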