Skip to content
Snippets Groups Projects
Commit 09deddee authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd][WIP] Slice impl compiles but tests do not pass yet

parent 3b8417ed
No related branches found
No related tags found
2 merge requests!22Update operators implementation,!16Draft: Tiling
Pipeline #32745 failed
......@@ -12,12 +12,15 @@
#ifndef AIDGE_CPU_OPERATOR_SLICEIMPL_H_
#define AIDGE_CPU_OPERATOR_SLICEIMPL_H_
#include <memory>
#include <tuple>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include <memory>
#include <vector>
namespace Aidge {
// class Slice_Op;
......@@ -25,12 +28,14 @@ namespace Aidge {
// compute kernel registry for forward and backward
template <DimIdx_t DIM>
class SliceImplForward_cpu
: public Registrable<SliceImplForward_cpu<DIM>, std::tuple<DataType>, void(std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, const void*, void*)> {
};
: public Registrable<SliceImplForward_cpu<DIM>, std::tuple<DataType>,
void(std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, const void*,
void*)> {};
template <DimIdx_t DIM>
class SliceImplBackward_cpu
: public Registrable<SliceImplBackward_cpu<DIM>, std::tuple<DataType>, void(std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, const void*, void*)> {
};
: public Registrable<SliceImplBackward_cpu<DIM>, std::tuple<DataType>,
void(std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, const void*,
void*)> {};
template <DimIdx_t DIM>
class SliceImpl_cpu : public OperatorImpl {
......@@ -42,8 +47,8 @@ class SliceImpl_cpu : public OperatorImpl {
public:
SliceImpl_cpu(const Slice_Op<DIM>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
static std::unique_ptr<SliceImpl_cpu> create(const Slice_Op<DIM>& op) {
return std::make_unique<SliceImpl_cpu>(op);
static std::unique_ptr<SliceImpl_cpu<DIM>> create(const Slice_Op<DIM>& op) {
return std::make_unique<SliceImpl_cpu<DIM>>(op);
}
public:
......@@ -53,18 +58,17 @@ class SliceImpl_cpu : public OperatorImpl {
// Requires the whole tensors
const auto& inputDims = mOp.getInput(0)->dims();
return std::accumulate(inputDims.begin(), inputDims.end(),
static_cast<NbElts_t>(1), std::multiplies<NbElts_t>());
return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1),
std::multiplies<NbElts_t>());
}
NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final {
return 0;
}
NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t>& inputsSize) const override final {
(void) outputIdx;
(void) inputsSize;
NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final { return 0; }
NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
const std::vector<DimSize_t>& inputsSize) const override final {
(void)outputIdx;
(void)inputsSize;
const auto& outputDims = mOp.getOutput(0)->dims();
return std::accumulate(outputDims.begin(), outputDims.end(),
static_cast<NbElts_t>(1), std::multiplies<NbElts_t>());
return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1),
std::multiplies<NbElts_t>());
}
NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final {
return mNbConsumedData[0];
......@@ -73,9 +77,10 @@ class SliceImpl_cpu : public OperatorImpl {
return mNbProducedData[0];
}
void updateConsummerProducer() override final {
mNbConsumedData[0]+= getNbRequiredData(0); // each input is consumed by the minimum amount for a forward pass
// each input is consumed by the minimum amount for a forward pass
mNbConsumedData[0] += getNbRequiredData(0);
mNbProducedData[0]+= getRequiredMemory(0, {});
mNbProducedData[0] += getRequiredMemory(0, {});
}
void forward() {
......@@ -83,30 +88,149 @@ class SliceImpl_cpu : public OperatorImpl {
assert(mOp.getInput(0) && "missing input #0");
// Find the correct kernel type
auto kernelFunc = Registrar<SliceImplForward_cpu<DIM>>::create({
mOp.getInput(0)->dataType(),
mOp.getOutput(0)->dataType()});
auto kernelFunc = Registrar<SliceImplForward_cpu<DIM>>::create(
{mOp.getInput(0)->dataType()});
// Call kernel
kernelFunc(mOp->getInput(0)->dims(),
mOp->template getAttr<SliceAttr::SliceDims>(),
mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr());
kernelFunc(mOp.getInput(0)->template dims<DIM>(),
std::get<1>(mOp.getStaticAttributes()),
mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr()
);
mNbConsumedData[0]+= getNbRequiredData(0); // each input is consumed by the minimum amount for a forward pass
// each input is consumed by the minimum amount for a forward pass
mNbConsumedData[0] += getNbRequiredData(0);
mNbProducedData[0]+= getRequiredMemory(0, {});
mNbProducedData[0] += getRequiredMemory(0, {});
}
void backward() {
printf("Not implemented yet.\n");
void backward() { printf("Not implemented yet.\n"); }
};
/******************************************************************************/
/// CPU implementation of the 1-dimensional Slice operator.
/// Holds a reference to the operator and the scheduler book-keeping counters.
template <>
class SliceImpl_cpu<1> : public OperatorImpl {
private:
    const Slice_Op<1>& mOp;                   // operator providing inputs/outputs/attributes
    std::array<NbElts_t, 1> mNbConsumedData;  // elements consumed so far on input #0
    std::array<NbElts_t, 1> mNbProducedData;  // elements produced so far on output #0

public:
    SliceImpl_cpu(const Slice_Op<1>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}

    /// Factory used by the backend registrar.
    static std::unique_ptr<SliceImpl_cpu<1>> create(const Slice_Op<1>& op) {
        return std::make_unique<SliceImpl_cpu<1>>(op);
    }

public:
    NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
    NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
    NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
                               const std::vector<DimSize_t>& inputsSize) const override final;
    NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final;
    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
    void updateConsummerProducer() override final;

    void forward();
    void backward();
};
/******************************************************************************/
/// CPU implementation of the 2-dimensional Slice operator.
/// Holds a reference to the operator and the scheduler book-keeping counters.
template <>
class SliceImpl_cpu<2> : public OperatorImpl {
private:
    const Slice_Op<2>& mOp;                   // operator providing inputs/outputs/attributes
    std::array<NbElts_t, 1> mNbConsumedData;  // elements consumed so far on input #0
    std::array<NbElts_t, 1> mNbProducedData;  // elements produced so far on output #0

public:
    SliceImpl_cpu(const Slice_Op<2>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}

    /// Factory used by the backend registrar.
    static std::unique_ptr<SliceImpl_cpu<2>> create(const Slice_Op<2>& op) {
        return std::make_unique<SliceImpl_cpu<2>>(op);
    }

public:
    NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
    NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
    NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
                               const std::vector<DimSize_t>& inputsSize) const override final;
    NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final;
    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
    void updateConsummerProducer() override final;

    void forward();
    void backward();
};
/******************************************************************************/
/// CPU implementation of the 3-dimensional Slice operator.
/// Holds a reference to the operator and the scheduler book-keeping counters.
template <>
class SliceImpl_cpu<3> : public OperatorImpl {
private:
    const Slice_Op<3>& mOp;                   // operator providing inputs/outputs/attributes
    std::array<NbElts_t, 1> mNbConsumedData;  // elements consumed so far on input #0
    std::array<NbElts_t, 1> mNbProducedData;  // elements produced so far on output #0

public:
    SliceImpl_cpu(const Slice_Op<3>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}

    /// Factory used by the backend registrar.
    static std::unique_ptr<SliceImpl_cpu<3>> create(const Slice_Op<3>& op) {
        return std::make_unique<SliceImpl_cpu<3>>(op);
    }

public:
    NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
    NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
    NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
                               const std::vector<DimSize_t>& inputsSize) const override final;
    NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final;
    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
    void updateConsummerProducer() override final;

    void forward();
    void backward();
};
/******************************************************************************/
/// CPU implementation of the 4-dimensional Slice operator.
/// Holds a reference to the operator and the scheduler book-keeping counters.
template <>
class SliceImpl_cpu<4> : public OperatorImpl {
private:
    const Slice_Op<4>& mOp;                   // operator providing inputs/outputs/attributes
    std::array<NbElts_t, 1> mNbConsumedData;  // elements consumed so far on input #0
    std::array<NbElts_t, 1> mNbProducedData;  // elements produced so far on output #0

public:
    SliceImpl_cpu(const Slice_Op<4>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}

    /// Factory used by the backend registrar.
    static std::unique_ptr<SliceImpl_cpu<4>> create(const Slice_Op<4>& op) {
        return std::make_unique<SliceImpl_cpu<4>>(op);
    }

public:
    NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
    NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
    NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
                               const std::vector<DimSize_t>& inputsSize) const override final;
    NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final;
    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
    void updateConsummerProducer() override final;

    void forward();
    void backward();
};
namespace {
template <DimIdx_t DIM>
static Registrar<Slice_Op<DIM>> registrarSliceImpl_cpu("cpu", Aidge::SliceImpl_cpu<DIM>::create);
}
static Registrar<Slice_Op<1>> registrarSliceImpl_1D_cpu("cpu", Aidge::SliceImpl_cpu<1>::create);
static Registrar<Slice_Op<2>> registrarSliceImpl_2D_cpu("cpu", Aidge::SliceImpl_cpu<2>::create);
static Registrar<Slice_Op<3>> registrarSliceImpl_3D_cpu("cpu", Aidge::SliceImpl_cpu<3>::create);
static Registrar<Slice_Op<4>> registrarSliceImpl_4D_cpu("cpu", Aidge::SliceImpl_cpu<4>::create);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_SLICEIMPL_H_ */
......@@ -16,17 +16,18 @@
#include <array>
#include <cstddef>
#include "aidge/backend/cpu/operator/SliceImpl.hpp"
#include "aidge/data/Data.hpp"
namespace Aidge {
template <class I, class O, std::size_t DIM>
void SliceImpl_cpu_forward_kernel(const std::array<std::size_t, DIM> inputDims,
const std::array<std::size_t, DIM> slicedDims,
template <class I, std::size_t DIM>
void SliceImpl_cpu_forward_kernel(const Slice_Op::Attrs& attrs,
const std::array<std::size_t, DIM> inputDims,
const void* input_,
void* output_) {
const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_);
const I* input = static_cast<const I*>(input_) + std::get<0>(attrs);
I* output = static_cast<I*>(output_);
const std::array<std::size_t, DIM> slicedDims = std::get<1>(attrs);
// for inputDims = {4,5,5,3} & slicedDims = {3,2,2,1}, substractDims = {1,5,5,3}
std::array<std::size_t, DIM> substractedDims;
......@@ -40,31 +41,54 @@ void SliceImpl_cpu_forward_kernel(const std::array<std::size_t, DIM> inputDims,
prodSlicedDims[DIM - 1] = slicedDims[DIM - 1];
prodInputDims[DIM - 1] = inputDims[DIM - 1];
prodInputDims[DIM] = 1;
for (std::size_t i = 2; i < DIM; ++i) {
for (std::size_t i = 2; i <= DIM; ++i) {
prodSlicedDims[DIM - i] = prodSlicedDims[DIM - i + 1]*slicedDims[DIM - i];
prodInputDims[DIM - i] = prodInputDims[DIM - i + 1]*inputDims[DIM - i];
}
std::size_t j = 0;
std::size_t i = 0;
for (std::size_t = 0; j < prodSlicedDims[0]; ++j) {
for (; j < prodSlicedDims[0]; ++j) {
output[j] = input[i++];
for (std::size_t idx = DIM - 1; idx > 0; --idx) {
i += j % prodSlicedDims[idx] == 0 ? substractedDims[idx]*prodInputDims[idx+1] : 0;
i += j % prodSlicedDims[idx] == 0 ? substractedDims[idx]*prodInputDims[idx+1] : 0;
}
}
}
namespace {
template <std::size_t DIM>
static Registrar<SliceImplForward_cpu<DIM>> registrarSliceImplForward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, float, DIM>);
template <std::size_t DIM>
static Registrar<SliceImplForward_cpu<DIM>> registrarSliceImplForward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, int, DIM>);
template <std::size_t DIM>
static Registrar<SliceImplForward_cpu<DIM>> registrarSliceImplForward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, double, DIM>);
// DIM = 1
static Registrar<SliceImplForward_cpu<1>> registrarSliceImplForward_1D_cpu_Float32(
{DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 1>);
static Registrar<SliceImplForward_cpu<1>> registrarSliceImplForward_1D_cpu_Int32(
{DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 1>);
static Registrar<SliceImplForward_cpu<1>> registrarSliceImplForward_1D_cpu_Float64(
{DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 1>);
// DIM = 2
static Registrar<SliceImplForward_cpu<2>> registrarSliceImplForward_2D_cpu_Float32(
{DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 2>);
static Registrar<SliceImplForward_cpu<2>> registrarSliceImplForward_2D_cpu_Int32(
{DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 2>);
static Registrar<SliceImplForward_cpu<2>> registrarSliceImplForward_2D_cpu_Float64(
{DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 2>);
// DIM = 3
static Registrar<SliceImplForward_cpu<3>> registrarSliceImplForward_3D_cpu_Float32(
{DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 3>);
static Registrar<SliceImplForward_cpu<3>> registrarSliceImplForward_3D_cpu_Int32(
{DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 3>);
static Registrar<SliceImplForward_cpu<3>> registrarSliceImplForward_3D_cpu_Float64(
{DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 3>);
// DIM = 4
static Registrar<SliceImplForward_cpu<4>> registrarSliceImplForward_4D_cpu_Float32(
{DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float, 4>);
static Registrar<SliceImplForward_cpu<4>> registrarSliceImplForward_4D_cpu_Int32(
{DataType::Int32}, Aidge::SliceImpl_cpu_forward_kernel<int, 4>);
static Registrar<SliceImplForward_cpu<4>> registrarSliceImplForward_4D_cpu_Float64(
{DataType::Float64}, Aidge::SliceImpl_cpu_forward_kernel<double, 4>);
} // namespace
} // namespace Aidge
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <numeric> // std::accumulate
#include <functional> // std::multiplies
#include "aidge/operator/Slice.hpp"
#include "aidge/backend/cpu/operator/SliceImpl.hpp"
#include "aidge/backend/cpu/operator/SliceImpl_forward_kernels.hpp"
#include "aidge/utils/Types.h"
#include <vector>
#include <cassert>
#include <tuple>
Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
    assert(mOp.getInput(0) && "requires valid input");

    // The kernel processes the whole tensor at once; for 1D that is dims[0].
    const auto& inputDims = mOp.getInput(0)->dims<1>();
    return inputDims[0];
}
Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; }
Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
                                                           const std::vector<Aidge::DimSize_t>& inputsSize) const {
    // Output size depends only on the operator's sliced dims, not on the arguments.
    (void)outputIdx;
    (void)inputsSize;
    const auto& outputDims = mOp.getOutput(0)->dims<1>();
    return outputDims[0];
}
// Total number of input elements consumed so far (single-input operator).
Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const {
return mNbConsumedData[0];
}
// Total number of output elements produced so far (single-output operator).
Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const {
return mNbProducedData[0];
}
// Advance the scheduler counters as if one forward pass had run.
void Aidge::SliceImpl_cpu<1>::updateConsummerProducer() {
// each input is consumed by the minimum amount for a forward pass
mNbConsumedData[0] += getNbRequiredData(0);
mNbProducedData[0] += getRequiredMemory(0, {});
}
void Aidge::SliceImpl_cpu<1>::forward() {
    // FIXME: uncomment the following code once memory handling will work
    assert(mOp.getInput(0) && "missing input #0");

    const auto& input = mOp.getInput(0);
    const auto& output = mOp.getOutput(0);

    // Fetch the kernel registered for this input data type.
    auto kernelFunc = Registrar<SliceImplForward_cpu<1>>::create({input->dataType()});

    // Run it: (input dims, sliced-dims attribute, raw input ptr, raw output ptr).
    kernelFunc(input->template dims<1>(),
               std::get<1>(mOp.getStaticAttributes()),
               input->getImpl()->rawPtr(),
               output->getImpl()->rawPtr());

    // Scheduler book-keeping: consume the required input, produce the output.
    mNbConsumedData[0] += getNbRequiredData(0);
    mNbProducedData[0] += getRequiredMemory(0, {});
}
void Aidge::SliceImpl_cpu<1>::backward() { printf("Not implemented yet.\n"); }
/////////////////////////////////////////////////////////////////////////
Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
    assert(mOp.getInput(0) && "requires valid input");

    // The whole 2D tensor is required: product of its two dimensions.
    const auto& inputDims = mOp.getInput(0)->dims<2>();
    return std::accumulate(inputDims.cbegin(), inputDims.cend(), static_cast<NbElts_t>(1),
                           std::multiplies<NbElts_t>());
}
Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; }
Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
                                                           const std::vector<Aidge::DimSize_t>& inputsSize) const {
    // Output size depends only on the operator's sliced dims, not on the arguments.
    (void)outputIdx;
    (void)inputsSize;
    const auto& outputDims = mOp.getOutput(0)->dims<2>();
    return std::accumulate(outputDims.cbegin(), outputDims.cend(), static_cast<NbElts_t>(1),
                           std::multiplies<NbElts_t>());
}
// Total number of input elements consumed so far (single-input operator).
Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const {
return mNbConsumedData[0];
}
// Total number of output elements produced so far (single-output operator).
Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const {
return mNbProducedData[0];
}
// Advance the scheduler counters as if one forward pass had run.
void Aidge::SliceImpl_cpu<2>::updateConsummerProducer() {
// each input is consumed by the minimum amount for a forward pass
mNbConsumedData[0] += getNbRequiredData(0);
mNbProducedData[0] += getRequiredMemory(0, {});
}
void Aidge::SliceImpl_cpu<2>::forward() {
    // FIXME: uncomment the following code once memory handling will work
    assert(mOp.getInput(0) && "missing input #0");

    // Find the correct kernel type
    auto kernelFunc = Registrar<SliceImplForward_cpu<2>>::create(
                        {mOp.getInput(0)->dataType()});

    // Call kernel.
    // BUG FIX: the previous call was `kernelFunc(mOp.getStaticAttributes()` —
    // a missing comma, and the full attribute tuple passed where the input
    // dimensions are expected. The registered kernel signature is
    // (input dims, sliced dims, raw input, raw output), exactly as the
    // 1D/3D/4D implementations call it, so this call is made consistent
    // with them.
    kernelFunc(mOp.getInput(0)->template dims<2>(),
               std::get<1>(mOp.getStaticAttributes()),
               mOp.getInput(0)->getImpl()->rawPtr(),
               mOp.getOutput(0)->getImpl()->rawPtr()
               );

    // each input is consumed by the minimum amount for a forward pass
    mNbConsumedData[0] += getNbRequiredData(0);
    mNbProducedData[0] += getRequiredMemory(0, {});
}
void Aidge::SliceImpl_cpu<2>::backward() { printf("Not implemented yet.\n"); }
////////////////////////////////////////////////////////////////////////////
Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
    assert(mOp.getInput(0) && "requires valid input");

    // The whole 3D tensor is required: product of its three dimensions.
    const auto& inputDims = mOp.getInput(0)->dims<3>();
    return static_cast<NbElts_t>(inputDims[0]) * inputDims[1] * inputDims[2];
}
Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; }
Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
                                                           const std::vector<Aidge::DimSize_t>& inputsSize) const {
    // Output size depends only on the operator's sliced dims, not on the arguments.
    (void)outputIdx;
    (void)inputsSize;
    const auto& outputDims = mOp.getOutput(0)->dims<3>();
    return static_cast<NbElts_t>(outputDims[0]) * outputDims[1] * outputDims[2];
}
// Total number of input elements consumed so far (single-input operator).
Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const {
return mNbConsumedData[0];
}
// Total number of output elements produced so far (single-output operator).
Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const {
return mNbProducedData[0];
}
// Advance the scheduler counters as if one forward pass had run.
void Aidge::SliceImpl_cpu<3>::updateConsummerProducer() {
// each input is consumed by the minimum amount for a forward pass
mNbConsumedData[0] += getNbRequiredData(0);
mNbProducedData[0] += getRequiredMemory(0, {});
}
void Aidge::SliceImpl_cpu<3>::forward() {
    // FIXME: uncomment the following code once memory handling will work
    assert(mOp.getInput(0) && "missing input #0");

    const auto& input = mOp.getInput(0);
    const auto& output = mOp.getOutput(0);

    // Fetch the kernel registered for this input data type.
    auto kernelFunc = Registrar<SliceImplForward_cpu<3>>::create({input->dataType()});

    // Run it: (input dims, sliced-dims attribute, raw input ptr, raw output ptr).
    kernelFunc(input->template dims<3>(),
               std::get<1>(mOp.getStaticAttributes()),
               input->getImpl()->rawPtr(),
               output->getImpl()->rawPtr());

    // Scheduler book-keeping: consume the required input, produce the output.
    mNbConsumedData[0] += getNbRequiredData(0);
    mNbProducedData[0] += getRequiredMemory(0, {});
}
void Aidge::SliceImpl_cpu<3>::backward() { printf("Not implemented yet.\n"); }
//////////////////////////////////////////////////////////////////////////////
Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
    assert(mOp.getInput(0) && "requires valid input");

    // The whole 4D tensor is required: product of its four dimensions.
    const auto& inputDims = mOp.getInput(0)->dims<4>();
    return static_cast<NbElts_t>(inputDims[0]) * inputDims[1] * inputDims[2] * inputDims[3];
}
Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; }
Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
                                                           const std::vector<Aidge::DimSize_t>& inputsSize) const {
    // Output size depends only on the operator's sliced dims, not on the arguments.
    (void)outputIdx;
    (void)inputsSize;
    const auto& outputDims = mOp.getOutput(0)->template dims<4>();
    return static_cast<NbElts_t>(outputDims[0]) * outputDims[1] * outputDims[2] * outputDims[3];
}
// Total number of input elements consumed so far (single-input operator).
Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const {
return mNbConsumedData[0];
}
// Total number of output elements produced so far (single-output operator).
Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const {
return mNbProducedData[0];
}
// Advance the scheduler counters as if one forward pass had run.
void Aidge::SliceImpl_cpu<4>::updateConsummerProducer() {
// each input is consumed by the minimum amount for a forward pass
mNbConsumedData[0] += getNbRequiredData(0);
mNbProducedData[0] += getRequiredMemory(0, {});
}
void Aidge::SliceImpl_cpu<4>::forward() {
    // FIXME: uncomment the following code once memory handling will work
    assert(mOp.getInput(0) && "missing input #0");

    const auto& input = mOp.getInput(0);
    const auto& output = mOp.getOutput(0);

    // Fetch the kernel registered for this input data type.
    auto kernelFunc = Registrar<SliceImplForward_cpu<4>>::create({input->dataType()});

    // Run it: (input dims, sliced-dims attribute, raw input ptr, raw output ptr).
    kernelFunc(input->template dims<4>(),
               std::get<1>(mOp.getStaticAttributes()),
               input->getImpl()->rawPtr(),
               output->getImpl()->rawPtr());

    // Scheduler book-keeping: consume the required input, produce the output.
    mNbConsumedData[0] += getNbRequiredData(0);
    mNbProducedData[0] += getRequiredMemory(0, {});
}
void Aidge::SliceImpl_cpu<4>::backward() { printf("Not implemented yet.\n"); }
\ No newline at end of file
......@@ -18,7 +18,7 @@
using namespace Aidge;
TEST_CASE("[cpu/operator] Slice(forward)") {
TEST_CASE("[cpu/operator] Slice(forward)", "[Slice]") {
SECTION("1D Tensor") {
std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array1D<int,10> {
{0, 1, 2,-3, 4,-5,-6, 7, 8, 9}
......@@ -33,6 +33,7 @@ TEST_CASE("[cpu/operator] Slice(forward)") {
mySlice->getOperator()->associateInput(0,input0);
mySlice->getOperator()->computeOutputDims();
mySlice->forward();
mySlice->getOperator()->output(0).print();
REQUIRE(mySlice->getOperator()->output(0) == *expectedOutput);
}
......@@ -56,6 +57,7 @@ TEST_CASE("[cpu/operator] Slice(forward)") {
mySlice->getOperator()->associateInput(0,input0);
mySlice->getOperator()->computeOutputDims();
mySlice->forward();
mySlice->getOperator()->output(0).print();
REQUIRE(*mySlice->getOperator()->getOutput(0) == *expectedOutput);
}
......@@ -86,6 +88,7 @@ TEST_CASE("[cpu/operator] Slice(forward)") {
mySlice->getOperator()->associateInput(0,input0);
mySlice->getOperator()->computeOutputDims();
mySlice->forward();
mySlice->getOperator()->output(0).print();
REQUIRE(mySlice->getOperator()->output(0) == *expectedOutput);
}
......@@ -104,12 +107,12 @@ TEST_CASE("[cpu/operator] Slice(forward)") {
},
{
{
{ 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
{ 0, 1, 2,-3, 6,-5,-6, 7, 8, 9},
{-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
},
{
{ 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
{-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
{-5, 4, 2,-3,11,-5,-6, 7,-1,10}
}
}
}
......@@ -128,12 +131,12 @@ TEST_CASE("[cpu/operator] Slice(forward)") {
},
{
{
{ 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
{ 0, 1, 2,-3, 6,-5,-6, 7, 8, 9},
{-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
},
{
{ 0, 1, 2,-3, 4,-5,-6, 7, 8, 9},
{-5, 4, 2,-3, 4,-5,-6, 7,-1,10}
{-5, 4, 2,-3,11,-5,-6, 7,-1,10}
}
}
}
......@@ -145,6 +148,7 @@ TEST_CASE("[cpu/operator] Slice(forward)") {
mySlice->getOperator()->associateInput(0,input0);
mySlice->getOperator()->computeOutputDims();
mySlice->forward();
mySlice->getOperator()->output(0).print();
REQUIRE(mySlice->getOperator()->output(0) == *expectedOutput);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment