Skip to content
Snippets Groups Projects
Commit ade64013 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

switch slice attrs into inputs

parent 76ff38f0
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!93Change Gather and Slice's attributes into intputs
...@@ -20,31 +20,16 @@ ...@@ -20,31 +20,16 @@
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
namespace Aidge { namespace Aidge {
enum class SliceAttr { Starts, Ends, Axes };
class Slice_Op class Slice_Op
: public OperatorTensor, : public OperatorTensor,
public Registrable<Slice_Op, std::string, std::shared_ptr<OperatorImpl>(const Slice_Op &)>, public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>{
public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>> {
public: public:
static const std::string Type; static const std::string Type;
Slice_Op() = delete; Slice_Op() : OperatorTensor(Type, 4, 0, 1) {}
using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>>;
template <SliceAttr e>
using attr = typename Attributes_::template attr<e>;
Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int64_t>& axes)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(attr<SliceAttr::Starts>(starts),
attr<SliceAttr::Ends>(ends),
attr<SliceAttr::Axes>(axes))
{}
/** /**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its
...@@ -52,8 +37,7 @@ public: ...@@ -52,8 +37,7 @@ public:
* @param op Operator to copy. * @param op Operator to copy.
*/ */
Slice_Op(const Slice_Op &op) Slice_Op(const Slice_Op &op)
: OperatorTensor(op), : OperatorTensor(op)
Attributes_(op)
{ {
if (op.mImpl){ if (op.mImpl){
SET_IMPL_MACRO(Slice_Op, *this, op.mOutputs[0]->getImpl()->backend()); SET_IMPL_MACRO(Slice_Op, *this, op.mOutputs[0]->getImpl()->backend());
...@@ -77,7 +61,7 @@ public: ...@@ -77,7 +61,7 @@ public:
} }
static const std::vector<std::string> getInputsName(){ static const std::vector<std::string> getInputsName(){
return {"data_input"}; return {"data_input", "starts", "ends", "axes"};
} }
static const std::vector<std::string> getOutputsName(){ static const std::vector<std::string> getOutputsName(){
return {"data_output"}; return {"data_output"};
...@@ -86,29 +70,12 @@ public: ...@@ -86,29 +70,12 @@ public:
/** /**
* @brief Exract a sub-Tensor from a bigger original Tensor. * @brief Exract a sub-Tensor from a bigger original Tensor.
* @param starts Indexes for each dimension of the first element.
* Can be a negative value. Negative values start their reference from the last index.
* ``-1`` referes to the last index of a dimension.
* @param ends Indexes for each dimension of the last element.
* Can be a negative value. Negative values start their reference from the last index.
* ``-1`` referes to the last index of a dimension.
* @param axes Dimensions for which start/end indexes apply. Not specifying a dimensions
* means the whole dimensions is extracted.
* @param name Name of the Operator. * @param name Name of the Operator.
* @return std::shared_ptr<Node> A Node containing the Operator. * @return std::shared_ptr<Node> A Node containing the Operator.
*/ */
inline std::shared_ptr<Node> Slice(const std::vector<std::int64_t> starts, inline std::shared_ptr<Node> Slice(const std::string &name = "") {
const std::vector<std::int64_t> ends, return std::make_shared<Node>(std::make_shared<Slice_Op>(), name);
const std::vector<std::int64_t> axes,
const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name);
} }
} // namespace Aidge } // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes" };
}
#endif /* AIDGE_CORE_OPERATOR_RELU_H_ */ #endif /* AIDGE_CORE_OPERATOR_RELU_H_ */
...@@ -25,6 +25,6 @@ void init_Gather(py::module& m) { ...@@ -25,6 +25,6 @@ void init_Gather(py::module& m) {
.def("attributes_name", &Gather_Op::staticGetAttrsName); .def("attributes_name", &Gather_Op::staticGetAttrsName);
declare_registrable<Gather_Op>(m, "GatherOp"); declare_registrable<Gather_Op>(m, "GatherOp");
m.def("Gather", &Gather, py::arg("axis")=0, py::arg("name") = ""); m.def("Gather", &Gather, py::arg("axis") = 0, py::arg("name") = "");
} }
} // namespace Aidge } // namespace Aidge
...@@ -22,6 +22,7 @@ void init_Slice(py::module& m) { ...@@ -22,6 +22,7 @@ void init_Slice(py::module& m) {
.def("get_inputs_name", &Slice_Op::getInputsName) .def("get_inputs_name", &Slice_Op::getInputsName)
.def("get_outputs_name", &Slice_Op::getOutputsName); .def("get_outputs_name", &Slice_Op::getOutputsName);
declare_registrable<Slice_Op>(m, "SliceOp"); declare_registrable<Slice_Op>(m, "SliceOp");
m.def("Slice", &Slice, py::arg("starts"), py::arg("ends"), py::arg("axes"), py::arg("name") = "");
m.def("Slice", &Slice, py::arg("name") = "");
} }
} // namespace Aidge } // namespace Aidge
...@@ -8,17 +8,16 @@ ...@@ -8,17 +8,16 @@
* SPDX-License-Identifier: EPL-2.0 * SPDX-License-Identifier: EPL-2.0
* *
********************************************************************************/ ********************************************************************************/
#include "aidge/operator/Slice.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <cstdint>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
...@@ -26,28 +25,50 @@ const std::string Aidge::Slice_Op::Type = "Slice"; ...@@ -26,28 +25,50 @@ const std::string Aidge::Slice_Op::Type = "Slice";
void Aidge::Slice_Op::computeOutputDims() { void Aidge::Slice_Op::computeOutputDims() {
// check input have been associated // check input have been associated
if (!getInput(0) || (getInput(0)->empty())) { if (!getInput(0) || !getInput(1) || !getInput(2) || !getInput(3)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
} }
const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); if((!getInput(0)->empty()) && (!getInput(1)->empty()) && (!getInput(2)->empty()) && (!getInput(3)->empty()))
std::vector<DimSize_t> outDims = getInput(0)->dims(); {
for (std::size_t i = 0; i < nbAxes; ++i) { const auto starts = mInputs[1]->getImpl()->rawPtr();
// For each slice operation get the params and cast them to size_t const auto ends = mInputs[2]->getImpl()->rawPtr();
const std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i]; const auto axes = mInputs[3]->getImpl()->rawPtr();
const std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i]; DimSize_t nbAxes = mInputs[1]->size();
const std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i]; std::vector<DimSize_t> outDims = getInput(0)->dims();
const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : static_cast<std::size_t>(axis_) + getInput(0)->nbDims(); for (std::size_t i = 0; i < nbAxes; ++i) {
const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : static_cast<std::size_t>(start_) + getInput(0)->dims()[axis]; // For each slice operation get the params and cast them to size_t
const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : static_cast<std::size_t>(end_) + getInput(0)->dims()[axis]; std::size_t axis, start, end; //TODO find a better way to cast "starts", "ends" and "axes"
if (mInputs[1]->dataType() == DataType::Float32 && mInputs[2]->dataType() == DataType::Float32 && mInputs[3]->dataType() == DataType::Float32)
const std::size_t sliceLength = end - start + 1; {
// Check if slice length is valid const float* axes_ = static_cast<float*>(axes);
if (sliceLength > getInput(0)->dims()[axis]) axis = axes_[i] >= 0 ? static_cast<std::size_t>(axes_[i]) : static_cast<std::size_t>(axes_[i]) + getInput(0)->nbDims();
{ const float* starts_ = static_cast<float*>(starts);
AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds"); start = starts_[i] >= 0 ? static_cast<std::size_t>(starts_[i]) : static_cast<std::size_t>(starts_[i]) + getInput(0)->dims()[axis];
const float* ends_ = static_cast<float*>(ends);
end = ends_[i] >= 0 ? static_cast<std::size_t>(ends_[i]) : static_cast<std::size_t>(ends_[i]) + getInput(0)->dims()[ends_[i]];
}
else if(mInputs[1]->dataType() == DataType::Int32 && mInputs[2]->dataType() == DataType::Int32 && mInputs[3]->dataType() == DataType::Int32)
{
const std::int32_t* axes_ = static_cast<std::int32_t*>(axes);
axis = axes_[i] >= 0 ? static_cast<std::size_t>(axes_[i]) : static_cast<std::size_t>(axes_[i]) + getInput(0)->nbDims();
const std::int32_t* starts_ = static_cast<std::int32_t*>(starts);
start = starts_[i] >= 0 ? static_cast<std::size_t>(starts_[i]) : static_cast<std::size_t>(starts_[i]) + getInput(0)->dims()[axis];
const std::int32_t* ends_ = static_cast<std::int32_t*>(ends);
end = ends_[i] >= 0 ? static_cast<std::size_t>(ends_[i]) : static_cast<std::size_t>(ends_[i]) + getInput(0)->dims()[ends_[i]];
}
else
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "Slice inputs type is not supported yet");
}
const std::size_t sliceLength = end - start;
// Check if slice length is valid
if (sliceLength > getInput(0)->dims()[axis])
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
}
outDims[axis] = sliceLength;
} }
outDims[axis] = sliceLength; mOutputs[0]->resize(outDims);
} }
mOutputs[0]->resize(outDims); }
} \ No newline at end of file
...@@ -93,7 +93,10 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: ...@@ -93,7 +93,10 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
} }
std::vector<std::int64_t> usedDims(inputDimsEnd.size()); std::vector<std::int64_t> usedDims(inputDimsEnd.size());
std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int64_t>(0)); std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int64_t>(0));
auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis])); Tensor(std::vector<std::size_t>({inputDimsStart.size()}));
// TODO create producer nodes for the attributes
// auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis]));
auto slice = Slice("Slice_" + std::to_string(currentFirstDims[axis]));
slice -> addChild(newNode, 0, 0); slice -> addChild(newNode, 0, 0);
newNode -> addChild(concat, 0, i); newNode -> addChild(concat, 0, i);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment