# Patch summary (free text before the first "diff --git" header).
#
# * DivImpl.hpp: the kernel signature registered by DivImplForward_cpu and
#   DivImplBackward_cpu changes from two std::size_t lengths to three
#   const std::vector<std::size_t>& arguments, followed by the same
#   (const void*, const void*, void*) pointers as before.
# * DivImpl_forward_kernels.hpp: DivImpl_cpu_forward_kernel drops its three
#   length-based branches (equal lengths, second input of length 1, second
#   input indexed modulo its length). It now walks every output element,
#   obtains indices from getMultiDimIndices(outputDims, oIndex) and maps them
#   onto each input with getFlattenedIndex.
# * DivImpl.cpp: DivImpl_cpu::forward builds inputDims0/inputDims1 with
#   getBroadcastedDims(outputDims, inputDims) and passes them, together with
#   the output dims, to the kernel.
#
# NOTE(review): getBroadcastedDims, getMultiDimIndices and getFlattenedIndex
# come from Broadcasting.hpp, which is not part of this patch; their exact
# semantics should be confirmed there.
# NOTE(review): the line structure below was rebuilt from a whitespace-collapsed
# copy. Blank context lines are placed to satisfy the hunk line counts;
# indentation inside records is a best-effort reconstruction.
diff --git a/include/aidge/backend/cpu/operator/DivImpl.hpp b/include/aidge/backend/cpu/operator/DivImpl.hpp
index 73809ee81e26fff23e40763405857ddd2c95db0c..9d1614018dd34374f4324b2d9708f1525a23b0e5 100644
--- a/include/aidge/backend/cpu/operator/DivImpl.hpp
+++ b/include/aidge/backend/cpu/operator/DivImpl.hpp
@@ -25,10 +25,10 @@ namespace Aidge {
 
 // compute kernel registry for forward and backward
 class DivImplForward_cpu
-    : public Registrable<DivImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const std::size_t, const void*, const void*,void*)> {
+    : public Registrable<DivImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> {
 };
 class DivImplBackward_cpu
-    : public Registrable<DivImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const std::size_t, const void*, const void*, void*)> {
+    : public Registrable<DivImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> {
 };
 
 class DivImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/DivImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/DivImpl_forward_kernels.hpp
index e2ead9ca8de3ed8328b659906336766fbfbb6a47..494fb5ad37d66f91748810f6cb35d2c8ab8448cb 100644
--- a/include/aidge/backend/cpu/operator/DivImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/DivImpl_forward_kernels.hpp
@@ -14,37 +14,36 @@
 
 #include "aidge/utils/Registrar.hpp"
 
+#include "aidge/backend/cpu/data/Broadcasting.hpp"
 #include "aidge/backend/cpu/operator/DivImpl.hpp"
 
 namespace Aidge {
 template <class I1, class I2, class O>
-void DivImpl_cpu_forward_kernel(std::size_t input1Length,
-                                std::size_t input2Length,
-                                const void* input1_,
-                                const void* input2_,
-                                void* output_) {
+void DivImpl_cpu_forward_kernel(const std::vector<std::size_t>& input1Dims,
+                                const std::vector<std::size_t>& input2Dims,
+                                const std::vector<std::size_t>& outputDims,
+                                const void* input1_,
+                                const void* input2_,
+                                void* output_) {
 
     const I1* input_1 = static_cast<const I1*>(input1_);
     const I2* input_2 = static_cast<const I2*>(input2_);
     O* output = static_cast<O*>(output_);
-    if (input2Length == input1Length)
-    {
-        for (std::size_t i = 0; i < input1Length; ++i) {
-            output[i] = input_1[i] / input_2[i];
-        }
-    }
-    else if (input2Length == 1)
-    {
-        for (std::size_t i = 0; i < input1Length; ++i) {
-            output[i] = input_1[i] / input_2[0];
-        }
+
+    size_t totalElements = 1;
+    for (size_t dimSize : outputDims) {
+        totalElements *= dimSize;
     }
-    else // input_2 is 1d and of size the number of channels of input_1
-    {
-        for (std::size_t i = 0; i < input1Length; ++i) {
-            std::size_t channelIdx = i % input2Length;
-            output[i] = input_1[i] / input_2[channelIdx];
-        }
+
+    for (std::size_t oIndex = 0; oIndex < totalElements; ++oIndex)
+    {
+        std::vector<size_t> indexes = getMultiDimIndices(outputDims, oIndex);
+
+        std::size_t idx1 = getFlattenedIndex(input1Dims, indexes);
+        std::size_t idx2 = getFlattenedIndex(input2Dims, indexes);
+
+        // TODO assert if input_2 is bad?
+        output[oIndex] = input_1[idx1] / input_2[idx2];
     }
 }
 
diff --git a/src/operator/DivImpl.cpp b/src/operator/DivImpl.cpp
index f5cde077bd5a414d8b9add8b8b8715952a27ad01..fc6207cf72e7005716c520ce3d54612c7e814022 100644
--- a/src/operator/DivImpl.cpp
+++ b/src/operator/DivImpl.cpp
@@ -17,6 +17,7 @@
 
 #include "aidge/operator/Div.hpp"
 #include "aidge/utils/Types.h"
+#include "aidge/backend/cpu/data/Broadcasting.hpp"
 #include "aidge/backend/cpu/data/GetCPUPtr.h"
 
 #include "aidge/backend/cpu/operator/DivImpl.hpp"
@@ -34,9 +35,15 @@ void Aidge::DivImpl_cpu::forward() {
         std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
         std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
 
+    const std::vector<std::size_t> inputDims0 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
+                                                                   std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims());
+    const std::vector<std::size_t> inputDims1 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
+                                                                   std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims());
+
     // Call kernel
-    kernelFunc(std::static_pointer_cast<Tensor>(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)))->size(),
-        std::static_pointer_cast<Tensor>(std::static_pointer_cast<Tensor>(mOp.getRawInput(1)))->size(),
+    kernelFunc(inputDims0,
+        inputDims1,
+        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
         getCPUPtr(mOp.getRawInput(0)),
         getCPUPtr(mOp.getRawInput(1)),
         getCPUPtr(mOp.getRawOutput(0)));