Skip to content
Snippets Groups Projects
Commit ade64013 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

switch slice attrs into inputs

parent 76ff38f0
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!93Change Gather and Slice's attributes into intputs
......@@ -20,31 +20,16 @@
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
enum class SliceAttr { Starts, Ends, Axes };
class Slice_Op
: public OperatorTensor,
public Registrable<Slice_Op, std::string, std::shared_ptr<OperatorImpl>(const Slice_Op &)>,
public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>> {
public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>{
public:
static const std::string Type;
Slice_Op() = delete;
using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>>;
template <SliceAttr e>
using attr = typename Attributes_::template attr<e>;
Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int64_t>& axes)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(attr<SliceAttr::Starts>(starts),
attr<SliceAttr::Ends>(ends),
attr<SliceAttr::Axes>(axes))
{}
Slice_Op() : OperatorTensor(Type, 4, 0, 1) {}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its
......@@ -52,8 +37,7 @@ public:
* @param op Operator to copy.
*/
Slice_Op(const Slice_Op &op)
: OperatorTensor(op),
Attributes_(op)
: OperatorTensor(op)
{
if (op.mImpl){
SET_IMPL_MACRO(Slice_Op, *this, op.mOutputs[0]->getImpl()->backend());
......@@ -77,7 +61,7 @@ public:
}
static const std::vector<std::string> getInputsName(){
return {"data_input"};
return {"data_input", "starts", "ends", "axes"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
......@@ -86,29 +70,12 @@ public:
/**
* @brief Exract a sub-Tensor from a bigger original Tensor.
* @param starts Indexes for each dimension of the first element.
* Can be a negative value. Negative values start their reference from the last index.
* ``-1`` referes to the last index of a dimension.
* @param ends Indexes for each dimension of the last element.
* Can be a negative value. Negative values start their reference from the last index.
* ``-1`` referes to the last index of a dimension.
* @param axes Dimensions for which start/end indexes apply. Not specifying a dimensions
* means the whole dimensions is extracted.
* @param name Name of the Operator.
* @return std::shared_ptr<Node> A Node containing the Operator.
*/
inline std::shared_ptr<Node> Slice(const std::vector<std::int64_t> starts,
const std::vector<std::int64_t> ends,
const std::vector<std::int64_t> axes,
const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name);
inline std::shared_ptr<Node> Slice(const std::string &name = "") {
return std::make_shared<Node>(std::make_shared<Slice_Op>(), name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes" };
}
#endif /* AIDGE_CORE_OPERATOR_RELU_H_ */
......@@ -25,6 +25,6 @@ void init_Gather(py::module& m) {
.def("attributes_name", &Gather_Op::staticGetAttrsName);
declare_registrable<Gather_Op>(m, "GatherOp");
m.def("Gather", &Gather, py::arg("axis")=0, py::arg("name") = "");
m.def("Gather", &Gather, py::arg("axis") = 0, py::arg("name") = "");
}
} // namespace Aidge
......@@ -22,6 +22,7 @@ void init_Slice(py::module& m) {
.def("get_inputs_name", &Slice_Op::getInputsName)
.def("get_outputs_name", &Slice_Op::getOutputsName);
declare_registrable<Slice_Op>(m, "SliceOp");
m.def("Slice", &Slice, py::arg("starts"), py::arg("ends"), py::arg("axes"), py::arg("name") = "");
m.def("Slice", &Slice, py::arg("name") = "");
}
} // namespace Aidge
......@@ -8,17 +8,16 @@
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/operator/Slice.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
......@@ -26,28 +25,50 @@ const std::string Aidge::Slice_Op::Type = "Slice";
void Aidge::Slice_Op::computeOutputDims() {
// check input have been associated
if (!getInput(0) || (getInput(0)->empty())) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
if (!getInput(0) || !getInput(1) || !getInput(2) || !getInput(3)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
}
const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
std::vector<DimSize_t> outDims = getInput(0)->dims();
for (std::size_t i = 0; i < nbAxes; ++i) {
// For each slice operation get the params and cast them to size_t
const std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i];
const std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i];
const std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i];
const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : static_cast<std::size_t>(axis_) + getInput(0)->nbDims();
const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : static_cast<std::size_t>(start_) + getInput(0)->dims()[axis];
const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : static_cast<std::size_t>(end_) + getInput(0)->dims()[axis];
const std::size_t sliceLength = end - start + 1;
// Check if slice length is valid
if (sliceLength > getInput(0)->dims()[axis])
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
if((!getInput(0)->empty()) && (!getInput(1)->empty()) && (!getInput(2)->empty()) && (!getInput(3)->empty()))
{
const auto starts = mInputs[1]->getImpl()->rawPtr();
const auto ends = mInputs[2]->getImpl()->rawPtr();
const auto axes = mInputs[3]->getImpl()->rawPtr();
DimSize_t nbAxes = mInputs[1]->size();
std::vector<DimSize_t> outDims = getInput(0)->dims();
for (std::size_t i = 0; i < nbAxes; ++i) {
// For each slice operation get the params and cast them to size_t
std::size_t axis, start, end; //TODO find a better way to cast "starts", "ends" and "axes"
if (mInputs[1]->dataType() == DataType::Float32 && mInputs[2]->dataType() == DataType::Float32 && mInputs[3]->dataType() == DataType::Float32)
{
const float* axes_ = static_cast<float*>(axes);
axis = axes_[i] >= 0 ? static_cast<std::size_t>(axes_[i]) : static_cast<std::size_t>(axes_[i]) + getInput(0)->nbDims();
const float* starts_ = static_cast<float*>(starts);
start = starts_[i] >= 0 ? static_cast<std::size_t>(starts_[i]) : static_cast<std::size_t>(starts_[i]) + getInput(0)->dims()[axis];
const float* ends_ = static_cast<float*>(ends);
end = ends_[i] >= 0 ? static_cast<std::size_t>(ends_[i]) : static_cast<std::size_t>(ends_[i]) + getInput(0)->dims()[ends_[i]];
}
else if(mInputs[1]->dataType() == DataType::Int32 && mInputs[2]->dataType() == DataType::Int32 && mInputs[3]->dataType() == DataType::Int32)
{
const std::int32_t* axes_ = static_cast<std::int32_t*>(axes);
axis = axes_[i] >= 0 ? static_cast<std::size_t>(axes_[i]) : static_cast<std::size_t>(axes_[i]) + getInput(0)->nbDims();
const std::int32_t* starts_ = static_cast<std::int32_t*>(starts);
start = starts_[i] >= 0 ? static_cast<std::size_t>(starts_[i]) : static_cast<std::size_t>(starts_[i]) + getInput(0)->dims()[axis];
const std::int32_t* ends_ = static_cast<std::int32_t*>(ends);
end = ends_[i] >= 0 ? static_cast<std::size_t>(ends_[i]) : static_cast<std::size_t>(ends_[i]) + getInput(0)->dims()[ends_[i]];
}
else
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "Slice inputs type is not supported yet");
}
const std::size_t sliceLength = end - start;
// Check if slice length is valid
if (sliceLength > getInput(0)->dims()[axis])
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
}
outDims[axis] = sliceLength;
}
outDims[axis] = sliceLength;
mOutputs[0]->resize(outDims);
}
mOutputs[0]->resize(outDims);
}
}
\ No newline at end of file
......@@ -93,7 +93,10 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
}
std::vector<std::int64_t> usedDims(inputDimsEnd.size());
std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int64_t>(0));
auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis]));
Tensor(std::vector<std::size_t>({inputDimsStart.size()}));
// TODO create producer nodes for the attributes
// auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis]));
auto slice = Slice("Slice_" + std::to_string(currentFirstDims[axis]));
slice -> addChild(newNode, 0, 0);
newNode -> addChild(concat, 0, i);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment