Commit 7844f1ae authored by Maxence Naud

Change 1D attribute Tensors for scalar Tensors and use compound assignment operators on 'Parameter' to avoid reset of 'mGrad' attribute
parent 04eb304f
Merge requests: !28 "v0.2.2", !24 "Change 1D attribute Tensors for scalar Tensors and use compound assignment operators on 'Parameter' to avoid reset of 'mGrad' attribute" (this commit is part of !24).
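A minimal sketch of why the compound-assignment half of the change matters, assuming the behaviour described in the commit message (this is not the Aidge Tensor class; MiniTensor, its float payload, and its mGrad pointer are hypothetical stand-ins): *p = *p - delta copy-assigns a temporary that was created without a gradient, so the parameter's mGrad gets reset, whereas *p -= delta mutates the stored data in place and leaves mGrad alone.

// Hypothetical illustration only -- not the Aidge implementation.
#include <cassert>
#include <memory>

struct MiniTensor {
    float value;
    std::shared_ptr<MiniTensor> mGrad;  // gradient buffer owned by the parameter

    // operator- builds a fresh temporary; the temporary carries no gradient.
    MiniTensor operator-(const MiniTensor& rhs) const {
        return MiniTensor{value - rhs.value, nullptr};
    }

    // operator-= updates the stored data in place and never touches mGrad.
    MiniTensor& operator-=(const MiniTensor& rhs) {
        value -= rhs.value;
        return *this;
    }
};

int main() {
    MiniTensor param{1.0f, std::make_shared<MiniTensor>()};
    const MiniTensor step{0.1f, nullptr};

    param -= step;              // compound assignment: param.mGrad survives
    assert(param.mGrad != nullptr);

    param = param - step;       // copy-assigning the gradient-less temporary resets mGrad
    assert(param.mGrad == nullptr);
    return 0;
}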
@@ -35,12 +35,12 @@ class Adam: public Optimizer, public StaticAttributes<AdamAttr, float, float, float>
 private:
     std::vector<Tensor> mMomentum1;
     std::vector<Tensor> mMomentum2;
-    Tensor mLR{std::vector<std::size_t>({1})};
-    Tensor mBeta1{std::vector<std::size_t>({1})};
-    Tensor mReversedBeta1{std::vector<std::size_t>({1})};
-    Tensor mBeta2{std::vector<std::size_t>({1})};
-    Tensor mReversedBeta2{std::vector<std::size_t>({1})};
-    Tensor mEpsilon{std::vector<std::size_t>({1})};
+    Tensor mLR{1.0f};
+    Tensor mBeta1;
+    Tensor mReversedBeta1;
+    Tensor mBeta2;
+    Tensor mReversedBeta2;
+    Tensor mEpsilon;

 public:
     using Attributes_ = StaticAttributes<AdamAttr, float, float, float>;
@@ -51,19 +51,17 @@ public:
         : Optimizer(),
           Attributes_(attr<AdamAttr::Beta1>(beta1),
                       attr<AdamAttr::Beta2>(beta2),
-                      attr<AdamAttr::Epsilon>(epsilon))
+                      attr<AdamAttr::Epsilon>(epsilon)),
+          mBeta1(beta1),
+          mReversedBeta1(1.0f - beta1),
+          mBeta2(beta2),
+          mReversedBeta2(1.0f - beta2),
+          mEpsilon(epsilon)
     {
-        mBeta1 = Tensor(Array1D<float, 1>{{beta1}});
-        mReversedBeta1 = Tensor(Array1D<float, 1>{{1.0f - beta1}});
-        mBeta2 = Tensor(Array1D<float, 1>{{beta2}});
-        mReversedBeta2 = Tensor(Array1D<float, 1>{{1.0f - beta2}});
-        mEpsilon = Tensor(Array1D<float, 1>{{epsilon}});
     }

     void update() override final {
-        mLR = Tensor(Array1D<float, 1>{{learningRate()}});
+        mLR = Tensor(learningRate());
         mLR.setBackend(mParameters[0]->getImpl()->backend());

         if (mParameters[0]->getImpl()->backend() != mBeta1.getImpl()->backend()) {
@@ -73,11 +71,11 @@ public:
             mReversedBeta2.setBackend(mParameters[0]->getImpl()->backend());
         }

-        Tensor alpha = Tensor(Array1D<float, 1>{{ static_cast<float>(learningRate() * std::sqrt(1.0f - std::pow(this->getAttr<AdamAttr::Beta2>(), mLRScheduler.step() + 1))
-                                                   / (1.0f - std::pow(this->getAttr<AdamAttr::Beta1>(), mLRScheduler.step() + 1))) }});
+        Tensor alpha = Tensor(learningRate() * std::sqrt(1.0f - std::pow(this->getAttr<AdamAttr::Beta2>(), static_cast<float>(mLRScheduler.step() + 1)))
+                                             / (1.0f - std::pow(this->getAttr<AdamAttr::Beta1>(), static_cast<float>(mLRScheduler.step() + 1))));
         alpha.setBackend(mParameters[0]->getImpl()->backend());

-        Tensor epsilon = Tensor(Array1D<float, 1>{{ static_cast<float>(this->getAttr<AdamAttr::Epsilon>() * std::sqrt(1.0f - std::pow(this->getAttr<AdamAttr::Beta2>(), mLRScheduler.step() + 1))) }});
+        Tensor epsilon = Tensor(this->getAttr<AdamAttr::Epsilon>() * std::sqrt(1.0f - std::pow(this->getAttr<AdamAttr::Beta2>(), static_cast<float>(mLRScheduler.step() + 1))));
         epsilon.setBackend(mParameters[0]->getImpl()->backend());

         if (mLRScheduler.step() == 0) {
@@ -90,13 +88,13 @@ public:
                 mMomentum2[i].zeros();
             }
         }

         for (std::size_t i = 0; i < mParameters.size(); ++i) {
             mMomentum1[i] = mBeta1 * mMomentum1[i] + mReversedBeta1 * (*mParameters[i]->grad());
             mMomentum2[i] = mBeta2 * mMomentum2[i] + mReversedBeta2 * (*mParameters[i]->grad()) * (*mParameters[i]->grad());
-            *mParameters[i] = *mParameters[i] - alpha * mMomentum1[i] / (mMomentum2[i].sqrt() + epsilon);
+            *mParameters[i] -= alpha * mMomentum1[i] / (mMomentum2[i].sqrt() + epsilon);
         }

         mLRScheduler.update();
     }
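For anyone checking the numerics of the Adam hunks above: as I read the code (this folding is not stated in the commit itself), the loop implements the standard bias-corrected Adam step with the corrections folded into the step size, where t = mLRScheduler.step() + 1, η = learningRate(), and g_t is the parameter gradient:

\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g_t, &
v_t &= \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^2, \\
\alpha_t &= \eta\, \frac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}}, &
\hat{\varepsilon}_t &= \varepsilon\, \sqrt{1-\beta_2^{\,t}}, \\
\theta_t &= \theta_{t-1} - \alpha_t\, \frac{m_t}{\sqrt{v_t} + \hat{\varepsilon}_t}.
\end{aligned}

This is algebraically equivalent to the textbook form θ_t = θ_{t-1} − η·m̂_t/(√v̂_t + ε) with m̂_t = m_t/(1−β1^t) and v̂_t = v_t/(1−β2^t). The second file touched by this commit, the SGD optimizer below, gets the same Array1D-to-scalar and compound-assignment treatment.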
...
@@ -47,23 +47,23 @@ public:
           Attributes_(attr<SGDAttr::Momentum>(momentum),
                       attr<SGDAttr::Dampening>(dampening))
     {
-        mMomentum = Tensor(Array1D<float, 1>{{momentum}});
-        mReversedDampening = Tensor(Array1D<float, 1>{{1.0f - dampening}});
+        mMomentum = Tensor(momentum);
+        mReversedDampening = Tensor(1.0f - dampening);
     }

     void update() override final {
-        mLR = Tensor(Array1D<float, 1>{{learningRate()}});
+        mLR = Tensor(learningRate());
         mLR.setBackend(mParameters[0]->getImpl()->backend());

         if (mLRScheduler.step() == 0) {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
                 mGradientInertia[i] = mParameters[i]->grad()->clone();
-                *mParameters[i] = *mParameters[i] - mLR*mGradientInertia[i];
+                *mParameters[i] -= mLR*mGradientInertia[i];
             }
         } else {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
                 mGradientInertia[i] = mMomentum*mGradientInertia[i] + mReversedDampening*(*mParameters[i]->grad());
-                *mParameters[i] = *mParameters[i] - mLR*mGradientInertia[i];
+                *mParameters[i] -= mLR*mGradientInertia[i];
             }
         }

         mLRScheduler.update();
...
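For completeness, the rule the SGD hunk implements, again as read from the code, with μ = Momentum, d = Dampening, η = learningRate(), g_t the gradient, and v_t the corresponding entry of mGradientInertia:

v_0 = g_0, \qquad v_t = \mu\, v_{t-1} + (1-d)\, g_t \quad (t \ge 1), \qquad \theta_t = \theta_{t-1} - \eta\, v_t.

As in the Adam case, the compound assignment on the final step is what keeps each parameter's gradient buffer alive between iterations.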