Skip to content
Snippets Groups Projects
Commit 05c3fd5c authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add consts

parent ce73448a
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!20Vit operators
...@@ -27,7 +27,9 @@ void GatherImpl_cpu_forward_kernel(const typename Gather_Op::Attrs& attrs, const ...@@ -27,7 +27,9 @@ void GatherImpl_cpu_forward_kernel(const typename Gather_Op::Attrs& attrs, const
const I* input = static_cast<const I*>(input_); const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_); O* output = static_cast<O*>(output_);
std::size_t axisIdx = std::get<2>(attrs)>=0 ? std::get<2>(attrs) : static_cast<std::size_t>(std::get<2>(attrs)) + inputDims.size(); const std::size_t axisIdx = std::get<2>(attrs)>=0 ?
std::get<2>(attrs) :
static_cast<std::size_t>(std::get<2>(attrs)) + inputDims.size();
std::size_t postAxisElems = 1; std::size_t postAxisElems = 1;
for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i) { for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i) {
...@@ -38,12 +40,12 @@ void GatherImpl_cpu_forward_kernel(const typename Gather_Op::Attrs& attrs, const ...@@ -38,12 +40,12 @@ void GatherImpl_cpu_forward_kernel(const typename Gather_Op::Attrs& attrs, const
preAxisElems *= inputDims[i]; preAxisElems *= inputDims[i];
} }
std::vector<std::int64_t> indices = std::get<0>(attrs); const std::vector<std::int64_t> indices = std::get<0>(attrs);
for (std::size_t i=0; i<preAxisElems; ++i) for (std::size_t i=0; i<preAxisElems; ++i)
{ {
for(std::size_t j=0; j<indices.size(); ++j) for(std::size_t j=0; j<indices.size(); ++j)
{ {
std::size_t idx = indices[j] >= 0 ? indices[j] : indices[j] + inputDims[axisIdx]; const std::size_t idx = indices[j] >= 0 ? indices[j] : indices[j] + inputDims[axisIdx];
const I* startPtr = std::next(input, i * postAxisElems * inputDims[axisIdx] + idx * postAxisElems); const I* startPtr = std::next(input, i * postAxisElems * inputDims[axisIdx] + idx * postAxisElems);
std::copy_n(startPtr, postAxisElems, output); std::copy_n(startPtr, postAxisElems, output);
output += postAxisElems; output += postAxisElems;
......
...@@ -30,7 +30,7 @@ void ReduceMeanImpl_cpu_forward_kernel(const typename ReduceMean_Op<DIM>::Attrs& ...@@ -30,7 +30,7 @@ void ReduceMeanImpl_cpu_forward_kernel(const typename ReduceMean_Op<DIM>::Attrs&
const I* input = static_cast<const I*>(input_); const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_); O* output = static_cast<O*>(output_);
DimSize_t keepDims = std::get<1>(attrs); const DimSize_t keepDims = std::get<1>(attrs);
// Calculate the total number of elements in the input array // Calculate the total number of elements in the input array
size_t totalElements = 1; size_t totalElements = 1;
for (size_t dimSize : inputDims) { for (size_t dimSize : inputDims) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment