From a724058e5fe6ee278c918a035dcfa2b222665072 Mon Sep 17 00:00:00 2001
From: Antoni Olivier <olivier.antoni@cea.fr>
Date: Mon, 17 Jun 2024 17:23:10 +0200
Subject: [PATCH] Add log for BCE loss function

---
 include/aidge/backend/cpu/operator/LnImpl.hpp |  2 +-
 .../cpu/operator/LnImpl_backward_kernels.hpp  | 19 ++++++++++++-------
 .../cpu/operator/LnImpl_forward_kernels.hpp   |  9 ++++++++-
 src/operator/LnImpl.cpp                       |  3 ++-
 4 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/LnImpl.hpp b/include/aidge/backend/cpu/operator/LnImpl.hpp
index faa03855..31d1ff79 100755
--- a/include/aidge/backend/cpu/operator/LnImpl.hpp
+++ b/include/aidge/backend/cpu/operator/LnImpl.hpp
@@ -28,7 +28,7 @@ class LnImplForward_cpu
     : public Registrable<LnImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
 };
 class LnImplBackward_cpu
-    : public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {
+    : public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> {
 };
 
 class LnImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp
index 70e34cf3..eb0e3a70 100755
--- a/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp
@@ -18,27 +18,32 @@
 #include "aidge/utils/Registrar.hpp"
 
 namespace Aidge {
-template <class I, class GI, class GO>
+template <class I, class O, class GI, class GO>
 void LnImpl_cpu_backward_kernel(const std::size_t inputLenght,
-                                const void* input_, const void* grad_output_,
+                                const void* input_, const void* output_, const void* grad_output_,
 	                        void* grad_input_) {
 									 
     const I* input = static_cast<const I*>(input_);
+	const O* output = static_cast<const O*>(output_);
     const GO* grad_output = static_cast<const GO*>(grad_output_);
     GI* grad_input = static_cast<GI*>(grad_input_);
 	
     for (std::size_t i = 0; i < inputLenght; ++i) {
-        grad_input[i] = grad_output[i] / input[i];
+		if (output[i] > O(-100)) {
+			grad_input[i] = grad_output[i] / input[i];
+		} else {
+			grad_input[i] = O(0);
+		}
     }
 }
 
 namespace {
 static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float32(
-    {DataType::Float32, DataType::Float32, DataType::Float32},
-    Aidge::LnImpl_cpu_backward_kernel<float, float, float>);	
+    {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32},
+    Aidge::LnImpl_cpu_backward_kernel<float, float, float, float>);	
 static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float64(
-    {DataType::Float64, DataType::Float64, DataType::Float64},
-    Aidge::LnImpl_cpu_backward_kernel<double, double, double>);
+    {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64},
+    Aidge::LnImpl_cpu_backward_kernel<double, double, double, double>);
 }  // namespace
 }  // namespace Aidge
 
diff --git a/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp
index c7681730..1bc27332 100755
--- a/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp
@@ -27,7 +27,14 @@ void LnImpl_cpu_forward_kernel(std::size_t inputLenght,
 
 //#pragma omp parallel for if (inputLenght > 1024)
     for (std::size_t i = 0; i < inputLenght; ++i) {
-        output[i] = std::log(input[i]);
+		if (input[i] > I(0)) {
+			output[i] = std::log(input[i]);
+			if (output[i] < O(-100)) {
+				output[i] = O(-100);
+			}
+		} else {
+			output[i] = O(-100);
+		}
     }
 }
 
diff --git a/src/operator/LnImpl.cpp b/src/operator/LnImpl.cpp
index 12885a94..b4c82473 100644
--- a/src/operator/LnImpl.cpp
+++ b/src/operator/LnImpl.cpp
@@ -56,10 +56,11 @@ void Aidge::LnImpl_cpu::backward() {
     // Find the correct kernel type
     auto kernelFunc = Registrar<LnImplBackward_cpu>::create({
         in0->dataType(),
+		out0->dataType(),
 	    gra_int0->dataType(),
         gra_out0->dataType()        
     });
 
     // Call kernel
-    kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
+    kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
 }
-- 
GitLab