Skip to content
Snippets Groups Projects
Commit 56197a2f authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add Slice operator

parent 62a6a58d
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!20Vit operators
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_SLICEIMPL_H_
#define AIDGE_CPU_OPERATOR_SLICEIMPL_H_
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include <memory>
#include <vector>
namespace Aidge {
// class Slice_Op;
// compute kernel registry for forward and backward
class SliceImplForward_cpu
: public Registrable<SliceImplForward_cpu,
std::tuple<DataType, DataType>,
void(const std::vector<DimSize_t>, const void*, DimSize_t, const void*, const void*, const void*, void*)> {
};
class SliceImplBackward_cpu
: public Registrable<SliceImplBackward_cpu,
std::tuple<DataType, DataType>,
void(const std::vector<DimSize_t>, const void*, DimSize_t, const void*, const void*, const void*, void*)> {
};
class SliceImpl_cpu : public OperatorImpl {
public:
SliceImpl_cpu(const Slice_Op& op) : OperatorImpl(op) {}
static std::unique_ptr<SliceImpl_cpu> create(const Slice_Op& op) {
return std::make_unique<SliceImpl_cpu>(op);
}
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
};
namespace {
static Registrar<Slice_Op> registrarSliceImpl_cpu("cpu", Aidge::SliceImpl_cpu::create);
}
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_SLICEIMPL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_SLICEIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_SLICEIMPL_FORWARD_KERNEL_H_
#include "aidge/utils/Registrar.hpp"
#include "aidge/backend/cpu/operator/SliceImpl.hpp"
namespace Aidge {
template <class I, class O>
void SliceImpl_cpu_forward_kernel(const std::vector<DimSize_t> inputDims,
const void* input_,
DimSize_t nbSlices,
const void* axes_,
const void* starts_,
const void* ends_,
void* output_) {
const I* input = static_cast<const I*>(input_);
const int* axes = static_cast<const int*>(axes_);
const int* starts = static_cast<const int*>(starts_);
const int* ends = static_cast<const int*>(ends_);
O* output = static_cast<O*>(output_);
// Calculate the total number of elements in the input array
size_t totalElements = 1;
for (size_t dimSize : inputDims) {
totalElements *= dimSize;
}
// Create a temporary arrays to store intermediate input/output for each slice op
std::vector<I> tempInArray(input, input + totalElements);
std::vector<I> tempOutArray(input, input + totalElements);
std::vector<size_t> currentDims = inputDims;
size_t copiedElems = 0;
// Loop over each slice operation
for(size_t i=0; i< nbSlices; ++i)
{
copiedElems = 0;
I* tempOutArrayPtr = tempOutArray.data();
// Extract parameters for the current slice, make sure indexes are positive
size_t axisIdx = axes[i]>=0?axes[i]:(axes[i]+currentDims.size());
size_t startIdx = starts[i]>=0?starts[i]:(starts[i]+currentDims[axisIdx]);
size_t endIdx = ends[i]>=0?ends[i]:(ends[i]+currentDims[axisIdx]);
// Compute the size of the slice over each element on the axis
size_t strideOnCurrDim = 1;
for(size_t j=(axisIdx+1); j<currentDims.size(); ++j)
{
strideOnCurrDim *= currentDims[j];
}
size_t sliceSize = (endIdx - startIdx) * strideOnCurrDim;
// For each slice operation, we will slice all elements on the axis (subSlice)
// the number of sublices is the product of dimension previous to the slice dimension
size_t nbSubSlices = 1;
for(size_t j=0; j<axisIdx; ++j)
{
nbSubSlices*=currentDims[j];
}
// Operate the slice over each element of the dim we want to slice
for(size_t s=0; s<nbSubSlices; ++s)
{
// Compute the pointer postion on input
std::size_t copyStartPos = s * strideOnCurrDim * currentDims[axisIdx] + strideOnCurrDim;
const I* copyPtr = std::next(tempInArray.data(), copyStartPos);
// Copy slice to output array and update pointer
std::copy_n(copyPtr, sliceSize , tempOutArrayPtr);
tempOutArrayPtr += sliceSize ;
copiedElems+= sliceSize ;
}
// Update the input for the next slice operation
tempInArray.assign(tempOutArray.begin(), tempOutArray.begin() + copiedElems);
currentDims[axisIdx] = endIdx - startIdx;
}
std::copy_n(tempInArray.data(), copiedElems, output);
}
namespace {
static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Float32(
{DataType::Float32, DataType::Float32},
Aidge::SliceImpl_cpu_forward_kernel<float, float>);
static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Int32(
{DataType::Int32, DataType::Int32},
Aidge::SliceImpl_cpu_forward_kernel<int, int>);
static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Float64(
{DataType::Float64, DataType::Float64},
Aidge::SliceImpl_cpu_forward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_SLICEIMPL_FORWARD_KERNEL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector>
#include "aidge/operator/Slice.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/operator/SliceImpl.hpp"
#include "aidge/backend/cpu/operator/SliceImpl_forward_kernels.hpp"
Aidge::NbElts_t Aidge::SliceImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
// this implementation can be in-place
return 0;
}
void Aidge::SliceImpl_cpu::forward() {
assert(mOp.getInput(0) && "missing input #0");
assert(mOp.getInput(1) && "missing input #1"); //TODO fill axes (input #1) if not given
assert(mOp.getInput(2) && "missing input #2");
assert(mOp.getInput(3) && "missing input #3");
assert((mOp.getInput(1)->nbDims() == 1) && "input #1 must either be a tensor of rank 1");
assert((mOp.getInput(2)->nbDims() == 1) && "input #2 must either be a tensor of rank 1");
assert((mOp.getInput(3)->nbDims() == 1) && "input #3 must either be a tensor of rank 1");
// Find the correct kernel type
auto kernelFunc = Registrar<SliceImplForward_cpu>::create({
mOp.getInput(0)->dataType(),
mOp.getInput(1)->dataType()});
// Call kernel
kernelFunc(mOp.getInput(0)->dims(),
mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getInput(1)->dims()[0],
mOp.getInput(1)->getImpl()->rawPtr(),
mOp.getInput(2)->getImpl()->rawPtr(),
mOp.getInput(3)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr());
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/backend/cpu.hpp"
#include <memory>
using namespace Aidge;
TEST_CASE("[cpu/operator] Slice(forward)") {
SECTION("2D Tensor") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<int,2,4> {
{
{1, 2, 3, 4},
{5, 6, 7, 8}
}
});
std::shared_ptr<Tensor> axes = std::make_shared<Tensor>(Array1D<int,2>{{0, 1}});
std::shared_ptr<Tensor> starts = std::make_shared<Tensor>(Array1D<int,2>{{1, 1}});
std::shared_ptr<Tensor> ends = std::make_shared<Tensor>(Array1D<int,2>{{2, 4}});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<int,1,3> {
{
{6, 7, 8}
}
});
std::shared_ptr<Node> mySlice = Slice();
mySlice->getOperator()->setDatatype(DataType::Int32);
mySlice->getOperator()->setBackend("cpu");
mySlice->getOperator()->associateInput(0, input);
mySlice->getOperator()->associateInput(1, axes);
mySlice->getOperator()->associateInput(2, starts);
mySlice->getOperator()->associateInput(3, ends);
mySlice->getOperator()->computeOutputDims();
mySlice->forward();
REQUIRE(*std::static_pointer_cast<Tensor>(mySlice->getOperator()->getOutput(0)) == *expectedOutput);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment