Commit af4cacb7 authored by Cyril Moineau

Fix Relu backward.

parent 350ea993
2 merge requests: !61 (0.2.2), !59 (Fix Relu backward.)
Pipeline #45019 passed
@@ -30,7 +30,7 @@ class ReLUImplForward_cpu
     : public Registrable<ReLUImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
 };
 class ReLUImplBackward_cpu
-    : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+    : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {
 };
 class ReLUImpl_cpu : public OperatorImpl {
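The heart of the fix is already visible in this header: the key of the backward registrable grows from two DataTypes to three ({output, grad_output, grad_input}), and the registered function type gains a second const source pointer. A minimal sketch of the signature change (the alias names below are illustrative, not part of the repository):

#include <cstddef>

// Before: backward was registered with the same shape as forward, (n, input, output).
using OldBackwardKernel = void (*)(const std::size_t, const void*, void*);
// After: backward receives the forward output and the incoming gradient, and
// writes the gradient with respect to the input: (n, output, grad_output, grad_input).
using NewBackwardKernel = void (*)(const std::size_t, const void*, const void*, void*);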
@@ -14,31 +14,32 @@
 #include <cstddef>  // std::size_t

+#include "aidge/utils/Registrar.hpp"
 #include "aidge/backend/cpu/operator/ReLUImpl.hpp"
-#include "aidge/utils/Registrar.hpp"

 namespace Aidge {
-template <class I, class O>
+template <class O, class GI, class GO>
 void ReLUImpl_cpu_backward_kernel(const std::size_t inputLenght,
-                                  const void* input_,
-                                  void* output_) {
-    const I* input = static_cast<const I*>(input_);
-    O* output = static_cast<O*>(output_);
+                                  const void* output_, const void* grad_output_,
+                                  void* grad_input_) {
+    const O* output = static_cast<const O*>(output_);
+    const GO* grad_output = static_cast<const GO*>(grad_output_);
+    GI* grad_input = static_cast<GI*>(grad_input_);
     for (std::size_t i = 0; i < inputLenght; ++i) {
-        output[i] = (input[i] > I(0)) ? static_cast<O>(input[i]) : O(0);
+        grad_input[i] = (output[i] > GO(0)) ? GI(grad_output[i]) : GI(0);
     }
 }

 namespace {
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float32(
-        {DataType::Float32, DataType::Float32}, Aidge::ReLUImpl_cpu_backward_kernel<float, float>);
+        {DataType::Float32, DataType::Float32, DataType::Float32},
+        Aidge::ReLUImpl_cpu_backward_kernel<float, float, float>);
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Int32(
-        {DataType::Int32, DataType::Int32}, Aidge::ReLUImpl_cpu_backward_kernel<int, int>);
+        {DataType::Int32, DataType::Int32, DataType::Int32},
+        Aidge::ReLUImpl_cpu_backward_kernel<int, int, int>);
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float64(
-        {DataType::Float64, DataType::Float64}, Aidge::ReLUImpl_cpu_backward_kernel<double, double>);
+        {DataType::Float64, DataType::Float64, DataType::Float64},
+        Aidge::ReLUImpl_cpu_backward_kernel<double, double, double>);
 }  // namespace
 }  // namespace Aidge
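Why testing the forward output instead of the input is sound: ReLU computes y = max(x, 0), so y > 0 exactly when x > 0, and the chain rule gives dL/dx_i = dL/dy_i when y_i > 0 and 0 otherwise. A self-contained sketch of the rule the new kernel implements, runnable outside Aidge (names here are illustrative):

#include <cstddef>
#include <cstdio>

// grad_input[i] = grad_output[i] where the forward output was positive, else 0.
template <class O, class GI, class GO>
void relu_backward(std::size_t n, const O* output, const GO* grad_output, GI* grad_input) {
    for (std::size_t i = 0; i < n; ++i) {
        grad_input[i] = (output[i] > O(0)) ? GI(grad_output[i]) : GI(0);
    }
}

int main() {
    const float output[4]      = {0.0f, 2.5f, 0.0f, 1.0f};  // forward result y = max(x, 0)
    const float grad_output[4] = {0.1f, 0.2f, 0.3f, 0.4f};  // incoming gradient dL/dy
    float grad_input[4];
    relu_backward(4, output, grad_output, grad_input);
    for (float g : grad_input) std::printf("%g ", g);        // prints: 0 0.2 0 0.4
    std::printf("\n");
}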
@@ -45,16 +45,18 @@ void Aidge::ReLUImpl_cpu::forward() {
 void Aidge::ReLUImpl_cpu::backward() {
     // reversing in and out Tensors
     const ReLU_Op& op_ = dynamic_cast<const ReLU_Op&>(mOp);
-    std::shared_ptr<Tensor> in0 = op_.getOutput(0)->grad();
-    std::shared_ptr<Tensor> out0 = op_.getInput(0)->grad();
+    std::shared_ptr<Tensor> out0 = op_.getOutput(0);
+    std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
+    std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad();
     AIDGE_ASSERT(out0, "current {} operator output#0 has not gradient Tensor.", op_.type());

     // Find the correct kernel type
     auto kernelFunc = Registrar<ReLUImplBackward_cpu>::create({
-        in0->dataType(),
-        out0->dataType()
+        out0->dataType(),
+        gra_out0->dataType(),
+        gra_int0->dataType()
     });

     // Call kernel
-    kernelFunc(in0->size(), getCPUPtr(in0), getCPUPtr(out0));
+    kernelFunc(gra_int0->size(), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
 }
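backward() now builds the registrar key from a triple of data types, one per tensor involved, and invokes the kernel with the forward output, the output gradient, and the input-gradient buffer. A simplified sketch of that tuple-keyed dispatch (an assumed simplification: Aidge's real Registrar/Registrable are templated classes, not this map):

#include <cstddef>
#include <functional>
#include <map>
#include <stdexcept>
#include <tuple>

enum class DataType { Float32, Float64, Int32 };
using Key = std::tuple<DataType, DataType, DataType>;  // {output, grad_output, grad_input}
using BackwardFn = std::function<void(std::size_t, const void*, const void*, void*)>;

std::map<Key, BackwardFn>& backwardRegistry() {
    static std::map<Key, BackwardFn> registry;  // populated by static Registrar objects
    return registry;
}

// Mirrors Registrar<ReLUImplBackward_cpu>::create({...}): a kernel is found only
// if the exact DataType triple was registered, which is why the commit registers
// the homogeneous float, int and double triples.
BackwardFn createBackward(const Key& key) {
    auto it = backwardRegistry().find(key);
    if (it == backwardRegistry().end())
        throw std::runtime_error("no ReLU backward kernel for this DataType triple");
    return it->second;
}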