diff --git a/include/aidge/backend/cpu/operator/PowImpl.hpp b/include/aidge/backend/cpu/operator/PowImpl.hpp
index d3cafa7e7380e31dd331950e381e08210c3f3a4c..c6e4cd36746141d7f1d1092c9bd45af41d8a9173 100644
--- a/include/aidge/backend/cpu/operator/PowImpl.hpp
+++ b/include/aidge/backend/cpu/operator/PowImpl.hpp
@@ -25,10 +25,10 @@ namespace Aidge {
 
 // compute kernel registry for forward and backward
 class PowImplForward_cpu
-    : public Registrable<PowImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const std::size_t, const void*, const void*,void*)> {
+    : public Registrable<PowImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> {
 };
 class PowImplBackward_cpu
-    : public Registrable<PowImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const std::size_t, const void*, const void*, void*)> {
+    : public Registrable<PowImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> {
 };
 
 class PowImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/PowImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/PowImpl_forward_kernels.hpp
index c9c5db7e9aef07d24ba8f80c94b8f2494865e004..1146cfa77464f8bd1c33a0ec0113415dcf599b53 100644
--- a/include/aidge/backend/cpu/operator/PowImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/PowImpl_forward_kernels.hpp
@@ -15,39 +15,36 @@
 #include "aidge/utils/Registrar.hpp"
 
 #include <cmath>
 
+#include "aidge/backend/cpu/data/Broadcasting.hpp"
 #include "aidge/backend/cpu/operator/PowImpl.hpp"
 
 namespace Aidge {
 template <class I1, class I2, class O>
-void PowImpl_cpu_forward_kernel(std::size_t input1Length,
-                                std::size_t input2Length,
-                                const void* input1_,
-                                const void* input2_,
-                                void* output_) {
+void PowImpl_cpu_forward_kernel(const std::vector<std::size_t>& input1Dims,
+                                const std::vector<std::size_t>& input2Dims,
+                                const std::vector<std::size_t>& outputDims,
+                                const void* input1_,
+                                const void* input2_,
+                                void* output_) {
 
     const I1* input_1 = static_cast<const I1*>(input1_);
     const I2* input_2 = static_cast<const I2*>(input2_);
     O* output = static_cast<O*>(output_);
-    if (input2Length == input1Length)
-    {
-        for (std::size_t i = 0; i < input1Length; ++i) {
-            output[i] = std::pow(input_1[i], input_2[i]);
-        }
-    }
-    else if (input2Length == 1)
-    {
-        for (std::size_t i = 0; i < input1Length; ++i) {
-            output[i] = std::pow(input_1[i], input_2[0]);
-        }
-    }
-    else // input_2 is 1d and of size the number of channels of input_1
-    {
-        for (std::size_t i = 0; i < input1Length; ++i) {
-            std::size_t channelIdx = i % input2Length;
-            output[i] = std::pow(input_1[i], input_2[channelIdx]);
-        }
+    size_t totalElements = 1;
+    for (size_t dimSize : outputDims) {
+        totalElements *= dimSize;
     }
+
+    for (std::size_t oIndex = 0; oIndex < totalElements; ++oIndex)
+    {
+        std::vector<size_t> indexes = getMultiDimIndices(outputDims, oIndex);
+
+        std::size_t idx1 = getFlattenedIndex(input1Dims, indexes);
+        std::size_t idx2 = getFlattenedIndex(input2Dims, indexes);
+
+        output[oIndex] = std::pow(input_1[idx1], input_2[idx2]);
+    }
 }
 
 namespace {
diff --git a/src/operator/PowImpl.cpp b/src/operator/PowImpl.cpp
index 496646402e33869cfcbe7dae96e1fc81b875d0dd..22b4e27afd4e327c42be066bf7eeb6effdd8b2a9 100644
--- a/src/operator/PowImpl.cpp
+++ b/src/operator/PowImpl.cpp
@@ -17,6 +17,7 @@
 #include "aidge/operator/Pow.hpp"
 #include "aidge/utils/Types.h"
 
+#include "aidge/backend/cpu/data/Broadcasting.hpp"
 #include "aidge/backend/cpu/data/GetCPUPtr.h"
 #include "aidge/backend/cpu/operator/PowImpl.hpp"
 
@@ -34,9 +35,15 @@ void Aidge::PowImpl_cpu::forward() {
         std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
         std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
 
+    const std::vector<std::size_t> inputDims0 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
+                                                                   std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims());
+    const std::vector<std::size_t> inputDims1 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
+                                                                   std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims());
+
     // Call kernel
-    kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size(),
+    kernelFunc(inputDims0,
+        inputDims1,
+        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
         getCPUPtr(mOp.getRawInput(0)),
         getCPUPtr(mOp.getRawInput(1)),
         getCPUPtr(mOp.getRawOutput(0)));
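
Note: the patch above replaces the length-based special cases in the Pow forward kernel with generic multi-directional broadcasting driven by the helpers declared in aidge/backend/cpu/data/Broadcasting.hpp. The snippet below is a minimal, self-contained sketch of the semantics this patch assumes for getMultiDimIndices and getFlattenedIndex (it is not the Aidge implementation): a flat output offset is expanded into per-dimension coordinates, and each input reads those coordinates back through its own dims (already padded to the output rank by getBroadcastedDims), with size-1 dimensions clamped to index 0.

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for the helpers from Broadcasting.hpp (assumed row-major layout).
std::vector<std::size_t> getMultiDimIndices(const std::vector<std::size_t>& dims,
                                            std::size_t flatIdx) {
    std::vector<std::size_t> idx(dims.size(), 0);
    for (std::size_t i = dims.size(); i-- > 0;) {
        idx[i] = flatIdx % dims[i];   // coordinate along dimension i
        flatIdx /= dims[i];
    }
    return idx;
}

std::size_t getFlattenedIndex(const std::vector<std::size_t>& dims,
                              const std::vector<std::size_t>& idx) {
    std::size_t flat = 0;
    for (std::size_t i = 0; i < dims.size(); ++i) {
        flat = flat * dims[i] + (dims[i] == 1 ? 0 : idx[i]);  // size-1 dims broadcast to index 0
    }
    return flat;
}

int main() {
    // {2,1} ** {1,3} -> {2,3}: a case the old length-based kernel could not handle.
    const std::vector<std::size_t> dims1{2, 1}, dims2{1, 3}, outDims{2, 3};
    const float in1[] = {2.f, 3.f};
    const float in2[] = {1.f, 2.f, 3.f};
    float out[6];
    for (std::size_t o = 0; o < 6; ++o) {
        const auto idx = getMultiDimIndices(outDims, o);
        out[o] = std::pow(in1[getFlattenedIndex(dims1, idx)],
                          in2[getFlattenedIndex(dims2, idx)]);
        std::cout << out[o] << (o == 5 ? '\n' : ' ');
    }
    return 0;
}

Compiled with any C++11 compiler, this prints "2 4 8 3 9 27", i.e. the base column {2,3} raised element-wise to the exponent row {1,2,3} after broadcasting both inputs to {2,3}, mirroring the index mapping used by the new kernel loop.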