diff --git a/include/aidge/backend/cpu/operator/ReduceMeanImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/ReduceMeanImpl_forward_kernels.hpp
index b7e7924b052a55a78ded012e4a13e90b64bab8ee..cc9a0a7df98f8ab6a185a3aecc650d78f04e14a6 100644
--- a/include/aidge/backend/cpu/operator/ReduceMeanImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ReduceMeanImpl_forward_kernels.hpp
@@ -42,34 +42,34 @@ void ReduceMeanImpl_cpu_forward_kernel(const typename ReduceMean_Op<DIM>::Attrs&
     std::vector<I> tempOutArray(input, input + totalElements);
     std::vector<size_t> currentDims = inputDims;
 
-
     std::size_t addedElems = 0;
     for(std::size_t i=0; i<DIM ; ++i)
     {
 		addedElems = 0;
+		int axis_ = std::get<0>(attrs)[i];
+		std::size_t axis = axis_>=0? axis_: axis_ + inputDims.size();
+
 		I* tempOutArrayPtr = tempOutArray.data();
-	
-        std::size_t axis = std::get<0>(attrs)[i];
-        std::size_t nbElemAfterAxis = 1;
-        std::size_t nbElemBeforeAxis = 1;
 
-        for (size_t d = 0; d < currentDims.size(); ++d) {
-			if (d > axis)
-            	nbElemAfterAxis *= currentDims[d];
-			else if (d < axis)
-				nbElemBeforeAxis *= currentDims[d];
-        }
+		std::size_t postAxisElems = 1;
+		for (std::size_t d = axis + 1; d < inputDims.size(); ++d) {
+			postAxisElems *= inputDims[d];
+		}
+		std::size_t preAxisElems = 1;
+		for (std::size_t d = 0; d < axis; ++d) {
+			preAxisElems *= inputDims[d];
+		}
 
-        for (std::size_t j=0; j<nbElemBeforeAxis; ++j)
+        for (std::size_t j=0; j<preAxisElems; ++j)
         {
-            for (std::size_t k=0; k<nbElemAfterAxis; ++k)
+            for (std::size_t k=0; k<postAxisElems; ++k)
             {
 				// Compute the mean value for the element k of each stride
                 I mean = 0;
                 for(std::size_t l=0; l<currentDims[axis];l++)
                 {
-                        size_t idx = j * (nbElemAfterAxis * currentDims[axis]) + l * nbElemAfterAxis + k;
-                        mean += tempInArray[idx];
+					size_t idx = j * (postAxisElems * currentDims[axis]) + l * postAxisElems + k;
+					mean += tempInArray[idx];
                 }
                 tempOutArrayPtr[addedElems] = mean / currentDims[axis];
                 addedElems++;
@@ -78,11 +78,10 @@ void ReduceMeanImpl_cpu_forward_kernel(const typename ReduceMean_Op<DIM>::Attrs&
 
         // Update the input for the next reduce operation
         tempInArray.assign(tempOutArray.begin(), tempOutArray.begin() + addedElems);
-		if(keepDims)
-        	currentDims[axis] = 1;
-		else
+        if(keepDims)
+			currentDims[axis] = 1;
+        else if (currentDims.size()>1)
 			currentDims.erase(currentDims.begin()+axis);
-
     }
 	std::copy_n(tempInArray.data(), addedElems, output);
 }
diff --git a/unit_tests/operator/Test_ReduceMeanImpl.cpp b/unit_tests/operator/Test_ReduceMeanImpl.cpp
index ff21fc63ede6d89e85daa05bad1e08c944ce63cb..6a5166c7b22784554f86c9b66c213f321eec4fc6 100644
--- a/unit_tests/operator/Test_ReduceMeanImpl.cpp
+++ b/unit_tests/operator/Test_ReduceMeanImpl.cpp
@@ -112,15 +112,11 @@ TEST_CASE("[cpu/operator] ReduceMean(forward)") {
                 }
             }
         });
-        std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array3D<float,1,1,1> {
-            { 
-                {
-                    {18.25}
-                }
-            }
+        std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array1D<float,1> {
+            {18.25}
         });
 
-        std::shared_ptr<Node> myReduceMean = ReduceMean({0, 1, 2});
+        std::shared_ptr<Node> myReduceMean = ReduceMean({0, 1, 2}, 0);
         auto op = std::static_pointer_cast<OperatorTensor>(myReduceMean -> getOperator());
         op->associateInput(0,myInput);
         op->setDataType(DataType::Float32);
@@ -128,7 +124,7 @@ TEST_CASE("[cpu/operator] ReduceMean(forward)") {
         op->computeOutputDims();
         myReduceMean->forward();
         op->getOutput(0)->print();
-
+    
         REQUIRE(*(op->getOutput(0)) == *myOutput);
     }
 }
\ No newline at end of file