Skip to content
Snippets Groups Projects
Commit e1dbf501 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Fix] Add, Concat and Slice implementations

parent 563fb105
No related branches found
No related tags found
1 merge request!22Update operators implementation
...@@ -31,13 +31,8 @@ class AddImplBackward_cpu ...@@ -31,13 +31,8 @@ class AddImplBackward_cpu
class AddImpl_cpu : public OperatorImpl { class AddImpl_cpu : public OperatorImpl {
private:
const Add_Op& mOp;
std::vector<NbElts_t> mNbConsumedData;
std::array<NbElts_t, 1> mNbProducedData = {};
public: public:
AddImpl_cpu(const Add_Op& op) : mOp(op), mNbConsumedData(std::vector<NbElts_t>(op.nbInputs())) {} AddImpl_cpu(const Add_Op& op) : OperatorImpl(op) {}
static std::unique_ptr<AddImpl_cpu> create(const Add_Op& op) { static std::unique_ptr<AddImpl_cpu> create(const Add_Op& op) {
return std::make_unique<AddImpl_cpu>(op); return std::make_unique<AddImpl_cpu>(op);
......
...@@ -39,13 +39,8 @@ class ConcatImplBackward_cpu ...@@ -39,13 +39,8 @@ class ConcatImplBackward_cpu
class ConcatImpl_cpu : public OperatorImpl { class ConcatImpl_cpu : public OperatorImpl {
private:
const Concat_Op& mOp;
std::vector<NbElts_t> mNbConsumedData;
std::array<NbElts_t, 1> mNbProducedData = {};
public: public:
ConcatImpl_cpu(const Concat_Op& op) : mOp(op), mNbConsumedData(std::vector<NbElts_t>(op.nbInputs())) {} ConcatImpl_cpu(const Concat_Op& op) : OperatorImpl(op) {}
static std::unique_ptr<ConcatImpl_cpu> create(const Concat_Op& op) { static std::unique_ptr<ConcatImpl_cpu> create(const Concat_Op& op) {
return std::make_unique<ConcatImpl_cpu>(op); return std::make_unique<ConcatImpl_cpu>(op);
......
...@@ -43,13 +43,8 @@ class SliceImplBackward_cpu ...@@ -43,13 +43,8 @@ class SliceImplBackward_cpu
template <DimIdx_t DIM> template <DimIdx_t DIM>
class SliceImpl_cpu : public OperatorImpl { class SliceImpl_cpu : public OperatorImpl {
private:
const Slice_Op<DIM>& mOp;
std::array<NbElts_t, 1> mNbConsumedData;
std::array<NbElts_t, 1> mNbProducedData;
public: public:
SliceImpl_cpu(const Slice_Op<DIM>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {} SliceImpl_cpu(const Slice_Op<DIM>& op) : OperatorImpl(op) {}
static std::unique_ptr<SliceImpl_cpu<DIM>> create(const Slice_Op<DIM>& op) { static std::unique_ptr<SliceImpl_cpu<DIM>> create(const Slice_Op<DIM>& op) {
return std::make_unique<SliceImpl_cpu<DIM>>(op); return std::make_unique<SliceImpl_cpu<DIM>>(op);
...@@ -57,10 +52,10 @@ class SliceImpl_cpu : public OperatorImpl { ...@@ -57,10 +52,10 @@ class SliceImpl_cpu : public OperatorImpl {
public: public:
NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final { NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final {
assert(mOp.getInput(0) && "requires valid input"); assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input");
// Requires the whole tensors // Requires the whole tensors
const auto& inputDims = mOp.getInput(0)->dims(); const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims();
return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1), return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1),
std::multiplies<NbElts_t>()); std::multiplies<NbElts_t>());
...@@ -70,7 +65,7 @@ class SliceImpl_cpu : public OperatorImpl { ...@@ -70,7 +65,7 @@ class SliceImpl_cpu : public OperatorImpl {
const std::vector<DimSize_t>& inputsSize) const override final { const std::vector<DimSize_t>& inputsSize) const override final {
(void)outputIdx; (void)outputIdx;
(void)inputsSize; (void)inputsSize;
const auto& outputDims = mOp.getOutput(0)->dims(); const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims();
return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1), return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1),
std::multiplies<NbElts_t>()); std::multiplies<NbElts_t>());
} }
...@@ -89,17 +84,17 @@ class SliceImpl_cpu : public OperatorImpl { ...@@ -89,17 +84,17 @@ class SliceImpl_cpu : public OperatorImpl {
void forward() { void forward() {
// FIXME: uncomment the following code once memory handling will work // FIXME: uncomment the following code once memory handling will work
assert(mOp.getInput(0) && "missing input #0"); assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
// Find the correct kernel type // Find the correct kernel type
auto kernelFunc = Registrar<SliceImplForward_cpu<DIM>>::create( auto kernelFunc = Registrar<SliceImplForward_cpu<DIM>>::create(
{mOp.getInput(0)->dataType()}); {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()});
// Call kernel // Call kernel
kernelFunc(mOp.getInput(0)->template dims<DIM>(), kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<DIM>(),
std::get<1>(mOp.getStaticAttributes()), std::get<1>(std::static_pointer_cast<const Slice_Op<DIM>&>(mOp).getStaticAttributes()),
mOp.getInput(0)->getImpl()->rawPtr(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr() std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
); );
// each input is consumed by the minimum amount for a forward pass // each input is consumed by the minimum amount for a forward pass
...@@ -115,19 +110,14 @@ class SliceImpl_cpu : public OperatorImpl { ...@@ -115,19 +110,14 @@ class SliceImpl_cpu : public OperatorImpl {
template <> template <>
class SliceImpl_cpu<1> : public OperatorImpl { class SliceImpl_cpu<1> : public OperatorImpl {
private: public:
const Slice_Op<1>& mOp; SliceImpl_cpu(const Slice_Op<1>& op) : OperatorImpl(op) {}
std::array<NbElts_t, 1> mNbConsumedData;
std::array<NbElts_t, 1> mNbProducedData;
public:
SliceImpl_cpu(const Slice_Op<1>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
static std::unique_ptr<SliceImpl_cpu<1>> create(const Slice_Op<1>& op) { static std::unique_ptr<SliceImpl_cpu<1>> create(const Slice_Op<1>& op) {
return std::make_unique<SliceImpl_cpu<1>>(op); return std::make_unique<SliceImpl_cpu<1>>(op);
} }
public: public:
NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final; NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final; NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
NbElts_t getRequiredMemory(const IOIndex_t outputIdx, NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
...@@ -144,13 +134,8 @@ class SliceImpl_cpu<1> : public OperatorImpl { ...@@ -144,13 +134,8 @@ class SliceImpl_cpu<1> : public OperatorImpl {
template <> template <>
class SliceImpl_cpu<2> : public OperatorImpl { class SliceImpl_cpu<2> : public OperatorImpl {
private:
const Slice_Op<2>& mOp;
std::array<NbElts_t, 1> mNbConsumedData;
std::array<NbElts_t, 1> mNbProducedData;
public: public:
SliceImpl_cpu(const Slice_Op<2>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {} SliceImpl_cpu(const Slice_Op<2>& op) : OperatorImpl(op) {}
static std::unique_ptr<SliceImpl_cpu<2>> create(const Slice_Op<2>& op) { static std::unique_ptr<SliceImpl_cpu<2>> create(const Slice_Op<2>& op) {
return std::make_unique<SliceImpl_cpu<2>>(op); return std::make_unique<SliceImpl_cpu<2>>(op);
...@@ -173,13 +158,8 @@ class SliceImpl_cpu<2> : public OperatorImpl { ...@@ -173,13 +158,8 @@ class SliceImpl_cpu<2> : public OperatorImpl {
template <> template <>
class SliceImpl_cpu<3> : public OperatorImpl { class SliceImpl_cpu<3> : public OperatorImpl {
private:
const Slice_Op<3>& mOp;
std::array<NbElts_t, 1> mNbConsumedData;
std::array<NbElts_t, 1> mNbProducedData;
public: public:
SliceImpl_cpu(const Slice_Op<3>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {} SliceImpl_cpu(const Slice_Op<3>& op) : OperatorImpl(op) {}
static std::unique_ptr<SliceImpl_cpu<3>> create(const Slice_Op<3>& op) { static std::unique_ptr<SliceImpl_cpu<3>> create(const Slice_Op<3>& op) {
return std::make_unique<SliceImpl_cpu<3>>(op); return std::make_unique<SliceImpl_cpu<3>>(op);
...@@ -202,13 +182,8 @@ class SliceImpl_cpu<3> : public OperatorImpl { ...@@ -202,13 +182,8 @@ class SliceImpl_cpu<3> : public OperatorImpl {
template <> template <>
class SliceImpl_cpu<4> : public OperatorImpl { class SliceImpl_cpu<4> : public OperatorImpl {
private:
const Slice_Op<4>& mOp;
std::array<NbElts_t, 1> mNbConsumedData;
std::array<NbElts_t, 1> mNbProducedData;
public: public:
SliceImpl_cpu(const Slice_Op<4>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {} SliceImpl_cpu(const Slice_Op<4>& op) : OperatorImpl(op) {}
static std::unique_ptr<SliceImpl_cpu<4>> create(const Slice_Op<4>& op) { static std::unique_ptr<SliceImpl_cpu<4>> create(const Slice_Op<4>& op) {
return std::make_unique<SliceImpl_cpu<4>>(op); return std::make_unique<SliceImpl_cpu<4>>(op);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment