Skip to content
Snippets Groups Projects
Commit a679ca4a authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd] Slice kernel signature to be more generic

parent a23bb0cf
No related branches found
No related tags found
2 merge requests!22Update operators implementation,!16Draft: Tiling
......@@ -29,12 +29,16 @@ namespace Aidge {
template <DimIdx_t DIM>
class SliceImplForward_cpu
: public Registrable<SliceImplForward_cpu<DIM>, std::tuple<DataType>,
void(std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, const void*,
void(const typename Slice_Op<DIM>::Attrs&,
const std::array<std::size_t, DIM>,
const void*,
void*)> {};
template <DimIdx_t DIM>
class SliceImplBackward_cpu
: public Registrable<SliceImplBackward_cpu<DIM>, std::tuple<DataType>,
void(std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, const void*,
void(const typename Slice_Op<DIM>::Attrs&,
const std::array<std::size_t, DIM>,
const void*,
void*)> {};
template <DimIdx_t DIM>
......
......@@ -63,8 +63,8 @@ void Aidge::SliceImpl_cpu<1>::forward() {
{mOp.getInput(0)->dataType()});
// Call kernel
kernelFunc(mOp.getInput(0)->template dims<1>(),
std::get<1>(mOp.getStaticAttributes()),
kernelFunc(mOp.getStaticAttributes(),
mOp.getInput(0)->template dims<1>(),
mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr()
);
......@@ -121,7 +121,7 @@ void Aidge::SliceImpl_cpu<2>::forward() {
{mOp.getInput(0)->dataType()});
// Call kernel
kernelFunc(mOp.getStaticAttributes()
kernelFunc(mOp.getStaticAttributes(),
mOp.getInput(0)->template dims<2>(),
mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr()
......@@ -182,8 +182,8 @@ void Aidge::SliceImpl_cpu<3>::forward() {
{mOp.getInput(0)->dataType()});
// Call kernel
kernelFunc(mOp.getInput(0)->template dims<3>(),
std::get<1>(mOp.getStaticAttributes()),
kernelFunc(mOp.getStaticAttributes(),
mOp.getInput(0)->template dims<3>(),
mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr()
);
......@@ -243,8 +243,8 @@ void Aidge::SliceImpl_cpu<4>::forward() {
{mOp.getInput(0)->dataType()});
// Call kernel
kernelFunc(mOp.getInput(0)->template dims<4>(),
std::get<1>(mOp.getStaticAttributes()),
kernelFunc(mOp.getStaticAttributes(),
mOp.getInput(0)->template dims<4>(),
mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr()
);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment