Skip to content
Snippets Groups Projects

Fix tensor data type in optimizer (SGD, Adam)

Merged Olivier Antoni requested to merge oantoni/aidge_learning:fix_optim_dataType into dev
2 files changed
+13 −2
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -65,15 +65,15 @@ public:
Tensor alpha = Tensor(learningRate() * mSqrtReversedBeta2Power / mReversedBeta1Power);
alpha.setBackend(mParameters[0]->getImpl()->backend());
alpha.setDataType(mParameters[0]->dataType());
Tensor epsilon_hat = Tensor(this->getAttr<AdamAttr::Epsilon>() * mSqrtReversedBeta2Power);
epsilon_hat.setBackend(mParameters[0]->getImpl()->backend());
epsilon_hat.setDataType(mParameters[0]->dataType());
if (mLRScheduler.step() == 0) {
for (std::size_t i = 0; i < mParameters.size(); ++i) {
mMomentum1[i].setDataType(mParameters[i]->grad()->dataType());
mMomentum1[i].zeros();
mMomentum2[i].setDataType(mParameters[i]->grad()->dataType());
mMomentum2[i].zeros();
}
}
@@ -94,14 +94,20 @@ public:
for (std::size_t i = 0; i < parameters.size(); ++i) {
mMomentum1[i] = Tensor(parameters[i]->dims());
mMomentum1[i].setBackend(parameters[i]->getImpl()->backend());
mMomentum1[i].setDataType(parameters[i]->dataType());
mMomentum2[i] = Tensor(parameters[i]->dims());
mMomentum2[i].setBackend(parameters[i]->getImpl()->backend());
mMomentum2[i].setDataType(parameters[i]->dataType());
}
if (parameters.size() > 0) {
mBeta1.setBackend(mParameters[0]->getImpl()->backend());
mBeta1.setDataType(parameters[0]->dataType());
mReversedBeta1.setBackend(mParameters[0]->getImpl()->backend());
mReversedBeta1.setDataType(parameters[0]->dataType());
mBeta2.setBackend(mParameters[0]->getImpl()->backend());
mBeta2.setDataType(parameters[0]->dataType());
mReversedBeta2.setBackend(mParameters[0]->getImpl()->backend());
mReversedBeta2.setDataType(parameters[0]->dataType());
}
}
};
Loading