Commit 53e9162f authored by Maxence Naud

Merge branch 'scheduling' into 'dev'

Improved scheduling

See merge request !45
parents f805a9af cd558133
2 merge requests: !50 version 0.2.0, !45 Improved scheduling
Pipeline #42701 passed
Showing 25 additions and 83 deletions
@@ -38,7 +38,7 @@ public:
         return std::make_unique<ReshapeImpl_cpu>(op);
     }
-    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
+    Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
     void forward() override;
 };

@@ -40,7 +40,7 @@ public:
         return std::make_unique<ScalingImpl_cpu>(op);
     }
-    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
+    Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
     void forward() override;
 };

@@ -39,7 +39,7 @@ public:
         return std::make_unique<SigmoidImpl_cpu>(op);
     }
-    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
+    Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
     void forward() override;
 };

@@ -46,14 +46,6 @@ public:
         return std::make_unique<SliceImpl_cpu>(op);
     }
-    NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
-    NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
-    NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
-                               const std::vector<DimSize_t>& inputsSize) const override final;
-    NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final;
-    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
-    void updateConsummerProducer() override final;
     void forward() override;
     void backward() override;
 };

@@ -39,7 +39,7 @@ public:
         return std::make_unique<SoftmaxImpl_cpu>(op);
     }
-    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
+    Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
     void forward() override;
 };

@@ -40,7 +40,7 @@ public:
         return std::make_unique<SqrtImpl_cpu>(op);
     }
-    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
+    Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
     void forward() override final;

@@ -39,7 +39,7 @@ public:
         return std::make_unique<SubImpl_cpu>(op);
     }
-    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
+    Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
     void forward() override;
 };

@@ -39,7 +39,7 @@ public:
         return std::make_unique<TanhImpl_cpu>(op);
     }
-    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
+    Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
     void forward() override;
 };

@@ -63,7 +63,6 @@ public:
         return std::make_unique<TransposeImpl2D_cpu>(op);
     }
-    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
     void forward() override;
 };
 class TransposeImpl3D_cpu : public OperatorImpl {
@@ -74,7 +73,6 @@ public:
         return std::make_unique<TransposeImpl3D_cpu>(op);
     }
-    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
     void forward() override;
 };
 class TransposeImpl4D_cpu : public OperatorImpl {
@@ -85,7 +83,6 @@ public:
         return std::make_unique<TransposeImpl4D_cpu>(op);
     }
-    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
     void forward() override;
 };
 class TransposeImpl5D_cpu : public OperatorImpl {
@@ -96,7 +93,6 @@ public:
         return std::make_unique<TransposeImpl5D_cpu>(op);
     }
-    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
     void forward() override;
 };
 class TransposeImpl6D_cpu : public OperatorImpl {
@@ -107,7 +103,6 @@ public:
         return std::make_unique<TransposeImpl6D_cpu>(op);
     }
-    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
     void forward() override;
 };

@@ -22,9 +22,9 @@
 #include "aidge/utils/Types.h"
 #include "aidge/utils/ErrorHandling.hpp"
-Aidge::NbElts_t Aidge::AddImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
+Aidge::Elts_t Aidge::AddImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
-    return 0;
+    return Elts_t::DataElts(0);
 }
 void Aidge::AddImpl_cpu::forward() {

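The same migration repeats in each kernel below: scheduling element counts move from the plain integer NbElts_t to the Elts_t wrapper, and a bare 0 becomes Elts_t::DataElts(0). A minimal before/after sketch for a hypothetical out-of-tree operator, assuming the same OperatorImpl interface used in this merge request (MyOpImpl_cpu is illustrative, not part of this change):

    // Before: counts were raw integers.
    // Aidge::NbElts_t Aidge::MyOpImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
    //     return 0; // in-place: no already-written input data needs protection
    // }

    // After (pattern of this merge request): counts are element quantities
    // built through Elts_t factory helpers such as Elts_t::DataElts(...).
    Aidge::Elts_t Aidge::MyOpImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
        // this implementation can be in-place
        return Aidge::Elts_t::DataElts(0);
    }
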
@@ -21,9 +21,9 @@
 #include "aidge/backend/cpu/operator/AvgPoolingImpl.hpp"
 #include "aidge/backend/cpu/operator/AvgPoolingImpl_forward_kernels.hpp"
-Aidge::NbElts_t Aidge::AvgPoolingImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
+Aidge::Elts_t Aidge::AvgPoolingImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
-    return 0;
+    return Elts_t::DataElts(0);
 }
 void Aidge::AvgPoolingImpl2D_cpu::forward() {

@@ -20,9 +20,9 @@
 #include "aidge/backend/cpu/operator/BatchNormImpl.hpp"
 #include "aidge/backend/cpu/operator/BatchNormImpl_forward_kernels.hpp"
-Aidge::NbElts_t Aidge::BatchNormImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
+Aidge::Elts_t Aidge::BatchNormImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
-    return 0;
+    return Elts_t::DataElts(0);
 }
 void Aidge::BatchNormImpl2D_cpu::forward() {

@@ -21,46 +21,6 @@
 #include "aidge/backend/cpu/operator/ConcatImpl.hpp"
 #include "aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp"
-Aidge::NbElts_t Aidge::ConcatImpl_cpu::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
-    assert(mOp.getRawInput(inputIdx) && "requires valid input");
-    // Requires the whole tensors
-    const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->dims();
-    return std::accumulate(inputDims.begin(), inputDims.end(), NbElts_t(1), std::multiplies<NbElts_t>());
-}
-
-Aidge::NbElts_t Aidge::ConcatImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
-    // for the direct convolution algorithm, convolutions can be in-place, if there is no padding!
-    return 0;
-}
-
-Aidge::NbElts_t Aidge::ConcatImpl_cpu::getRequiredMemory(const Aidge::IOIndex_t outputIdx, const std::vector<Aidge::DimSize_t>& /*inputsSize*/) const {
-    // Requires the whole tensors, regardless of available data on inputs
-    assert(outputIdx == 0 && "operator has only one output");
-    (void) outputIdx;
-    const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims();
-    return std::accumulate(outputDims.begin(), outputDims.end(), NbElts_t(1), std::multiplies<NbElts_t>());
-}
-
-Aidge::NbElts_t Aidge::ConcatImpl_cpu::getNbConsumedData(const Aidge::IOIndex_t inputIdx) const {
-    assert(inputIdx < mNbConsumedData.size());
-    return mNbConsumedData[inputIdx];
-}
-
-Aidge::NbElts_t Aidge::ConcatImpl_cpu::getNbProducedData(const Aidge::IOIndex_t outputIdx) const {
-    assert(outputIdx < mNbProducedData.size());
-    return mNbProducedData[outputIdx];
-}
-
-void Aidge::ConcatImpl_cpu::updateConsummerProducer() {
-    for (IOIndex_t inputIdx = 0; static_cast<NbElts_t>(inputIdx) < mNbConsumedData.size(); ++inputIdx)
-        mNbConsumedData[inputIdx] += getNbRequiredData(inputIdx); // each input is consumed by the minimum amount for a forward pass
-    mNbProducedData[0] += getRequiredMemory(0, {});
-}
 void Aidge::ConcatImpl_cpu::forward() {
     assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input in Concat operator");
     DataType datatypeFirstInput = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType();

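Removing these overrides means ConcatImpl_cpu now relies on whatever default data-requirement and consumer/producer accounting the OperatorImpl base class provides; only genuinely operator-specific hooks stay in the backend. A sketch of the slimmed-down shape an implementation header takes after this change, mirroring the Reshape/Softmax headers above (the class name is hypothetical and the constructor/Registrar boilerplate is elided):

    class MyOpImpl_cpu : public OperatorImpl {
    public:
        static std::unique_ptr<MyOpImpl_cpu> create(const MyOp_Op& op) {
            return std::make_unique<MyOpImpl_cpu>(op);
        }

        // Kept only because this kernel can run in-place; required data,
        // consumed/produced counts and updateConsummerProducer() now come
        // from the OperatorImpl base class.
        Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;

        void forward() override;
    };
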
@@ -22,9 +22,9 @@
 #include "aidge/backend/cpu/operator/ConvDepthWiseImpl.hpp"
 #include "aidge/backend/cpu/operator/ConvDepthWiseImpl_forward_kernels.hpp"
-Aidge::NbElts_t Aidge::ConvDepthWiseImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
+Aidge::Elts_t Aidge::ConvDepthWiseImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
-    return 0;
+    return Elts_t::DataElts(0);
 }
 void Aidge::ConvDepthWiseImpl2D_cpu::forward() {

@@ -22,9 +22,9 @@
 #include "aidge/backend/cpu/operator/ConvImpl.hpp"
 #include "aidge/backend/cpu/operator/ConvImpl_forward_kernels.hpp"
-Aidge::NbElts_t Aidge::ConvImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
+Aidge::Elts_t Aidge::ConvImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
-    return 0;
+    return Elts_t::DataElts(0);
 }
 void Aidge::ConvImpl2D_cpu::forward() {

@@ -19,9 +19,9 @@
 #include "aidge/data/Tensor.hpp"
 #include "aidge/utils/Types.h"
-Aidge::NbElts_t Aidge::DivImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
+Aidge::Elts_t Aidge::DivImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
-    return 0;
+    return Elts_t::DataElts(0);
 }
 void Aidge::DivImpl_cpu::forward() {

@@ -19,9 +19,9 @@
 #include "aidge/operator/Erf.hpp"
 #include "aidge/utils/Types.h"
-Aidge::NbElts_t Aidge::ErfImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
+Aidge::Elts_t Aidge::ErfImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
-    return 0;
+    return Elts_t::DataElts(0);
 }
 void Aidge::ErfImpl_cpu::forward() {

@@ -20,11 +20,6 @@
 #include "aidge/operator/Gather.hpp"
 #include "aidge/utils/Types.h"
-Aidge::NbElts_t Aidge::GatherImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
-    // this implementation can be in-place
-    return 0;
-}
 void Aidge::GatherImpl_cpu::forward() {
     const Gather_Op& op = static_cast<const Gather_Op&>(mOp);

@@ -22,9 +22,9 @@
 #include "aidge/backend/cpu/operator/LeakyReLUImpl_forward_kernels.hpp"
 #include "aidge/backend/cpu/operator/LeakyReLUImpl_backward_kernels.hpp"
-Aidge::NbElts_t Aidge::LeakyReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
+Aidge::Elts_t Aidge::LeakyReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
-    return 0;
+    return Elts_t::DataElts(0);
 }
 void Aidge::LeakyReLUImpl_cpu::forward() {

@@ -21,9 +21,9 @@
 #include "aidge/backend/cpu/operator/MaxPoolingImpl.hpp"
 #include "aidge/backend/cpu/operator/MaxPoolingImpl_forward_kernels.hpp"
-Aidge::NbElts_t Aidge::MaxPoolingImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
+Aidge::Elts_t Aidge::MaxPoolingImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
-    return 0;
+    return Elts_t::DataElts(0);
 }
 void Aidge::MaxPoolingImpl2D_cpu::forward() {
