From a724058e5fe6ee278c918a035dcfa2b222665072 Mon Sep 17 00:00:00 2001 From: Antoni Olivier <olivier.antoni@cea.fr> Date: Mon, 17 Jun 2024 17:23:10 +0200 Subject: [PATCH] Add log for BCE loss function --- include/aidge/backend/cpu/operator/LnImpl.hpp | 2 +- .../cpu/operator/LnImpl_backward_kernels.hpp | 19 ++++++++++++------- .../cpu/operator/LnImpl_forward_kernels.hpp | 9 ++++++++- src/operator/LnImpl.cpp | 3 ++- 4 files changed, 23 insertions(+), 10 deletions(-) diff --git a/include/aidge/backend/cpu/operator/LnImpl.hpp b/include/aidge/backend/cpu/operator/LnImpl.hpp index faa03855..31d1ff79 100755 --- a/include/aidge/backend/cpu/operator/LnImpl.hpp +++ b/include/aidge/backend/cpu/operator/LnImpl.hpp @@ -28,7 +28,7 @@ class LnImplForward_cpu : public Registrable<LnImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> { }; class LnImplBackward_cpu - : public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> { + : public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> { }; class LnImpl_cpu : public OperatorImpl { diff --git a/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp index 70e34cf3..eb0e3a70 100755 --- a/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp @@ -18,27 +18,32 @@ #include "aidge/utils/Registrar.hpp" namespace Aidge { -template <class I, class GI, class GO> +template <class I, class O, class GI, class GO> void LnImpl_cpu_backward_kernel(const std::size_t inputLenght, - const void* input_, const void* grad_output_, + const void* input_, const void* output_, const void* grad_output_, void* grad_input_) { const I* input = static_cast<const I*>(input_); + const O* output = static_cast<const O*>(output_); const GO* grad_output = static_cast<const GO*>(grad_output_); GI* grad_input = static_cast<GI*>(grad_input_); for (std::size_t i = 0; i < inputLenght; ++i) { - grad_input[i] = grad_output[i] / input[i]; + if (output[i] > O(-100)) { + grad_input[i] = grad_output[i] / input[i]; + } else { + grad_input[i] = O(0); + } } } namespace { static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float32( - {DataType::Float32, DataType::Float32, DataType::Float32}, - Aidge::LnImpl_cpu_backward_kernel<float, float, float>); + {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32}, + Aidge::LnImpl_cpu_backward_kernel<float, float, float, float>); static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float64( - {DataType::Float64, DataType::Float64, DataType::Float64}, - Aidge::LnImpl_cpu_backward_kernel<double, double, double>); + {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64}, + Aidge::LnImpl_cpu_backward_kernel<double, double, double, double>); } // namespace } // namespace Aidge diff --git a/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp index c7681730..1bc27332 100755 --- a/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp @@ -27,7 +27,14 @@ void LnImpl_cpu_forward_kernel(std::size_t inputLenght, //#pragma omp parallel for if (inputLenght > 1024) for (std::size_t i = 0; i < inputLenght; ++i) { - output[i] = std::log(input[i]); + if (input[i] > I(0)) { + output[i] = std::log(input[i]); + if (output[i] < O(-100)) { + output[i] = O(-100); + } + } else { + output[i] = O(-100); + } } } diff --git a/src/operator/LnImpl.cpp b/src/operator/LnImpl.cpp index 12885a94..b4c82473 100644 --- a/src/operator/LnImpl.cpp +++ b/src/operator/LnImpl.cpp @@ -56,10 +56,11 @@ void Aidge::LnImpl_cpu::backward() { // Find the correct kernel type auto kernelFunc = Registrar<LnImplBackward_cpu>::create({ in0->dataType(), + out0->dataType(), gra_int0->dataType(), gra_out0->dataType() }); // Call kernel - kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(gra_out0), getCPUPtr(gra_int0)); + kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0)); } -- GitLab