Skip to content
Snippets Groups Projects
Commit a28eb541 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add noop_with_empty_axes attr

parent c9700bee
No related branches found
No related tags found
2 merge requests!93Release v0.3.0,!75Learning backend cuda
......@@ -38,7 +38,10 @@ void ReduceMeanImpl_cpu_forward_kernel(const std::vector<std::int32_t>& axes,
const std::size_t nb_dims = inputDims.size();
const std::size_t totalElements = std::accumulate(inputDims.cbegin(), inputDims.cend(), 1, std::multiplies<std::size_t>());
if (axes.size() == 1) {
if (axes.empty()){
std::copy_n(input,totalElements, output);
}
else if (axes.size() == 1) {
const std::size_t stride_pre = std::accumulate(inputDims.cbegin(), inputDims.cbegin() + axes[0], 1, std::multiplies<std::size_t>());
const std::size_t stride_post = std::accumulate(inputDims.crbegin(), inputDims.crbegin() + nb_dims -1 - axes[0], 1, std::multiplies<std::size_t>());
......
......@@ -38,7 +38,10 @@ void ReduceSumImpl_cpu_forward_kernel(const std::vector<std::int32_t>& axes,
const std::size_t nb_dims = inputDims.size();
const std::size_t totalElements = std::accumulate(inputDims.cbegin(), inputDims.cend(), 1, std::multiplies<std::size_t>());
if (axes.size() == 1) {
if (axes.empty()){
std::copy_n(input,totalElements, output);
}
else if (axes.size() == 1) {
const std::size_t stride_pre = std::accumulate(inputDims.cbegin(), inputDims.cbegin() + axes[0], 1, std::multiplies<std::size_t>());
const std::size_t stride_post = std::accumulate(inputDims.crbegin(), inputDims.crbegin() + nb_dims -1 - axes[0], 1, std::multiplies<std::size_t>());
......
......@@ -157,7 +157,7 @@ TEST_CASE("[cpu/operator] ReduceMean(forward)", "[ReduceMean][CPU]") {
{18.25}
});
std::shared_ptr<Node> myReduceMean = ReduceMean({0, 1, 2}, 0);
std::shared_ptr<Node> myReduceMean = ReduceMean({}, 0);
auto op = std::static_pointer_cast<OperatorTensor>(myReduceMean -> getOperator());
op->associateInput(0,myInput);
op->setDataType(DataType::Float32);
......@@ -179,7 +179,7 @@ TEST_CASE("[cpu/operator] ReduceMean(forward)", "[ReduceMean][CPU]") {
{0.1293547f}
});
std::shared_ptr<Node> myReduceMean = ReduceMean({0, 1}, 0);
std::shared_ptr<Node> myReduceMean = ReduceMean({}, 0);
auto op = std::static_pointer_cast<OperatorTensor>(myReduceMean -> getOperator());
op->associateInput(0,myInput);
op->setDataType(DataType::Float32);
......@@ -189,5 +189,33 @@ TEST_CASE("[cpu/operator] ReduceMean(forward)", "[ReduceMean][CPU]") {
// approxEq<float>(*(op->getOutput(0)), *myOutput);
REQUIRE(approxEq<float>(*(op->getOutput(0)), *myOutput));
}
SECTION("noop_with_empty_axes") {
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array3D<float,3,2,2> {
{
{
{ 5.0, 1.0 },
{ 20.0, 2.0 }
},
{
{ 30.0, 1.0 },
{ 40.0, 2.0 }
},
{
{ 55.0, 1.0 },
{ 60.0, 2.0 }
}
}
});
std::shared_ptr<Node> myReduceSum = ReduceSum({}, 0, 1);
auto op = std::static_pointer_cast<OperatorTensor>(myReduceSum -> getOperator());
op->associateInput(0,myInput);
op->setDataType(DataType::Float32);
op->setBackend("cpu");
myReduceSum->forward();
op->getOutput(0)->print();
REQUIRE(*(op->getOutput(0)) == *myInput);
}
}
}
\ No newline at end of file
......@@ -157,7 +157,7 @@ TEST_CASE("[cpu/operator] ReduceSum(forward)", "[ReduceSum][CPU]") {
{219.0}
});
std::shared_ptr<Node> myReduceSum = ReduceSum({0, 1, 2}, 0);
std::shared_ptr<Node> myReduceSum = ReduceSum({}, 0);
auto op = std::static_pointer_cast<OperatorTensor>(myReduceSum -> getOperator());
op->associateInput(0,myInput);
op->setDataType(DataType::Float32);
......@@ -189,5 +189,33 @@ TEST_CASE("[cpu/operator] ReduceSum(forward)", "[ReduceSum][CPU]") {
// approxEq<float>(*(op->getOutput(0)), *myOutput);
REQUIRE(approxEq<float>(*(op->getOutput(0)), *myOutput));
}
SECTION("noop_with_empty_axes") {
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array3D<float,3,2,2> {
{
{
{ 5.0, 1.0 },
{ 20.0, 2.0 }
},
{
{ 30.0, 1.0 },
{ 40.0, 2.0 }
},
{
{ 55.0, 1.0 },
{ 60.0, 2.0 }
}
}
});
std::shared_ptr<Node> myReduceSum = ReduceSum({}, 0, 1);
auto op = std::static_pointer_cast<OperatorTensor>(myReduceSum -> getOperator());
op->associateInput(0,myInput);
op->setDataType(DataType::Float32);
op->setBackend("cpu");
myReduceSum->forward();
op->getOutput(0)->print();
REQUIRE(*(op->getOutput(0)) == *myInput);
}
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment