Commit 7c94c8b0 authored by Maxence Naud

Remove the need to specify the number of dimensions for the input of Slice_Op

parent f0100a63
1 merge request: !23 Tiling
Pipeline #35103 failed
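
In short: the CPU Slice implementation was previously templated on the input's compile-time rank (Slice_Op<DIM>, with shapes passed as std::array<std::size_t, DIM>), which forced a separate class specialization, kernel instantiation, and registrar for every supported rank from 1 to 4. The diff below replaces all of that with a single non-templated implementation whose kernels take the shape as a runtime std::vector<std::size_t>, so one kernel instantiation per data type now handles any number of dimensions.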
@@ -26,95 +26,26 @@ namespace Aidge {
 // class Slice_Op;
 // compute kernel registry for forward and backward
-template <DimIdx_t DIM>
 class SliceImplForward_cpu
-    : public Registrable<SliceImplForward_cpu<DIM>, std::tuple<DataType>,
-                         void(const typename Slice_Op<DIM>::Attrs&,
-                              const std::array<std::size_t, DIM>,
+    : public Registrable<SliceImplForward_cpu, std::tuple<DataType>,
+                         void(const typename Slice_Op::Attrs&,
+                              const std::vector<std::size_t>,
                               const void*,
                               void*)> {};
-template <DimIdx_t DIM>
 class SliceImplBackward_cpu
-    : public Registrable<SliceImplBackward_cpu<DIM>, std::tuple<DataType>,
-                         void(const typename Slice_Op<DIM>::Attrs&,
-                              const std::array<std::size_t, DIM>,
+    : public Registrable<SliceImplBackward_cpu, std::tuple<DataType>,
+                         void(const typename Slice_Op::Attrs&,
+                              const std::vector<std::size_t>,
                               const void*,
                               void*)> {};
 
-template <DimIdx_t DIM>
-class SliceImpl_cpu : public OperatorImpl {
-public:
-    SliceImpl_cpu(const Slice_Op<DIM>& op) : OperatorImpl(op) {}
-
-    static std::unique_ptr<SliceImpl_cpu<DIM>> create(const Slice_Op<DIM>& op) {
-        return std::make_unique<SliceImpl_cpu<DIM>>(op);
-    }
-
-public:
-    NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final {
-        assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input");
-        // Requires the whole tensors
-        const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims();
-        return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1),
-                               std::multiplies<NbElts_t>());
-    }
-    NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final { return 0; }
-    NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
-                               const std::vector<DimSize_t>& inputsSize) const override final {
-        (void)outputIdx;
-        (void)inputsSize;
-        const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims();
-        return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1),
-                               std::multiplies<NbElts_t>());
-    }
-    NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final {
-        return mNbConsumedData[0];
-    }
-    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final {
-        return mNbProducedData[0];
-    }
-    void updateConsummerProducer() override final {
-        // each input is consumed by the minimum amount for a forward pass
-        mNbConsumedData[0] += getNbRequiredData(0);
-        mNbProducedData[0] += getRequiredMemory(0, {});
-    }
-
-    void forward() {
-        // FIXME: uncomment the following code once memory handling will work
-        assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
-
-        // Find the correct kernel type
-        auto kernelFunc = Registrar<SliceImplForward_cpu<DIM>>::create(
-            {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()});
-
-        // Call kernel
-        kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<DIM>(),
-                   std::get<1>(std::static_pointer_cast<const Slice_Op<DIM>&>(mOp).getStaticAttributes()),
-                   std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
-                   std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
-        );
-
-        // each input is consumed by the minimum amount for a forward pass
-        mNbConsumedData[0] += getNbRequiredData(0);
-        mNbProducedData[0] += getRequiredMemory(0, {});
-    }
-
-    void backward() { printf("Not implemented yet.\n"); }
-};
-
-/******************************************************************************/
-
-template <>
-class SliceImpl_cpu<1> : public OperatorImpl {
+class SliceImpl_cpu : public OperatorImpl {
 public:
-    SliceImpl_cpu(const Slice_Op<1>& op) : OperatorImpl(op) {}
+    SliceImpl_cpu(const Slice_Op& op) : OperatorImpl(op) {}
 
-    static std::unique_ptr<SliceImpl_cpu<1>> create(const Slice_Op<1>& op) {
-        return std::make_unique<SliceImpl_cpu<1>>(op);
+    static std::unique_ptr<SliceImpl_cpu> create(const Slice_Op& op) {
+        return std::make_unique<SliceImpl_cpu>(op);
     }
 
 public:
@@ -127,89 +58,14 @@ public:
     void updateConsummerProducer() override final;
 
     void forward();
 
     void backward();
 };
 
-/******************************************************************************/
-
-template <>
-class SliceImpl_cpu<2> : public OperatorImpl {
-public:
-    SliceImpl_cpu(const Slice_Op<2>& op) : OperatorImpl(op) {}
-
-    static std::unique_ptr<SliceImpl_cpu<2>> create(const Slice_Op<2>& op) {
-        return std::make_unique<SliceImpl_cpu<2>>(op);
-    }
-
-public:
-    NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
-    NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
-    NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
-                               const std::vector<DimSize_t>& inputsSize) const override final;
-    NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final;
-    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
-    void updateConsummerProducer() override final;
-
-    void forward();
-
-    void backward();
-};
-
-/******************************************************************************/
-
-template <>
-class SliceImpl_cpu<3> : public OperatorImpl {
-public:
-    SliceImpl_cpu(const Slice_Op<3>& op) : OperatorImpl(op) {}
-
-    static std::unique_ptr<SliceImpl_cpu<3>> create(const Slice_Op<3>& op) {
-        return std::make_unique<SliceImpl_cpu<3>>(op);
-    }
-
-public:
-    NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
-    NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
-    NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
-                               const std::vector<DimSize_t>& inputsSize) const override final;
-    NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final;
-    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
-    void updateConsummerProducer() override final;
-
-    void forward();
-
-    void backward();
-};
-
-/******************************************************************************/
-
-template <>
-class SliceImpl_cpu<4> : public OperatorImpl {
-public:
-    SliceImpl_cpu(const Slice_Op<4>& op) : OperatorImpl(op) {}
-
-    static std::unique_ptr<SliceImpl_cpu<4>> create(const Slice_Op<4>& op) {
-        return std::make_unique<SliceImpl_cpu<4>>(op);
-    }
-
-public:
-    NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
-    NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
-    NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
-                               const std::vector<DimSize_t>& inputsSize) const override final;
-    NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final;
-    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
-    void updateConsummerProducer() override final;
-
-    void forward();
-
-    void backward();
-};
 
 namespace {
-static Registrar<Slice_Op<1>> registrarSliceImpl_1D_cpu("cpu", Aidge::SliceImpl_cpu<1>::create);
-static Registrar<Slice_Op<2>> registrarSliceImpl_2D_cpu("cpu", Aidge::SliceImpl_cpu<2>::create);
-static Registrar<Slice_Op<3>> registrarSliceImpl_3D_cpu("cpu", Aidge::SliceImpl_cpu<3>::create);
-static Registrar<Slice_Op<4>> registrarSliceImpl_4D_cpu("cpu", Aidge::SliceImpl_cpu<4>::create);
+static Registrar<Slice_Op> registrarSliceImpl_cpu("cpu", Aidge::SliceImpl_cpu::create);
 } // namespace
 } // namespace Aidge
 
-#endif /* AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_ */
+#endif /* AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_ */
\ No newline at end of file
@@ -15,46 +15,47 @@
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/operator/Slice.hpp"
 #include "aidge/backend/cpu/operator/SliceImpl.hpp"
-#include <array>
+#include <vector>
 #include <cstddef>
 #include "aidge/data/Data.hpp"
 
 namespace Aidge {
-template <class I, std::size_t DIM>
-void SliceImpl_cpu_forward_kernel(const typename Slice_Op<DIM>::Attrs& attrs,
-                                  const std::array<std::size_t, DIM> inputDims,
+template <class I>
+void SliceImpl_cpu_forward_kernel(const typename Slice_Op::Attrs& attrs,
+                                  const std::vector<std::size_t> inputDims,
                                   const void* input_,
                                   void* output_) {
     const I* input = static_cast<const I*>(input_) + std::get<0>(attrs);
     I* output = static_cast<I*>(output_);
-    const std::array<std::size_t, DIM> slicedDims = std::get<1>(attrs);
+    const std::vector<std::size_t> slicedDims = std::get<1>(attrs);
+    const std::size_t nbDims = slicedDims.size();
 
     // for inputDims = {4,5,5,3} & slicedDims = {3,2,2,1}, substractDims = {1,5,5,3}
-    std::array<std::size_t, DIM> substractedDims;
-    for (std::size_t i = 0; i < DIM; ++i) {
+    std::vector<std::size_t> substractedDims = std::vector<std::size_t>(nbDims);
+    for (std::size_t i = 0; i < nbDims; ++i) {
         substractedDims[i] = inputDims[i] - slicedDims[i];
     }
 
     // for slicedDims = {3,2,2,1}, prodSlicedDims = {12,4,2,1}
-    std::array<std::size_t, DIM> prodSlicedDims;
-    std::array<std::size_t, DIM+1> prodInputDims;
-    prodSlicedDims[DIM - 1] = slicedDims[DIM - 1];
-    prodInputDims[DIM - 1] = inputDims[DIM - 1];
-    prodInputDims[DIM] = 1;
-    for (std::size_t i = 2; i <= DIM; ++i) {
-        prodSlicedDims[DIM - i] = prodSlicedDims[DIM - i + 1]*slicedDims[DIM - i];
-        prodInputDims[DIM - i] = prodInputDims[DIM - i + 1]*inputDims[DIM - i];
+    std::vector<std::size_t> prodSlicedDims = std::vector<std::size_t>(nbDims);
+    std::vector<std::size_t> prodInputDims = std::vector<std::size_t>(nbDims+1);
+    prodSlicedDims[nbDims - 1] = slicedDims[nbDims - 1];
+    prodInputDims[nbDims - 1] = inputDims[nbDims - 1];
+    prodInputDims[nbDims] = 1;
+    for (std::size_t i = 2; i <= nbDims; ++i) {
+        prodSlicedDims[nbDims - i] = prodSlicedDims[nbDims - i + 1]*slicedDims[nbDims - i];
+        prodInputDims[nbDims - i] = prodInputDims[nbDims - i + 1]*inputDims[nbDims - i];
    }
 
     std::size_t j = 0;
     std::size_t i = 0;
     for (; j < prodSlicedDims[0];) {
         output[j] = input[i++];
-        ++j;
-        for (std::size_t idx = DIM - 1; idx > 0; --idx) {
-            i += j % prodSlicedDims[idx] == 0 ? substractedDims[idx]*prodInputDims[idx+1] : 0;
+        ++j;
+        for (std::size_t idx = nbDims - 1; idx > 0; --idx) {
+            i += j % prodSlicedDims[idx] == 0 ? substractedDims[idx]*prodInputDims[idx+1] : 0;
         }
     }
 }
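
To make the index arithmetic above concrete, here is a small self-contained walkthrough; it is an illustration written for this note, not code from the commit. It runs the same pre-computations and copy loop on a 4x5 row-major input sliced down to its top-left 2x3 block (start offset 0): j walks the output contiguously while i jumps over the elements left outside the slice whenever a slice boundary is crossed.

#include <cstdio>
#include <cstddef>
#include <vector>

int main() {
    // 4x5 row-major input holding 0..19; slice = rows 0-1, cols 0-2.
    const std::vector<std::size_t> inputDims{4, 5};
    const std::vector<std::size_t> slicedDims{2, 3};
    const std::size_t nbDims = slicedDims.size();

    std::vector<int> input(20);
    for (std::size_t k = 0; k < input.size(); ++k) { input[k] = static_cast<int>(k); }
    std::vector<int> output(6);

    // Same pre-computations as the kernel: per-dimension leftovers and suffix products.
    std::vector<std::size_t> substractedDims(nbDims);
    for (std::size_t d = 0; d < nbDims; ++d) { substractedDims[d] = inputDims[d] - slicedDims[d]; }
    std::vector<std::size_t> prodSlicedDims(nbDims);
    std::vector<std::size_t> prodInputDims(nbDims + 1);
    prodSlicedDims[nbDims - 1] = slicedDims[nbDims - 1];
    prodInputDims[nbDims - 1] = inputDims[nbDims - 1];
    prodInputDims[nbDims] = 1;
    for (std::size_t d = 2; d <= nbDims; ++d) {
        prodSlicedDims[nbDims - d] = prodSlicedDims[nbDims - d + 1] * slicedDims[nbDims - d];
        prodInputDims[nbDims - d] = prodInputDims[nbDims - d + 1] * inputDims[nbDims - d];
    }

    // Same copy loop: whenever j completes a slice row in dimension idx,
    // i skips the (inputDims[idx] - slicedDims[idx]) elements outside the slice.
    std::size_t j = 0;
    std::size_t i = 0;
    for (; j < prodSlicedDims[0];) {
        output[j] = input[i++];
        ++j;
        for (std::size_t idx = nbDims - 1; idx > 0; --idx) {
            i += j % prodSlicedDims[idx] == 0 ? substractedDims[idx] * prodInputDims[idx + 1] : 0;
        }
    }

    // Prints "0 1 2 5 6 7": the top-left 2x3 block of the 4x5 input.
    for (std::size_t k = 0; k < output.size(); ++k) { printf("%d ", output[k]); }
    printf("\n");
    return 0;
}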
@@ -62,37 +63,13 @@ void SliceImpl_cpu_forward_kernel(const typename Slice_Op<DIM>::Attrs& attrs,
 namespace {
-// DIM = 1
-static Registrar<SliceImplForward_cpu<1>> registrarSliceImplForward_1D_cpu_Float32(
-        {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 1>);
-static Registrar<SliceImplForward_cpu<1>> registrarSliceImplForward_1D_cpu_Int32(
-        {DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 1>);
-static Registrar<SliceImplForward_cpu<1>> registrarSliceImplForward_1D_cpu_Float64(
-        {DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 1>);
-
-// DIM = 2
-static Registrar<SliceImplForward_cpu<2>> registrarSliceImplForward_2D_cpu_Float32(
-        {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 2>);
-static Registrar<SliceImplForward_cpu<2>> registrarSliceImplForward_2D_cpu_Int32(
-        {DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 2>);
-static Registrar<SliceImplForward_cpu<2>> registrarSliceImplForward_2D_cpu_Float64(
-        {DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 2>);
-
-// DIM = 3
-static Registrar<SliceImplForward_cpu<3>> registrarSliceImplForward_3D_cpu_Float32(
-        {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 3>);
-static Registrar<SliceImplForward_cpu<3>> registrarSliceImplForward_3D_cpu_Int32(
-        {DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 3>);
-static Registrar<SliceImplForward_cpu<3>> registrarSliceImplForward_3D_cpu_Float64(
-        {DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 3>);
-
-// DIM = 4
-static Registrar<SliceImplForward_cpu<4>> registrarSliceImplForward_4D_cpu_Float32(
-        {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 4>);
-static Registrar<SliceImplForward_cpu<4>> registrarSliceImplForward_4D_cpu_Int32(
-        {DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 4>);
-static Registrar<SliceImplForward_cpu<4>> registrarSliceImplForward_4D_cpu_Float64(
-        {DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 4>);
+static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Float32(
+        {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float>);
+static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Int32(
+        {DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int>);
+static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Float64(
+        {DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double>);
 } // namespace
 } // namespace Aidge
 
-#endif /* AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_FORWARD_KERNEL_H_ */
+#endif /* AIDGE_CPU_OPERATOR_SLICEIMPL_FORWARD_KERNEL_H_ */
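
For context on how these registrations are consumed: each static Registrar object adds one (DataType -> kernel) entry at static-initialization time, and SliceImpl_cpu::forward() later fetches the matching instantiation with Registrar<SliceImplForward_cpu>::create({dataType}). The sketch below shows that pattern in miniature; all names in it are hypothetical stand-ins, not Aidge's actual Registrable/Registrar implementation.

#include <functional>
#include <map>
#include <stdexcept>
#include <utility>

enum class DataType { Float32, Float64, Int32 };

// Kernel signature reduced to the essentials for this sketch.
using Kernel = std::function<void(const void*, void*)>;

// Function-local static sidesteps static-initialization-order issues.
static std::map<DataType, Kernel>& kernelRegistry() {
    static std::map<DataType, Kernel> registry;
    return registry;
}

// Each static KernelRegistrar inserts one entry before main() runs,
// mirroring the registrarSliceImplForward_cpu_* objects above.
struct KernelRegistrar {
    KernelRegistrar(DataType key, Kernel kernel) {
        kernelRegistry().emplace(key, std::move(kernel));
    }
};

// Runtime lookup by DataType, as forward() does via Registrar<...>::create().
static Kernel createKernel(DataType key) {
    const auto it = kernelRegistry().find(key);
    if (it == kernelRegistry().end()) {
        throw std::runtime_error("no kernel registered for this DataType");
    }
    return it->second;
}

// One template instantiation per data type; the kernel body is elided here.
template <class I>
void sliceKernelSketch(const void* /*input*/, void* /*output*/) {}

static KernelRegistrar regFloat32(DataType::Float32, sliceKernelSketch<float>);
static KernelRegistrar regInt32(DataType::Int32, sliceKernelSketch<int>);
static KernelRegistrar regFloat64(DataType::Float64, sliceKernelSketch<double>);

Since the kernel no longer carries a DIM template parameter, the registry key shrinks from (DataType, DIM) to DataType alone, which is what lets the twelve registrars above collapse to three.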
@@ -22,231 +22,55 @@
 #include <cassert>
 #include <tuple>
 
-Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input");
-
-    // Requires the whole tensors
-    return std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<1>()[0];
-}
-
-Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; }
-
-Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
-                                                           const std::vector<Aidge::DimSize_t>& inputsSize) const {
-    (void)outputIdx;
-    (void)inputsSize;
-    return std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->template dims<1>()[0];
-}
-
-Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const {
-    return mNbConsumedData[0];
-}
-
-Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const {
-    return mNbProducedData[0];
-}
-
-void Aidge::SliceImpl_cpu<1>::updateConsummerProducer() {
-    // each input is consumed by the minimum amount for a forward pass
-    mNbConsumedData[0] += getNbRequiredData(0);
-    mNbProducedData[0] += getRequiredMemory(0, {});
-}
-
-void Aidge::SliceImpl_cpu<1>::forward() {
-    // FIXME: uncomment the following code once memory handling will work
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
-
-    // Find the correct kernel type
-    auto kernelFunc = Registrar<SliceImplForward_cpu<1>>::create(
-        {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()});
-
-    // Call kernel
-    kernelFunc(dynamic_cast<const Slice_Op<1>&>(mOp).getStaticAttributes(),
-               std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<1>(),
-               std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
-               std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
-    );
-
-    // each input is consumed by the minimum amount for a forward pass
-    mNbConsumedData[0] += getNbRequiredData(0);
-    mNbProducedData[0] += getRequiredMemory(0, {});
-}
-
-void Aidge::SliceImpl_cpu<1>::backward() { printf("Not implemented yet.\n"); }
-
-/////////////////////////////////////////////////////////////////////////
-
-Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input");
-
-    // Requires the whole tensors
-    const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<2>();
-    return inputDims[0]*inputDims[1];
-}
-
-Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; }
-
-Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
-                                                           const std::vector<Aidge::DimSize_t>& inputsSize) const {
-    (void)outputIdx;
-    (void)inputsSize;
-    const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->template dims<2>();
-    return outputDims[0]*outputDims[1];
-}
-
-Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const {
-    return mNbConsumedData[0];
-}
-
-Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const {
-    return mNbProducedData[0];
-}
-
-void Aidge::SliceImpl_cpu<2>::updateConsummerProducer() {
-    // each input is consumed by the minimum amount for a forward pass
-    mNbConsumedData[0] += getNbRequiredData(0);
-    mNbProducedData[0] += getRequiredMemory(0, {});
-}
-
-void Aidge::SliceImpl_cpu<2>::forward() {
-    // FIXME: uncomment the following code once memory handling will work
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
-
-    // Find the correct kernel type
-    auto kernelFunc = Registrar<SliceImplForward_cpu<2>>::create(
-        {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()});
-
-    // Call kernel
-    kernelFunc(dynamic_cast<const Slice_Op<2>&>(mOp).getStaticAttributes(),
-               std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<2>(),
-               std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
-               std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
-    );
-
-    // each input is consumed by the minimum amount for a forward pass
-    mNbConsumedData[0] += getNbRequiredData(0);
-    mNbProducedData[0] += getRequiredMemory(0, {});
-}
-
-void Aidge::SliceImpl_cpu<2>::backward() { printf("Not implemented yet.\n"); }
-
-////////////////////////////////////////////////////////////////////////////
-
-Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input");
-
-    // Requires the whole tensors
-    const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<3>();
-    return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1),
-                           std::multiplies<NbElts_t>());
-}
-
-Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; }
-
-Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
-                                                           const std::vector<Aidge::DimSize_t>& inputsSize) const {
-    (void)outputIdx;
-    (void)inputsSize;
-    const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->template dims<3>();
-    return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1),
-                           std::multiplies<NbElts_t>());
-}
-
-Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const {
-    return mNbConsumedData[0];
-}
-
-Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const {
-    return mNbProducedData[0];
-}
-
-void Aidge::SliceImpl_cpu<3>::updateConsummerProducer() {
-    // each input is consumed by the minimum amount for a forward pass
-    mNbConsumedData[0] += getNbRequiredData(0);
-    mNbProducedData[0] += getRequiredMemory(0, {});
-}
-
-void Aidge::SliceImpl_cpu<3>::forward() {
-    // FIXME: uncomment the following code once memory handling will work
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
-
-    // Find the correct kernel type
-    auto kernelFunc = Registrar<SliceImplForward_cpu<3>>::create(
-        {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()});
-
-    // Call kernel
-    kernelFunc(dynamic_cast<const Slice_Op<3>&>(mOp).getStaticAttributes(),
-               std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<3>(),
-               std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
-               std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
-    );
-
-    // each input is consumed by the minimum amount for a forward pass
-    mNbConsumedData[0] += getNbRequiredData(0);
-    mNbProducedData[0] += getRequiredMemory(0, {});
-}
-
-void Aidge::SliceImpl_cpu<3>::backward() { printf("Not implemented yet.\n"); }
-
-//////////////////////////////////////////////////////////////////////////////
-
-Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
+Aidge::NbElts_t Aidge::SliceImpl_cpu::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
     assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input");
 
     // Requires the whole tensors
-    const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>();
+    const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims();
     return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1),
                            std::multiplies<NbElts_t>());
 }
 
-Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; }
+Aidge::NbElts_t Aidge::SliceImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; }
 
-Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
-                                                           const std::vector<Aidge::DimSize_t>& inputsSize) const {
+Aidge::NbElts_t Aidge::SliceImpl_cpu::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
+                                                        const std::vector<Aidge::DimSize_t>& inputsSize) const {
     (void)outputIdx;
     (void)inputsSize;
-    const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->template dims<4>();
+    const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims();
     return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1),
                            std::multiplies<NbElts_t>());
 }
 
-Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const {
+Aidge::NbElts_t Aidge::SliceImpl_cpu::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const {
     return mNbConsumedData[0];
 }
 
-Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const {
+Aidge::NbElts_t Aidge::SliceImpl_cpu::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const {
     return mNbProducedData[0];
 }
 
-void Aidge::SliceImpl_cpu<4>::updateConsummerProducer() {
+void Aidge::SliceImpl_cpu::updateConsummerProducer() {
     // each input is consumed by the minimum amount for a forward pass
     mNbConsumedData[0] += getNbRequiredData(0);
     mNbProducedData[0] += getRequiredMemory(0, {});
 }
 
-void Aidge::SliceImpl_cpu<4>::forward() {
+void Aidge::SliceImpl_cpu::forward() {
     // FIXME: uncomment the following code once memory handling will work
     assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
 
     // Find the correct kernel type
-    auto kernelFunc = Registrar<SliceImplForward_cpu<4>>::create(
+    auto kernelFunc = Registrar<SliceImplForward_cpu>::create(
         {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()});
 
     // Call kernel
-    kernelFunc(dynamic_cast<const Slice_Op<4>&>(mOp).getStaticAttributes(),
-               std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(),
-               std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
-               std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
+    kernelFunc(dynamic_cast<const Slice_Op&>(mOp).getStaticAttributes(),
+               std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
+               std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
+               std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
     );
 
     // each input is consumed by the minimum amount for a forward pass
@@ -255,4 +79,4 @@ void Aidge::SliceImpl_cpu<4>::forward() {
     mNbProducedData[0] += getRequiredMemory(0, {});
 }
 
-void Aidge::SliceImpl_cpu<4>::backward() { printf("Not implemented yet.\n"); }
\ No newline at end of file
+void Aidge::SliceImpl_cpu::backward() { printf("Not implemented yet.\n"); }
\ No newline at end of file
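
A worked example of the scheduler bookkeeping above, using the shapes from the kernel comment (illustrative numbers only): for a {4,5,5,3} input, getNbRequiredData(0) returns 4*5*5*3 = 300 because slicing needs the whole input tensor, and for a {3,2,2,1} output getRequiredMemory(0, {}) returns 3*2*2*1 = 12; each forward pass therefore advances mNbConsumedData[0] by 300 and mNbProducedData[0] by 12.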