diff --git a/include/aidge/backend/cpu/operator/SliceImpl.hpp b/include/aidge/backend/cpu/operator/SliceImpl.hpp index 6c4b50c089f43f146bb52f7e6f1ee0301c7e986d..dddab386005114c67e72e8db4e1d56eebc11b08e 100644 --- a/include/aidge/backend/cpu/operator/SliceImpl.hpp +++ b/include/aidge/backend/cpu/operator/SliceImpl.hpp @@ -29,12 +29,16 @@ namespace Aidge { template <DimIdx_t DIM> class SliceImplForward_cpu : public Registrable<SliceImplForward_cpu<DIM>, std::tuple<DataType>, - void(std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, const void*, + void(const typename Slice_Op<DIM>::Attrs&, + const std::array<std::size_t, DIM>, + const void*, void*)> {}; template <DimIdx_t DIM> class SliceImplBackward_cpu : public Registrable<SliceImplBackward_cpu<DIM>, std::tuple<DataType>, - void(std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, const void*, + void(const typename Slice_Op<DIM>::Attrs&, + const std::array<std::size_t, DIM>, + const void*, void*)> {}; template <DimIdx_t DIM> diff --git a/src/operator/SliceImpl.cpp b/src/operator/SliceImpl.cpp index bd571f24f7a5537a7ad503069085f34e04633592..4baaef6087ad611f041b6d11fb3a6a0b5b3b63a1 100644 --- a/src/operator/SliceImpl.cpp +++ b/src/operator/SliceImpl.cpp @@ -63,8 +63,8 @@ void Aidge::SliceImpl_cpu<1>::forward() { {mOp.getInput(0)->dataType()}); // Call kernel - kernelFunc(mOp.getInput(0)->template dims<1>(), - std::get<1>(mOp.getStaticAttributes()), + kernelFunc(mOp.getStaticAttributes(), + mOp.getInput(0)->template dims<1>(), mOp.getInput(0)->getImpl()->rawPtr(), mOp.getOutput(0)->getImpl()->rawPtr() ); @@ -121,7 +121,7 @@ void Aidge::SliceImpl_cpu<2>::forward() { {mOp.getInput(0)->dataType()}); // Call kernel - kernelFunc(mOp.getStaticAttributes() + kernelFunc(mOp.getStaticAttributes(), mOp.getInput(0)->template dims<2>(), mOp.getInput(0)->getImpl()->rawPtr(), mOp.getOutput(0)->getImpl()->rawPtr() @@ -182,8 +182,8 @@ void Aidge::SliceImpl_cpu<3>::forward() { {mOp.getInput(0)->dataType()}); // Call kernel - kernelFunc(mOp.getInput(0)->template dims<3>(), - std::get<1>(mOp.getStaticAttributes()), + kernelFunc(mOp.getStaticAttributes(), + mOp.getInput(0)->template dims<3>(), mOp.getInput(0)->getImpl()->rawPtr(), mOp.getOutput(0)->getImpl()->rawPtr() ); @@ -243,8 +243,8 @@ void Aidge::SliceImpl_cpu<4>::forward() { {mOp.getInput(0)->dataType()}); // Call kernel - kernelFunc(mOp.getInput(0)->template dims<4>(), - std::get<1>(mOp.getStaticAttributes()), + kernelFunc(mOp.getStaticAttributes(), + mOp.getInput(0)->template dims<4>(), mOp.getInput(0)->getImpl()->rawPtr(), mOp.getOutput(0)->getImpl()->rawPtr() );