From a679ca4ab479bba63c2a08b994f11f8f7a34cae7 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Tue, 17 Oct 2023 12:25:32 +0000 Subject: [PATCH] [Upd] Slice kernel signature to be more generic --- include/aidge/backend/cpu/operator/SliceImpl.hpp | 8 ++++++-- src/operator/SliceImpl.cpp | 14 +++++++------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/include/aidge/backend/cpu/operator/SliceImpl.hpp b/include/aidge/backend/cpu/operator/SliceImpl.hpp index 6c4b50c0..dddab386 100644 --- a/include/aidge/backend/cpu/operator/SliceImpl.hpp +++ b/include/aidge/backend/cpu/operator/SliceImpl.hpp @@ -29,12 +29,16 @@ namespace Aidge { template <DimIdx_t DIM> class SliceImplForward_cpu : public Registrable<SliceImplForward_cpu<DIM>, std::tuple<DataType>, - void(std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, const void*, + void(const typename Slice_Op<DIM>::Attrs&, + const std::array<std::size_t, DIM>, + const void*, void*)> {}; template <DimIdx_t DIM> class SliceImplBackward_cpu : public Registrable<SliceImplBackward_cpu<DIM>, std::tuple<DataType>, - void(std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, const void*, + void(const typename Slice_Op<DIM>::Attrs&, + const std::array<std::size_t, DIM>, + const void*, void*)> {}; template <DimIdx_t DIM> diff --git a/src/operator/SliceImpl.cpp b/src/operator/SliceImpl.cpp index bd571f24..4baaef60 100644 --- a/src/operator/SliceImpl.cpp +++ b/src/operator/SliceImpl.cpp @@ -63,8 +63,8 @@ void Aidge::SliceImpl_cpu<1>::forward() { {mOp.getInput(0)->dataType()}); // Call kernel - kernelFunc(mOp.getInput(0)->template dims<1>(), - std::get<1>(mOp.getStaticAttributes()), + kernelFunc(mOp.getStaticAttributes(), + mOp.getInput(0)->template dims<1>(), mOp.getInput(0)->getImpl()->rawPtr(), mOp.getOutput(0)->getImpl()->rawPtr() ); @@ -121,7 +121,7 @@ void Aidge::SliceImpl_cpu<2>::forward() { {mOp.getInput(0)->dataType()}); // Call kernel - kernelFunc(mOp.getStaticAttributes() + kernelFunc(mOp.getStaticAttributes(), mOp.getInput(0)->template dims<2>(), mOp.getInput(0)->getImpl()->rawPtr(), mOp.getOutput(0)->getImpl()->rawPtr() @@ -182,8 +182,8 @@ void Aidge::SliceImpl_cpu<3>::forward() { {mOp.getInput(0)->dataType()}); // Call kernel - kernelFunc(mOp.getInput(0)->template dims<3>(), - std::get<1>(mOp.getStaticAttributes()), + kernelFunc(mOp.getStaticAttributes(), + mOp.getInput(0)->template dims<3>(), mOp.getInput(0)->getImpl()->rawPtr(), mOp.getOutput(0)->getImpl()->rawPtr() ); @@ -243,8 +243,8 @@ void Aidge::SliceImpl_cpu<4>::forward() { {mOp.getInput(0)->dataType()}); // Call kernel - kernelFunc(mOp.getInput(0)->template dims<4>(), - std::get<1>(mOp.getStaticAttributes()), + kernelFunc(mOp.getStaticAttributes(), + mOp.getInput(0)->template dims<4>(), mOp.getInput(0)->getImpl()->rawPtr(), mOp.getOutput(0)->getImpl()->rawPtr() ); -- GitLab