diff --git a/include/aidge/backend/cpu/operator/SliceImpl.hpp b/include/aidge/backend/cpu/operator/SliceImpl.hpp
index c50049206f9fd9b70d6d724aa6d651998d1f1de1..6c4b50c089f43f146bb52f7e6f1ee0301c7e986d 100644
--- a/include/aidge/backend/cpu/operator/SliceImpl.hpp
+++ b/include/aidge/backend/cpu/operator/SliceImpl.hpp
@@ -12,12 +12,15 @@
 #ifndef AIDGE_CPU_OPERATOR_SLICEIMPL_H_
 #define AIDGE_CPU_OPERATOR_SLICEIMPL_H_
 
+#include <memory>
+#include <tuple>
+#include <vector>
+
 #include "aidge/backend/OperatorImpl.hpp"
 #include "aidge/operator/Slice.hpp"
+
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
-#include <memory>
-#include <vector>
 
 namespace Aidge {
 // class Slice_Op;
@@ -25,12 +28,14 @@ namespace Aidge {
 // compute kernel registry for forward and backward
 template <DimIdx_t DIM>
 class SliceImplForward_cpu
-    : public Registrable<SliceImplForward_cpu<DIM>, std::tuple<DataType>, void(std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, const void*, void*)> {
-};
+    : public Registrable<SliceImplForward_cpu<DIM>, std::tuple<DataType>,
+                         void(const typename Slice_Op<DIM>::Attrs&,
+                              const std::array<DimSize_t, DIM>, const void*, void*)> {};
 template <DimIdx_t DIM>
 class SliceImplBackward_cpu
-    : public Registrable<SliceImplBackward_cpu<DIM>, std::tuple<DataType>, void(std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, const void*, void*)> {
-};
+    : public Registrable<SliceImplBackward_cpu<DIM>, std::tuple<DataType>,
+                         void(std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, const void*,
+                              void*)> {};
 
 template <DimIdx_t DIM>
 class SliceImpl_cpu : public OperatorImpl {
@@ -42,8 +47,8 @@ class SliceImpl_cpu : public OperatorImpl {
    public:
    SliceImpl_cpu(const Slice_Op<DIM>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
 
-    static std::unique_ptr<SliceImpl_cpu> create(const Slice_Op<DIM>& op) {
-        return std::make_unique<SliceImpl_cpu>(op);
+    static std::unique_ptr<SliceImpl_cpu<DIM>> create(const Slice_Op<DIM>& op) {
+        return std::make_unique<SliceImpl_cpu<DIM>>(op);
     }
 
    public:
@@ -53,18 +58,17 @@ class SliceImpl_cpu : public OperatorImpl {
         // Requires the whole tensors
         const auto& inputDims = mOp.getInput(0)->dims();
 
-        return std::accumulate(inputDims.begin(), inputDims.end(),
-                        static_cast<NbElts_t>(1), std::multiplies<NbElts_t>());
+        return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1),
+                               std::multiplies<NbElts_t>());
     }
-    NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final {
-        return 0;
-    }
-    NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t>& inputsSize) const override final {
-        (void) outputIdx;
-        (void) inputsSize;
+    NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final { return 0; }
+    NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
+                               const std::vector<DimSize_t>& inputsSize) const override final {
+        (void)outputIdx;
+        (void)inputsSize;
         const auto& outputDims = mOp.getOutput(0)->dims();
-        return std::accumulate(outputDims.begin(), outputDims.end(),
-                        static_cast<NbElts_t>(1), std::multiplies<NbElts_t>());
+        return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1),
+                               std::multiplies<NbElts_t>());
     }
     NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final {
         return mNbConsumedData[0];
@@ -73,9 +77,10 @@ class SliceImpl_cpu : public OperatorImpl {
         return mNbProducedData[0];
     }
     void updateConsummerProducer() override final {
-        mNbConsumedData[0]+= getNbRequiredData(0); // each input is consumed by the minimum amount for a forward pass
+        // each input is consumed by the minimum amount for a forward pass
+        mNbConsumedData[0] += getNbRequiredData(0);
 
-        mNbProducedData[0]+= getRequiredMemory(0, {});
+        mNbProducedData[0] += getRequiredMemory(0, {});
     }
 
     void forward() {
@@ -83,30 +88,149 @@ class SliceImpl_cpu : public OperatorImpl {
         assert(mOp.getInput(0) && "missing input #0");
 
         // Find the correct kernel type
-        auto kernelFunc = Registrar<SliceImplForward_cpu<DIM>>::create({
-            mOp.getInput(0)->dataType(),
-            mOp.getOutput(0)->dataType()});
+        auto kernelFunc = Registrar<SliceImplForward_cpu<DIM>>::create(
+                {mOp.getInput(0)->dataType()});
 
         // Call kernel
-        kernelFunc(mOp->getInput(0)->dims(),
-                   mOp->template getAttr<SliceAttr::SliceDims>(),
-                   mOp.getInput(0)->getImpl()->rawPtr(),
-                   mOp.getOutput(0)->getImpl()->rawPtr());
+        kernelFunc(mOp.getStaticAttributes(),
+                   mOp.getInput(0)->template dims<DIM>(),
+                   mOp.getInput(0)->getImpl()->rawPtr(),
+                   mOp.getOutput(0)->getImpl()->rawPtr()
+                   );
 
-        mNbConsumedData[0]+= getNbRequiredData(0); // each input is consumed by the minimum amount for a forward pass
+        // each input is consumed by the minimum amount for a forward pass
+        mNbConsumedData[0] += getNbRequiredData(0);
 
-        mNbProducedData[0]+= getRequiredMemory(0, {});
+        mNbProducedData[0] += getRequiredMemory(0, {});
     }
 
-    void backward() {
-        printf("Not implemented yet.\n");
+    void backward() { printf("Not implemented yet.\n"); }
+};
+
+/******************************************************************************/
+
+template <>
+class SliceImpl_cpu<1> : public OperatorImpl {
+   private:
+    const Slice_Op<1>& mOp;
+    std::array<NbElts_t, 1> mNbConsumedData;
+    std::array<NbElts_t, 1> mNbProducedData;
+
+   public:
+    SliceImpl_cpu(const Slice_Op<1>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
+
+    static std::unique_ptr<SliceImpl_cpu<1>> create(const Slice_Op<1>& op) {
+        return std::make_unique<SliceImpl_cpu<1>>(op);
     }
+
+   public:
+    NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
+    NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
+    NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
+                               const std::vector<DimSize_t>& inputsSize) const override final;
+    NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final;
+    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
+    void updateConsummerProducer() override final;
+
+    void forward();
+    void backward();
+};
+
+/******************************************************************************/
+
+template <>
+class SliceImpl_cpu<2> : public OperatorImpl {
+   private:
+    const Slice_Op<2>& mOp;
+    std::array<NbElts_t, 1> mNbConsumedData;
+    std::array<NbElts_t, 1> mNbProducedData;
+
+   public:
+    SliceImpl_cpu(const Slice_Op<2>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
+
+    static std::unique_ptr<SliceImpl_cpu<2>> create(const Slice_Op<2>& op) {
+        return std::make_unique<SliceImpl_cpu<2>>(op);
+    }
+
+   public:
+    NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
+    NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
+    NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
+                               const std::vector<DimSize_t>& inputsSize) const override final;
+    NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final;
+    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
+    void updateConsummerProducer() override final;
+
+    void forward();
+    void backward();
+};
+
+/******************************************************************************/
+
+template <>
+class SliceImpl_cpu<3> : public OperatorImpl {
+   private:
+    const Slice_Op<3>& mOp;
+    std::array<NbElts_t, 1> mNbConsumedData;
+    std::array<NbElts_t, 1> mNbProducedData;
+
+   public:
+    SliceImpl_cpu(const Slice_Op<3>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
+
+    static std::unique_ptr<SliceImpl_cpu<3>> create(const Slice_Op<3>& op) {
+        return std::make_unique<SliceImpl_cpu<3>>(op);
+    }
+
+   public:
+    NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
+    NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
+    NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
+                               const std::vector<DimSize_t>& inputsSize) const override final;
+    NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final;
+    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
+    void updateConsummerProducer() override final;
+
+    void forward();
+    void backward();
+};
+
+/******************************************************************************/
+
+template <>
+class SliceImpl_cpu<4> : public OperatorImpl {
+   private:
+    const Slice_Op<4>& mOp;
+    std::array<NbElts_t, 1> mNbConsumedData;
+    std::array<NbElts_t, 1> mNbProducedData;
+
+   public:
+    SliceImpl_cpu(const Slice_Op<4>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
+
+    static std::unique_ptr<SliceImpl_cpu<4>> create(const Slice_Op<4>& op) {
+        return std::make_unique<SliceImpl_cpu<4>>(op);
+    }
+
+   public:
+    NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
+    NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
+    NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
+                               const std::vector<DimSize_t>& inputsSize) const override final;
+    NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final;
+    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
+    void updateConsummerProducer() override final;
+
+    void forward();
+    void backward();
+};
+
+
+
 namespace {
-template <DimIdx_t DIM>
-static Registrar<Slice_Op<DIM>> registrarSliceImpl_cpu("cpu", Aidge::SliceImpl_cpu<DIM>::create);
-}
+static Registrar<Slice_Op<1>> registrarSliceImpl_1D_cpu("cpu", Aidge::SliceImpl_cpu<1>::create);
+static Registrar<Slice_Op<2>> registrarSliceImpl_2D_cpu("cpu", Aidge::SliceImpl_cpu<2>::create);
+static Registrar<Slice_Op<3>> registrarSliceImpl_3D_cpu("cpu", Aidge::SliceImpl_cpu<3>::create);
+static Registrar<Slice_Op<4>> registrarSliceImpl_4D_cpu("cpu", Aidge::SliceImpl_cpu<4>::create);
+}  // namespace
 }  // namespace Aidge
 
 #endif /* AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_ */
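The header above drops the templated registrar in favour of one explicit registration per dimensionality (1D to 4D), each bound to the "cpu" backend key, and the forward kernels are now looked up by the input data type only. As a rough orientation, the sketch below shows the general shape of such a static-registration scheme using plain standard-library types; the names are hypothetical and this is not the actual Aidge Registrable/Registrar machinery.

    // Minimal sketch of the "static registrar" pattern used above (hypothetical
    // names; not the Aidge Registrable/Registrar classes). Each static
    // KernelRegistrar object inserts one factory into a global map during static
    // initialization, and createKernel() looks the factory up by key at run time.
    #include <functional>
    #include <map>
    #include <stdexcept>
    #include <string>

    using Kernel = std::function<void(const void*, void*)>;  // simplified kernel signature

    static std::map<std::string, Kernel>& kernelRegistry() {
        static std::map<std::string, Kernel> registry;  // constructed on first use
        return registry;
    }

    struct KernelRegistrar {
        KernelRegistrar(const std::string& key, Kernel k) { kernelRegistry()[key] = std::move(k); }
    };

    // One registration per supported configuration, mirroring the per-DIM entries above.
    static KernelRegistrar registrarFloat32_1D("Float32/1D",
        [](const void* /*input*/, void* /*output*/) { /* kernel body for this configuration */ });

    // Lookup, in the spirit of Registrar<SliceImplForward_cpu<1>>::create({dataType}).
    inline Kernel createKernel(const std::string& key) {
        const auto it = kernelRegistry().find(key);
        if (it == kernelRegistry().end()) {
            throw std::runtime_error("no kernel registered for " + key);
        }
        return it->second;
    }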
diff --git a/include/aidge/backend/cpu/operator/SliceImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/SliceImpl_forward_kernels.hpp
index 18e599880a91fd881525977c1d37591944565c8c..01e2735af9801c4fe07d8db61e2f01d2d7caeb93 100644
--- a/include/aidge/backend/cpu/operator/SliceImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/SliceImpl_forward_kernels.hpp
@@ -16,17 +16,18 @@
 #include <array>
 #include <cstddef>
 
-#include "aidge/backend/cpu/operator/SliceImpl.hpp"
+#include "aidge/data/Data.hpp"
 
 namespace Aidge {
 
-template <class I, class O, std::size_t DIM>
-void SliceImpl_cpu_forward_kernel(const std::array<std::size_t, DIM> inputDims,
-                                  const std::array<std::size_t, DIM> slicedDims,
+template <class I, std::size_t DIM>
+void SliceImpl_cpu_forward_kernel(const typename Slice_Op<DIM>::Attrs& attrs,
+                                  const std::array<std::size_t, DIM> inputDims,
                                   const void* input_,
                                   void* output_) {
-    const I* input = static_cast<const I*>(input_);
-    O* output = static_cast<O*>(output_);
+    const I* input = static_cast<const I*>(input_) + std::get<0>(attrs);
+    I* output = static_cast<I*>(output_);
+    const std::array<std::size_t, DIM> slicedDims = std::get<1>(attrs);
 
     // for inputDims = {4,5,5,3} & slicedDims = {3,2,2,1}, substractDims = {1,5,5,3}
     std::array<std::size_t, DIM> substractedDims;
@@ -40,31 +41,54 @@ void SliceImpl_cpu_forward_kernel(const std::array<std::size_t, DIM> inputDims,
     prodSlicedDims[DIM - 1] = slicedDims[DIM - 1];
     prodInputDims[DIM - 1] = inputDims[DIM - 1];
     prodInputDims[DIM] = 1;
-    for (std::size_t i = 2; i < DIM; ++i) {
+    for (std::size_t i = 2; i <= DIM; ++i) {
         prodSlicedDims[DIM - i] = prodSlicedDims[DIM - i + 1]*slicedDims[DIM - i];
         prodInputDims[DIM - i] = prodInputDims[DIM - i + 1]*inputDims[DIM - i];
     }
 
     std::size_t j = 0;
     std::size_t i = 0;
-    for (std::size_t = 0; j < prodSlicedDims[0]; ++j) {
+    for (; j < prodSlicedDims[0]; ++j) {
         output[j] = input[i++];
         for (std::size_t idx = DIM - 1; idx > 0; --idx) {
-            i += j % prodSlicedDims[idx] == 0 ? substractedDims[idx]*prodInputDims[idx+1] : 0;
+            i += j % prodSlicedDims[idx] == 0 ? substractedDims[idx]*prodInputDims[idx+1] : 0;
         }
     }
 }
 
 namespace {
-template <std::size_t DIM>
-static Registrar<SliceImplForward_cpu<DIM>> registrarSliceImplForward_cpu_Float32(
-        {DataType::Float32, DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, float, DIM>);
-template <std::size_t DIM>
-static Registrar<SliceImplForward_cpu<DIM>> registrarSliceImplForward_cpu_Int32(
-        {DataType::Int32, DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, int, DIM>);
-template <std::size_t DIM>
-static Registrar<SliceImplForward_cpu<DIM>> registrarSliceImplForward_cpu_Float64(
-        {DataType::Float64, DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, double, DIM>);
+
+// DIM = 1
+static Registrar<SliceImplForward_cpu<1>> registrarSliceImplForward_1D_cpu_Float32(
+        {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 1>);
+static Registrar<SliceImplForward_cpu<1>> registrarSliceImplForward_1D_cpu_Int32(
+        {DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 1>);
+static Registrar<SliceImplForward_cpu<1>> registrarSliceImplForward_1D_cpu_Float64(
+        {DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 1>);
+
+// DIM = 2
+static Registrar<SliceImplForward_cpu<2>> registrarSliceImplForward_2D_cpu_Float32(
+        {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 2>);
+static Registrar<SliceImplForward_cpu<2>> registrarSliceImplForward_2D_cpu_Int32(
+        {DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 2>);
+static Registrar<SliceImplForward_cpu<2>> registrarSliceImplForward_2D_cpu_Float64(
+        {DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 2>);
+
+// DIM = 3
+static Registrar<SliceImplForward_cpu<3>> registrarSliceImplForward_3D_cpu_Float32(
+        {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 3>);
+static Registrar<SliceImplForward_cpu<3>> registrarSliceImplForward_3D_cpu_Int32(
+        {DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 3>);
+static Registrar<SliceImplForward_cpu<3>> registrarSliceImplForward_3D_cpu_Float64(
+        {DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 3>);
+
+// DIM = 4
+static Registrar<SliceImplForward_cpu<4>> registrarSliceImplForward_4D_cpu_Float32(
+        {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 4>);
+static Registrar<SliceImplForward_cpu<4>> registrarSliceImplForward_4D_cpu_Int32(
+        {DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 4>);
+static Registrar<SliceImplForward_cpu<4>> registrarSliceImplForward_4D_cpu_Float64(
+        {DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 4>);
 }  // namespace
 }  // namespace Aidge
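The kernel above copies the sliced region by walking the output linearly while advancing a single flattened input index `i`, skipping `substractedDims[idx]*prodInputDims[idx+1]` input elements whenever a sliced dimension wraps around. The snippet below is a coordinate-based reference formulation of the same copy, useful for cross-checking that index arithmetic; it is a standalone illustration with a hypothetical helper name, not the registered Aidge kernel. Here `begin` stands for the flattened start offset added to the input pointer (`std::get<0>(attrs)`) and `slicedDims` for `std::get<1>(attrs)`.

    // Reference formulation of the slice copy, with explicit output coordinates
    // instead of the incremental index update used in SliceImpl_cpu_forward_kernel.
    // Illustration only (hypothetical helper, plain standard-library types).
    #include <array>
    #include <cstddef>
    #include <vector>

    template <class T, std::size_t DIM>
    std::vector<T> sliceReference(const std::vector<T>& input,
                                  const std::array<std::size_t, DIM>& inputDims,
                                  const std::array<std::size_t, DIM>& slicedDims,
                                  std::size_t begin) {
        // Row-major strides of the input tensor.
        std::array<std::size_t, DIM> stride{};
        stride[DIM - 1] = 1;
        for (std::size_t d = DIM - 1; d > 0; --d) {
            stride[d - 1] = stride[d] * inputDims[d];
        }
        std::size_t outSize = 1;
        for (std::size_t d = 0; d < DIM; ++d) { outSize *= slicedDims[d]; }

        std::vector<T> output(outSize);
        std::array<std::size_t, DIM> coord{};  // current output coordinate
        for (std::size_t j = 0; j < outSize; ++j) {
            // Map the output coordinate to a flattened input index, offset by "begin".
            std::size_t i = begin;
            for (std::size_t d = 0; d < DIM; ++d) { i += coord[d] * stride[d]; }
            output[j] = input[i];
            // Increment the output coordinate (last dimension varies fastest).
            for (std::size_t d = DIM; d-- > 0;) {
                if (++coord[d] < slicedDims[d]) break;
                coord[d] = 0;
            }
        }
        return output;
    }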
diff --git a/src/operator/SliceImpl.cpp b/src/operator/SliceImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..bd571f24f7a5537a7ad503069085f34e04633592
--- /dev/null
+++ b/src/operator/SliceImpl.cpp
@@ -0,0 +1,258 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <numeric>     // std::accumulate
+#include <functional>  // std::multiplies
+
+#include "aidge/operator/Slice.hpp"
+
+#include "aidge/backend/cpu/operator/SliceImpl.hpp"
+#include "aidge/backend/cpu/operator/SliceImpl_forward_kernels.hpp"
+#include "aidge/utils/Types.h"
+#include <vector>
+#include <cassert>
+#include <tuple>
+
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
+    assert(mOp.getInput(0) && "requires valid input");
+
+    // Requires the whole tensors
+    return mOp.getInput(0)->dims<1>()[0];
+}
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; }
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
+                                                           const std::vector<Aidge::DimSize_t>& inputsSize) const {
+    (void)outputIdx;
+    (void)inputsSize;
+    return mOp.getOutput(0)->dims<1>()[0];
+}
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const {
+    return mNbConsumedData[0];
+}
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const {
+    return mNbProducedData[0];
+}
+
+void Aidge::SliceImpl_cpu<1>::updateConsummerProducer() {
+    // each input is consumed by the minimum amount for a forward pass
+    mNbConsumedData[0] += getNbRequiredData(0);
+
+    mNbProducedData[0] += getRequiredMemory(0, {});
+}
+
+void Aidge::SliceImpl_cpu<1>::forward() {
+    // FIXME: uncomment the following code once memory handling will work
+    assert(mOp.getInput(0) && "missing input #0");
+
+    // Find the correct kernel type
+    auto kernelFunc = Registrar<SliceImplForward_cpu<1>>::create(
+            {mOp.getInput(0)->dataType()});
+
+    // Call kernel
+    kernelFunc(mOp.getStaticAttributes(),
+               mOp.getInput(0)->template dims<1>(),
+               mOp.getInput(0)->getImpl()->rawPtr(),
+               mOp.getOutput(0)->getImpl()->rawPtr()
+               );
+
+    // each input is consumed by the minimum amount for a forward pass
+    mNbConsumedData[0] += getNbRequiredData(0);
+
+    mNbProducedData[0] += getRequiredMemory(0, {});
+}
+
+void Aidge::SliceImpl_cpu<1>::backward() { printf("Not implemented yet.\n"); }
+
+/////////////////////////////////////////////////////////////////////////
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
+    assert(mOp.getInput(0) && "requires valid input");
+
+    // Requires the whole tensors
+    const auto& inputDims = mOp.getInput(0)->dims<2>();
+    return inputDims[0]*inputDims[1];
+}
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; }
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
+                                                           const std::vector<Aidge::DimSize_t>& inputsSize) const {
+    (void)outputIdx;
+    (void)inputsSize;
+    const auto& outputDims = mOp.getOutput(0)->dims<2>();
+    return outputDims[0]*outputDims[1];
+}
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const {
+    return mNbConsumedData[0];
+}
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const {
+    return mNbProducedData[0];
+}
+
+void Aidge::SliceImpl_cpu<2>::updateConsummerProducer() {
+    // each input is consumed by the minimum amount for a forward pass
+    mNbConsumedData[0] += getNbRequiredData(0);
+
+    mNbProducedData[0] += getRequiredMemory(0, {});
+}
+
+void Aidge::SliceImpl_cpu<2>::forward() {
+    // FIXME: uncomment the following code once memory handling will work
+    assert(mOp.getInput(0) && "missing input #0");
+
+    // Find the correct kernel type
+    auto kernelFunc = Registrar<SliceImplForward_cpu<2>>::create(
+            {mOp.getInput(0)->dataType()});
+
+    // Call kernel
+    kernelFunc(mOp.getStaticAttributes(),
+               mOp.getInput(0)->template dims<2>(),
+               mOp.getInput(0)->getImpl()->rawPtr(),
+               mOp.getOutput(0)->getImpl()->rawPtr()
+               );
+
+    // each input is consumed by the minimum amount for a forward pass
+    mNbConsumedData[0] += getNbRequiredData(0);
+
+    mNbProducedData[0] += getRequiredMemory(0, {});
+}
+
+void Aidge::SliceImpl_cpu<2>::backward() { printf("Not implemented yet.\n"); }
+
+////////////////////////////////////////////////////////////////////////////
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
+    assert(mOp.getInput(0) && "requires valid input");
+
+    // Requires the whole tensors
+    const auto& inputDims = mOp.getInput(0)->dims<3>();
+
+    return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1),
+                           std::multiplies<NbElts_t>());
+}
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; }
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
+                                                           const std::vector<Aidge::DimSize_t>& inputsSize) const {
+    (void)outputIdx;
+    (void)inputsSize;
+    const auto& outputDims = mOp.getOutput(0)->dims<3>();
+    return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1),
+                           std::multiplies<NbElts_t>());
+}
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const {
+    return mNbConsumedData[0];
+}
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const {
+    return mNbProducedData[0];
+}
+
+void Aidge::SliceImpl_cpu<3>::updateConsummerProducer() {
+    // each input is consumed by the minimum amount for a forward pass
+    mNbConsumedData[0] += getNbRequiredData(0);
+
+    mNbProducedData[0] += getRequiredMemory(0, {});
+}
+
+void Aidge::SliceImpl_cpu<3>::forward() {
+    // FIXME: uncomment the following code once memory handling will work
+    assert(mOp.getInput(0) && "missing input #0");
+
+    // Find the correct kernel type
+    auto kernelFunc = Registrar<SliceImplForward_cpu<3>>::create(
+            {mOp.getInput(0)->dataType()});
+
+    // Call kernel
+    kernelFunc(mOp.getStaticAttributes(),
+               mOp.getInput(0)->template dims<3>(),
+               mOp.getInput(0)->getImpl()->rawPtr(),
+               mOp.getOutput(0)->getImpl()->rawPtr()
+               );
+
+    // each input is consumed by the minimum amount for a forward pass
+    mNbConsumedData[0] += getNbRequiredData(0);
+
+    mNbProducedData[0] += getRequiredMemory(0, {});
+}
+
+void Aidge::SliceImpl_cpu<3>::backward() { printf("Not implemented yet.\n"); }
+
+//////////////////////////////////////////////////////////////////////////////
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
+    assert(mOp.getInput(0) && "requires valid input");
+
+    // Requires the whole tensors
+    const auto& inputDims = mOp.getInput(0)->dims<4>();
+
+    return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1),
+                           std::multiplies<NbElts_t>());
+}
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; }
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
+                                                           const std::vector<Aidge::DimSize_t>& inputsSize) const {
+    (void)outputIdx;
+    (void)inputsSize;
+    const auto& outputDims = mOp.getOutput(0)->template dims<4>();
+    return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1),
+                           std::multiplies<NbElts_t>());
+}
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const {
+    return mNbConsumedData[0];
+}
+
+Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const {
+    return mNbProducedData[0];
+}
+
+void Aidge::SliceImpl_cpu<4>::updateConsummerProducer() {
+    // each input is consumed by the minimum amount for a forward pass
+    mNbConsumedData[0] += getNbRequiredData(0);
+
+    mNbProducedData[0] += getRequiredMemory(0, {});
+}
+
+void Aidge::SliceImpl_cpu<4>::forward() {
+    // FIXME: uncomment the following code once memory handling will work
+    assert(mOp.getInput(0) && "missing input #0");
+
+    // Find the correct kernel type
+    auto kernelFunc = Registrar<SliceImplForward_cpu<4>>::create(
+            {mOp.getInput(0)->dataType()});
+
+    // Call kernel
+    kernelFunc(mOp.getStaticAttributes(),
+               mOp.getInput(0)->template dims<4>(),
+               mOp.getInput(0)->getImpl()->rawPtr(),
+               mOp.getOutput(0)->getImpl()->rawPtr()
+               );
+
+    // each input is consumed by the minimum amount for a forward pass
+    mNbConsumedData[0] += getNbRequiredData(0);
+
+    mNbProducedData[0] += getRequiredMemory(0, {});
+}
+
+void Aidge::SliceImpl_cpu<4>::backward() { printf("Not implemented yet.\n"); }
\ No newline at end of file
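All four specializations above consume the whole input tensor and produce the whole output tensor on every forward pass: the 1D/2D versions spell the element count out as `dims[0]` (times `dims[1]`), while the 3D/4D versions fold the dimensions with `std::accumulate`. For reference, a standalone helper equivalent to that computation (hypothetical name, illustration only):

    // Product of the tensor dimensions, i.e. the element count reported by
    // getNbRequiredData()/getRequiredMemory() above. Illustration only.
    #include <array>
    #include <cstddef>
    #include <functional>
    #include <numeric>

    template <std::size_t DIM>
    std::size_t nbElements(const std::array<std::size_t, DIM>& dims) {
        return std::accumulate(dims.cbegin(), dims.cend(),
                               static_cast<std::size_t>(1), std::multiplies<std::size_t>());
    }

    // e.g. nbElements<4>({2, 2, 2, 10}) == 80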
diff --git a/unit_tests/operator/Test_SliceImpl.cpp b/unit_tests/operator/Test_SliceImpl.cpp
index 486f6edec006d3505cabe2d66b3820862cde3b69..0bf12f9b0faa01798b041462a50ec7db07347130 100644
--- a/unit_tests/operator/Test_SliceImpl.cpp
+++ b/unit_tests/operator/Test_SliceImpl.cpp
@@ -18,7 +18,7 @@
 
 using namespace Aidge;
 
-TEST_CASE("[cpu/operator] Slice(forward)") {
+TEST_CASE("[cpu/operator] Slice(forward)", "[Slice]") {
     SECTION("1D Tensor") {
         std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array1D<int,10> {
             {0, 1, 2,-3, 4,-5,-6, 7, 8, 9}
@@ -33,6 +33,7 @@
         mySlice->getOperator()->associateInput(0,input0);
         mySlice->getOperator()->computeOutputDims();
         mySlice->forward();
+        mySlice->getOperator()->output(0).print();
 
         REQUIRE(mySlice->getOperator()->output(0) == *expectedOutput);
     }
@@ -56,6 +57,7 @@ TEST_CASE("[cpu/operator] Slice(forward)") {
         mySlice->getOperator()->associateInput(0,input0);
         mySlice->getOperator()->computeOutputDims();
         mySlice->forward();
+        mySlice->getOperator()->output(0).print();
 
         REQUIRE(*mySlice->getOperator()->getOutput(0) == *expectedOutput);
     }
@@ -86,6 +88,7 @@ TEST_CASE("[cpu/operator] Slice(forward)") {
         mySlice->getOperator()->associateInput(0,input0);
         mySlice->getOperator()->computeOutputDims();
         mySlice->forward();
+        mySlice->getOperator()->output(0).print();
 
         REQUIRE(mySlice->getOperator()->output(0) == *expectedOutput);
     }
@@ -104,12 +107,12 @@ TEST_CASE("[cpu/operator] Slice(forward)") {
                 },
                 {
                     {
-                        { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
+                        { 0, 1, 2,-3, 6,-5,-6, 7, 8, 9},
                         {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
                     },
                     {
                         { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
-                        {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
+                        {-5, 4, 2,-3,11,-5,-6, 7,-1,10}
                     }
                 }
             }
@@ -128,12 +131,12 @@ TEST_CASE("[cpu/operator] Slice(forward)") {
                 },
                 {
                     {
-                        { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
+                        { 0, 1, 2,-3, 6,-5,-6, 7, 8, 9},
                         {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
                     },
                     {
                         { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
-                        {-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
+                        {-5, 4, 2,-3,11,-5,-6, 7,-1,10}
                     }
                 }
             }
@@ -145,6 +148,7 @@ TEST_CASE("[cpu/operator] Slice(forward)") {
         mySlice->getOperator()->associateInput(0,input0);
        mySlice->getOperator()->computeOutputDims();
         mySlice->forward();
+        mySlice->getOperator()->output(0).print();
 
         REQUIRE(mySlice->getOperator()->output(0) == *expectedOutput);
     }
 }
\ No newline at end of file
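For the 1D section of the test above, the slice reduces to a contiguous copy: the kernel offsets the input pointer by the start attribute and copies `slicedDims[0]` elements. A standalone illustration of that behaviour follows; the start index and length are chosen for the example and are not taken from the test fixture.

    // 1D slice as a plain contiguous copy from a start offset. Illustration only;
    // it does not use the Aidge API, and begin/sliceSize are example values.
    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main() {
        const std::vector<int> input = {0, 1, 2, -3, 4, -5, -6, 7, 8, 9};  // same data as the 1D test input
        const std::size_t begin = 3;      // start offset, chosen for this illustration
        const std::size_t sliceSize = 4;  // slice length, chosen for this illustration

        std::vector<int> output(sliceSize);
        for (std::size_t j = 0; j < sliceSize; ++j) {
            output[j] = input[begin + j];  // copy one element of the sliced range
        }

        assert((output == std::vector<int>{-3, 4, -5, -6}));
        return 0;
    }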