diff --git a/include/aidge/backend/cpu/operator/LnImpl.hpp b/include/aidge/backend/cpu/operator/LnImpl.hpp index faa03855a4f881f2a644ebc4023871b7acd6275c..31d1ff79b0201b1cf1037a6b33fb797862feb01f 100755 --- a/include/aidge/backend/cpu/operator/LnImpl.hpp +++ b/include/aidge/backend/cpu/operator/LnImpl.hpp @@ -28,7 +28,7 @@ class LnImplForward_cpu : public Registrable<LnImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> { }; class LnImplBackward_cpu - : public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> { + : public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> { }; class LnImpl_cpu : public OperatorImpl { diff --git a/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp index 70e34cf3cc2bc9aacde01eea4bef797d81d27f26..eb0e3a70011ef55d579d6526bde20d43e21867a6 100755 --- a/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/LnImpl_backward_kernels.hpp @@ -18,27 +18,32 @@ #include "aidge/utils/Registrar.hpp" namespace Aidge { -template <class I, class GI, class GO> +template <class I, class O, class GI, class GO> void LnImpl_cpu_backward_kernel(const std::size_t inputLenght, - const void* input_, const void* grad_output_, + const void* input_, const void* output_, const void* grad_output_, void* grad_input_) { const I* input = static_cast<const I*>(input_); + const O* output = static_cast<const O*>(output_); const GO* grad_output = static_cast<const GO*>(grad_output_); GI* grad_input = static_cast<GI*>(grad_input_); for (std::size_t i = 0; i < inputLenght; ++i) { - grad_input[i] = grad_output[i] / input[i]; + if (output[i] > O(-100)) { + grad_input[i] = grad_output[i] / input[i]; + } else { + grad_input[i] = O(0); + } } } namespace { static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float32( - {DataType::Float32, DataType::Float32, DataType::Float32}, - Aidge::LnImpl_cpu_backward_kernel<float, float, float>); + {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32}, + Aidge::LnImpl_cpu_backward_kernel<float, float, float, float>); static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float64( - {DataType::Float64, DataType::Float64, DataType::Float64}, - Aidge::LnImpl_cpu_backward_kernel<double, double, double>); + {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64}, + Aidge::LnImpl_cpu_backward_kernel<double, double, double, double>); } // namespace } // namespace Aidge diff --git a/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp index c7681730cd0557bf51137cb02147b234de43a3f9..1bc273325d3d014014b8ba1d4f9417fec091024b 100755 --- a/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/LnImpl_forward_kernels.hpp @@ -27,7 +27,14 @@ void LnImpl_cpu_forward_kernel(std::size_t inputLenght, //#pragma omp parallel for if (inputLenght > 1024) for (std::size_t i = 0; i < inputLenght; ++i) { - output[i] = std::log(input[i]); + if (input[i] > I(0)) { + output[i] = std::log(input[i]); + if (output[i] < O(-100)) { + output[i] = O(-100); + } + } else { + output[i] = O(-100); + } } } diff --git a/src/operator/LnImpl.cpp b/src/operator/LnImpl.cpp index 12885a944be46a977463e900af4047319bb1c8b2..b4c82473c34bc30db2dc3a5ff3e834f8f6653fd6 100644 --- a/src/operator/LnImpl.cpp +++ b/src/operator/LnImpl.cpp @@ -56,10 +56,11 @@ void Aidge::LnImpl_cpu::backward() { // Find the correct kernel type auto kernelFunc = Registrar<LnImplBackward_cpu>::create({ in0->dataType(), + out0->dataType(), gra_int0->dataType(), gra_out0->dataType() }); // Call kernel - kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(gra_out0), getCPUPtr(gra_int0)); + kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0)); }