Skip to content
Snippets Groups Projects
Commit f531d7f5 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add Slice operator

parent ddc3bea1
No related branches found
No related tags found
No related merge requests found
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_SLICE_H_
#define AIDGE_CORE_OPERATOR_SLICE_H_
#include <cassert>
#include <memory>
#include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class Slice_Op : public Operator,
public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op&)> {
public:
// FIXME: change accessibility
std::array<std::shared_ptr<Tensor>, 4> mInputs = {std::make_shared<Tensor>(),
std::make_shared<Tensor>(),
std::make_shared<Tensor>(),
std::make_shared<Tensor>()};
const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
public:
static constexpr const char* Type = "Slice";
Slice_Op()
: Operator(Type)
{
setDatatype(DataType::Float32);
}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
Slice_Op(const Slice_Op& op)
: Operator(Type),
mOutput(std::make_shared<Tensor>(*op.mOutput))
{
// cpy-ctor
setDatatype(op.mOutput->dataType());
mImpl = op.mImpl ? Registrar<Slice_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Slice_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<Slice_Op>(*this);
}
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(inputIdx < 4 && "operator Slice supports only 4 inputs");
assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
}
void computeOutputDims() override final {
if (!mInputs[0]->empty() && !mInputs[1]->empty() && !mInputs[2]->empty()&& !mInputs[3]->empty())
{
const int* axes = static_cast<const int*>(mInputs[1]->getImpl()->rawPtr());
const int* starts = static_cast<const int*>(mInputs[2]->getImpl()->rawPtr());
const int* ends = static_cast<const int*>(mInputs[3]->getImpl()->rawPtr());
DimSize_t nbAxes = mInputs[1]->dims()[0];
std::vector<DimSize_t> outDims;
for(std::size_t i=0; i<mInputs[0]->dims().size();++i)
{
const int* idxPos = std::find(axes, axes + nbAxes, static_cast<int>(i));
if(idxPos != (axes + nbAxes))
{
// TODO make sure all indxes are positive before this
size_t idx = static_cast<size_t>(*idxPos);
int startVal = starts[idx];
int endVal = ends[idx];
outDims.push_back(endVal - startVal);
}
else
{
outDims.push_back(mInputs[0]->dims()[i]);
}
}
mOutput->resize(outDims);
}
}
bool outputDimsForwarded() const override final {
return !(mOutput->empty());
}
inline Tensor& input(const IOIndex_t inputIdx) const override final {
assert(static_cast<std::size_t>(inputIdx) < 4 && "wrong inputIdx for Slice operator.");
return *(mInputs[inputIdx].get());
}
inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
assert((inputIdx < 4) && "Slice Operator has 4 inputs");
return mInputs[inputIdx];
}
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
assert((outputIdx == 0) && "Slice Operator has only 1 output");
(void) outputIdx; // avoid unused warning
return mOutput;
}
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
assert(inputIdx < 4 && "operator supports only 4 inputs");
return std::static_pointer_cast<Data>(mInputs[inputIdx]);
}
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
assert(outputIdx == 0 && "operator supports only 1 output");
(void) outputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mOutput);
}
void setBackend(const std::string& name) override {
mImpl = Registrar<Slice_Op>::create(name)(*this);
mOutput->setBackend(name);
// FIXME: temporary workaround
mInputs[0]->setBackend(name);
mInputs[1]->setBackend(name);
mInputs[2]->setBackend(name);
mInputs[3]->setBackend(name);
}
void setDatatype(const DataType& datatype) override {
mOutput->setDatatype(datatype);
// FIXME: temporary workaround
mInputs[0]->setDatatype(datatype);
mInputs[1]->setDatatype(DataType::Int32);
mInputs[2]->setDatatype(DataType::Int32);
mInputs[3]->setDatatype(DataType::Int32);
}
inline IOIndex_t nbInputs() const noexcept override final { return 4; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 4; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input", "starts", "ends", "axes"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> Slice(const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Slice_Op>(), name);
}
}
#endif /* AIDGE_CORE_OPERATOR_SLICE_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/operator/Slice.hpp"
#include "aidge/operator/Operator.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Slice(py::module& m) {
py::class_<Slice_Op, std::shared_ptr<Slice_Op>, Operator>(m, "SliceOp", py::multiple_inheritance())
.def("get_inputs_name", &Slice_Op::getInputsName)
.def("get_outputs_name", &Slice_Op::getOutputsName);
m.def("Slice", &Slice, py::arg("name") = "");
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment