diff --git a/include/aidge/backend/cpu/operator/LnImpl.hpp b/include/aidge/backend/cpu/operator/LnImpl.hpp index 31d1ff79b0201b1cf1037a6b33fb797862feb01f..faa03855a4f881f2a644ebc4023871b7acd6275c 100755 --- a/include/aidge/backend/cpu/operator/LnImpl.hpp +++ b/include/aidge/backend/cpu/operator/LnImpl.hpp @@ -28,7 +28,7 @@ class LnImplForward_cpu : public Registrable<LnImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> { }; class LnImplBackward_cpu - : public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> { + : public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> { }; class LnImpl_cpu : public OperatorImpl { diff --git a/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp index eb0e3a70011ef55d579d6526bde20d43e21867a6..5fb82e35f8855d9d6e2eb85e9ab380c9f1fc9b90 100755 --- a/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp @@ -18,32 +18,32 @@ #include "aidge/utils/Registrar.hpp" namespace Aidge { -template <class I, class O, class GI, class GO> +template <class I, class GI, class GO> void LnImpl_cpu_backward_kernel(const std::size_t inputLenght, - const void* input_, const void* output_, const void* grad_output_, - void* grad_input_) { - + const void* input_, const void* grad_output_, + void* grad_input_) { + const I* input = static_cast<const I*>(input_); - const O* output = static_cast<const O*>(output_); const GO* grad_output = static_cast<const GO*>(grad_output_); GI* grad_input = static_cast<GI*>(grad_input_); + const float eps = 1.0e-20f; for (std::size_t i = 0; i < inputLenght; ++i) { - if (output[i] > O(-100)) { + if (input[i] > I(eps)) { grad_input[i] = grad_output[i] / input[i]; } else { - grad_input[i] = O(0); + grad_input[i] = GI(0); } } } namespace { static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float32( - {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32}, - Aidge::LnImpl_cpu_backward_kernel<float, float, float, float>); + {DataType::Float32, DataType::Float32, DataType::Float32}, + Aidge::LnImpl_cpu_backward_kernel<float, float, float>); static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float64( - {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64}, - Aidge::LnImpl_cpu_backward_kernel<double, double, double, double>); + {DataType::Float64, DataType::Float64, DataType::Float64}, + Aidge::LnImpl_cpu_backward_kernel<double, double, double>); } // namespace } // namespace Aidge diff --git a/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp index 1bc273325d3d014014b8ba1d4f9417fec091024b..ebb975512a6e7c0f7225c305372f0ec6e7060786 100755 --- a/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp @@ -24,16 +24,14 @@ void LnImpl_cpu_forward_kernel(std::size_t inputLenght, const I* input = static_cast<const I*>(input_); O* output = static_cast<O*>(output_); + const float eps = 1.0e-20f; //#pragma omp parallel for if (inputLenght > 1024) for (std::size_t i = 0; i < inputLenght; ++i) { - if (input[i] > I(0)) { + if (input[i] > I(eps)) { output[i] = std::log(input[i]); - if (output[i] < O(-100)) { - output[i] = O(-100); - } } else { - output[i] = O(-100); + output[i] = std::log(I(eps)); } } } diff --git a/src/operator/LnImpl.cpp b/src/operator/LnImpl.cpp index b4c82473c34bc30db2dc3a5ff3e834f8f6653fd6..12885a944be46a977463e900af4047319bb1c8b2 100644 --- a/src/operator/LnImpl.cpp +++ b/src/operator/LnImpl.cpp @@ -56,11 +56,10 @@ void Aidge::LnImpl_cpu::backward() { // Find the correct kernel type auto kernelFunc = Registrar<LnImplBackward_cpu>::create({ in0->dataType(), - out0->dataType(), gra_int0->dataType(), gra_out0->dataType() }); // Call kernel - kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0)); + kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(gra_out0), getCPUPtr(gra_int0)); }