diff --git a/include/aidge/backend/cpu/operator/ReLUImpl.hpp b/include/aidge/backend/cpu/operator/ReLUImpl.hpp
index cef82482813757312c638aebac9f2afd738493db..e2ebf44616db876b462157db650ff48362dd7bac 100644
--- a/include/aidge/backend/cpu/operator/ReLUImpl.hpp
+++ b/include/aidge/backend/cpu/operator/ReLUImpl.hpp
@@ -30,7 +30,7 @@ class ReLUImplForward_cpu
     : public Registrable<ReLUImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
 };
 class ReLUImplBackward_cpu
-    : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+    : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {
 };
 
 class ReLUImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp
index b68ea076cb94eb9550b4a7af89ef58162ee15aea..43a9714ad2d32228fac9bf9c526191f0cec5bfa0 100644
--- a/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp
@@ -14,31 +14,32 @@
 
 #include <cstddef>  // std::size_t
 
-#include "aidge/utils/Registrar.hpp"
-
 #include "aidge/backend/cpu/operator/ReLUImpl.hpp"
+#include "aidge/utils/Registrar.hpp"
 
 namespace Aidge {
-template <class I, class O>
+template <class O, class GI, class GO>
 void ReLUImpl_cpu_backward_kernel(const std::size_t inputLenght,
-                                  const void* input_,
-                                  void* output_) {
-
-    const I* input = static_cast<const I*>(input_);
-    O* output = static_cast<O*>(output_);
-
+                                  const void* output_, const void* grad_output_,
+                                  void* grad_input_) {
+    const O* output = static_cast<const O*>(output_);
+    const GO* grad_output = static_cast<const GO*>(grad_output_);
+    GI* grad_input = static_cast<GI*>(grad_input_);
     for (std::size_t i = 0; i < inputLenght; ++i) {
-        output[i] = (input[i] > I(0)) ? static_cast<O>(input[i]) : O(0);
+        grad_input[i] = (output[i] > GO(0)) ? GI(grad_output[i]) : GI(0);
     }
 }
 
 namespace {
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float32(
-        {DataType::Float32, DataType::Float32}, Aidge::ReLUImpl_cpu_backward_kernel<float, float>);
+        {DataType::Float32, DataType::Float32, DataType::Float32},
+        Aidge::ReLUImpl_cpu_backward_kernel<float, float, float>);
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Int32(
-        {DataType::Int32, DataType::Int32}, Aidge::ReLUImpl_cpu_backward_kernel<int, int>);
+        {DataType::Int32, DataType::Int32, DataType::Int32},
+        Aidge::ReLUImpl_cpu_backward_kernel<int, int, int>);
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float64(
-        {DataType::Float64, DataType::Float64}, Aidge::ReLUImpl_cpu_backward_kernel<double, double>);
+        {DataType::Float64, DataType::Float64, DataType::Float64},
+        Aidge::ReLUImpl_cpu_backward_kernel<double, double, double>);
 } // namespace
 } // namespace Aidge
 
diff --git a/src/operator/ReLUImpl.cpp b/src/operator/ReLUImpl.cpp
index 4bba09b6fbeea1552bf5b7cc7e491291345fca45..06859f09db169946175a93140e04f2e2a99e3362 100644
--- a/src/operator/ReLUImpl.cpp
+++ b/src/operator/ReLUImpl.cpp
@@ -45,16 +45,18 @@ void Aidge::ReLUImpl_cpu::forward() {
 void Aidge::ReLUImpl_cpu::backward() {
     // reversing in and out Tensors
     const ReLU_Op& op_ = dynamic_cast<const ReLU_Op&>(mOp);
-    std::shared_ptr<Tensor> in0 = op_.getOutput(0)->grad();
-    std::shared_ptr<Tensor> out0 = op_.getInput(0)->grad();
+    std::shared_ptr<Tensor> out0 = op_.getOutput(0);
+    std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
+    std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad();
     AIDGE_ASSERT(out0, "current {} operator output#0 has not gradient Tensor.", op_.type());
 
     // Find the correct kernel type
     auto kernelFunc = Registrar<ReLUImplBackward_cpu>::create({
-        in0->dataType(),
-        out0->dataType()
+        out0->dataType(),
+        gra_out0->dataType(),
+        gra_int0->dataType()
     });
 
     // Call kernel
-    kernelFunc(in0->size(), getCPUPtr(in0), getCPUPtr(out0));
+    kernelFunc(gra_int0->size(), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
 }
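For reviewers, here is a minimal standalone sketch of the contract the patched kernel now implements. It is not part of the patch: the file name, the relu_backward free function, and the main() harness are illustrative assumptions, and the Registrar's void*-based type erasure is omitted; only the loop body mirrors the diff.

// relu_backward_sketch.cpp -- hypothetical standalone illustration, not
// part of the patch. Mirrors the body of the patched
// ReLUImpl_cpu_backward_kernel without the void*-based type erasure.
#include <cstddef>   // std::size_t
#include <iostream>

template <class O, class GI, class GO>
void relu_backward(const std::size_t length,
                   const O* output, const GO* grad_output, GI* grad_input) {
    for (std::size_t i = 0; i < length; ++i) {
        // Pass the incoming gradient through wherever the forward output
        // was positive; ReLU's derivative is 1 there and 0 elsewhere.
        grad_input[i] = (output[i] > GO(0)) ? GI(grad_output[i]) : GI(0);
    }
}

int main() {
    // Forward pass of ReLU over inputs {-1, 0, 2, 3} yields {0, 0, 2, 3}.
    const float output[]      = {0.f, 0.f, 2.f, 3.f};
    const float grad_output[] = {0.1f, 0.2f, 0.3f, 0.4f};
    float grad_input[4]       = {};

    relu_backward<float, float, float>(4, output, grad_output, grad_input);

    for (float g : grad_input) std::cout << g << ' ';   // prints: 0 0 0.3 0.4
    std::cout << '\n';
}

Note the design choice visible in the diff: the mask is taken on the forward output rather than on the input. Since output = max(input, 0), the predicate output > 0 holds exactly where input > 0 (values at exactly 0 receive a zero gradient either way), so the backward pass no longer needs the input tensor at all, only getOutput(0) and the two gradient tensors, which is precisely what ReLUImpl_cpu::backward() now fetches and forwards to the kernel.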