Commit af4cacb7 authored by Cyril Moineau

Fix Relu backward.

parent 350ea993
2 merge requests: !61 "0.2.2", !59 "Fix Relu backward."
Pipeline #45019 passed
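
In short: the old backward kernel reused the forward formula on the gradient buffer, computing grad_input[i] = max(grad_output[i], 0) instead of the actual ReLU gradient. The fix widens the backward registrar key to three data types and masks the incoming gradient with the forward output:

    grad_input[i] = (output[i] > 0) ? grad_output[i] : 0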
@@ -30,7 +30,7 @@ class ReLUImplForward_cpu
     : public Registrable<ReLUImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
 };
 class ReLUImplBackward_cpu
-    : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+    : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {
 };
 class ReLUImpl_cpu : public OperatorImpl {
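
For context, Registrable maps a tuple of DataTypes to a kernel function pointer, so widening the backward key from two entries to three lets the lookup distinguish the output, grad_output, and grad_input types. A minimal sketch of that dispatch pattern, assuming a plain map; the names backwardRegistry and createBackwardKernel are hypothetical, and Aidge's real Registrar templates are generic over the key and function type:

    // Illustrative only; not Aidge's actual implementation.
    #include <cstddef>
    #include <map>
    #include <stdexcept>
    #include <tuple>

    enum class DataType { Float32, Float64, Int32 };

    // New backward signature: (length, output, grad_output, grad_input).
    using BackwardKernel = void (*)(const std::size_t, const void*, const void*, void*);
    using Key = std::tuple<DataType, DataType, DataType>;

    std::map<Key, BackwardKernel>& backwardRegistry() {
        static std::map<Key, BackwardKernel> registry;
        return registry;
    }

    // Lookup picks the kernel registered for the {output, grad_output, grad_input} types.
    BackwardKernel createBackwardKernel(const Key& key) {
        const auto it = backwardRegistry().find(key);
        if (it == backwardRegistry().end()) {
            throw std::runtime_error("no backward kernel registered for this DataType triple");
        }
        return it->second;
    }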
@@ -14,31 +14,32 @@
 #include <cstddef>  // std::size_t
-#include "aidge/utils/Registrar.hpp"

 #include "aidge/backend/cpu/operator/ReLUImpl.hpp"
+#include "aidge/utils/Registrar.hpp"

 namespace Aidge {
-template <class I, class O>
+template <class O, class GI, class GO>
 void ReLUImpl_cpu_backward_kernel(const std::size_t inputLenght,
-                                  const void* input_,
-                                  void* output_) {
-    const I* input = static_cast<const I*>(input_);
-    O* output = static_cast<O*>(output_);
+                                  const void* output_, const void* grad_output_,
+                                  void* grad_input_) {
+    const O* output = static_cast<const O*>(output_);
+    const GO* grad_output = static_cast<const GO*>(grad_output_);
+    GI* grad_input = static_cast<GI*>(grad_input_);
     for (std::size_t i = 0; i < inputLenght; ++i) {
-        output[i] = (input[i] > I(0)) ? static_cast<O>(input[i]) : O(0);
+        grad_input[i] = (output[i] > GO(0)) ? GI(grad_output[i]) : GI(0);
     }
 }

 namespace {
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float32(
-        {DataType::Float32, DataType::Float32}, Aidge::ReLUImpl_cpu_backward_kernel<float, float>);
+        {DataType::Float32, DataType::Float32, DataType::Float32},
+        Aidge::ReLUImpl_cpu_backward_kernel<float, float, float>);
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Int32(
-        {DataType::Int32, DataType::Int32}, Aidge::ReLUImpl_cpu_backward_kernel<int, int>);
+        {DataType::Int32, DataType::Int32, DataType::Int32},
+        Aidge::ReLUImpl_cpu_backward_kernel<int, int, int>);
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float64(
-        {DataType::Float64, DataType::Float64}, Aidge::ReLUImpl_cpu_backward_kernel<double, double>);
+        {DataType::Float64, DataType::Float64, DataType::Float64},
+        Aidge::ReLUImpl_cpu_backward_kernel<double, double, double>);
 }  // namespace
 }  // namespace Aidge
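
The rewritten kernel implements the usual ReLU gradient: the incoming gradient passes through wherever the forward output was positive and is zeroed elsewhere. A self-contained sketch exercising that rule; relu_backward here is a stand-in for ReLUImpl_cpu_backward_kernel with the void* casts stripped:

    #include <cstddef>
    #include <cstdio>

    // Same rule as the kernel above: mask grad_output by the forward output.
    template <class O, class GI, class GO>
    void relu_backward(const std::size_t len,
                       const O* output, const GO* grad_output, GI* grad_input) {
        for (std::size_t i = 0; i < len; ++i) {
            grad_input[i] = (output[i] > O(0)) ? static_cast<GI>(grad_output[i]) : GI(0);
        }
    }

    int main() {
        const float y[4]  = {0.0f, 1.5f, 0.0f, 2.0f};  // forward outputs
        const float dy[4] = {1.0f, 1.0f, 1.0f, 1.0f};  // incoming gradients
        float dx[4];
        relu_backward<float, float, float>(4, y, dy, dx);
        for (float v : dx) { std::printf("%g ", v); }  // prints: 0 1 0 1
        std::printf("\n");
        return 0;
    }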
@@ -45,16 +45,18 @@ void Aidge::ReLUImpl_cpu::forward() {
 void Aidge::ReLUImpl_cpu::backward() {
     // reversing in and out Tensors
     const ReLU_Op& op_ = dynamic_cast<const ReLU_Op&>(mOp);
-    std::shared_ptr<Tensor> in0 = op_.getOutput(0)->grad();
-    std::shared_ptr<Tensor> out0 = op_.getInput(0)->grad();
+    std::shared_ptr<Tensor> out0 = op_.getOutput(0);
+    std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
+    std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad();
     AIDGE_ASSERT(out0, "current {} operator output#0 has not gradient Tensor.", op_.type());

     // Find the correct kernel type
     auto kernelFunc = Registrar<ReLUImplBackward_cpu>::create({
-        in0->dataType(),
-        out0->dataType()
+        out0->dataType(),
+        gra_out0->dataType(),
+        gra_int0->dataType()
     });

     // Call kernel
-    kernelFunc(in0->size(), getCPUPtr(in0), getCPUPtr(out0));
+    kernelFunc(gra_int0->size(), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
 }
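
Note that backward() now reads the forward output rather than the input: for ReLU the two masks coincide (output[i] > 0 exactly when input[i] > 0), so the gradient can be computed from the output alone. A quick standalone check of the rule against a finite-difference estimate, not part of the commit:

    #include <cstdio>

    static float relu(float x) { return (x > 0.0f) ? x : 0.0f; }

    int main() {
        const float xs[] = {-2.0f, -0.5f, 0.3f, 1.7f};
        const float eps = 1e-3f;
        for (float x : xs) {
            const float y = relu(x);
            // Kernel rule with grad_output = 1: pass-through where output > 0.
            const float analytic = (y > 0.0f) ? 1.0f : 0.0f;
            const float numeric = (relu(x + eps) - relu(x - eps)) / (2.0f * eps);
            std::printf("x=% .2f  analytic=%g  numeric=%g\n", x, analytic, numeric);
        }
        return 0;
    }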