Commit 7717ad3c authored by Houssem ROUIS

Minor code cleanups

parent 3b6e6292
Part of 2 merge requests: !50 version 0.2.0, !20 Vit operators
@@ -24,10 +24,10 @@ namespace Aidge {
 // compute kernel registry for forward and backward
 class GatherImplForward_cpu
-    : public Registrable<GatherImplForward_cpu, std::tuple<DataType, DataType>, void(const int, const std::vector<DimSize_t>, const std::vector<DimSize_t>, const void*, const void*, void*)> {
+    : public Registrable<GatherImplForward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const std::vector<DimSize_t>&, const void*, const void*, void*)> {
 };
 class GatherImplBackward_cpu
-    : public Registrable<GatherImplBackward_cpu, std::tuple<DataType, DataType>, void(const int, const std::vector<DimSize_t>, const std::vector<DimSize_t>, const void*, const void*, void*)> {
+    : public Registrable<GatherImplBackward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const std::vector<DimSize_t>&, const void*, const void*, void*)> {
 };
 class GatherImpl_cpu : public OperatorImpl {
...
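Note: kernels are attached to these registries by instantiating a Registrar keyed on the (input, output) DataType pair with a function matching the signature above. The snippet below is only a hedged sketch of that pattern; the registrar variable name, the anonymous namespace, and the float specialization are assumptions for illustration, not lines from this commit.

// Hedged sketch (names assumed): binds a <float, float> kernel to the forward registry above.
namespace {
static Aidge::Registrar<Aidge::GatherImplForward_cpu> registrarGatherImplForward_cpu_Float32(
        {Aidge::DataType::Float32, Aidge::DataType::Float32},   // registry key: (input type, output type)
        Aidge::GatherImpl_cpu_forward_kernel<float, float>);    // must match void(std::size_t, const std::vector<DimSize_t>&, ...)
}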
@@ -22,38 +22,31 @@
 namespace Aidge {
 template <class I, class O>
-void GatherImpl_cpu_forward_kernel(const int& axisIdx_, std::vector<DimSize_t> inputDims, const std::vector<DimSize_t> indicesDims, const void* input_, const void* indexes_, void* output_)
+void GatherImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSize_t>& inputDims, const std::vector<DimSize_t>& indicesDims, const void* input_, const void* indexes_, void* output_)
 {
     const I* input = static_cast<const I*>(input_);
     const int* indexes = static_cast<const int*>(indexes_);
-    const std::size_t axisIdx = axisIdx_;
     O* output = static_cast<O*>(output_);
-    // Calculate the total number of elements in the input array
-    size_t totalElements = 1;
-    for (size_t dimSize : inputDims) {
-        totalElements *= dimSize;
-    }
-    std::size_t nbElemAfterAxis = 1;
-    std::size_t nbElemBeforeAxis = 1;
-    for (size_t d = 0; d < inputDims.size(); ++d) {
-        if( d < axisIdx )
-            nbElemBeforeAxis *= inputDims[d];
-        else if ( d > axisIdx )
-            nbElemAfterAxis *= inputDims[d];
-    }
+    std::size_t postAxisElems = 1;
+    for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i) {
+        postAxisElems *= inputDims[i];
+    }
+    std::size_t preAxisElems = 1;
+    for (std::size_t i = 0; i < axisIdx; ++i) {
+        preAxisElems *= inputDims[i];
+    }
-    for (std::size_t i=0; i<nbElemBeforeAxis; ++i)
+    for (std::size_t i=0; i<preAxisElems; ++i)
     {
         for(std::size_t idxRow=0; idxRow<indicesDims[0]; ++idxRow)
         {
             for(std::size_t idxCol=0; idxCol<indicesDims[1]; ++idxCol)
             {
                 std::size_t idx = indexes[indicesDims[1] * idxRow + idxCol];
-                const I* startPtr = std::next(input, i * nbElemAfterAxis * inputDims[axisIdx] + idx * nbElemAfterAxis);
-                std::copy_n(startPtr, nbElemAfterAxis, output);
-                output += nbElemAfterAxis;
+                const I* startPtr = std::next(input, i * postAxisElems * inputDims[axisIdx] + idx * postAxisElems);
+                std::copy_n(startPtr, postAxisElems, output);
+                output += postAxisElems;
             }
         }
     }
...
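To see what the renamed preAxisElems / postAxisElems factors compute, here is a self-contained sketch (plain C++, written for this review, not part of the commit) that gathers indices {0, 2} along axis 1 of a 2x3x2 tensor with the same three-level loop; the dimensions, values, and main() harness are made up for illustration.

#include <algorithm>   // std::copy_n
#include <cstddef>     // std::size_t
#include <iostream>
#include <vector>

int main() {
    const std::vector<std::size_t> inputDims{2, 3, 2};   // 2x3x2 input tensor
    const std::size_t axisIdx = 1;                       // gather along the middle axis
    const std::vector<std::size_t> indicesDims{1, 2};    // a 1x2 block of indices
    const std::vector<int> indexes{0, 2};                // keep slices 0 and 2 of axis 1
    const std::vector<float> input{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};

    // Same decomposition as the kernel above: elements after / before the gather axis.
    std::size_t postAxisElems = 1;
    for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i) postAxisElems *= inputDims[i];
    std::size_t preAxisElems = 1;
    for (std::size_t i = 0; i < axisIdx; ++i) preAxisElems *= inputDims[i];

    std::vector<float> output(preAxisElems * indicesDims[0] * indicesDims[1] * postAxisElems);
    float* out = output.data();
    for (std::size_t i = 0; i < preAxisElems; ++i) {
        for (std::size_t r = 0; r < indicesDims[0]; ++r) {
            for (std::size_t c = 0; c < indicesDims[1]; ++c) {
                const std::size_t idx = indexes[indicesDims[1] * r + c];
                // Jump to the selected slice and copy its postAxisElems contiguous values.
                const float* start = input.data() + i * postAxisElems * inputDims[axisIdx] + idx * postAxisElems;
                out = std::copy_n(start, postAxisElems, out);
            }
        }
    }
    for (float v : output) std::cout << v << ' ';        // prints: 0 1 4 5 6 7 10 11
    std::cout << '\n';
}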
@@ -24,10 +24,10 @@ namespace Aidge {
 // compute kernel registry for forward and backward
 class ReshapeImplForward_cpu
-    : public Registrable<ReshapeImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+    : public Registrable<ReshapeImplForward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const void*, void*)> {
 };
 class ReshapeImplBackward_cpu
-    : public Registrable<ReshapeImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+    : public Registrable<ReshapeImplBackward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const void*, void*)> {
 };
 class ReshapeImpl_cpu : public OperatorImpl {
...
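The registered signature here is just void(std::size_t, const void*, void*): a flattened element count plus raw input/output pointers, because a reshape leaves the data itself untouched. Below is a minimal sketch of a kernel compatible with that signature, assuming a plain element-wise copy; it mirrors the real kernels only in shape and is not code from this commit.

#include <algorithm>  // std::copy_n
#include <cstddef>    // std::size_t

// Hedged sketch of a forward kernel matching void(std::size_t, const void*, void*):
// a reshape only reinterprets dimensions, so the payload is copied unchanged.
template <class I, class O>
void Reshape_cpu_forward_kernel_sketch(std::size_t inputLength, const void* input_, void* output_) {
    const I* input = static_cast<const I*>(input_);
    O* output = static_cast<O*>(output_);
    std::copy_n(input, inputLength, output);
}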
@@ -21,8 +21,6 @@
 #include "aidge/backend/cpu/operator/GatherImpl.hpp"
 #include "aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp"
-#include <iostream>
 Aidge::NbElts_t Aidge::GatherImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
     return 0;
...
@@ -41,7 +41,7 @@ TEST_CASE("[cpu/operator] Erf(forward)") {
     float* resPtr = static_cast<float*>(myErf->getOperator()->getOutput(0)->getImpl()->rawPtr());
     float* expectedPtr = static_cast<float*>(expectedOutput->getImpl()->rawPtr());
-    for (std::size_t i = 0; i< 10; ++i) {
+    for (std::size_t i = 0; i< expectedOutput->size(); ++i) {
         REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001);
     }
 }
@@ -81,7 +81,7 @@ TEST_CASE("[cpu/operator] Erf(forward)") {
     float* resPtr = static_cast<float*>(myErf->getOperator()->getOutput(0)->getImpl()->rawPtr());
     float* expectedPtr = static_cast<float*>(expectedOutput->getImpl()->rawPtr());
-    for (std::size_t i = 0; i< 12; ++i) {
+    for (std::size_t i = 0; i< expectedOutput->size(); ++i) {
         REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001);
     }
 }
...
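Both Erf hunks keep the raw-pointer tolerance loop and only generalize its bound to expectedOutput->size(). A hypothetical helper like the one below (not part of this commit; it assumes the Catch2 REQUIRE macro already available in these test files) could factor out that repeated pattern.

#include <cmath>    // std::abs
#include <cstddef>  // std::size_t

// Hypothetical helper: element-wise comparison with the 1e-5 tolerance used in the tests above.
// Relies on the Catch2 REQUIRE macro already included by the enclosing test file.
inline void requireApproxEqual(const float* result, const float* expected, std::size_t size) {
    for (std::size_t i = 0; i < size; ++i) {
        REQUIRE(std::abs(result[i] - expected[i]) < 0.00001f);
    }
}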
@@ -22,10 +22,10 @@ using namespace Aidge;
 TEST_CASE("[cpu/operator] Reshape(forward)") {
     SECTION("1D Tensor") {
-        std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array1D<float,6> {
+        std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array1D<float,6> {
             {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}
         });
-        std::shared_ptr<Tensor> input_2 = std::make_shared<Tensor>(Array1D<int,2>{{2, 3}});
+        std::shared_ptr<Tensor> shape = std::make_shared<Tensor>(Array1D<int,2>{{2, 3}});
         std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,2,3> {
             {
                 {1.0, 2.0, 3.0},
@@ -36,28 +36,22 @@ TEST_CASE("[cpu/operator] Reshape(forward)") {
         std::shared_ptr<Node> myReshape = Reshape();
         myReshape->getOperator()->setDatatype(DataType::Float32);
         myReshape->getOperator()->setBackend("cpu");
-        myReshape->getOperator()->associateInput(0, input_1);
-        myReshape->getOperator()->associateInput(1, input_2);
+        myReshape->getOperator()->associateInput(0, input);
+        myReshape->getOperator()->associateInput(1, shape);
         myReshape->getOperator()->computeOutputDims();
         myReshape->forward();
-        float* resPtr = static_cast<float*>(myReshape->getOperator()->getOutput(0)->getImpl()->rawPtr());
-        float* expectedPtr = static_cast<float*>(expectedOutput->getImpl()->rawPtr());
-        for (std::size_t i = 0; i< 6; ++i) {
-            printf("res %f, expected %f", resPtr[i], expectedPtr[i]);
-            REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001);
-        }
+        REQUIRE(*(myReshape->getOperator()->getOutput(0)) == *expectedOutput);
     }
     SECTION("2D Tensor") {
-        std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array2D<float,2,3> {
+        std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<float,2,3> {
             {
                 {1.0, 2.0, 3.0},
                 {4.0, 5.0, 6.0}
             }
         });
-        std::shared_ptr<Tensor> input_2 = std::make_shared<Tensor>(Array1D<int,2>{{3, 2}});
+        std::shared_ptr<Tensor> shape = std::make_shared<Tensor>(Array1D<int,2>{{3, 2}});
         std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,3,2> {
             {
                 {1.0, 2.0},
@@ -69,17 +63,11 @@ TEST_CASE("[cpu/operator] Reshape(forward)") {
         std::shared_ptr<Node> myReshape = Reshape();
         myReshape->getOperator()->setDatatype(DataType::Float32);
         myReshape->getOperator()->setBackend("cpu");
-        myReshape->getOperator()->associateInput(0, input_1);
-        myReshape->getOperator()->associateInput(1, input_2);
+        myReshape->getOperator()->associateInput(0, input);
+        myReshape->getOperator()->associateInput(1, shape);
         myReshape->getOperator()->computeOutputDims();
         myReshape->forward();
-        float* resPtr = static_cast<float*>(myReshape->getOperator()->getOutput(0)->getImpl()->rawPtr());
-        float* expectedPtr = static_cast<float*>(expectedOutput->getImpl()->rawPtr());
-        for (std::size_t i = 0; i< 6; ++i) {
-            printf("res %f, expected %f", resPtr[i], expectedPtr[i]);
-            REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001);
-        }
+        REQUIRE(*(myReshape->getOperator()->getOutput(0)) == *expectedOutput);
     }
 }
\ No newline at end of file