Skip to content
Snippets Groups Projects
Commit e1dbf501 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Fix] Add, Concat and Slice implementations

parent 563fb105
No related branches found
No related tags found
1 merge request: !22 "Update operators implementation"
......@@ -31,13 +31,8 @@ class AddImplBackward_cpu
class AddImpl_cpu : public OperatorImpl {
private:
const Add_Op& mOp;
std::vector<NbElts_t> mNbConsumedData;
std::array<NbElts_t, 1> mNbProducedData = {};
public:
AddImpl_cpu(const Add_Op& op) : mOp(op), mNbConsumedData(std::vector<NbElts_t>(op.nbInputs())) {}
AddImpl_cpu(const Add_Op& op) : OperatorImpl(op) {}
static std::unique_ptr<AddImpl_cpu> create(const Add_Op& op) {
return std::make_unique<AddImpl_cpu>(op);
......
......@@ -39,13 +39,8 @@ class ConcatImplBackward_cpu
class ConcatImpl_cpu : public OperatorImpl {
private:
const Concat_Op& mOp;
std::vector<NbElts_t> mNbConsumedData;
std::array<NbElts_t, 1> mNbProducedData = {};
public:
ConcatImpl_cpu(const Concat_Op& op) : mOp(op), mNbConsumedData(std::vector<NbElts_t>(op.nbInputs())) {}
ConcatImpl_cpu(const Concat_Op& op) : OperatorImpl(op) {}
static std::unique_ptr<ConcatImpl_cpu> create(const Concat_Op& op) {
return std::make_unique<ConcatImpl_cpu>(op);
......
......@@ -43,13 +43,8 @@ class SliceImplBackward_cpu
template <DimIdx_t DIM>
class SliceImpl_cpu : public OperatorImpl {
private:
const Slice_Op<DIM>& mOp;
std::array<NbElts_t, 1> mNbConsumedData;
std::array<NbElts_t, 1> mNbProducedData;
public:
SliceImpl_cpu(const Slice_Op<DIM>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
SliceImpl_cpu(const Slice_Op<DIM>& op) : OperatorImpl(op) {}
static std::unique_ptr<SliceImpl_cpu<DIM>> create(const Slice_Op<DIM>& op) {
return std::make_unique<SliceImpl_cpu<DIM>>(op);
......@@ -57,10 +52,10 @@ class SliceImpl_cpu : public OperatorImpl {
public:
NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final {
assert(mOp.getInput(0) && "requires valid input");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input");
// Requires the whole tensors
const auto& inputDims = mOp.getInput(0)->dims();
const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims();
return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1),
std::multiplies<NbElts_t>());
......@@ -70,7 +65,7 @@ class SliceImpl_cpu : public OperatorImpl {
const std::vector<DimSize_t>& inputsSize) const override final {
(void)outputIdx;
(void)inputsSize;
const auto& outputDims = mOp.getOutput(0)->dims();
const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims();
return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1),
std::multiplies<NbElts_t>());
}
......@@ -89,17 +84,17 @@ class SliceImpl_cpu : public OperatorImpl {
void forward() {
// FIXME: uncomment the following code once memory handling will work
assert(mOp.getInput(0) && "missing input #0");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
// Find the correct kernel type
auto kernelFunc = Registrar<SliceImplForward_cpu<DIM>>::create(
{mOp.getInput(0)->dataType()});
{std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()});
// Call kernel
kernelFunc(mOp.getInput(0)->template dims<DIM>(),
std::get<1>(mOp.getStaticAttributes()),
mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr()
kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<DIM>(),
std::get<1>(std::static_pointer_cast<const Slice_Op<DIM>&>(mOp).getStaticAttributes()),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
);
// each input is consumed by the minimum amount for a forward pass
......@@ -115,19 +110,14 @@ class SliceImpl_cpu : public OperatorImpl {
template <>
class SliceImpl_cpu<1> : public OperatorImpl {
private:
const Slice_Op<1>& mOp;
std::array<NbElts_t, 1> mNbConsumedData;
std::array<NbElts_t, 1> mNbProducedData;
public:
SliceImpl_cpu(const Slice_Op<1>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
public:
SliceImpl_cpu(const Slice_Op<1>& op) : OperatorImpl(op) {}
static std::unique_ptr<SliceImpl_cpu<1>> create(const Slice_Op<1>& op) {
return std::make_unique<SliceImpl_cpu<1>>(op);
}
public:
public:
NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
......@@ -144,13 +134,8 @@ class SliceImpl_cpu<1> : public OperatorImpl {
template <>
class SliceImpl_cpu<2> : public OperatorImpl {
private:
const Slice_Op<2>& mOp;
std::array<NbElts_t, 1> mNbConsumedData;
std::array<NbElts_t, 1> mNbProducedData;
public:
SliceImpl_cpu(const Slice_Op<2>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
SliceImpl_cpu(const Slice_Op<2>& op) : OperatorImpl(op) {}
static std::unique_ptr<SliceImpl_cpu<2>> create(const Slice_Op<2>& op) {
return std::make_unique<SliceImpl_cpu<2>>(op);
......@@ -173,13 +158,8 @@ class SliceImpl_cpu<2> : public OperatorImpl {
template <>
class SliceImpl_cpu<3> : public OperatorImpl {
private:
const Slice_Op<3>& mOp;
std::array<NbElts_t, 1> mNbConsumedData;
std::array<NbElts_t, 1> mNbProducedData;
public:
SliceImpl_cpu(const Slice_Op<3>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
SliceImpl_cpu(const Slice_Op<3>& op) : OperatorImpl(op) {}
static std::unique_ptr<SliceImpl_cpu<3>> create(const Slice_Op<3>& op) {
return std::make_unique<SliceImpl_cpu<3>>(op);
......@@ -202,13 +182,8 @@ class SliceImpl_cpu<3> : public OperatorImpl {
template <>
class SliceImpl_cpu<4> : public OperatorImpl {
private:
const Slice_Op<4>& mOp;
std::array<NbElts_t, 1> mNbConsumedData;
std::array<NbElts_t, 1> mNbProducedData;
public:
SliceImpl_cpu(const Slice_Op<4>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
SliceImpl_cpu(const Slice_Op<4>& op) : OperatorImpl(op) {}
static std::unique_ptr<SliceImpl_cpu<4>> create(const Slice_Op<4>& op) {
return std::make_unique<SliceImpl_cpu<4>>(op);
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment