Skip to content
Snippets Groups Projects
Commit 1580181f authored by Olivier Antoni's avatar Olivier Antoni
Browse files

Add log for BCE loss function

parent a724058e
No related branches found
No related tags found
2 merge requests!73version 0.2.3,!68Add log operator for BCE loss function
Pipeline #48725 passed
...@@ -28,7 +28,7 @@ class LnImplForward_cpu ...@@ -28,7 +28,7 @@ class LnImplForward_cpu
: public Registrable<LnImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> { : public Registrable<LnImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
}; };
class LnImplBackward_cpu class LnImplBackward_cpu
: public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> { : public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {
}; };
class LnImpl_cpu : public OperatorImpl { class LnImpl_cpu : public OperatorImpl {
......
...@@ -18,32 +18,32 @@ ...@@ -18,32 +18,32 @@
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
namespace Aidge { namespace Aidge {
template <class I, class O, class GI, class GO> template <class I, class GI, class GO>
void LnImpl_cpu_backward_kernel(const std::size_t inputLenght, void LnImpl_cpu_backward_kernel(const std::size_t inputLenght,
const void* input_, const void* output_, const void* grad_output_, const void* input_, const void* grad_output_,
void* grad_input_) { void* grad_input_) {
const I* input = static_cast<const I*>(input_); const I* input = static_cast<const I*>(input_);
const O* output = static_cast<const O*>(output_);
const GO* grad_output = static_cast<const GO*>(grad_output_); const GO* grad_output = static_cast<const GO*>(grad_output_);
GI* grad_input = static_cast<GI*>(grad_input_); GI* grad_input = static_cast<GI*>(grad_input_);
const float eps = 1.0e-20f;
for (std::size_t i = 0; i < inputLenght; ++i) { for (std::size_t i = 0; i < inputLenght; ++i) {
if (output[i] > O(-100)) { if (input[i] > I(eps)) {
grad_input[i] = grad_output[i] / input[i]; grad_input[i] = grad_output[i] / input[i];
} else { } else {
grad_input[i] = O(0); grad_input[i] = GI(0);
} }
} }
} }
namespace { namespace {
static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float32( static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float32(
{DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32}, {DataType::Float32, DataType::Float32, DataType::Float32},
Aidge::LnImpl_cpu_backward_kernel<float, float, float, float>); Aidge::LnImpl_cpu_backward_kernel<float, float, float>);
static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float64( static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float64(
{DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64}, {DataType::Float64, DataType::Float64, DataType::Float64},
Aidge::LnImpl_cpu_backward_kernel<double, double, double, double>); Aidge::LnImpl_cpu_backward_kernel<double, double, double>);
} // namespace } // namespace
} // namespace Aidge } // namespace Aidge
......
...@@ -24,16 +24,14 @@ void LnImpl_cpu_forward_kernel(std::size_t inputLenght, ...@@ -24,16 +24,14 @@ void LnImpl_cpu_forward_kernel(std::size_t inputLenght,
const I* input = static_cast<const I*>(input_); const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_); O* output = static_cast<O*>(output_);
const float eps = 1.0e-20f;
//#pragma omp parallel for if (inputLenght > 1024) //#pragma omp parallel for if (inputLenght > 1024)
for (std::size_t i = 0; i < inputLenght; ++i) { for (std::size_t i = 0; i < inputLenght; ++i) {
if (input[i] > I(0)) { if (input[i] > I(eps)) {
output[i] = std::log(input[i]); output[i] = std::log(input[i]);
if (output[i] < O(-100)) {
output[i] = O(-100);
}
} else { } else {
output[i] = O(-100); output[i] = std::log(I(eps));
} }
} }
} }
......
...@@ -56,11 +56,10 @@ void Aidge::LnImpl_cpu::backward() { ...@@ -56,11 +56,10 @@ void Aidge::LnImpl_cpu::backward() {
// Find the correct kernel type // Find the correct kernel type
auto kernelFunc = Registrar<LnImplBackward_cpu>::create({ auto kernelFunc = Registrar<LnImplBackward_cpu>::create({
in0->dataType(), in0->dataType(),
out0->dataType(),
gra_int0->dataType(), gra_int0->dataType(),
gra_out0->dataType() gra_out0->dataType()
}); });
// Call kernel // Call kernel
kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0)); kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment