Skip to content
Snippets Groups Projects
Commit e9fdb3b6 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add Step attribute to Slice

parent b97a42bc
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!122Add missing attributes to operators
......@@ -30,25 +30,26 @@ public:
void forward() override;
};
enum class SliceAttr { Starts, Ends, Axes };
enum class SliceAttr { Starts, Ends, Axes, Steps };
class Slice_Op
: public OperatorTensor,
public Registrable<Slice_Op, std::string, std::shared_ptr<OperatorImpl>(const Slice_Op &)>,
public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>> {
public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>, std::vector<std::int64_t>> {
public:
static const std::string Type;
Slice_Op() = delete;
using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>>;
using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>, std::vector<std::int64_t>>;
template <SliceAttr e> using attr = typename Attributes_::template attr<e>;
Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int8_t>& axes)
: OperatorTensor(Type, 4, 0, 1),
Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int8_t>& axes, const std::vector<std::int64_t>& steps)
: OperatorTensor(Type, 5, 0, 1),
Attributes_(attr<SliceAttr::Starts>(starts),
attr<SliceAttr::Ends>(ends),
attr<SliceAttr::Axes>(axes))
attr<SliceAttr::Axes>(axes),
attr<SliceAttr::Steps>(steps))
{
mImpl = std::make_shared<Slice_OpImpl>(*this);
}
......@@ -83,11 +84,12 @@ public:
void setBackend(const std::string &name, DeviceIdx_t device = 0) override;
static const std::vector<std::string> getInputsName(){
return {"data_input", "starts", "ends", "axes"};
return {"data_input", "starts", "ends", "axes", "steps"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
/**
......@@ -98,14 +100,32 @@ public:
inline std::shared_ptr<Node> Slice(const std::vector<std::int64_t>& starts = {},
const std::vector<std::int64_t>& ends = {},
const std::vector<std::int8_t>& axes = {},
const std::vector<std::int64_t>& steps = {},
const std::string &name = "") {
return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name);
return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes, steps), name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes" };
const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes", "Steps" };
}
namespace Aidge {
class SliceImplForward
: public Registrable<SliceImplForward,
std::tuple<DataType>,
void(const Slice_Op::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
template <typename I>
void Slice_forward_kernel(const Slice_Op::Attrs &attrs, const std::vector<DimSize_t>&inputDims, const void *input_, void *output_);
namespace {
static Registrar<SliceImplForward> registrarSliceImplForward_Float32(
{DataType::Float32}, Slice_forward_kernel<float>);
static Registrar<SliceImplForward> registrarSliceImplForward_Int32(
{DataType::Int32}, Slice_forward_kernel<int>);
static Registrar<SliceImplForward> registrarSliceImplForward_Int64(
{DataType::Float64}, Slice_forward_kernel<double>);
}
}
#endif /* AIDGE_CORE_OPERATOR_RELU_H_ */
......@@ -30,6 +30,7 @@ void init_Slice(py::module& m) {
py::arg("starts") = std::vector<std::int64_t>(),
py::arg("ends") = std::vector<std::int64_t>(),
py::arg("axes") = std::vector<std::int8_t>(),
py::arg("steps") = std::vector<std::int64_t>(),
py::arg("name") = "");
}
} // namespace Aidge
......@@ -11,6 +11,7 @@
#include "aidge/operator/Slice.hpp"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
......@@ -25,88 +26,91 @@
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
#include "aidge/data/Data.hpp"
#include "aidge/utils/Registrar.hpp"
void Aidge::Slice_OpImpl::forward() {
const Slice_Op& op = dynamic_cast<const Slice_Op&>(mOp);
if (!op.getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", op.Type);
}
AIDGE_ASSERT((op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Ends>().size()) &&
(op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Axes>().size()),
"start, end and axes arguments should be the same size.");
const std::size_t nbDims = op.getInput(0)->nbDims();
const std::vector<std::size_t>& inputDims = op.getInput(0)->dims();
auto outputDims = op.getInput(0)->dims();
template<class I>
void Aidge::Slice_forward_kernel(const Slice_Op::Attrs &attrs, const std::vector<DimSize_t>&inputDims, const void *input_, void *output_){
const I* input = static_cast<const I*>(input_);
I* output = static_cast<I*>(output_);
// compute index of the output's first element
// compute output dimension at the same time (may change between two forward calls)
std::size_t beginning = 0;
const std::size_t nbAxes = op.template getAttr<SliceAttr::Axes>().size();
const std::size_t nbDims = inputDims.size();
std::vector<DimSize_t> dims = inputDims;
DimSize_t totalSize = std::accumulate(inputDims.cbegin(), inputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
I* outputAccumulation = new I[totalSize];
const I* inputAccumulation = input;
const std::size_t nbAxes = std::get<0>(attrs).size();
for (std::size_t i = 0; i < nbAxes; ++i) {
// For each slice operation get the params and cast them to size_t
DimIdx_t axis = op.template getAttr<SliceAttr::Axes>()[i] >= 0 ?
static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i]) :
static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(inputDims.size()));
DimSize_t start = op.template getAttr<SliceAttr::Starts>()[i] >= 0 ?
static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i]) :
static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(inputDims[axis]));
DimSize_t end = op.template getAttr<SliceAttr::Ends>()[i] >= 0 ?
static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i]) :
static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(inputDims[axis]));
const std::size_t stridePostAxis = std::accumulate(inputDims.cbegin()+axis+1, inputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
beginning += start * stridePostAxis;
const std::size_t sliceLength = end - start;
outputDims[axis] = sliceLength;
}
op.getOutput(0)->resize(outputDims);
DimIdx_t axis = std::get<2>(attrs)[i] >= 0 ?
static_cast<DimIdx_t>(std::get<2>(attrs)[i]) :
static_cast<DimIdx_t>(std::get<2>(attrs)[i] + static_cast<DimIdx_t>(inputDims.size()));
std::int64_t start = std::get<0>(attrs)[i] >= 0 ?
std::get<0>(attrs)[i] :
std::get<0>(attrs)[i] + static_cast<std::int64_t>(inputDims[axis]);
std::int64_t end = std::get<1>(attrs)[i] >= 0 ?
std::get<1>(attrs)[i] :
std::get<1>(attrs)[i] + static_cast<std::int64_t>(inputDims[axis]);
std::int64_t step = std::get<3>(attrs)[i];
std::size_t sliceSize = static_cast<std::size_t>((end - start) / std::abs(step));
// for inputDims = {4,5,5,3} & outputDims = {3,2,2,1}: substractDims = {1,5,5,3}
std::vector<std::size_t> substractedDims = std::vector<std::size_t>(nbDims);
for (std::size_t i = 0; i < nbDims; ++i) {
substractedDims[i] = inputDims[i] - outputDims[i];
}
// for outputDims = {3,2,2,1}: prodOutputDims = {12,4,2,1}
std::vector<std::size_t> prodOutputDims = std::vector<std::size_t>(nbDims);
std::vector<std::size_t> prodInputDims = std::vector<std::size_t>(nbDims + 1);
prodOutputDims[nbDims - 1] = outputDims[nbDims - 1];
prodInputDims[nbDims - 1] = inputDims[nbDims - 1];
prodInputDims[nbDims] = 1;
for (std::size_t i = 2; i <= nbDims; ++i) {
prodOutputDims[nbDims - i] = prodOutputDims[nbDims - i + 1] * outputDims[nbDims - i];
prodInputDims[nbDims - i] = prodInputDims[nbDims - i + 1] * inputDims[nbDims - i];
}
if ( i > 0) {
outputAccumulation = new I[totalSize];
}
const std::size_t stride_pre = std::accumulate(dims.cbegin(), dims.cbegin() + axis, 1, std::multiplies<std::size_t>());
const std::size_t stride_post = std::accumulate(dims.crbegin(), dims.crbegin() + nbDims -1 - axis, 1, std::multiplies<std::size_t>());
std::int64_t firstElem = step > 0 ? start : end;
std::int64_t lastElem = step > 0 ? end : start;
std::size_t i = beginning;
std::size_t size = 0; // number of elements to copy
std::size_t offset = 0;
for (std::size_t j = 0; j < prodOutputDims[0];) {
++size;
++i;
++j;
bool newChunk = false;
for (std::size_t idx = nbDims - 1; idx > 0; --idx) {
if (j % prodOutputDims[idx] == 0) {
i += substractedDims[idx] * prodInputDims[idx + 1];
newChunk = true;
for (std::size_t outer = 0; outer < stride_pre; outer++)
{
std::size_t addedSlices = 0;
for (std::int64_t inner = firstElem; inner < lastElem; inner+=step)
{
size_t idx = outer * stride_post * dims[axis] + inner * stride_post;
size_t idx_out = outer * stride_post * sliceSize + addedSlices * stride_post;
if (idx < totalSize) {
std::copy_n(std::next(inputAccumulation, idx), stride_post, std::next(outputAccumulation, idx_out));
}
addedSlices++;
}
}
if (newChunk) {
op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(beginning), size, offset);
beginning = i;
offset += size;
size = 0;
totalSize /= dims[axis];
totalSize *= sliceSize;
dims[axis] = sliceSize;
if (inputAccumulation != input) {
delete[] inputAccumulation;
}
inputAccumulation = outputAccumulation;
}
// Copy elements from inputAccumulation to output while dividing by divisor
std::copy_n(inputAccumulation, totalSize, output);
// op.getOutput(0)->getImpl()->copy(inputAccumulation, totalSize);
if (outputAccumulation) {
delete[] outputAccumulation;
}
}
if (size > 0) {
op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(beginning), size, offset);
void Aidge::Slice_OpImpl::forward() {
const Slice_Op& op = dynamic_cast<const Slice_Op&>(mOp);
if (!op.getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", op.Type);
}
AIDGE_ASSERT((op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Ends>().size()) &&
(op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Axes>().size()),
"start, end and axes arguments should be the same size.");
// Find the correct kernel type
auto kernelFunc =
Registrar<SliceImplForward>::create({std::static_pointer_cast<Tensor>(op.getRawInput(0))->dataType()});
// Call kernel
kernelFunc(dynamic_cast<const Slice_Op&>(mOp).getStaticAttributes(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
std::static_pointer_cast<Tensor>(op.getInput(0))->getImpl()->hostPtr(),
std::static_pointer_cast<Tensor>(op.getOutput(0))->getImpl()->hostPtr());
}
const std::string Aidge::Slice_Op::Type = "Slice";
......@@ -127,7 +131,7 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
AIDGE_ASSERT((mInputs[1]->dataType() == mInputs[2]->dataType()) && (mInputs[1]->dataType() == mInputs[3]->dataType()), "Slice inputs must have the same dataType.");
this->template getAttr<SliceAttr::Starts>().clear(); // If both are provided input would override attrs
this->template getAttr<SliceAttr::Starts>().clear();
this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size());
this->template getAttr<SliceAttr::Ends>().clear();
this->template getAttr<SliceAttr::Ends>().reserve(getInput(1)->size());
......@@ -179,11 +183,46 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Input DataType is not supported.", type());
break;
}
}
// Fill Steps attr if empty
if(this->template getAttr<SliceAttr::Steps>().empty()) {
// In case the input Steps is not provided, default value is 1
this->template getAttr<SliceAttr::Steps>() = std::vector<std::int64_t>(getInput(1)->size(), 1);
if (getInput(4) && !getInput(4)->empty()) {
this->template getAttr<SliceAttr::Steps>().clear();
this->template getAttr<SliceAttr::Steps>().reserve(getInput(1)->size());
switch (mInputs[1]->dataType()) {
case DataType::Float64:
std::copy_n(static_cast<double*>(mInputs[4]->getImpl()->rawPtr()),
getInput(4)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
break;
case DataType::Float32:
std::copy_n(static_cast<float*>(mInputs[4]->getImpl()->rawPtr()),
getInput(4)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
break;
case DataType::Int64:
std::copy_n(static_cast<std::int64_t*>(mInputs[4]->getImpl()->rawPtr()),
getInput(4)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
break;
case DataType::Int32:
std::copy_n(static_cast<std::int32_t*>(mInputs[4]->getImpl()->rawPtr()),
getInput(4)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
break;
}
}
}
DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
std::vector<DimSize_t> outDims = getInput(0)->dims();
for (std::size_t i = 0; i < nbAxes; ++i) {
......@@ -197,7 +236,10 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) :
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
const std::size_t sliceLength = end - start;
if(this->template getAttr<SliceAttr::Steps>()[i] == 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step must be a non-zero value", type());
}
const std::size_t sliceLength = (end - start) / static_cast<DimSize_t>(std::abs(this->template getAttr<SliceAttr::Steps>()[i]));
// Check if slice length is valid
if (sliceLength > getInput(0)->dims()[axis])
{
......
......@@ -106,7 +106,11 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
std::vector<std::int8_t> usedDims(inputDimsEnd.size());
std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int8_t>(0));
auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis]));
// Create Slice's Steps attribute
std::vector<std::int64_t> steps(inputDimsEnd.size());
std::iota(steps.begin(), steps.end(), static_cast<std::int64_t>(1));
auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, steps, "Slice_" + std::to_string(currentFirstDims[axis]));
slice -> addChild(newNode, 0, 0);
newNode -> addChild(concat, 0, i);
......
......@@ -69,7 +69,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
mySlice->getOperator()->setDataType(DataType::Int32);
mySlice->getOperator()->setBackend("cpu");
mySlice->forward();
// mySlice->getOperator()->output(0).print();
op->getOutput(0)->print();
REQUIRE(*(op->getOutput(0)) == *expectedOutput);
REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims());
REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType());
......@@ -176,7 +176,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
mySlice->getOperator()->setDataType(DataType::Int32);
mySlice->getOperator()->setBackend("cpu");
mySlice->forward();
// mySlice->getOperator()->output(0).print();
// op->getOutput(0)->print();
REQUIRE(*(op->getOutput(0)) == *expectedOutput);
REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims());
REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType());
......@@ -217,13 +217,13 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
}
});
std::shared_ptr<Node> mySlice = Slice({0,0,0,0}, {1,1,1,5}, {0,1,2,3});
std::shared_ptr<Node> mySlice = Slice({0,0,0,0}, {1,1,1,5}, {0,1,2,3}, {1,1,1,1});
auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator());
mySlice->getOperator()->associateInput(0,input0);
mySlice->getOperator()->setDataType(DataType::Int32);
mySlice->getOperator()->setBackend("cpu");
mySlice->forward();
// mySlice->getOperator()->output(0).print();
// op->getOutput(0)->print();
REQUIRE(*(op->getOutput(0)) == *expectedOutput);
REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims());
REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment