diff --git a/include/aidge/backend/cpu/operator/ConcatImpl.hpp b/include/aidge/backend/cpu/operator/ConcatImpl.hpp
index 84020f5d53a13459441104650136912ce2e0123b..6db0045792a9db7742a97e4a0ed7f43ebfa2cc09 100644
--- a/include/aidge/backend/cpu/operator/ConcatImpl.hpp
+++ b/include/aidge/backend/cpu/operator/ConcatImpl.hpp
@@ -24,10 +24,10 @@ namespace Aidge {
 
 // compute kernel registry for forward and backward
 class ConcatImplForward_cpu
-    : public Registrable<ConcatImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const std::vector<DimSize_t>, const std::vector<void*>, void*)> {
+    : public Registrable<ConcatImplForward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const std::vector<DimSize_t>&, const std::vector<void*>, void*)> {
 };
 class ConcatImplBackward_cpu
-    : public Registrable<ConcatImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const std::vector<DimSize_t>, const std::vector<void*>, void*)> {
+    : public Registrable<ConcatImplBackward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const std::vector<DimSize_t>&, const std::vector<void*>, void*)> {
 };
 
 class ConcatImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp
index c962f6dae022de9757b6042be995cc37fb16bc3b..99825e10f6c3789687e64a3e45f08afb0cf43d4e 100644
--- a/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp
@@ -22,7 +22,7 @@
 
 namespace Aidge {
 template <class I, class O>
-void ConcatImpl_cpu_forward_kernel(const int& axisIdx, std::vector<DimSize_t> arraysDims, const std::vector<void*> input_, void* output_)
+void ConcatImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSize_t>& inputDims, const std::vector<DimSize_t>& dimsOnAxis, const std::vector<void*> input_, void* output_)
 {
     O* output = static_cast<O*>(output_);
     std::vector<I*> input;
@@ -31,27 +31,25 @@ void ConcatImpl_cpu_forward_kernel(const int& axisIdx, std::vector<DimSize_t> ar
         input.emplace_back(static_cast<I*>(elem));
     }
 
-    // compute length of chunks to copy from each input tensor
-    size_t chunkSize = 1;
-    size_t totalTensorSize = 1;
-    for(size_t i=arraysDims.size()-1; i>0 ; --i)
-    {
-        if(i >= axisIdx)
-			chunkSize *= arraysDims[i];
-		totalTensorSize *= arraysDims[i];
+    std::size_t postAxisElems = 1;
+    for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i) {
+        postAxisElems *= inputDims[i];
+    }
+    std::size_t preAxisElems = 1;
+    for (std::size_t i = 0; i < axisIdx; ++i) {
+        preAxisElems *= inputDims[i];
     }
 
-	size_t iterationsCount = totalTensorSize / chunkSize;
-	for(size_t i=0; i<iterationsCount ; ++i)
+    for(std::size_t i=0; i<preAxisElems ; ++i)
     {
-		for(size_t j=0; j < input.size(); ++j)
+		for(std::size_t j=0; j < input.size(); ++j)
 		{
-			I* copyPtr = std::next(input[j], i * chunkSize);
-			std::copy_n(copyPtr, chunkSize, output);
-			output += chunkSize;
-		}
+			const std::size_t strideOnAxis = postAxisElems * dimsOnAxis[j];
+			const I* copyPtr = std::next(input[j], i * strideOnAxis);
+			std::copy_n(copyPtr, strideOnAxis, output);
+			output += strideOnAxis;
+		}
 	}
