Skip to content
Snippets Groups Projects
Commit f531d7f5 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add Slice operator

parent ddc3bea1
No related branches found
No related tags found
2 merge requests!59Improvements and fixes,!47Vit operators
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_SLICE_H_
#define AIDGE_CORE_OPERATOR_SLICE_H_
#include <cassert>
#include <memory>
#include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class Slice_Op : public Operator,
public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op&)> {
public:
// FIXME: change accessibility
std::array<std::shared_ptr<Tensor>, 4> mInputs = {std::make_shared<Tensor>(),
std::make_shared<Tensor>(),
std::make_shared<Tensor>(),
std::make_shared<Tensor>()};
const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
public:
static constexpr const char* Type = "Slice";
Slice_Op()
: Operator(Type)
{
setDatatype(DataType::Float32);
}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
Slice_Op(const Slice_Op& op)
: Operator(Type),
mOutput(std::make_shared<Tensor>(*op.mOutput))
{
// cpy-ctor
setDatatype(op.mOutput->dataType());
mImpl = op.mImpl ? Registrar<Slice_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Slice_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<Slice_Op>(*this);
}
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(inputIdx < 4 && "operator Slice supports only 4 inputs");
assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
}
void computeOutputDims() override final {
if (!mInputs[0]->empty() && !mInputs[1]->empty() && !mInputs[2]->empty()&& !mInputs[3]->empty())
{
const int* axes = static_cast<const int*>(mInputs[1]->getImpl()->rawPtr());
const int* starts = static_cast<const int*>(mInputs[2]->getImpl()->rawPtr());
const int* ends = static_cast<const int*>(mInputs[3]->getImpl()->rawPtr());
DimSize_t nbAxes = mInputs[1]->dims()[0];
std::vector<DimSize_t> outDims;
for(std::size_t i=0; i<mInputs[0]->dims().size();++i)
{
const int* idxPos = std::find(axes, axes + nbAxes, static_cast<int>(i));
if(idxPos != (axes + nbAxes))
{
// TODO make sure all indxes are positive before this
size_t idx = static_cast<size_t>(*idxPos);
int startVal = starts[idx];
int endVal = ends[idx];
outDims.push_back(endVal - startVal);
}
else
{
outDims.push_back(mInputs[0]->dims()[i]);
}
}
mOutput->resize(outDims);
}
}
bool outputDimsForwarded() const override final {
return !(mOutput->empty());
}
inline Tensor& input(const IOIndex_t inputIdx) const override final {
assert(static_cast<std::size_t>(inputIdx) < 4 && "wrong inputIdx for Slice operator.");
return *(mInputs[inputIdx].get());
}
inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
assert((inputIdx < 4) && "Slice Operator has 4 inputs");
return mInputs[inputIdx];
}
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
assert((outputIdx == 0) && "Slice Operator has only 1 output");
(void) outputIdx; // avoid unused warning
return mOutput;
}
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
assert(inputIdx < 4 && "operator supports only 4 inputs");
return std::static_pointer_cast<Data>(mInputs[inputIdx]);
}
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
assert(outputIdx == 0 && "operator supports only 1 output");
(void) outputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mOutput);
}
void setBackend(const std::string& name) override {
mImpl = Registrar<Slice_Op>::create(name)(*this);
mOutput->setBackend(name);
// FIXME: temporary workaround
mInputs[0]->setBackend(name);
mInputs[1]->setBackend(name);
mInputs[2]->setBackend(name);
mInputs[3]->setBackend(name);
}
void setDatatype(const DataType& datatype) override {
mOutput->setDatatype(datatype);
// FIXME: temporary workaround
mInputs[0]->setDatatype(datatype);
mInputs[1]->setDatatype(DataType::Int32);
mInputs[2]->setDatatype(DataType::Int32);
mInputs[3]->setDatatype(DataType::Int32);
}
inline IOIndex_t nbInputs() const noexcept override final { return 4; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 4; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input", "starts", "ends", "axes"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> Slice(const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Slice_Op>(), name);
}
}
#endif /* AIDGE_CORE_OPERATOR_SLICE_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/operator/Slice.hpp"
#include "aidge/operator/Operator.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Slice(py::module& m) {
py::class_<Slice_Op, std::shared_ptr<Slice_Op>, Operator>(m, "SliceOp", py::multiple_inheritance())
.def("get_inputs_name", &Slice_Op::getInputsName)
.def("get_outputs_name", &Slice_Op::getOutputsName);
m.def("Slice", &Slice, py::arg("name") = "");
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment