From e31dca3ae146db69cccc3a29fcff4e8e9e76880e Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Thu, 23 Nov 2023 15:30:24 +0100
Subject: [PATCH] add axis attr to Softmax

---
 .../backend/cpu/operator/SoftmaxImpl.hpp      |  4 +-
 .../operator/SoftmaxImpl_forward_kernels.hpp  | 39 ++++++++++---------
 src/operator/SoftmaxImpl.cpp                  | 12 +++---
 unit_tests/operator/Test_SoftmaxImpl.cpp      |  9 ++---
 4 files changed, 32 insertions(+), 32 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/SoftmaxImpl.hpp b/include/aidge/backend/cpu/operator/SoftmaxImpl.hpp
index 995f57f7..2b15eb36 100644
--- a/include/aidge/backend/cpu/operator/SoftmaxImpl.hpp
+++ b/include/aidge/backend/cpu/operator/SoftmaxImpl.hpp
@@ -24,10 +24,10 @@ namespace Aidge {
 // compute kernel registry for forward and backward
 class SoftmaxImplForward_cpu
-    : public Registrable<SoftmaxImplForward_cpu, std::tuple<DataType, DataType>, void(const DimSize_t, const DimSize_t, const DimSize_t, const void*, void*)> {
+    : public Registrable<SoftmaxImplForward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const void*, void*)> {
 };
 class SoftmaxImplBackward_cpu
-    : public Registrable<SoftmaxImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+    : public Registrable<SoftmaxImplBackward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const void*, void*)> {
 };
 
 class SoftmaxImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp
index 297a3a32..fb264afd 100644
--- a/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp
@@ -22,30 +22,33 @@ namespace Aidge {
 template <class I, class O>
-void SoftmaxImpl_cpu_forward_kernel(const DimSize_t batchSize,
-                                    const DimSize_t channelSize,
-                                    const DimSize_t featureSize,
-                                    const void* input_,
-                                    void* output_) {
-
+void SoftmaxImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSize_t>& inputDims, const void* input_, void* output_)
+{
     const I* input = static_cast<const I*>(input_);
     O* output = static_cast<O*>(output_);
 
-    for (std::size_t batch = 0; batch < batchSize; ++batch) {
-        for (std::size_t feature = 0; feature < featureSize; ++feature) {
-            std::size_t ioIndex = batch*channelSize*featureSize + feature;
+    std::size_t postAxisElems = 1;
+    for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i) {
+        postAxisElems *= inputDims[i];
+    }
+    std::size_t preAxisElems = 1;
+    for (std::size_t i = 0; i < axisIdx; ++i) {
+        preAxisElems *= inputDims[i];
+    }
 
-            I sum(0.0);
-            for (std::size_t ch = 0; ch < channelSize; ++ch) {
-                output[ioIndex] = std::exp(input[ioIndex]);
-                sum += output[ioIndex];
-                ioIndex+=featureSize;
+    for (std::size_t i = 0; i < preAxisElems; ++i) {
+        for (std::size_t j = 0; j < postAxisElems; ++j) {
+            // Calculate sum of exponentials within the axis
+            I sumExp = 0;
+            for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) {
+                std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
+                sumExp += std::exp(input[inIdx]);
             }
-            ioIndex = batch*channelSize*featureSize + feature;
-            for (std::size_t ch = 0; ch < channelSize; ++ch) {
-                output[ioIndex] /= sum;
-                ioIndex += featureSize;
+            // Calculate softmax for the current slice along the axis
+            for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) {
+                std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
+                output[inIdx] = std::exp(input[inIdx]) / sumExp;
             }
         }
     }
diff --git a/src/operator/SoftmaxImpl.cpp b/src/operator/SoftmaxImpl.cpp
index 45b455a3..ae89090a 100644
--- a/src/operator/SoftmaxImpl.cpp
+++ b/src/operator/SoftmaxImpl.cpp
@@ -28,20 +28,18 @@ Aidge::NbElts_t Aidge::SoftmaxImpl_cpu::getNbRequiredProtected(const Aidge::IOIn
 void Aidge::SoftmaxImpl_cpu::forward() {
     assert(mOp.getInput(0) && "missing input #0");
-    assert(mOp.getInput(0)->nbDims()>1);
+    // assert(mOp.getInput(0)->nbDims()>1);
 
     // Find the correct kernel type
     auto kernelFunc = Registrar<SoftmaxImplForward_cpu>::create({
        mOp.getInput(0)->dataType(),
        mOp.getOutput(0)->dataType()});
 
-    DimSize_t batchSize = mOp.getInput(0)->dims()[0];
-    DimSize_t channelSize = mOp.getInput(0)->dims()[1];
-    DimSize_t featureSize = mOp.getInput(0)->sizeM1()/channelSize;
+    Softmax_Op::Attrs attr = dynamic_cast<const Softmax_Op&>(mOp).getStaticAttributes();
+    const int& axisIdx = static_cast<const int&>(std::get<0>(attr));
     // Call kernel
-    kernelFunc(batchSize,
-               channelSize,
-               featureSize,
+    kernelFunc(axisIdx,
+               mOp.getInput(0)->dims(),
                mOp.getInput(0)->getImpl()->rawPtr(),
                mOp.getOutput(0)->getImpl()->rawPtr());
 }
diff --git a/unit_tests/operator/Test_SoftmaxImpl.cpp b/unit_tests/operator/Test_SoftmaxImpl.cpp
index bad34102..81b73786 100644
--- a/unit_tests/operator/Test_SoftmaxImpl.cpp
+++ b/unit_tests/operator/Test_SoftmaxImpl.cpp
@@ -39,7 +39,7 @@ TEST_CASE("[cpu/operator] Softmax(forward)") {
             }
         });
 
-        std::shared_ptr<Node> mySoftmax = Softmax();
+        std::shared_ptr<Node> mySoftmax = Softmax(1);
         mySoftmax->getOperator()->setDatatype(DataType::Float32);
         mySoftmax->getOperator()->setBackend("cpu");
         mySoftmax->getOperator()->associateInput(0,input);
@@ -48,7 +48,7 @@ TEST_CASE("[cpu/operator] Softmax(forward)") {
         float* resPtr = static_cast<float*>(mySoftmax->getOperator()->getOutput(0)->getImpl()->rawPtr());
         float* expectedPtr = static_cast<float*>(expectedOutput->getImpl()->rawPtr());
 
-        for (std::size_t i = 0; i< 20; ++i) {
+        for (std::size_t i = 0; i< expectedOutput->size(); ++i) {
             REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001);
         }
 
@@ -107,7 +107,7 @@ TEST_CASE("[cpu/operator] Softmax(forward)") {
             }
         });
 
-        std::shared_ptr<Node> mySoftmax = Softmax();
+        std::shared_ptr<Node> mySoftmax = Softmax(1);
         mySoftmax->getOperator()->setDatatype(DataType::Float32);
         mySoftmax->getOperator()->setBackend("cpu");
         mySoftmax->getOperator()->associateInput(0,input);
@@ -116,9 +116,8 @@ TEST_CASE("[cpu/operator] Softmax(forward)") {
         float* resPtr = static_cast<float*>(mySoftmax->getOperator()->getOutput(0)->getImpl()->rawPtr());
         float* expectedPtr = static_cast<float*>(expectedOutput->getImpl()->rawPtr());
 
-        for (std::size_t i = 0; i< 54; ++i) {
+        for (std::size_t i = 0; i< expectedOutput->size(); ++i) {
             REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001);
         }
-        // REQUIRE(*mySoftmax->getOperator()->getOutput(0) == *expectedOutput);
     }
 }
\ No newline at end of file
--
GitLab
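
For illustration only (not part of the patch): a minimal standalone C++ sketch of the pre-/post-axis decomposition the new kernel relies on. The row-major shape {2, 3, 2}, the choice of axis 1, and the ramp input values are made up for this example, and there are no Aidge dependencies.

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    // Made-up example: row-major tensor of shape {2, 3, 2}, softmax over axis 1.
    const std::vector<std::size_t> dims{2, 3, 2};
    const std::size_t axis = 1;

    std::vector<float> in(2 * 3 * 2);
    for (std::size_t n = 0; n < in.size(); ++n) in[n] = 0.1f * static_cast<float>(n);
    std::vector<float> out(in.size());

    // Same decomposition as the new kernel: dimensions before the axis are
    // flattened into preAxisElems, dimensions after it into postAxisElems.
    std::size_t preAxisElems = 1, postAxisElems = 1;
    for (std::size_t d = 0; d < axis; ++d) preAxisElems *= dims[d];
    for (std::size_t d = axis + 1; d < dims.size(); ++d) postAxisElems *= dims[d];

    for (std::size_t i = 0; i < preAxisElems; ++i) {
        for (std::size_t j = 0; j < postAxisElems; ++j) {
            // Sum of exponentials over the softmax axis for this (i, j) slice.
            float sumExp = 0.0f;
            for (std::size_t k = 0; k < dims[axis]; ++k) {
                sumExp += std::exp(in[(i * dims[axis] + k) * postAxisElems + j]);
            }
            // Normalise each element of the slice.
            for (std::size_t k = 0; k < dims[axis]; ++k) {
                const std::size_t idx = (i * dims[axis] + k) * postAxisElems + j;
                out[idx] = std::exp(in[idx]) / sumExp;
            }
        }
    }

    // Each slice of 3 values along axis 1 now sums to 1.
    for (float v : out) std::cout << v << ' ';
    std::cout << '\n';
    return 0;
}

With this shape, preAxisElems = 2 and postAxisElems = 2, so four (i, j) slices of 3 values each are normalised independently, which mirrors what kernelFunc(axisIdx, mOp.getInput(0)->dims(), ...) computes for the Softmax(1) node used in the updated unit tests.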