Skip to content
Snippets Groups Projects

Fix Split kernel

Merged Houssem ROUIS requested to merge hrouis/aidge_core:Fix/Split into dev
1 file
+ 4
4
Compare changes
  • Side-by-side
  • Inline
+ 4
4
@@ -37,15 +37,15 @@ void Aidge::Split_OpImpl::forward() {
const std::size_t stride_post = std::accumulate(dims.crbegin(), dims.crbegin() + dims.size() -1 - axis, 1, std::multiplies<std::size_t>());
for (auto i = 0; i < op.nbOutputs(); ++i)
{
DimIdx_t chunkIdxOnAxis = std::accumulate(splits.cbegin(), splits.cbegin() + i, 0) * stride_post;
DimIdx_t offset = 0;
DimSize_t chunkIdxOnAxis = std::accumulate(splits.cbegin(), splits.cbegin() + i, 0) * stride_post;
DimSize_t offset = 0;
for (std::size_t j = 0; j < stride_pre; ++j)
{
// Compute chunk position in input tensor
DimIdx_t idx = j * stride_post * dims[axis] + chunkIdxOnAxis;
DimSize_t idx = j * stride_post * dims[axis] + chunkIdxOnAxis;
// Copy chunk in ouput
op.getOutput(i)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(idx),
splits[i] * stride_post, offset);
splits[i] * stride_post, offset);
offset += splits[i] * stride_post;
}
Loading