Skip to content
Snippets Groups Projects
Commit 8efe542d authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

change Slice attr to int64

parent 6cec7f27
No related branches found
No related tags found
2 merge requests!59Improvements and fixes,!47Vit operators
Pipeline #35846 failed
...@@ -29,17 +29,17 @@ enum class SliceAttr { Starts, Ends, Axes }; ...@@ -29,17 +29,17 @@ enum class SliceAttr { Starts, Ends, Axes };
class Slice_Op class Slice_Op
: public OperatorTensor, : public OperatorTensor,
public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>, public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>,
public StaticAttributes<SliceAttr, std::vector<int>, std::vector<int>, std::vector<int>> { public StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>> {
public: public:
static const std::string Type; static const std::string Type;
Slice_Op() = delete; Slice_Op() = delete;
using Attributes_ = StaticAttributes<SliceAttr, std::vector<int>, std::vector<int>, std::vector<int>>; using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>>;
template <SliceAttr e> template <SliceAttr e>
using attr = typename Attributes_::template attr<e>; using attr = typename Attributes_::template attr<e>;
Slice_Op(const std::vector<int>& starts, const std::vector<int>& ends, const std::vector<int>& axes) Slice_Op(const std::vector<std::int32_t>& starts, const std::vector<std::int32_t>& ends, const std::vector<std::int32_t>& axes)
: OperatorTensor(Type, 1, 0, 1), : OperatorTensor(Type, 1, 0, 1),
Attributes_(attr<SliceAttr::Starts>(starts), Attributes_(attr<SliceAttr::Starts>(starts),
attr<SliceAttr::Ends>(ends), attr<SliceAttr::Ends>(ends),
...@@ -85,9 +85,9 @@ public: ...@@ -85,9 +85,9 @@ public:
}; };
inline std::shared_ptr<Node> Slice(const std::vector<int> starts, inline std::shared_ptr<Node> Slice(const std::vector<std::int32_t> starts,
const std::vector<int> ends, const std::vector<std::int32_t> ends,
const std::vector<int> axes, const std::vector<std::int32_t> axes,
const std::string &name = "") { const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases // FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name); return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name);
......
...@@ -32,9 +32,9 @@ void Aidge::Slice_Op::computeOutputDims() { ...@@ -32,9 +32,9 @@ void Aidge::Slice_Op::computeOutputDims() {
for(std::size_t i=0; i<nbAxes;++i) for(std::size_t i=0; i<nbAxes;++i)
{ {
// For each slice operation get the params and cast them to size_t // For each slice operation get the params and cast them to size_t
int axis_ = this->template getAttr<SliceAttr::Axes>()[i]; std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i];
int start_ = this->template getAttr<SliceAttr::Starts>()[i]; std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i];
int end_ = this->template getAttr<SliceAttr::Ends>()[i]; std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i];
std::size_t axis = axis_>=0?axis_:axis_+getInput(0)->nbDims(); std::size_t axis = axis_>=0?axis_:axis_+getInput(0)->nbDims();
std::size_t start = start_>=0?start_:start_+getInput(0)->dims()[axis]; std::size_t start = start_>=0?start_:start_+getInput(0)->dims()[axis];
std::size_t end = end_>=0?end_:end_+getInput(0)->dims()[axis]; std::size_t end = end_>=0?end_:end_+getInput(0)->dims()[axis];
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment