Skip to content
Snippets Groups Projects
Commit a724058e authored by Olivier Antoni's avatar Olivier Antoni
Browse files

Add log for BCE loss function

parent b9fd7b57
No related branches found
No related tags found
2 merge requests!73version 0.2.3,!68Add log operator for BCE loss function
......@@ -28,7 +28,7 @@ class LnImplForward_cpu
: public Registrable<LnImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
};
class LnImplBackward_cpu
: public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {
: public Registrable<LnImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> {
};
class LnImpl_cpu : public OperatorImpl {
......
......@@ -18,27 +18,32 @@
#include "aidge/utils/Registrar.hpp"
namespace Aidge {
// Backward kernel for the natural-log (Ln) operator.
//
// Computes grad_input[i] = grad_output[i] * d(ln(x))/dx = grad_output[i] / input[i].
// The forward pass clamps ln(x) from below at -100 (BCE-style log clamp);
// where the stored output sits at that clamp the function is locally
// constant, so no gradient is propagated (grad_input = 0). Comparing
// against the saved output (rather than the input) is what makes the
// clamped region detectable here.
//
// Template parameters:
//   I  - input element type
//   O  - output element type (forward result, used for the clamp test)
//   GI - gradient-of-input element type (what this kernel writes)
//   GO - gradient-of-output element type (incoming gradient)
//
// inputLenght: number of elements in every buffer (flat iteration).
template <class I, class O, class GI, class GO>
void LnImpl_cpu_backward_kernel(const std::size_t inputLenght,
                                const void* input_, const void* output_,
                                const void* grad_output_, void* grad_input_) {
    const I* input = static_cast<const I*>(input_);
    const O* output = static_cast<const O*>(output_);
    const GO* grad_output = static_cast<const GO*>(grad_output_);
    GI* grad_input = static_cast<GI*>(grad_input_);

    for (std::size_t i = 0; i < inputLenght; ++i) {
        if (output[i] > O(-100)) {
            grad_input[i] = grad_output[i] / input[i];
        } else {
            // Forward output was clamped: gradient does not flow.
            // Fix: zero must be of the grad-input type GI, not O.
            grad_input[i] = GI(0);
        }
    }
}
namespace {
// Register the backward kernel for float32 and float64.
// The key now carries FOUR dtypes (input, output, grad-input, grad-output)
// because the backward kernel needs the saved forward output to detect
// the -100 log clamp; keep these in sync with LnImplBackward_cpu's
// Registrable signature.
static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float32(
    {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32},
    Aidge::LnImpl_cpu_backward_kernel<float, float, float, float>);
static Registrar<LnImplBackward_cpu> registrarLnImplBackward_cpu_Float64(
    {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64},
    Aidge::LnImpl_cpu_backward_kernel<double, double, double, double>);
}  // namespace
} // namespace Aidge
......
......@@ -27,7 +27,14 @@ void LnImpl_cpu_forward_kernel(std::size_t inputLenght,
//#pragma omp parallel for if (inputLenght > 1024)
for (std::size_t i = 0; i < inputLenght; ++i) {
output[i] = std::log(input[i]);
if (input[i] > I(0)) {
output[i] = std::log(input[i]);
if (output[i] < O(-100)) {
output[i] = O(-100);
}
} else {
output[i] = O(-100);
}
}
}
......
......@@ -56,10 +56,11 @@ void Aidge::LnImpl_cpu::backward() {
// Find the correct kernel type
auto kernelFunc = Registrar<LnImplBackward_cpu>::create({
in0->dataType(),
out0->dataType(),
gra_int0->dataType(),
gra_out0->dataType()
});
// Call kernel
kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
}
0% — Loading, or loading failed.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment