Skip to content
Snippets Groups Projects
Commit fc2e005e authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

fix Slice outputDims and add test

parent 7717ad3c
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!20Vit operators
......@@ -26,12 +26,12 @@ namespace Aidge {
class SliceImplForward_cpu
: public Registrable<SliceImplForward_cpu,
std::tuple<DataType, DataType>,
void(const std::vector<DimSize_t>, const void*, DimSize_t, const void*, const void*, const void*, void*)> {
void(const std::vector<DimSize_t>&, const void*, DimSize_t, const void*, const void*, const void*, void*)> {
};
class SliceImplBackward_cpu
: public Registrable<SliceImplBackward_cpu,
std::tuple<DataType, DataType>,
void(const std::vector<DimSize_t>, const void*, DimSize_t, const void*, const void*, const void*, void*)> {
void(const std::vector<DimSize_t>&, const void*, DimSize_t, const void*, const void*, const void*, void*)> {
};
class SliceImpl_cpu : public OperatorImpl {
......
......@@ -18,7 +18,7 @@
namespace Aidge {
template <class I, class O>
void SliceImpl_cpu_forward_kernel(const std::vector<DimSize_t> inputDims,
void SliceImpl_cpu_forward_kernel(const std::vector<DimSize_t>& inputDims,
const void* input_,
DimSize_t nbSlices,
const void* axes_,
......@@ -60,7 +60,7 @@ void SliceImpl_cpu_forward_kernel(const std::vector<DimSize_t> inputDims,
{
strideOnCurrDim *= currentDims[j];
}
size_t sliceSize = (endIdx - startIdx) * strideOnCurrDim;
size_t sliceSize = (endIdx - startIdx + 1) * strideOnCurrDim;
// For each slice operation, we will slice all elements on the axis (subSlice)
// the number of sublices is the product of dimension previous to the slice dimension
......@@ -74,7 +74,7 @@ void SliceImpl_cpu_forward_kernel(const std::vector<DimSize_t> inputDims,
for(size_t s=0; s<nbSubSlices; ++s)
{
// Compute the pointer postion on input
std::size_t copyStartPos = s * strideOnCurrDim * currentDims[axisIdx] + strideOnCurrDim;
std::size_t copyStartPos = s * strideOnCurrDim * currentDims[axisIdx] + startIdx * strideOnCurrDim;
const I* copyPtr = std::next(tempInArray.data(), copyStartPos);
// Copy slice to output array and update pointer
std::copy_n(copyPtr, sliceSize , tempOutArrayPtr);
......@@ -84,7 +84,7 @@ void SliceImpl_cpu_forward_kernel(const std::vector<DimSize_t> inputDims,
// Update the input for the next slice operation
tempInArray.assign(tempOutArray.begin(), tempOutArray.begin() + copiedElems);
currentDims[axisIdx] = endIdx - startIdx;
currentDims[axisIdx] = endIdx - startIdx + 1;
}
std::copy_n(tempInArray.data(), copiedElems, output);
......
......@@ -17,6 +17,7 @@
#include "aidge/backend/cpu.hpp"
#include <memory>
#include <iostream>
using namespace Aidge;
......@@ -30,7 +31,7 @@ TEST_CASE("[cpu/operator] Slice(forward)") {
});
std::shared_ptr<Tensor> axes = std::make_shared<Tensor>(Array1D<int,2>{{0, 1}});
std::shared_ptr<Tensor> starts = std::make_shared<Tensor>(Array1D<int,2>{{1, 1}});
std::shared_ptr<Tensor> ends = std::make_shared<Tensor>(Array1D<int,2>{{2, 4}});
std::shared_ptr<Tensor> ends = std::make_shared<Tensor>(Array1D<int,2>{{1, 3}});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<int,1,3> {
{
{6, 7, 8}
......@@ -47,9 +48,54 @@ TEST_CASE("[cpu/operator] Slice(forward)") {
mySlice->getOperator()->computeOutputDims();
mySlice->forward();
REQUIRE(*std::static_pointer_cast<Tensor>(mySlice->getOperator()->getOutput(0)) == *expectedOutput);
REQUIRE(*(mySlice->getOperator()->getOutput(0)) == *expectedOutput);
}
SECTION("3D Tensor") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array3D<int,2,4,3> {
{
{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9},
{10, 11, 12}
},
{
{13, 14, 15},
{16, 17, 18},
{18, 20, 21},
{22, 23, 24}
}
}
});
std::shared_ptr<Tensor> axes = std::make_shared<Tensor>(Array1D<int,2>{{1, 2}});
std::shared_ptr<Tensor> starts = std::make_shared<Tensor>(Array1D<int,2>{{0, 2}});
std::shared_ptr<Tensor> ends = std::make_shared<Tensor>(Array1D<int,2>{{2, 2}});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<int,2,3,1> {
{
{
{3},
{6},
{9}
},
{
{15},
{18},
{21}
}
}
});
std::shared_ptr<Node> mySlice = Slice();
mySlice->getOperator()->setDatatype(DataType::Int32);
mySlice->getOperator()->setBackend("cpu");
mySlice->getOperator()->associateInput(0, input);
mySlice->getOperator()->associateInput(1, axes);
mySlice->getOperator()->associateInput(2, starts);
mySlice->getOperator()->associateInput(3, ends);
mySlice->getOperator()->computeOutputDims();
mySlice->forward();
REQUIRE(*(mySlice->getOperator()->getOutput(0)) == *expectedOutput);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment