Commit 5160226e authored and committed by Cyril Moineau

Fix ReLU backward.

parent 06d9bb53
1 merge request: !61 "0.2.2"
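This commit fixes the ReLU backward pass. Previously the backward kernel shared the forward kernel's two-tensor signature and, as wired up in ReLUImpl_cpu::backward(), re-applied the forward ReLU to the output gradient instead of computing a gradient at all. After the fix, the backward kernel receives the forward output together with the output gradient and writes grad_input[i] = grad_output[i] where output[i] > 0, and 0 elsewhere. Three hunks change: the kernel registrable declaration, the backward kernel itself, and the backward() dispatch.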
@@ -30,7 +30,7 @@ class ReLUImplForward_cpu
     : public Registrable<ReLUImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
 };

 class ReLUImplBackward_cpu
-    : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+    : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {
 };
 class ReLUImpl_cpu : public OperatorImpl {
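The backward registrable is now keyed on three DataTypes instead of two, and its kernel type takes two const input pointers plus one output pointer. Matching the declaration against the kernel definition and the backward() call further down, the parameters line up as follows (the names are annotations added here for readability, not part of the declaration):

    void(const std::size_t,   // number of elements
         const void*,         // forward output
         const void*,         // output gradient (read)
         void*)               // input gradient (written)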
@@ -14,31 +14,32 @@
 #include <cstddef> // std::size_t

-#include "aidge/utils/Registrar.hpp"
 #include "aidge/backend/cpu/operator/ReLUImpl.hpp"
+#include "aidge/utils/Registrar.hpp"

 namespace Aidge {
-template <class I, class O>
+template <class O, class GI, class GO>
 void ReLUImpl_cpu_backward_kernel(const std::size_t inputLenght,
-                                  const void* input_,
-                                  void* output_) {
-    const I* input = static_cast<const I*>(input_);
-    O* output = static_cast<O*>(output_);
+                                  const void* output_, const void* grad_output_,
+                                  void* grad_input_) {
+    const O* output = static_cast<const O*>(output_);
+    const GO* grad_output = static_cast<const GO*>(grad_output_);
+    GI* grad_input = static_cast<GI*>(grad_input_);
     for (std::size_t i = 0; i < inputLenght; ++i) {
-        output[i] = (input[i] > I(0)) ? static_cast<O>(input[i]) : O(0);
+        grad_input[i] = (output[i] > GO(0)) ? GI(grad_output[i]) : GI(0);
     }
 }

 namespace {
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float32(
-        {DataType::Float32, DataType::Float32}, Aidge::ReLUImpl_cpu_backward_kernel<float, float>);
+        {DataType::Float32, DataType::Float32, DataType::Float32},
+        Aidge::ReLUImpl_cpu_backward_kernel<float, float, float>);
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Int32(
-        {DataType::Int32, DataType::Int32}, Aidge::ReLUImpl_cpu_backward_kernel<int, int>);
+        {DataType::Int32, DataType::Int32, DataType::Int32},
+        Aidge::ReLUImpl_cpu_backward_kernel<int, int, int>);
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float64(
-        {DataType::Float64, DataType::Float64}, Aidge::ReLUImpl_cpu_backward_kernel<double, double>);
+        {DataType::Float64, DataType::Float64, DataType::Float64},
+        Aidge::ReLUImpl_cpu_backward_kernel<double, double, double>);
 }  // namespace
 }  // namespace Aidge
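Since ReLU(x) = max(x, 0), its derivative is 1 where the activation is positive and 0 elsewhere, so masking on the saved forward output is equivalent to masking on the input. The following is a minimal, self-contained sketch of the fixed kernel's semantics; it re-declares the kernel locally (as relu_backward) instead of pulling in the Aidge headers, so the function name and the main() driver are illustrative only, not part of the library:

    #include <cstddef>
    #include <cstdio>

    // Local copy of the fixed kernel logic, for illustration only.
    template <class O, class GI, class GO>
    void relu_backward(const std::size_t nbElements,
                       const void* output_, const void* grad_output_,
                       void* grad_input_) {
        const O* output = static_cast<const O*>(output_);
        const GO* grad_output = static_cast<const GO*>(grad_output_);
        GI* grad_input = static_cast<GI*>(grad_input_);
        for (std::size_t i = 0; i < nbElements; ++i) {
            // Pass the gradient through where the forward output was positive.
            grad_input[i] = (output[i] > GO(0)) ? GI(grad_output[i]) : GI(0);
        }
    }

    int main() {
        // Forward already ran: output = ReLU(input) for input = {-2, -1, 0, 1, 2}.
        const float output[5]      = {0.f, 0.f, 0.f, 1.f, 2.f};
        const float grad_output[5] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
        float grad_input[5];

        relu_backward<float, float, float>(5, output, grad_output, grad_input);

        for (std::size_t i = 0; i < 5; ++i) {
            std::printf("%g ", grad_input[i]);  // prints: 0 0 0 0.4 0.5
        }
        std::printf("\n");
        return 0;
    }

The three template parameters allow the forward output and the two gradients to use different data types, although the registrations above only instantiate homogeneous float, int, and double combinations.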
@@ -45,16 +45,18 @@ void Aidge::ReLUImpl_cpu::forward() {
 void Aidge::ReLUImpl_cpu::backward() {
     // reversing in and out Tensors
     const ReLU_Op& op_ = dynamic_cast<const ReLU_Op&>(mOp);
-    std::shared_ptr<Tensor> in0 = op_.getOutput(0)->grad();
-    std::shared_ptr<Tensor> out0 = op_.getInput(0)->grad();
+    std::shared_ptr<Tensor> out0 = op_.getOutput(0);
+    std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
+    std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad();
     AIDGE_ASSERT(out0, "current {} operator output#0 has not gradient Tensor.", op_.type());

     // Find the correct kernel type
     auto kernelFunc = Registrar<ReLUImplBackward_cpu>::create({
-        in0->dataType(),
-        out0->dataType()
+        out0->dataType(),
+        gra_out0->dataType(),
+        gra_int0->dataType()
     });

     // Call kernel
-    kernelFunc(in0->size(), getCPUPtr(in0), getCPUPtr(out0));
+    kernelFunc(gra_int0->size(), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
 }
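With this dispatch, backward() looks the kernel up by the data types of the forward output, its gradient, and the input gradient, and passes the forward output alongside the two gradient buffers. The old wiring passed the output gradient where the kernel expected the forward input, so it effectively computed grad_input = max(grad_output, 0), masking on the sign of the incoming gradient rather than on the activation; that is the bug this commit corrects.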