Skip to content
Snippets Groups Projects

[Feature] Handle the Device Index

Merged Benjamin Halimi requested to merge use_device into dev
3 unresolved threads
Files
7
@@ -58,23 +58,43 @@ public:
@@ -58,23 +58,43 @@ public:
}
}
void update() override final {
void update() override final {
 
 
auto backend = mParameters[0]->backend();
 
auto device = mParameters[0]->device();
 
auto dataType = mParameters[0]->dataType();
 
float mBeta1Power = std::pow(this->getAttr<AdamAttr::Beta1>(), static_cast<float>(mLRScheduler.step() + 1));
float mBeta1Power = std::pow(this->getAttr<AdamAttr::Beta1>(), static_cast<float>(mLRScheduler.step() + 1));
float mBeta2Power = std::pow(this->getAttr<AdamAttr::Beta2>(), static_cast<float>(mLRScheduler.step() + 1));
float mBeta2Power = std::pow(this->getAttr<AdamAttr::Beta2>(), static_cast<float>(mLRScheduler.step() + 1));
 
float mReversedBeta1Power = 1.0f - mBeta1Power;
float mReversedBeta1Power = 1.0f - mBeta1Power;
float mSqrtReversedBeta2Power = std::sqrt(1.0f - mBeta2Power);
float mSqrtReversedBeta2Power = std::sqrt(1.0f - mBeta2Power);
Tensor alpha = Tensor(learningRate() * mSqrtReversedBeta2Power / mReversedBeta1Power);
Tensor alpha = Tensor(learningRate() * mSqrtReversedBeta2Power / mReversedBeta1Power);
alpha.setBackend(mParameters[0]->getImpl()->backend());
alpha.setBackend(backend, device);
alpha.setDataType(mParameters[0]->dataType());
alpha.setDataType(dataType);
Tensor epsilon_hat = Tensor(this->getAttr<AdamAttr::Epsilon>() * mSqrtReversedBeta2Power);
Tensor epsilon_hat = Tensor(this->getAttr<AdamAttr::Epsilon>() * mSqrtReversedBeta2Power);
epsilon_hat.setBackend(mParameters[0]->getImpl()->backend());
epsilon_hat.setBackend(backend, device);
epsilon_hat.setDataType(mParameters[0]->dataType());
epsilon_hat.setDataType(dataType);
 
 
mBeta1.setBackend(backend, device);
 
mBeta1.setDataType(dataType);
 
mReversedBeta1.setBackend(backend, device);
 
mReversedBeta1.setDataType(dataType);
 
 
mBeta2.setBackend(backend, device);
 
mBeta2.setDataType(dataType);
 
mReversedBeta2.setBackend(backend, device);
 
mReversedBeta2.setDataType(dataType);
if (mLRScheduler.step() == 0) {
if (mLRScheduler.step() == 0) {
for (std::size_t i = 0; i < mParameters.size(); ++i) {
for (std::size_t i = 0; i < mParameters.size(); ++i) {
mMomentum1[i].zeros();
mMomentum1[i].zeros();
mMomentum2[i].zeros();
mMomentum1[i].setBackend(backend, device);
 
mMomentum1[i].setDataType(dataType);
 
mMomentum2[i].zeros();
 
mMomentum2[i].setBackend(backend, device);
 
mMomentum2[i].setDataType(dataType);
}
}
}
}
@@ -88,25 +108,33 @@ public:
@@ -88,25 +108,33 @@ public:
}
}
void setParameters(const std::vector<std::shared_ptr<Tensor>>& parameters) override final {
void setParameters(const std::vector<std::shared_ptr<Tensor>>& parameters) override final {
 
Optimizer::setParameters(parameters);
Optimizer::setParameters(parameters);
mMomentum1 = std::vector<Tensor>(parameters.size());
mMomentum1 = std::vector<Tensor>(parameters.size());
mMomentum2 = std::vector<Tensor>(parameters.size());
mMomentum2 = std::vector<Tensor>(parameters.size());
 
for (std::size_t i = 0; i < parameters.size(); ++i) {
for (std::size_t i = 0; i < parameters.size(); ++i) {
 
mMomentum1[i] = Tensor(parameters[i]->dims());
mMomentum1[i] = Tensor(parameters[i]->dims());
mMomentum1[i].setBackend(parameters[i]->getImpl()->backend());
mMomentum1[i].setBackend(parameters[i]->backend(), parameters[i]->device());
mMomentum1[i].setDataType(parameters[i]->dataType());
mMomentum1[i].setDataType(parameters[i]->dataType());
 
mMomentum2[i] = Tensor(parameters[i]->dims());
mMomentum2[i] = Tensor(parameters[i]->dims());
mMomentum2[i].setBackend(parameters[i]->getImpl()->backend());
mMomentum2[i].setBackend(parameters[i]->backend(), parameters[i]->device());
mMomentum2[i].setDataType(parameters[i]->dataType());
mMomentum2[i].setDataType(parameters[i]->dataType());
}
}
if (parameters.size() > 0) {
if (parameters.size() > 0) {
mBeta1.setBackend(mParameters[0]->getImpl()->backend());
 
mBeta1.setBackend(mParameters[0]->backend(), mParameters[0]->device());
mBeta1.setDataType(parameters[0]->dataType());
mBeta1.setDataType(parameters[0]->dataType());
mReversedBeta1.setBackend(mParameters[0]->getImpl()->backend());
 
mReversedBeta1.setBackend(mParameters[0]->backend(), mParameters[0]->device());
mReversedBeta1.setDataType(parameters[0]->dataType());
mReversedBeta1.setDataType(parameters[0]->dataType());
mBeta2.setBackend(mParameters[0]->getImpl()->backend());
 
mBeta2.setBackend(mParameters[0]->backend(), mParameters[0]->device());
mBeta2.setDataType(parameters[0]->dataType());
mBeta2.setDataType(parameters[0]->dataType());
mReversedBeta2.setBackend(mParameters[0]->getImpl()->backend());
 
mReversedBeta2.setBackend(mParameters[0]->backend(), mParameters[0]->device());
mReversedBeta2.setDataType(parameters[0]->dataType());
mReversedBeta2.setDataType(parameters[0]->dataType());
}
}
}
}
Loading