Skip to content
Snippets Groups Projects
Commit 58f7c200 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'fix_optim_dataType' into 'dev'

Fix tensor data type in optimizer (SGD, Adam)

See merge request !42
parents 670fd4dc 00e34c6c
No related branches found
No related tags found
2 merge requests!44Update 0.2.3 -> 0.3.0,!42Fix tensor data type in optimizer (SGD, Adam)
Pipeline #70642 passed
...@@ -65,15 +65,15 @@ public: ...@@ -65,15 +65,15 @@ public:
Tensor alpha = Tensor(learningRate() * mSqrtReversedBeta2Power / mReversedBeta1Power); Tensor alpha = Tensor(learningRate() * mSqrtReversedBeta2Power / mReversedBeta1Power);
alpha.setBackend(mParameters[0]->getImpl()->backend()); alpha.setBackend(mParameters[0]->getImpl()->backend());
alpha.setDataType(mParameters[0]->dataType());
Tensor epsilon_hat = Tensor(this->getAttr<AdamAttr::Epsilon>() * mSqrtReversedBeta2Power); Tensor epsilon_hat = Tensor(this->getAttr<AdamAttr::Epsilon>() * mSqrtReversedBeta2Power);
epsilon_hat.setBackend(mParameters[0]->getImpl()->backend()); epsilon_hat.setBackend(mParameters[0]->getImpl()->backend());
epsilon_hat.setDataType(mParameters[0]->dataType());
if (mLRScheduler.step() == 0) { if (mLRScheduler.step() == 0) {
for (std::size_t i = 0; i < mParameters.size(); ++i) { for (std::size_t i = 0; i < mParameters.size(); ++i) {
mMomentum1[i].setDataType(mParameters[i]->grad()->dataType());
mMomentum1[i].zeros(); mMomentum1[i].zeros();
mMomentum2[i].setDataType(mParameters[i]->grad()->dataType());
mMomentum2[i].zeros(); mMomentum2[i].zeros();
} }
} }
...@@ -94,14 +94,20 @@ public: ...@@ -94,14 +94,20 @@ public:
for (std::size_t i = 0; i < parameters.size(); ++i) { for (std::size_t i = 0; i < parameters.size(); ++i) {
mMomentum1[i] = Tensor(parameters[i]->dims()); mMomentum1[i] = Tensor(parameters[i]->dims());
mMomentum1[i].setBackend(parameters[i]->getImpl()->backend()); mMomentum1[i].setBackend(parameters[i]->getImpl()->backend());
mMomentum1[i].setDataType(parameters[i]->dataType());
mMomentum2[i] = Tensor(parameters[i]->dims()); mMomentum2[i] = Tensor(parameters[i]->dims());
mMomentum2[i].setBackend(parameters[i]->getImpl()->backend()); mMomentum2[i].setBackend(parameters[i]->getImpl()->backend());
mMomentum2[i].setDataType(parameters[i]->dataType());
} }
if (parameters.size() > 0) { if (parameters.size() > 0) {
mBeta1.setBackend(mParameters[0]->getImpl()->backend()); mBeta1.setBackend(mParameters[0]->getImpl()->backend());
mBeta1.setDataType(parameters[0]->dataType());
mReversedBeta1.setBackend(mParameters[0]->getImpl()->backend()); mReversedBeta1.setBackend(mParameters[0]->getImpl()->backend());
mReversedBeta1.setDataType(parameters[0]->dataType());
mBeta2.setBackend(mParameters[0]->getImpl()->backend()); mBeta2.setBackend(mParameters[0]->getImpl()->backend());
mBeta2.setDataType(parameters[0]->dataType());
mReversedBeta2.setBackend(mParameters[0]->getImpl()->backend()); mReversedBeta2.setBackend(mParameters[0]->getImpl()->backend());
mReversedBeta2.setDataType(parameters[0]->dataType());
} }
} }
}; };
......
...@@ -58,7 +58,9 @@ public: ...@@ -58,7 +58,9 @@ public:
void update() override final { void update() override final {
mLR = Tensor(learningRate()); mLR = Tensor(learningRate());
mLR.setBackend(mParameters[0]->getImpl()->backend()); mLR.setBackend(mParameters[0]->getImpl()->backend());
mLR.setDataType(mParameters[0]->dataType());
mWeightDecay.setBackend(mParameters[0]->getImpl()->backend()); mWeightDecay.setBackend(mParameters[0]->getImpl()->backend());
mWeightDecay.setDataType(mParameters[0]->dataType());
if (mLRScheduler.step() == 0) { if (mLRScheduler.step() == 0) {
for (std::size_t i = 0; i < mParameters.size(); ++i) { for (std::size_t i = 0; i < mParameters.size(); ++i) {
...@@ -81,10 +83,13 @@ public: ...@@ -81,10 +83,13 @@ public:
for (std::size_t i = 0; i < parameters.size(); ++i) { for (std::size_t i = 0; i < parameters.size(); ++i) {
mGradientInertia[i] = Tensor(parameters[i]->dims()); mGradientInertia[i] = Tensor(parameters[i]->dims());
mGradientInertia[i].setBackend(parameters[i]->backend()); mGradientInertia[i].setBackend(parameters[i]->backend());
mGradientInertia[i].setDataType(parameters[i]->dataType());
} }
if (parameters.size() > 0) { if (parameters.size() > 0) {
mReversedDampening.setBackend(mParameters[0]->getImpl()->backend()); mReversedDampening.setBackend(mParameters[0]->getImpl()->backend());
mReversedDampening.setDataType(parameters[0]->dataType());
mMomentum.setBackend(mParameters[0]->getImpl()->backend()); mMomentum.setBackend(mParameters[0]->getImpl()->backend());
mMomentum.setDataType(parameters[0]->dataType());
} }
} }
}; };
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment