From 1580181fcacbc19779fae930443353fc06c7536f Mon Sep 17 00:00:00 2001
From: Antoni Olivier <olivier.antoni@cea.fr>
Date: Tue, 18 Jun 2024 10:11:09 +0200
Subject: [PATCH] Add log for BCE loss function

---
 include/aidge/backend/cpu/operator/LnImpl.hpp |  2 +-
 .../cpu/operator/LnImpl_backward_kernels.hpp  | 22 +++++++++----------
 .../cpu/operator/LnImpl_forward_kernels.hpp   |  8 +++----
 src/operator/LnImpl.cpp                       |  3 +--
 4 files changed, 16 insertions(+), 19 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/LnImpl.hpp b/include/aidge/backend/cpu/operator/LnImpl.hpp
index 31d1ff79..faa03855 100755
--- a/include/aidge/backend/cpu/operator/LnImpl.hpp
+++ b/include/aidge/backend/cpu/operator/LnImpl.hpp
@@ -28,7 +28,7 @@ class LnImplForward_cpu
     : public Registrable<LnImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
 };
 class LnImplBackward_cpu
-    : public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> {
+    : public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {
 };
 
 class LnImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp
index eb0e3a70..5fb82e35 100755
--- a/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp
@@ -18,32 +18,32 @@
 #include "aidge/utils/Registrar.hpp"
 
 namespace Aidge {
-template <class I, class O, class GI, class GO>
+template <class I, class GI, class GO>
 void LnImpl_cpu_backward_kernel(const std::size_t inputLenght,
-                                const void* input_, const void* output_, const void* grad_output_,
-	                        void* grad_input_) {
-									 
+                                const void* input_, const void* grad_output_,
+	                            void* grad_input_) {
+						 
     const I* input = static_cast<const I*>(input_);
-	const O* output = static_cast<const O*>(output_);
     const GO* grad_output = static_cast<const GO*>(grad_output_);
     GI* grad_input = static_cast<GI*>(grad_input_);
+	const float eps = 1.0e-20f;
 	
     for (std::size_t i = 0; i < inputLenght; ++i) {
-		if (output[i] > O(-100)) {
+		if (input[i] > I(eps)) {
 			grad_input[i] = grad_output[i] / input[i];
 		} else {
-			grad_input[i] = O(0);
+			grad_input[i] = GI(0);
 		}
     }
 }
 
 namespace {
 static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float32(
-    {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32},
-    Aidge::LnImpl_cpu_backward_kernel<float, float, float, float>);	
+    {DataType::Float32, DataType::Float32, DataType::Float32},
+    Aidge::LnImpl_cpu_backward_kernel<float, float, float>);	
 static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float64(
-    {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64},
-    Aidge::LnImpl_cpu_backward_kernel<double, double, double, double>);
+    {DataType::Float64, DataType::Float64, DataType::Float64},
+    Aidge::LnImpl_cpu_backward_kernel<double, double, double>);
 }  // namespace
 }  // namespace Aidge
 
diff --git a/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp
index 1bc27332..ebb97551 100755
--- a/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp
@@ -24,16 +24,14 @@ void LnImpl_cpu_forward_kernel(std::size_t inputLenght,
 
     const I* input = static_cast<const I*>(input_);
     O* output = static_cast<O*>(output_);
+	const float eps = 1.0e-20f;
 
 //#pragma omp parallel for if (inputLenght > 1024)
     for (std::size_t i = 0; i < inputLenght; ++i) {
-		if (input[i] > I(0)) {
+		if (input[i] > I(eps)) {
 			output[i] = std::log(input[i]);
-			if (output[i] < O(-100)) {
-				output[i] = O(-100);
-			}
 		} else {
-			output[i] = O(-100);
+			output[i] = std::log(I(eps));
 		}
     }
 }
diff --git a/src/operator/LnImpl.cpp b/src/operator/LnImpl.cpp
index b4c82473..12885a94 100644
--- a/src/operator/LnImpl.cpp
+++ b/src/operator/LnImpl.cpp
@@ -56,11 +56,10 @@ void Aidge::LnImpl_cpu::backward() {
     // Find the correct kernel type
     auto kernelFunc = Registrar<LnImplBackward_cpu>::create({
         in0->dataType(),
-		out0->dataType(),
 	    gra_int0->dataType(),
         gra_out0->dataType()        
     });
 
     // Call kernel
-    kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
+    kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
 }
-- 
GitLab