diff --git a/include/aidge/backend/cpu/operator/SliceImpl.hpp b/include/aidge/backend/cpu/operator/SliceImpl.hpp index 69dd88bb5e6e463b1122bc46b294f3220c88a905..41d5367f8397d867270f2134f70890f5fc576393 100644 --- a/include/aidge/backend/cpu/operator/SliceImpl.hpp +++ b/include/aidge/backend/cpu/operator/SliceImpl.hpp @@ -26,95 +26,26 @@ namespace Aidge { // class Slice_Op; // compute kernel registry for forward and backward -template <DimIdx_t DIM> class SliceImplForward_cpu - : public Registrable<SliceImplForward_cpu<DIM>, std::tuple<DataType>, - void(const typename Slice_Op<DIM>::Attrs&, - const std::array<std::size_t, DIM>, + : public Registrable<SliceImplForward_cpu, std::tuple<DataType>, + void(const typename Slice_Op::Attrs&, + const std::vector<std::size_t>, const void*, void*)> {}; -template <DimIdx_t DIM> class SliceImplBackward_cpu - : public Registrable<SliceImplBackward_cpu<DIM>, std::tuple<DataType>, - void(const typename Slice_Op<DIM>::Attrs&, - const std::array<std::size_t, DIM>, + : public Registrable<SliceImplBackward_cpu, std::tuple<DataType>, + void(const typename Slice_Op::Attrs&, + const std::vector<std::size_t>, const void*, void*)> {}; -template <DimIdx_t DIM> -class SliceImpl_cpu : public OperatorImpl { - public: - SliceImpl_cpu(const Slice_Op<DIM>& op) : OperatorImpl(op) {} - - static std::unique_ptr<SliceImpl_cpu<DIM>> create(const Slice_Op<DIM>& op) { - return std::make_unique<SliceImpl_cpu<DIM>>(op); - } - - public: - NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final { - assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input"); - - // Requires the whole tensors - const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(); - - return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1), - std::multiplies<NbElts_t>()); - } - NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final { return 0; } - NbElts_t getRequiredMemory(const IOIndex_t outputIdx, - const std::vector<DimSize_t>& inputsSize) const override final { - (void)outputIdx; - (void)inputsSize; - const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(); - return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1), - std::multiplies<NbElts_t>()); - } - NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final { - return mNbConsumedData[0]; - } - NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final { - return mNbProducedData[0]; - } - void updateConsummerProducer() override final { - // each input is consumed by the minimum amount for a forward pass - mNbConsumedData[0] += getNbRequiredData(0); - - mNbProducedData[0] += getRequiredMemory(0, {}); - } - - void forward() { - // FIXME: uncomment the following code once memory handling will work - assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0"); - - // Find the correct kernel type - auto kernelFunc = Registrar<SliceImplForward_cpu<DIM>>::create( - {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()}); - - // Call kernel - kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<DIM>(), - std::get<1>(std::static_pointer_cast<const Slice_Op<DIM>&>(mOp).getStaticAttributes()), - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(), - std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr() - ); - - // each input is consumed by the minimum amount for a forward pass - mNbConsumedData[0] += getNbRequiredData(0); - - mNbProducedData[0] += getRequiredMemory(0, {}); - } - - void backward() { printf("Not implemented yet.\n"); } -}; -/******************************************************************************/ - -template <> -class SliceImpl_cpu<1> : public OperatorImpl { +class SliceImpl_cpu : public OperatorImpl { public: - SliceImpl_cpu(const Slice_Op<1>& op) : OperatorImpl(op) {} + SliceImpl_cpu(const Slice_Op& op) : OperatorImpl(op) {} - static std::unique_ptr<SliceImpl_cpu<1>> create(const Slice_Op<1>& op) { - return std::make_unique<SliceImpl_cpu<1>>(op); + static std::unique_ptr<SliceImpl_cpu> create(const Slice_Op& op) { + return std::make_unique<SliceImpl_cpu>(op); } public: @@ -127,89 +58,14 @@ public: void updateConsummerProducer() override final; void forward(); - void backward(); -}; - -/******************************************************************************/ - -template <> -class SliceImpl_cpu<2> : public OperatorImpl { - public: - SliceImpl_cpu(const Slice_Op<2>& op) : OperatorImpl(op) {} - static std::unique_ptr<SliceImpl_cpu<2>> create(const Slice_Op<2>& op) { - return std::make_unique<SliceImpl_cpu<2>>(op); - } - - public: - NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final; - NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final; - NbElts_t getRequiredMemory(const IOIndex_t outputIdx, - const std::vector<DimSize_t>& inputsSize) const override final; - NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final; - NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final; - void updateConsummerProducer() override final; - - void forward(); - void backward(); -}; - -/******************************************************************************/ - -template <> -class SliceImpl_cpu<3> : public OperatorImpl { - public: - SliceImpl_cpu(const Slice_Op<3>& op) : OperatorImpl(op) {} - - static std::unique_ptr<SliceImpl_cpu<3>> create(const Slice_Op<3>& op) { - return std::make_unique<SliceImpl_cpu<3>>(op); - } - - public: - NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final; - NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final; - NbElts_t getRequiredMemory(const IOIndex_t outputIdx, - const std::vector<DimSize_t>& inputsSize) const override final; - NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final; - NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final; - void updateConsummerProducer() override final; - - void forward(); void backward(); }; -/******************************************************************************/ - -template <> -class SliceImpl_cpu<4> : public OperatorImpl { - public: - SliceImpl_cpu(const Slice_Op<4>& op) : OperatorImpl(op) {} - - static std::unique_ptr<SliceImpl_cpu<4>> create(const Slice_Op<4>& op) { - return std::make_unique<SliceImpl_cpu<4>>(op); - } - - public: - NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final; - NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final; - NbElts_t getRequiredMemory(const IOIndex_t outputIdx, - const std::vector<DimSize_t>& inputsSize) const override final; - NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final; - NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final; - void updateConsummerProducer() override final; - - void forward(); - void backward(); -}; - - namespace { -static Registrar<Slice_Op<1>> registrarSliceImpl_1D_cpu("cpu", Aidge::SliceImpl_cpu<1>::create); -static Registrar<Slice_Op<2>> registrarSliceImpl_2D_cpu("cpu", Aidge::SliceImpl_cpu<2>::create); -static Registrar<Slice_Op<3>> registrarSliceImpl_3D_cpu("cpu", Aidge::SliceImpl_cpu<3>::create); -static Registrar<Slice_Op<4>> registrarSliceImpl_4D_cpu("cpu", Aidge::SliceImpl_cpu<4>::create); +static Registrar<Slice_Op> registrarSliceImpl_cpu("cpu", Aidge::SliceImpl_cpu::create); } // namespace } // namespace Aidge -#endif /* AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_ */ +#endif /* AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_ */ \ No newline at end of file diff --git a/include/aidge/backend/cpu/operator/SliceImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/SliceImpl_forward_kernels.hpp index bbf4ccbae77089ca75dfee34f5bc5b0dd7d3697d..7eb4b9dc2cb8dddc8b7fdaf4d63b8f1d39d879b0 100644 --- a/include/aidge/backend/cpu/operator/SliceImpl_forward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/SliceImpl_forward_kernels.hpp @@ -15,46 +15,47 @@ #include "aidge/utils/Registrar.hpp" #include "aidge/operator/Slice.hpp" #include "aidge/backend/cpu/operator/SliceImpl.hpp" -#include <array> +#include <vector> #include <cstddef> #include "aidge/data/Data.hpp" namespace Aidge { -template <class I, std::size_t DIM> -void SliceImpl_cpu_forward_kernel(const typename Slice_Op<DIM>::Attrs& attrs, - const std::array<std::size_t, DIM> inputDims, +template <class I> +void SliceImpl_cpu_forward_kernel(const typename Slice_Op::Attrs& attrs, + const std::vector<std::size_t> inputDims, const void* input_, void* output_) { const I* input = static_cast<const I*>(input_) + std::get<0>(attrs); I* output = static_cast<I*>(output_); - const std::array<std::size_t, DIM> slicedDims = std::get<1>(attrs); + const std::vector<std::size_t> slicedDims = std::get<1>(attrs); + const std::size_t nbDims = slicedDims.size(); // for inputDims = {4,5,5,3} & slicedDims = {3,2,2,1}, substractDims = {1,5,5,3} - std::array<std::size_t, DIM> substractedDims; - for (std::size_t i = 0; i < DIM; ++i) { + std::vector<std::size_t> substractedDims = std::vector<std::size_t>(nbDims); + for (std::size_t i = 0; i < nbDims; ++i) { substractedDims[i] = inputDims[i] - slicedDims[i]; } // for slicedDims = {3,2,2,1}, prodSlicedDims = {12,4,2,1} - std::array<std::size_t, DIM> prodSlicedDims; - std::array<std::size_t, DIM+1> prodInputDims; - prodSlicedDims[DIM - 1] = slicedDims[DIM - 1]; - prodInputDims[DIM - 1] = inputDims[DIM - 1]; - prodInputDims[DIM] = 1; - for (std::size_t i = 2; i <= DIM; ++i) { - prodSlicedDims[DIM - i] = prodSlicedDims[DIM - i + 1]*slicedDims[DIM - i]; - prodInputDims[DIM - i] = prodInputDims[DIM - i + 1]*inputDims[DIM - i]; + std::vector<std::size_t> prodSlicedDims = std::vector<std::size_t>(nbDims); + std::vector<std::size_t> prodInputDims = std::vector<std::size_t>(nbDims+1); + prodSlicedDims[nbDims - 1] = slicedDims[nbDims - 1]; + prodInputDims[nbDims - 1] = inputDims[nbDims - 1]; + prodInputDims[nbDims] = 1; + for (std::size_t i = 2; i <= nbDims; ++i) { + prodSlicedDims[nbDims - i] = prodSlicedDims[nbDims - i + 1]*slicedDims[nbDims - i]; + prodInputDims[nbDims - i] = prodInputDims[nbDims - i + 1]*inputDims[nbDims - i]; } std::size_t j = 0; std::size_t i = 0; for (; j < prodSlicedDims[0];) { output[j] = input[i++]; - ++j; - for (std::size_t idx = DIM - 1; idx > 0; --idx) { - i += j % prodSlicedDims[idx] == 0 ? substractedDims[idx]*prodInputDims[idx+1] : 0; + ++j; + for (std::size_t idx = nbDims - 1; idx > 0; --idx) { + i += j % prodSlicedDims[idx] == 0 ? substractedDims[idx]*prodInputDims[idx+1] : 0; } } } @@ -62,37 +63,13 @@ void SliceImpl_cpu_forward_kernel(const typename Slice_Op<DIM>::Attrs& attrs, namespace { // DIM = 1 -static Registrar<SliceImplForward_cpu<1>> registrarSliceImplForward_1D_cpu_Float32( - {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 1>); -static Registrar<SliceImplForward_cpu<1>> registrarSliceImplForward_1D_cpu_Int32( - {DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 1>); -static Registrar<SliceImplForward_cpu<1>> registrarSliceImplForward_1D_cpu_Float64( - {DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 1>); - -// DIM = 2 -static Registrar<SliceImplForward_cpu<2>> registrarSliceImplForward_2D_cpu_Float32( - {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 2>); -static Registrar<SliceImplForward_cpu<2>> registrarSliceImplForward_2D_cpu_Int32( - {DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 2>); -static Registrar<SliceImplForward_cpu<2>> registrarSliceImplForward_2D_cpu_Float64( - {DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 2>); - -// DIM = 3 -static Registrar<SliceImplForward_cpu<3>> registrarSliceImplForward_3D_cpu_Float32( - {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 3>); -static Registrar<SliceImplForward_cpu<3>> registrarSliceImplForward_3D_cpu_Int32( - {DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 3>); -static Registrar<SliceImplForward_cpu<3>> registrarSliceImplForward_3D_cpu_Float64( - {DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 3>); - -// DIM = 4 -static Registrar<SliceImplForward_cpu<4>> registrarSliceImplForward_4D_cpu_Float32( - {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 4>); -static Registrar<SliceImplForward_cpu<4>> registrarSliceImplForward_4D_cpu_Int32( - {DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 4>); -static Registrar<SliceImplForward_cpu<4>> registrarSliceImplForward_4D_cpu_Float64( - {DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 4>); +static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Float32( + {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float>); +static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Int32( + {DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int>); +static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Float64( + {DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double>); } // namespace } // namespace Aidge -#endif /* AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_FORWARD_KERNEL_H_ */ +#endif /* AIDGE_CPU_OPERATOR_SLICEIMPL_FORWARD_KERNEL_H_ */ diff --git a/src/operator/SliceImpl.cpp b/src/operator/SliceImpl.cpp index 3ae56e1a4f613a4188dc51659853e07674e74768..b60bbe60188f416f28ff2562875dce6e5ee15bd5 100644 --- a/src/operator/SliceImpl.cpp +++ b/src/operator/SliceImpl.cpp @@ -22,231 +22,55 @@ #include <cassert> #include <tuple> - -Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const { - assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input"); - - // Requires the whole tensors - return std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<1>()[0]; -} - -Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; } - -Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getRequiredMemory(const Aidge::IOIndex_t outputIdx, - const std::vector<Aidge::DimSize_t>& inputsSize) const { - (void)outputIdx; - (void)inputsSize; - return std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->template dims<1>()[0]; -} - -Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const { - return mNbConsumedData[0]; -} - -Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const { - return mNbProducedData[0]; -} - -void Aidge::SliceImpl_cpu<1>::updateConsummerProducer() { - // each input is consumed by the minimum amount for a forward pass - mNbConsumedData[0] += getNbRequiredData(0); - - mNbProducedData[0] += getRequiredMemory(0, {}); -} - -void Aidge::SliceImpl_cpu<1>::forward() { - // FIXME: uncomment the following code once memory handling will work - assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0"); - - // Find the correct kernel type - auto kernelFunc = Registrar<SliceImplForward_cpu<1>>::create( - {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()}); - - // Call kernel - kernelFunc(dynamic_cast<const Slice_Op<1>&>(mOp).getStaticAttributes(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<1>(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(), - std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr() - ); - - // each input is consumed by the minimum amount for a forward pass - mNbConsumedData[0] += getNbRequiredData(0); - - mNbProducedData[0] += getRequiredMemory(0, {}); -} - -void Aidge::SliceImpl_cpu<1>::backward() { printf("Not implemented yet.\n"); } - -///////////////////////////////////////////////////////////////////////// - -Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const { - assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input"); - - // Requires the whole tensors - const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<2>(); - return inputDims[0]*inputDims[1]; -} - -Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; } - -Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getRequiredMemory(const Aidge::IOIndex_t outputIdx, - const std::vector<Aidge::DimSize_t>& inputsSize) const { - (void)outputIdx; - (void)inputsSize; - const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->template dims<2>(); - return outputDims[0]*outputDims[1]; -} - -Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const { - return mNbConsumedData[0]; -} - -Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const { - return mNbProducedData[0]; -} - -void Aidge::SliceImpl_cpu<2>::updateConsummerProducer() { - // each input is consumed by the minimum amount for a forward pass - mNbConsumedData[0] += getNbRequiredData(0); - - mNbProducedData[0] += getRequiredMemory(0, {}); -} - -void Aidge::SliceImpl_cpu<2>::forward() { - // FIXME: uncomment the following code once memory handling will work - assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0"); - - // Find the correct kernel type - auto kernelFunc = Registrar<SliceImplForward_cpu<2>>::create( - {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()}); - - // Call kernel - kernelFunc(dynamic_cast<const Slice_Op<2>&>(mOp).getStaticAttributes(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<2>(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(), - std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr() - ); - - // each input is consumed by the minimum amount for a forward pass - mNbConsumedData[0] += getNbRequiredData(0); - - mNbProducedData[0] += getRequiredMemory(0, {}); -} - -void Aidge::SliceImpl_cpu<2>::backward() { printf("Not implemented yet.\n"); } - -//////////////////////////////////////////////////////////////////////////// - -Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const { - assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input"); - - // Requires the whole tensors - const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<3>(); - - return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1), - std::multiplies<NbElts_t>()); -} - -Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; } - -Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getRequiredMemory(const Aidge::IOIndex_t outputIdx, - const std::vector<Aidge::DimSize_t>& inputsSize) const { - (void)outputIdx; - (void)inputsSize; - const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->template dims<3>(); - return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1), - std::multiplies<NbElts_t>()); -} - -Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const { - return mNbConsumedData[0]; -} - -Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const { - return mNbProducedData[0]; -} - -void Aidge::SliceImpl_cpu<3>::updateConsummerProducer() { - // each input is consumed by the minimum amount for a forward pass - mNbConsumedData[0] += getNbRequiredData(0); - - mNbProducedData[0] += getRequiredMemory(0, {}); -} - -void Aidge::SliceImpl_cpu<3>::forward() { - // FIXME: uncomment the following code once memory handling will work - assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0"); - - // Find the correct kernel type - auto kernelFunc = Registrar<SliceImplForward_cpu<3>>::create( - {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()}); - - // Call kernel - kernelFunc(dynamic_cast<const Slice_Op<3>&>(mOp).getStaticAttributes(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<3>(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(), - std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr() - ); - - // each input is consumed by the minimum amount for a forward pass - mNbConsumedData[0] += getNbRequiredData(0); - - mNbProducedData[0] += getRequiredMemory(0, {}); -} - -void Aidge::SliceImpl_cpu<3>::backward() { printf("Not implemented yet.\n"); } - -////////////////////////////////////////////////////////////////////////////// - -Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::NbElts_t Aidge::SliceImpl_cpu::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const { assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input"); // Requires the whole tensors - const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(); + const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(); return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1), std::multiplies<NbElts_t>()); } -Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; } +Aidge::NbElts_t Aidge::SliceImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; } -Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getRequiredMemory(const Aidge::IOIndex_t outputIdx, - const std::vector<Aidge::DimSize_t>& inputsSize) const { +Aidge::NbElts_t Aidge::SliceImpl_cpu::getRequiredMemory(const Aidge::IOIndex_t outputIdx, + const std::vector<Aidge::DimSize_t>& inputsSize) const { (void)outputIdx; (void)inputsSize; - const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->template dims<4>(); + const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(); return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1), std::multiplies<NbElts_t>()); } -Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::NbElts_t Aidge::SliceImpl_cpu::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const { return mNbConsumedData[0]; } -Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const { +Aidge::NbElts_t Aidge::SliceImpl_cpu::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const { return mNbProducedData[0]; } -void Aidge::SliceImpl_cpu<4>::updateConsummerProducer() { +void Aidge::SliceImpl_cpu::updateConsummerProducer() { // each input is consumed by the minimum amount for a forward pass mNbConsumedData[0] += getNbRequiredData(0); mNbProducedData[0] += getRequiredMemory(0, {}); } -void Aidge::SliceImpl_cpu<4>::forward() { +void Aidge::SliceImpl_cpu::forward() { // FIXME: uncomment the following code once memory handling will work assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0"); // Find the correct kernel type - auto kernelFunc = Registrar<SliceImplForward_cpu<4>>::create( + auto kernelFunc = Registrar<SliceImplForward_cpu>::create( {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()}); // Call kernel - kernelFunc(dynamic_cast<const Slice_Op<4>&>(mOp).getStaticAttributes(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(), - std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr() + kernelFunc(dynamic_cast<const Slice_Op&>(mOp).getStaticAttributes(), + std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(), + std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(), + std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr() ); // each input is consumed by the minimum amount for a forward pass @@ -255,4 +79,4 @@ void Aidge::SliceImpl_cpu<4>::forward() { mNbProducedData[0] += getRequiredMemory(0, {}); } -void Aidge::SliceImpl_cpu<4>::backward() { printf("Not implemented yet.\n"); } \ No newline at end of file +void Aidge::SliceImpl_cpu::backward() { printf("Not implemented yet.\n"); } \ No newline at end of file