Skip to content
Snippets Groups Projects
Commit 03a58652 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

change attrs back to int64_t

parent 7ec8b5d6
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!20Vit operators
Pipeline #38131 failed
...@@ -40,7 +40,7 @@ void GatherImpl_cpu_forward_kernel(const typename Gather_Op::Attrs& attrs, const ...@@ -40,7 +40,7 @@ void GatherImpl_cpu_forward_kernel(const typename Gather_Op::Attrs& attrs, const
preAxisElems *= inputDims[i]; preAxisElems *= inputDims[i];
} }
const std::vector<std::int32_t> indices = std::get<0>(attrs); const std::vector<std::int64_t> indices = std::get<0>(attrs);
for (std::size_t i=0; i<preAxisElems; ++i) for (std::size_t i=0; i<preAxisElems; ++i)
{ {
for(std::size_t j=0; j<indices.size(); ++j) for(std::size_t j=0; j<indices.size(); ++j)
......
...@@ -32,9 +32,9 @@ void SliceImpl_cpu_forward_kernel(const typename Slice_Op::Attrs& attrs, ...@@ -32,9 +32,9 @@ void SliceImpl_cpu_forward_kernel(const typename Slice_Op::Attrs& attrs,
DimSize_t nbAxes = std::get<2>(attrs).size(); DimSize_t nbAxes = std::get<2>(attrs).size();
for (std::size_t i = 0; i < nbAxes; ++i) { for (std::size_t i = 0; i < nbAxes; ++i) {
// For each slice operation get the params and cast them to size_t // For each slice operation get the params and cast them to size_t
const std::int32_t axis_ = std::get<2>(attrs)[i]; const std::int64_t axis_ = std::get<2>(attrs)[i];
const std::int32_t start_ = std::get<0>(attrs)[i]; const std::int64_t start_ = std::get<0>(attrs)[i];
const std::int32_t end_ = std::get<1>(attrs)[i]; const std::int64_t end_ = std::get<1>(attrs)[i];
const std::size_t axis = axis_ >= 0 ? axis_ : static_cast<std::size_t>(axis_) + inputDims.size(); const std::size_t axis = axis_ >= 0 ? axis_ : static_cast<std::size_t>(axis_) + inputDims.size();
const std::size_t start = start_ >= 0 ? start_ : start_ + inputDims[axis]; const std::size_t start = start_ >= 0 ? start_ : start_ + inputDims[axis];
const std::size_t end = end_ >= 0 ? end_ : end_ + inputDims[axis]; const std::size_t end = end_ >= 0 ? end_ : end_ + inputDims[axis];
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment