diff --git a/include/aidge/backend/cpu/operator/ConcatImpl.hpp b/include/aidge/backend/cpu/operator/ConcatImpl.hpp
index 1d9315e49195ce8ee6501f262025222590491860..880a2e6635ba8ab2ded6f934bf6eb2e3f6d38d5b 100644
--- a/include/aidge/backend/cpu/operator/ConcatImpl.hpp
+++ b/include/aidge/backend/cpu/operator/ConcatImpl.hpp
@@ -26,12 +26,14 @@ namespace Aidge {
 class ConcatImplForward_cpu
     : public Registrable<ConcatImplForward_cpu, std::tuple<DataType, DataType>,
                          void(const Concat_Op::Attrs&,
                               const std::vector<DimSize_t>,
+                              const std::vector<DimSize_t>&,
                               const std::vector<const void*>, void*)> {};
 
 class ConcatImplBackward_cpu
     : public Registrable<ConcatImplBackward_cpu, std::tuple<DataType, DataType>,
                          void(const Concat_Op::Attrs&,
                               const std::vector<DimSize_t>,
+                              const std::vector<DimSize_t>&,
                               const std::vector<const void*>, void*)> {};
 
diff --git a/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp
index 4e87e3695e2835d7611e37d05c4f199b0e52ac34..e67419d382e3a16cd48fe65289cd1c2b5922efd6 100644
--- a/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp
@@ -13,6 +13,7 @@
 #define AIDGE_CPU_OPERATOR_CONCATIMPL_FORWARD_KERNEL_H_
 
 #include <algorithm>
+#include <numeric>
 #include <cstddef>
 #include <vector>
 
@@ -26,8 +27,9 @@ namespace Aidge {
 
 template <class I, class O>
 void ConcatImpl_cpu_forward_kernel(const Concat_Op::Attrs& attrs,
-                                   const std::vector<DimSize_t> dimsFirstInput,
-                                   const std::vector<const void*> inputs_,
+                                   const std::vector<DimSize_t>& dimsFirstInput,
+                                   const std::vector<DimSize_t>& concatAxisValues,
+                                   const std::vector<const void*>& inputs_,
                                    void* output_)
 {
     // FIXME: missing Concat attributes as arguments
@@ -37,6 +39,8 @@ void ConcatImpl_cpu_forward_kernel(const Concat_Op::Attrs& attrs,
     }
     O* output = static_cast<O*>(output_);
 
+    DimSize_t outputAxisValue = std::accumulate(concatAxisValues.begin(), concatAxisValues.end(), DimSize_t(0));
+
     DimSize_t prodDimLower = 1;
     for (DimIdx_t i = 0; i < std::get<1>(attrs); ++i) {
         prodDimLower *= dimsFirstInput[i];
@@ -47,13 +51,16 @@ void ConcatImpl_cpu_forward_kernel(const Concat_Op::Attrs& attrs,
         prodDimHigher *= dimsFirstInput[i];
     }
 
+    std::size_t oIndexStart = 0;
     std::size_t oIndex = 0;
     for (std::size_t inputId = 0; inputId < inputs.size(); ++inputId) {
+        oIndex = oIndexStart;
+        const DimSize_t iOffset = prodDimHigher*concatAxisValues[inputId];
         for (std::size_t iIndex = 0; iIndex < prodDimLower; ++iIndex) {
-            std::copy(inputs[inputId] + iIndex, inputs[inputId] + iIndex + prodDimHigher,
-                      output + oIndex);
-            oIndex += prodDimHigher;
+            std::copy(inputs[inputId] + iIndex*iOffset, inputs[inputId] + (iIndex+1)*iOffset, output + oIndex);
+            oIndex += prodDimHigher*outputAxisValue;
         }
+        oIndexStart += concatAxisValues[inputId]*prodDimHigher;
     }
 }
 
diff --git a/src/operator/ConcatImpl.cpp b/src/operator/ConcatImpl.cpp
index 0a98f90f6582ee9e75519ce9cf4a52402b9065cb..5ad46f2ab3e0503445a0500207e2773e3321ff5f 100644
--- a/src/operator/ConcatImpl.cpp
+++ b/src/operator/ConcatImpl.cpp
@@ -73,12 +73,15 @@ void Aidge::ConcatImpl_cpu::forward() {
                                         mOp.getOutput(0)->dataType()});
 
     std::vector<const void*> opInputs;
+    std::vector<DimSize_t> opInputAxis;
     for (IOIndex_t i = 0; i < mOp.nbInputs(); ++i) {
         opInputs.push_back(mOp.getInput(i)->getImpl()->rawPtr());
+        opInputAxis.push_back(mOp.getInput(i)->dims()[mOp.template getAttr<DimSize_t>("Axis")]);
     }
 
     kernelFunc(mOp.getStaticAttributes(),
                mOp.getInput(0)->dims(),
+               opInputAxis,
                opInputs,
                mOp.getOutput(0)->getImpl()->rawPtr());
 }
diff --git a/unit_tests/operator/Test_ConcatImpl.cpp b/unit_tests/operator/Test_ConcatImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d46b7118c7905fd5ae2a9d413eaff51a97c7ed51
--- /dev/null
+++ b/unit_tests/operator/Test_ConcatImpl.cpp
@@ -0,0 +1,122 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <catch2/catch_test_macros.hpp>
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/operator/Concat.hpp"
+
+#include "aidge/backend/cpu.hpp"
+
+using namespace Aidge;
+
+TEST_CASE("[cpu/operator] Concat(forward)", "[Concat]") {
+    SECTION("Concat 4D inputs on 1st axis") {
+        std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> {
+            {                                       //
+                {                                   //
+                    {{20, 47},{21, 48},{22, 49}},   //
+                    {{23, 50},{24, 51},{25, 52}},   //
+                    {{26, 53},{27, 54},{28, 55}}    //
+                },                                  //
+            }                                       //
+        });                                         //
+        std::shared_ptr<Tensor> input2 = std::make_shared<Tensor>(Array4D<int,2,3,3,2> {
+            {
+                {                                   //
+                    {{29, 56},{30, 57},{31, 58}},   //
+                    {{32, 59},{33, 60},{34, 61}},   //
+                    {{35, 62},{36, 63},{37, 64}}    //
+                },                                  //
+                {                                   //
+                    {{38, 65},{39, 66},{40, 67}},   //
+                    {{41, 68},{42, 69},{43, 70}},   //
+                    {{44, 71},{45, 72},{46, 73}}    //
+                }                                   //
+            }                                       //
+        });                                         //
+
+        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,3,3,3,2> {
+            {                                       //
+                {                                   //
+                    {{20, 47},{21, 48},{22, 49}},   //
+                    {{23, 50},{24, 51},{25, 52}},   //
+                    {{26, 53},{27, 54},{28, 55}}    //
+                },                                  //
+                {                                   //
+                    {{29, 56},{30, 57},{31, 58}},   //
+                    {{32, 59},{33, 60},{34, 61}},   //
+                    {{35, 62},{36, 63},{37, 64}}    //
+                },                                  //
+                {                                   //
+                    {{38, 65},{39, 66},{40, 67}},   //
+                    {{41, 68},{42, 69},{43, 70}},   //
+                    {{44, 71},{45, 72},{46, 73}}    //
+                }                                   //
+            }                                       //
+        });                                         //
+
+        auto myConcat = Concat(2, 0);
+        myConcat->getOperator()->setBackend("cpu");
+        myConcat->getOperator()->setDatatype(DataType::Int32);
+        myConcat->getOperator()->associateInput(0, input1);
+        myConcat->getOperator()->associateInput(1, input2);
+        myConcat->getOperator()->computeOutputDims();
+        myConcat->forward();
+
+        myConcat->getOperator()->getOutput(0)->print();
+
+        REQUIRE(myConcat->getOperator()->output(0) == *expectedOutput);
+    }
+
+    SECTION("Concat 4D inputs on 3rd axis") {
+        std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> {
+            {                                       //
+                {                                   //
+                    {{20, 47},{21, 48},{22, 49}},   //
+                    {{23, 50},{24, 51},{25, 52}},   //
+                    {{26, 53},{27, 54},{28, 55}}    //
+                },                                  //
+            }                                       //
+        });                                         //
+        std::shared_ptr<Tensor> input2 = std::make_shared<Tensor>(Array4D<int,1,3,6,2> {
+            {
+                {                                                            //
+                    {{29, 56},{30, 57},{31, 58},{38, 65},{39, 66},{40, 67}}, //
+                    {{32, 59},{33, 60},{34, 61},{41, 68},{42, 69},{43, 70}}, //
+                    {{35, 62},{36, 63},{37, 64},{44, 71},{45, 72},{46, 73}}  //
+                },
+            }
+        });
+
+        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,1,3,9,2> {
+            {                                                                                        //
+                {                                                                                    //
+                    {{20, 47},{21, 48},{22, 49},{29, 56},{30, 57},{31, 58},{38, 65},{39, 66},{40, 67}}, //
+                    {{23, 50},{24, 51},{25, 52},{32, 59},{33, 60},{34, 61},{41, 68},{42, 69},{43, 70}}, //
+                    {{26, 53},{27, 54},{28, 55},{35, 62},{36, 63},{37, 64},{44, 71},{45, 72},{46, 73}}  //
+                },                                                                                   //
+            }                                                                                        //
+        });                                                                                          //
+
+        auto myConcat = Concat(2, 2);
+        myConcat->getOperator()->setBackend("cpu");
+        myConcat->getOperator()->setDatatype(DataType::Int32);
+        myConcat->getOperator()->associateInput(0, input1);
+        myConcat->getOperator()->associateInput(1, input2);
+        myConcat->getOperator()->computeOutputDims();
+        myConcat->forward();
+
+        myConcat->getOperator()->getOutput(0)->print();
+
+        REQUIRE(myConcat->getOperator()->output(0) == *expectedOutput);
+    }
+}
\ No newline at end of file
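
Note (not part of the patch): the index arithmetic introduced in the kernel can be exercised in isolation. The standalone C++ sketch below applies the same prodDimLower / prodDimHigher / concatAxisValues decomposition to plain std::vector buffers; the function name concatAxis and the toy 2x2 shapes are illustrative assumptions, not Aidge API.

// Standalone sketch mirroring the updated Concat kernel's index arithmetic.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

// Concatenate row-major 'inputs' along 'axis'. All inputs share 'dims' (the
// dims of the first input) except on 'axis', where input i has extent
// axisValues[i].
void concatAxis(const std::vector<std::vector<int>>& inputs,
                const std::vector<std::size_t>& dims,
                const std::vector<std::size_t>& axisValues,
                std::size_t axis,
                std::vector<int>& output) {
    // Total extent of the output along the concatenation axis.
    const std::size_t outputAxisValue =
        std::accumulate(axisValues.begin(), axisValues.end(), std::size_t(0));

    std::size_t prodDimLower = 1;   // product of dims before the axis
    for (std::size_t i = 0; i < axis; ++i) prodDimLower *= dims[i];
    std::size_t prodDimHigher = 1;  // product of dims after the axis
    for (std::size_t i = axis + 1; i < dims.size(); ++i) prodDimHigher *= dims[i];

    output.assign(prodDimLower * outputAxisValue * prodDimHigher, 0);

    std::size_t oIndexStart = 0;
    for (std::size_t inputId = 0; inputId < inputs.size(); ++inputId) {
        std::size_t oIndex = oIndexStart;
        // One contiguous slice of this input per lower-dimension index.
        const std::size_t iOffset = prodDimHigher * axisValues[inputId];
        for (std::size_t iIndex = 0; iIndex < prodDimLower; ++iIndex) {
            std::copy(inputs[inputId].begin() + iIndex * iOffset,
                      inputs[inputId].begin() + (iIndex + 1) * iOffset,
                      output.begin() + oIndex);
            // Jump to the matching slice of the next lower-dimension "row".
            oIndex += prodDimHigher * outputAxisValue;
        }
        // The next input starts further along the concatenation axis.
        oIndexStart += iOffset;
    }
}

int main() {
    // Two 2x2 matrices concatenated on axis 1 -> one 2x4 matrix.
    std::vector<std::vector<int>> inputs{{1, 2, 3, 4}, {5, 6, 7, 8}};
    std::vector<int> out;
    concatAxis(inputs, {2, 2}, {2, 2}, 1, out);
    for (int v : out) std::cout << v << ' ';  // expected: 1 2 5 6 3 4 7 8
    std::cout << '\n';
}

Compiled with any C++11 compiler, this should print 1 2 5 6 3 4 7 8, i.e. [[1,2],[3,4]] and [[5,6],[7,8]] concatenated along axis 1, which matches the per-slice copies performed by the updated ConcatImpl_cpu_forward_kernel.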