-
 }
 
 namespace {
diff --git a/src/operator/ConcatImpl.cpp b/src/operator/ConcatImpl.cpp
index 18bf031ad35679e8611ed5132f1cb0fd7e352872..d7e5167c6b509bea0a5d20e5ffb7e72a1188fa16 100644
--- a/src/operator/ConcatImpl.cpp
+++ b/src/operator/ConcatImpl.cpp
@@ -21,31 +21,33 @@
 #include "aidge/backend/cpu/operator/ConcatImpl.hpp"
 #include "aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp"
 
-#include <iostream>
-
 Aidge::NbElts_t Aidge::ConcatImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
     return 0;
 }
 
 void Aidge::ConcatImpl_cpu::forward() {
-    assert(mOp.getInput(0) && "missing input #0");
-
+    for (std::size_t i = 0; i < dynamic_cast<const Concat_Op&>(mOp).mNbIn; ++i) {
+        assert(mOp.getInput(i) && "missing input");
+    }
     Concat_Op::Attrs attr = dynamic_cast<const Concat_Op&>(mOp).getStaticAttributes();
-    const int& axisIdx = static_cast<const int&>(std::get<0>(attr));
-    assert(mOp.getInput(0)->nbDims() > 1);// > axisIdx && "input dim must be bigger than "+std::to_strint(axisIdx)
+    const std::size_t axisIdx = static_cast<std::size_t>(std::get<0>(attr));
+    assert(mOp.getInput(0)->nbDims() > axisIdx && "input rank must be greater than the concatenation axis");
 
     auto kernelFunc = Registrar<ConcatImplForward_cpu>::create({
         mOp.getInput(0)->dataType(),
         mOp.getOutput(0)->dataType()});
 
-    // Call kernel
     std::vector<void*> inputTensors;
+    std::vector<std::size_t> dimsOnAxis;
     for (std::size_t i = 0; i < dynamic_cast<const Concat_Op&>(mOp).mNbIn; ++i) {
         inputTensors.push_back(mOp.getInput(i)->getImpl()->rawPtr());
+        dimsOnAxis.push_back(mOp.getInput(i)->dims()[axisIdx]);
     }
+    // Call kernel
     kernelFunc(axisIdx,
                mOp.getInput(0)->dims(),
+               dimsOnAxis,
                inputTensors,
                mOp.getOutput(0)->getImpl()->rawPtr());
 }
diff --git a/unit_tests/operator/Test_ConcatImpl.cpp b/unit_tests/operator/Test_ConcatImpl.cpp
index 9bda07d1423fbbee3ce757d3ba4ac40948e605ec..b5222810ee602e703efca62fd552ec0763d257ec 100644
--- a/unit_tests/operator/Test_ConcatImpl.cpp
+++ b/unit_tests/operator/Test_ConcatImpl.cpp
@@ -21,7 +21,7 @@
 using namespace Aidge;
 
 TEST_CASE("[cpu/operator] Concat(forward)") {
-    SECTION("2D Tensor") {
+    SECTION("2D Tensors") {
         std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array2D<float,2,2> {
             {
                 {0.00543531, 0.53726782},
@@ -41,7 +41,7 @@ TEST_CASE("[cpu/operator] Concat(forward)") {
             }
         });
 
-        std::shared_ptr<Node> myConcat = Concat(1);
+        std::shared_ptr<Node> myConcat = Concat(1, 2);
         myConcat->getOperator()->setDatatype(DataType::Float32);
         myConcat->getOperator()->setBackend("cpu");
         myConcat->getOperator()->associateInput(0,input1);
@@ -49,11 +49,57 @@ TEST_CASE("[cpu/operator] Concat(forward)") {
         myConcat->getOperator()->computeOutputDims();
         myConcat->forward();
 
-        float* resPtr = static_cast<float*>(myConcat->getOperator()->getOutput(0)->getImpl()->rawPtr());
-        float* expectedPtr = static_cast<float*>(expectedOutput->getImpl()->rawPtr());
-        for (std::size_t i = 0; i< 3; ++i) {
-            REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001);
-        }
+        REQUIRE(*(myConcat->getOperator()->getOutput(0)) == *expectedOutput);
+    }
+    SECTION("3D Tensors") {
+        std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array3D<int,2,1,3> {
+            {
+                {
+                    {1, 2, 3}
+                },
+                {
+                    {4, 5, 6}
+                }
+            }
+        });
+        std::shared_ptr<Tensor> input2 = std::make_shared<Tensor>(Array3D<int,2,2,3> {
+            {
+                {
+                    {10, 11, 12},
+                    {13, 14, 15}
+                },
+                {
+                    {16, 17, 18},
+                    {19, 20, 21}
+                }
+            }
+        });
+        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<int,2,4,3> {
+            {
+                {
+                    { 1, 2, 3 },
+                    { 10, 11, 12 },
+                    { 13, 14, 15 },
+                    { 1, 2, 3 }
+                },
+                {
+                    { 4, 5, 6 },
+                    { 16, 17, 18 },
+                    { 19, 20, 21 },
+                    { 4, 5, 6 }
+                }
+            }
+        });
+
+        std::shared_ptr<Node> myConcat = Concat(1, 3);
+        myConcat->getOperator()->setDatatype(DataType::Int32);
+        myConcat->getOperator()->setBackend("cpu");
+        myConcat->getOperator()->associateInput(0,input1);
+        myConcat->getOperator()->associateInput(1,input2);
+        myConcat->getOperator()->associateInput(2,input1);
+        myConcat->getOperator()->computeOutputDims();
+        myConcat->forward();
 
+        REQUIRE(*(myConcat->getOperator()->getOutput(0)) == *expectedOutput);
     }
 }
\ No newline at end of file