Skip to content
Snippets Groups Projects
Commit a679ca4a authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd] Slice kernel signature to be more generic

parent a23bb0cf
No related branches found
No related tags found
2 merge requests!22Update operators implementation,!16Draft: Tiling
...@@ -29,12 +29,16 @@ namespace Aidge { ...@@ -29,12 +29,16 @@ namespace Aidge {
template <DimIdx_t DIM> template <DimIdx_t DIM>
class SliceImplForward_cpu class SliceImplForward_cpu
: public Registrable<SliceImplForward_cpu<DIM>, std::tuple<DataType>, : public Registrable<SliceImplForward_cpu<DIM>, std::tuple<DataType>,
void(std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, const void*, void(const typename Slice_Op<DIM>::Attrs&,
const std::array<std::size_t, DIM>,
const void*,
void*)> {}; void*)> {};
template <DimIdx_t DIM> template <DimIdx_t DIM>
class SliceImplBackward_cpu class SliceImplBackward_cpu
: public Registrable<SliceImplBackward_cpu<DIM>, std::tuple<DataType>, : public Registrable<SliceImplBackward_cpu<DIM>, std::tuple<DataType>,
void(std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, const void*, void(const typename Slice_Op<DIM>::Attrs&,
const std::array<std::size_t, DIM>,
const void*,
void*)> {}; void*)> {};
template <DimIdx_t DIM> template <DimIdx_t DIM>
......
...@@ -63,8 +63,8 @@ void Aidge::SliceImpl_cpu<1>::forward() { ...@@ -63,8 +63,8 @@ void Aidge::SliceImpl_cpu<1>::forward() {
{mOp.getInput(0)->dataType()}); {mOp.getInput(0)->dataType()});
// Call kernel // Call kernel
kernelFunc(mOp.getInput(0)->template dims<1>(), kernelFunc(mOp.getStaticAttributes(),
std::get<1>(mOp.getStaticAttributes()), mOp.getInput(0)->template dims<1>(),
mOp.getInput(0)->getImpl()->rawPtr(), mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr() mOp.getOutput(0)->getImpl()->rawPtr()
); );
...@@ -121,7 +121,7 @@ void Aidge::SliceImpl_cpu<2>::forward() { ...@@ -121,7 +121,7 @@ void Aidge::SliceImpl_cpu<2>::forward() {
{mOp.getInput(0)->dataType()}); {mOp.getInput(0)->dataType()});
// Call kernel // Call kernel
kernelFunc(mOp.getStaticAttributes() kernelFunc(mOp.getStaticAttributes(),
mOp.getInput(0)->template dims<2>(), mOp.getInput(0)->template dims<2>(),
mOp.getInput(0)->getImpl()->rawPtr(), mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr() mOp.getOutput(0)->getImpl()->rawPtr()
...@@ -182,8 +182,8 @@ void Aidge::SliceImpl_cpu<3>::forward() { ...@@ -182,8 +182,8 @@ void Aidge::SliceImpl_cpu<3>::forward() {
{mOp.getInput(0)->dataType()}); {mOp.getInput(0)->dataType()});
// Call kernel // Call kernel
kernelFunc(mOp.getInput(0)->template dims<3>(), kernelFunc(mOp.getStaticAttributes(),
std::get<1>(mOp.getStaticAttributes()), mOp.getInput(0)->template dims<3>(),
mOp.getInput(0)->getImpl()->rawPtr(), mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr() mOp.getOutput(0)->getImpl()->rawPtr()
); );
...@@ -243,8 +243,8 @@ void Aidge::SliceImpl_cpu<4>::forward() { ...@@ -243,8 +243,8 @@ void Aidge::SliceImpl_cpu<4>::forward() {
{mOp.getInput(0)->dataType()}); {mOp.getInput(0)->dataType()});
// Call kernel // Call kernel
kernelFunc(mOp.getInput(0)->template dims<4>(), kernelFunc(mOp.getStaticAttributes(),
std::get<1>(mOp.getStaticAttributes()), mOp.getInput(0)->template dims<4>(),
mOp.getInput(0)->getImpl()->rawPtr(), mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr() mOp.getOutput(0)->getImpl()->rawPtr()
); );
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment