[Feature] Handle the Device Index
3 unresolved threads
3 unresolved threads
Compare changes
Files
7@@ -58,23 +58,43 @@ public:
@@ -58,23 +58,43 @@ public:
float mBeta1Power = std::pow(this->getAttr<AdamAttr::Beta1>(), static_cast<float>(mLRScheduler.step() + 1));
float mBeta2Power = std::pow(this->getAttr<AdamAttr::Beta2>(), static_cast<float>(mLRScheduler.step() + 1));
@@ -88,25 +108,33 @@ public:
@@ -88,25 +108,33 @@ public: