Commit 1580181f authored by Olivier Antoni

Add log for BCE loss function

parent a724058e
@@ -28,7 +28,7 @@ class LnImplForward_cpu
     : public Registrable<LnImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
 };
 class LnImplBackward_cpu
-    : public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> {
+    : public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {
 };
 class LnImpl_cpu : public OperatorImpl {
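This hunk narrows the backward registration key from four DataTypes to three: the rewritten backward kernel (next hunk) no longer reads the forward output, so its type drops out of both the key tuple and the kernel signature. For readers unfamiliar with the pattern, here is a minimal sketch of a type-keyed kernel registry, assuming a simplified stand-in for Aidge's Registrable/Registrar templates (hypothetical code, not the library's API):

// Minimal sketch of the kernel-registry pattern, assuming a simplified
// stand-in for Aidge's Registrable/Registrar templates (hypothetical code,
// not the library's API). Kernels are stored under a tuple of DataTypes
// and looked up from the operator's tensor types at run time.
#include <cstddef>
#include <map>
#include <stdexcept>
#include <tuple>

enum class DataType { Float32, Float64 };

// After this commit the backward key holds three types:
// (input, grad_output, grad_input); the forward output type is gone.
using BackwardKey  = std::tuple<DataType, DataType, DataType>;
using BackwardFunc = void (*)(const std::size_t, const void*, const void*, void*);

static std::map<BackwardKey, BackwardFunc>& backwardRegistry() {
    static std::map<BackwardKey, BackwardFunc> registry;
    return registry;
}

static BackwardFunc createBackward(const BackwardKey& key) {
    auto it = backwardRegistry().find(key);
    if (it == backwardRegistry().end())
        throw std::runtime_error("no backward kernel registered for these types");
    return it->second;
}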
@@ -18,32 +18,32 @@
 #include "aidge/utils/Registrar.hpp"

 namespace Aidge {
-template <class I, class O, class GI, class GO>
+template <class I, class GI, class GO>
 void LnImpl_cpu_backward_kernel(const std::size_t inputLenght,
-                                const void* input_, const void* output_, const void* grad_output_,
-                                void* grad_input_) {
+                                const void* input_, const void* grad_output_,
+                                void* grad_input_) {
     const I* input = static_cast<const I*>(input_);
-    const O* output = static_cast<const O*>(output_);
     const GO* grad_output = static_cast<const GO*>(grad_output_);
     GI* grad_input = static_cast<GI*>(grad_input_);
+    const float eps = 1.0e-20f;

     for (std::size_t i = 0; i < inputLenght; ++i) {
-        if (output[i] > O(-100)) {
+        if (input[i] > I(eps)) {
             grad_input[i] = grad_output[i] / input[i];
         } else {
-            grad_input[i] = O(0);
+            grad_input[i] = GI(0);
         }
     }
 }

 namespace {
 static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float32(
-    {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32},
-    Aidge::LnImpl_cpu_backward_kernel<float, float, float, float>);
+    {DataType::Float32, DataType::Float32, DataType::Float32},
+    Aidge::LnImpl_cpu_backward_kernel<float, float, float>);
 static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float64(
-    {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64},
-    Aidge::LnImpl_cpu_backward_kernel<double, double, double, double>);
+    {DataType::Float64, DataType::Float64, DataType::Float64},
+    Aidge::LnImpl_cpu_backward_kernel<double, double, double>);
 } // namespace
 } // namespace Aidge
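Since d/dx ln(x) = 1/x, the backward pass only needs the input: the old test against the saved forward output (output[i] > O(-100)) is replaced by the same eps threshold the forward pass uses, and inputs at or below eps propagate a zero gradient instead of dividing by a near-zero value. A standalone check of this logic (hypothetical harness, not part of the commit):

// Standalone check of the clamped-log gradient from the hunk above
// (hypothetical harness, not part of the commit). For x > eps the gradient
// of ln(x) is grad_out / x; at or below eps the input was clamped in the
// forward pass, so the backward propagates zero.
#include <cstddef>
#include <cstdio>

template <class I, class GI, class GO>
void ln_backward(const std::size_t n, const I* input, const GO* grad_output,
                 GI* grad_input) {
    const float eps = 1.0e-20f;
    for (std::size_t i = 0; i < n; ++i) {
        if (input[i] > I(eps)) {
            grad_input[i] = grad_output[i] / input[i];  // d/dx ln(x) = 1/x
        } else {
            grad_input[i] = GI(0);                      // clamped region: no gradient
        }
    }
}

int main() {
    const float x[3]  = {2.0f, 0.5f, 0.0f};
    const float go[3] = {1.0f, 1.0f, 1.0f};
    float gi[3];
    ln_backward<float, float, float>(3, x, go, gi);
    std::printf("%g %g %g\n", gi[0], gi[1], gi[2]);  // prints: 0.5 2 0
    return 0;
}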
@@ -24,16 +24,14 @@ void LnImpl_cpu_forward_kernel(std::size_t inputLenght,
     const I* input = static_cast<const I*>(input_);
     O* output = static_cast<O*>(output_);
+    const float eps = 1.0e-20f;

 //#pragma omp parallel for if (inputLenght > 1024)
     for (std::size_t i = 0; i < inputLenght; ++i) {
-        if (input[i] > I(0)) {
+        if (input[i] > I(eps)) {
             output[i] = std::log(input[i]);
-            if (output[i] < O(-100)) {
-                output[i] = O(-100);
-            }
         } else {
-            output[i] = O(-100);
+            output[i] = std::log(I(eps));
         }
     }
 }
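The forward kernel previously forced the output to -100 both for non-positive inputs and for any log below -100; it now floors the input at eps instead, so inputs at or below eps map to std::log(eps), where ln(1e-20) = -20 * ln(10) ≈ -46.05, and valid outputs are never clipped. A quick check of the new floor value (hypothetical snippet, not in the commit):

// Quick check of the new output floor (hypothetical snippet, not in the
// commit): ln(1e-20) = -20 * ln(10) ≈ -46.0517, replacing the old -100 clamp.
#include <cmath>
#include <cstdio>

int main() {
    const float eps = 1.0e-20f;
    std::printf("log(eps) = %f\n", std::log(eps));  // ~ -46.051701
    return 0;
}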
@@ -56,11 +56,10 @@ void Aidge::LnImpl_cpu::backward() {
     // Find the correct kernel type
     auto kernelFunc = Registrar<LnImplBackward_cpu>::create({
         in0->dataType(),
-        out0->dataType(),
         gra_int0->dataType(),
         gra_out0->dataType()
     });

     // Call kernel
-    kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
+    kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
 }
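Finally, the call site drops the forward-output argument to match the three-pointer kernel signature. The commit title ties this work to binary cross-entropy: BCE(p, t) = -(t * ln(p) + (1 - t) * ln(1 - p)) evaluates ln at probabilities that can reach 0, so it needs exactly this kind of eps-clamped log. An illustrative sketch of that connection, assuming a hypothetical safe_log helper and bce function (Aidge's actual loss implementation is not part of this diff):

// Illustrative sketch of how an eps-clamped log feeds a BCE loss
// (hypothetical helpers, not Aidge's implementation).
#include <cmath>
#include <cstddef>
#include <cstdio>

static float safe_log(float x) {
    const float eps = 1.0e-20f;
    return (x > eps) ? std::log(x) : std::log(eps);  // same floor as the kernel
}

// Mean binary cross-entropy: BCE = -(t*ln(p) + (1-t)*ln(1-p)).
static float bce(const float* p, const float* t, std::size_t n) {
    float sum = 0.0f;
    for (std::size_t i = 0; i < n; ++i)
        sum += -(t[i] * safe_log(p[i]) + (1.0f - t[i]) * safe_log(1.0f - p[i]));
    return sum / static_cast<float>(n);
}

int main() {
    const float p[2] = {0.9f, 0.0f};  // a confident hit and a hard zero
    const float t[2] = {1.0f, 0.0f};
    std::printf("BCE = %f\n", bce(p, t, 2));  // finite thanks to the clamp
    return 0;
}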