diff --git a/include/aidge/backend/cpu/operator/ConcatImpl.hpp b/include/aidge/backend/cpu/operator/ConcatImpl.hpp
index 84020f5d53a13459441104650136912ce2e0123b..6db0045792a9db7742a97e4a0ed7f43ebfa2cc09 100644
--- a/include/aidge/backend/cpu/operator/ConcatImpl.hpp
+++ b/include/aidge/backend/cpu/operator/ConcatImpl.hpp
@@ -24,10 +24,10 @@ namespace Aidge {
 // compute kernel registry for forward and backward
 class ConcatImplForward_cpu
-    : public Registrable<ConcatImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const std::vector<DimSize_t>, const std::vector<void*>, void*)> {
+    : public Registrable<ConcatImplForward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const std::vector<DimSize_t>&, const std::vector<void*>, void*)> {
 };
 class ConcatImplBackward_cpu
-    : public Registrable<ConcatImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const std::vector<DimSize_t>, const std::vector<void*>, void*)> {
+    : public Registrable<ConcatImplBackward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const std::vector<DimSize_t>&, const std::vector<void*>, void*)> {
 };
 
 class ConcatImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp
index c962f6dae022de9757b6042be995cc37fb16bc3b..99825e10f6c3789687e64a3e45f08afb0cf43d4e 100644
--- a/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp
@@ -22,7 +22,7 @@ namespace Aidge {
 
 template <class I, class O>
-void ConcatImpl_cpu_forward_kernel(const int& axisIdx, std::vector<DimSize_t> arraysDims, const std::vector<void*> input_, void* output_)
+void ConcatImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSize_t>& inputDims, const std::vector<DimSize_t>& dimsOnAxis, const std::vector<void*> input_, void* output_)
 {
     O* output = static_cast<O*>(output_);
     std::vector<I*> input;
@@ -31,27 +31,25 @@ void ConcatImpl_cpu_forward_kernel(const int& axisIdx, std::vector<DimSize_t> ar
         input.emplace_back(static_cast<I*>(elem));
     }
 
-    // compute length of chunks to copy from each input tensor
-    size_t chunkSize = 1;
-    size_t totalTensorSize = 1;
-    for(size_t i=arraysDims.size()-1; i>0 ; --i)
-    {
-        if(i >= axisIdx)
-            chunkSize *= arraysDims[i];
-        totalTensorSize *= arraysDims[i];
+    std::size_t postAxisElems = 1;
+    for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i) {
+        postAxisElems *= inputDims[i];
+    }
+    std::size_t preAxisElems = 1;
+    for (std::size_t i = 0; i < axisIdx; ++i) {
+        preAxisElems *= inputDims[i];
     }
-    size_t iterationsCount = totalTensorSize / chunkSize;
 
-    for(size_t i=0; i<iterationsCount ; ++i)
+    for(std::size_t i=0; i<preAxisElems ; ++i)
     {
-        for(size_t j=0; j < input.size(); ++j)
+        for(std::size_t j=0; j < input.size(); ++j)
         {
-            I* copyPtr = std::next(input[j], i * chunkSize);
-            std::copy_n(copyPtr, chunkSize, output);
-            output += chunkSize;
-        }
+            std::size_t strideOnAxis = postAxisElems * dimsOnAxis[j];
+            const I* copyPtr = std::next(input[j], i * strideOnAxis);
+            std::copy_n(copyPtr, strideOnAxis, output);
+            output += strideOnAxis;
+        }
     }
-
 }
 
 namespace {
diff --git a/src/operator/ConcatImpl.cpp b/src/operator/ConcatImpl.cpp
index 18bf031ad35679e8611ed5132f1cb0fd7e352872..d7e5167c6b509bea0a5d20e5ffb7e72a1188fa16 100644
--- a/src/operator/ConcatImpl.cpp
+++ b/src/operator/ConcatImpl.cpp
@@ -21,31 +21,33 @@
 #include "aidge/backend/cpu/operator/ConcatImpl.hpp"
 #include "aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp"
 
-#include <iostream>
-
 Aidge::NbElts_t Aidge::ConcatImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
     return 0;
 }
 
 void Aidge::ConcatImpl_cpu::forward() {
-    assert(mOp.getInput(0) && "missing input #0");
-
+    for (std::size_t i = 0; i < dynamic_cast<const Concat_Op&>(mOp).mNbIn; ++i) {
+        assert(mOp.getInput(i) && ("missing input #"+std::to_string(i)).c_str());
+    }
     Concat_Op::Attrs attr = dynamic_cast<const Concat_Op&>(mOp).getStaticAttributes();
-    const int& axisIdx = static_cast<const int&>(std::get<0>(attr));
-    assert(mOp.getInput(0)->nbDims() > 1);// > axisIdx && "input dim must be bigger than "+std::to_strint(axisIdx)
+    std::size_t axisIdx = static_cast<const int&>(std::get<0>(attr));
+    assert(mOp.getInput(0)->nbDims() > axisIdx && ("input dim must be bigger than "+std::to_string(axisIdx)).c_str());
 
     auto kernelFunc = Registrar<ConcatImplForward_cpu>::create({
        mOp.getInput(0)->dataType(),
       mOp.getOutput(0)->dataType()});
 
-    // Call kernel
     std::vector<void*> inputTensors;
+    std::vector<std::size_t> dimsOnAxis;
     for (std::size_t i = 0; i < dynamic_cast<const Concat_Op&>(mOp).mNbIn; ++i) {
         inputTensors.push_back(mOp.getInput(i)->getImpl()->rawPtr());
+        dimsOnAxis.push_back(mOp.getInput(i)->dims()[axisIdx]);
     }
+    // Call kernel
     kernelFunc(axisIdx,
                mOp.getInput(0)->dims(),
+               dimsOnAxis,
                inputTensors,
                mOp.getOutput(0)->getImpl()->rawPtr());
 }
diff --git a/unit_tests/operator/Test_ConcatImpl.cpp b/unit_tests/operator/Test_ConcatImpl.cpp
index 9bda07d1423fbbee3ce757d3ba4ac40948e605ec..b5222810ee602e703efca62fd552ec0763d257ec 100644
--- a/unit_tests/operator/Test_ConcatImpl.cpp
+++ b/unit_tests/operator/Test_ConcatImpl.cpp
@@ -21,7 +21,7 @@ using namespace Aidge;
 
 TEST_CASE("[cpu/operator] Concat(forward)") {
-    SECTION("2D Tensor") {
+    SECTION("2D Tensors") {
         std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array2D<float,2,2> {
             {
                 {0.00543531, 0.53726782},
@@ -41,7 +41,7 @@ TEST_CASE("[cpu/operator] Concat(forward)") {
             }
         });
 
-        std::shared_ptr<Node> myConcat = Concat(1);
+        std::shared_ptr<Node> myConcat = Concat(1, 2);
         myConcat->getOperator()->setDatatype(DataType::Float32);
         myConcat->getOperator()->setBackend("cpu");
         myConcat->getOperator()->associateInput(0,input1);
@@ -49,11 +49,57 @@ TEST_CASE("[cpu/operator] Concat(forward)") {
         myConcat->getOperator()->computeOutputDims();
         myConcat->forward();
 
-        float* resPtr = static_cast<float*>(myConcat->getOperator()->getOutput(0)->getImpl()->rawPtr());
-        float* expectedPtr = static_cast<float*>(expectedOutput->getImpl()->rawPtr());
-        for (std::size_t i = 0; i< 3; ++i) {
-            REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001);
-        }
+        REQUIRE(*(myConcat->getOperator()->getOutput(0)) == *expectedOutput);
+    }
+    SECTION("3D Tensors") {
+        std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array3D<int,2,1,3> {
+            {
+                {
+                    {1, 2, 3}
+                },
+                {
+                    {4, 5, 6}
+                }
+            }
+        });
+        std::shared_ptr<Tensor> input2 = std::make_shared<Tensor>(Array3D<int,2,2,3> {
+            {
+                {
+                    {10, 11, 12},
+                    {13, 14, 15}
+                },
+                {
+                    {16, 17, 18},
+                    {19, 20, 21}
+                }
+            }
+        });
+        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<int,2,4,3> {
+            {
+                {
+                    { 1, 2, 3 },
+                    { 10, 11, 12 },
+                    { 13, 14, 15 },
+                    { 1, 2, 3 }
+                },
+                {
+                    { 4, 5, 6 },
+                    { 16, 17, 18 },
+                    { 19, 20, 21 },
+                    { 4, 5, 6 }
+                }
+            }
+        });
+
+        std::shared_ptr<Node> myConcat = Concat(1, 3);
+        myConcat->getOperator()->setDatatype(DataType::Int32);
+        myConcat->getOperator()->setBackend("cpu");
+        myConcat->getOperator()->associateInput(0,input1);
+        myConcat->getOperator()->associateInput(1,input2);
+        myConcat->getOperator()->associateInput(2,input1);
+        myConcat->getOperator()->computeOutputDims();
+        myConcat->forward();
+
+        REQUIRE(*(myConcat->getOperator()->getOutput(0)) == *expectedOutput);
     }
 }
\ No newline at end of file