From 1ab9eb0474102e8fe9269388abb0907ee7765c46 Mon Sep 17 00:00:00 2001 From: Antoni Olivier <olivier.antoni@cea.fr> Date: Thu, 6 Jun 2024 10:22:44 +0200 Subject: [PATCH] Add backward functions for ReLU, Sigmoid and Tanh --- .../aidge/backend/cpu/operator/ReLUImpl.hpp | 2 +- .../operator/ReLUImpl_backward_kernels.hpp | 19 +++++++++---------- .../backend/cpu/operator/SigmoidImpl.hpp | 2 +- .../operator/SigmoidImpl_backward_kernels.hpp | 15 +++++++-------- .../aidge/backend/cpu/operator/TanhImpl.hpp | 2 +- .../operator/TanhImpl_backward_kernels.hpp | 15 +++++++-------- src/operator/ReLUImpl.cpp | 11 +++++------ src/operator/SigmoidImpl.cpp | 12 +++++------- src/operator/TanhImpl.cpp | 8 +++----- 9 files changed, 39 insertions(+), 47 deletions(-) diff --git a/include/aidge/backend/cpu/operator/ReLUImpl.hpp b/include/aidge/backend/cpu/operator/ReLUImpl.hpp index f8abfcf2..e2ebf446 100644 --- a/include/aidge/backend/cpu/operator/ReLUImpl.hpp +++ b/include/aidge/backend/cpu/operator/ReLUImpl.hpp @@ -30,7 +30,7 @@ class ReLUImplForward_cpu : public Registrable<ReLUImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> { }; class ReLUImplBackward_cpu - : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> { + : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> { }; class ReLUImpl_cpu : public OperatorImpl { diff --git a/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp index e67a4588..1bd932e4 100644 --- a/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp @@ -18,12 +18,11 @@ #include "aidge/utils/Registrar.hpp" namespace Aidge { -template <class I, class O, class GI, class GO> +template <class I, class GI, class GO> void ReLUImpl_cpu_backward_kernel(const std::size_t inputLenght, - const void* input_, const void* output_, const void* grad_output_, - void* grad_input_) { + const void* input_, const void* grad_output_, + void* grad_input_) { const I* input = static_cast<const I*>(input_); - //const O* output = static_cast<const O*>(output_); const GO* grad_output = static_cast<const GO*>(grad_output_); GI* grad_input = static_cast<GI*>(grad_input_); for (std::size_t i = 0; i < inputLenght; ++i) { @@ -33,14 +32,14 @@ void ReLUImpl_cpu_backward_kernel(const std::size_t inputLenght, namespace { static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float32( - {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32}, - Aidge::ReLUImpl_cpu_backward_kernel<float, float, float, float>); + {DataType::Float32, DataType::Float32, DataType::Float32}, + Aidge::ReLUImpl_cpu_backward_kernel<float, float, float>); static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Int32( - {DataType::Int32, DataType::Int32, DataType::Int32, DataType::Int32}, - Aidge::ReLUImpl_cpu_backward_kernel<int, int, int, int>); + {DataType::Int32, DataType::Int32, DataType::Int32}, + Aidge::ReLUImpl_cpu_backward_kernel<int, int, int>); static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float64( - {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64}, - Aidge::ReLUImpl_cpu_backward_kernel<double, double, double, double>); + {DataType::Float64, DataType::Float64, DataType::Float64}, + Aidge::ReLUImpl_cpu_backward_kernel<double, double, double>); } // namespace } // namespace Aidge diff --git a/include/aidge/backend/cpu/operator/SigmoidImpl.hpp b/include/aidge/backend/cpu/operator/SigmoidImpl.hpp index ed9ffe13..34340e61 100644 --- a/include/aidge/backend/cpu/operator/SigmoidImpl.hpp +++ b/include/aidge/backend/cpu/operator/SigmoidImpl.hpp @@ -28,7 +28,7 @@ class SigmoidImplForward_cpu : public Registrable<SigmoidImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> { }; class SigmoidImplBackward_cpu - : public Registrable<SigmoidImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> { + : public Registrable<SigmoidImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> { }; class SigmoidImpl_cpu : public OperatorImpl { diff --git a/include/aidge/backend/cpu/operator/SigmoidImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/SigmoidImpl_backward_kernels.hpp index 931a2e4d..4ceb3bd7 100644 --- a/include/aidge/backend/cpu/operator/SigmoidImpl_backward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/SigmoidImpl_backward_kernels.hpp @@ -18,11 +18,10 @@ #include "aidge/utils/Registrar.hpp" namespace Aidge { -template <class I, class O, class GI, class GO> +template <class O, class GI, class GO> void SigmoidImpl_cpu_backward_kernel(const std::size_t inputLenght, - const void* input_, const void* output_, const void* grad_output_, - void* grad_input_) { - //const I* input = static_cast<const I*>(input_); + const void* output_, const void* grad_output_, + void* grad_input_) { const O* output = static_cast<const O*>(output_); const GO* grad_output = static_cast<const GO*>(grad_output_); GI* grad_input = static_cast<GI*>(grad_input_); @@ -33,11 +32,11 @@ void SigmoidImpl_cpu_backward_kernel(const std::size_t inputLenght, namespace { static Registrar<SigmoidImplBackward_cpu> registrarSigmoidImplBackward_cpu_Float32( - {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32}, - Aidge::SigmoidImpl_cpu_backward_kernel<float, float, float, float>); + {DataType::Float32, DataType::Float32, DataType::Float32}, + Aidge::SigmoidImpl_cpu_backward_kernel<float, float, float>); static Registrar<SigmoidImplBackward_cpu> registrarSigmoidImplBackward_cpu_Float64( - {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64}, - Aidge::SigmoidImpl_cpu_backward_kernel<double, double, double, double>); + {DataType::Float64, DataType::Float64, DataType::Float64}, + Aidge::SigmoidImpl_cpu_backward_kernel<double, double, double>); } // namespace } // namespace Aidge diff --git a/include/aidge/backend/cpu/operator/TanhImpl.hpp b/include/aidge/backend/cpu/operator/TanhImpl.hpp index a62cd050..0bf851e7 100644 --- a/include/aidge/backend/cpu/operator/TanhImpl.hpp +++ b/include/aidge/backend/cpu/operator/TanhImpl.hpp @@ -28,7 +28,7 @@ class TanhImplForward_cpu : public Registrable<TanhImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> { }; class TanhImplBackward_cpu - : public Registrable<TanhImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> { + : public Registrable<TanhImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> { }; class TanhImpl_cpu : public OperatorImpl { diff --git a/include/aidge/backend/cpu/operator/TanhImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/TanhImpl_backward_kernels.hpp index 3f49e4c4..3a13c2ca 100644 --- a/include/aidge/backend/cpu/operator/TanhImpl_backward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/TanhImpl_backward_kernels.hpp @@ -18,11 +18,10 @@ #include "aidge/utils/Registrar.hpp" namespace Aidge { -template <class I, class O, class GI, class GO> +template <class O, class GI, class GO> void TanhImpl_cpu_backward_kernel(const std::size_t inputLenght, - const void* input_, const void* output_, const void* grad_output_, - void* grad_input_) { - //const I* input = static_cast<const I*>(input_); + const void* output_, const void* grad_output_, + void* grad_input_) { const O* output = static_cast<const O*>(output_); const GO* grad_output = static_cast<const GO*>(grad_output_); GI* grad_input = static_cast<GI*>(grad_input_); @@ -33,11 +32,11 @@ void TanhImpl_cpu_backward_kernel(const std::size_t inputLenght, namespace { static Registrar<TanhImplBackward_cpu> registrarTanhImplBackward_cpu_Float32( - {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32}, - Aidge::TanhImpl_cpu_backward_kernel<float, float, float, float>); + {DataType::Float32, DataType::Float32, DataType::Float32}, + Aidge::TanhImpl_cpu_backward_kernel<float, float, float>); static Registrar<TanhImplBackward_cpu> registrarTanhImplBackward_cpu_Float64( - {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64}, - Aidge::TanhImpl_cpu_backward_kernel<double, double, double, double>); + {DataType::Float64, DataType::Float64, DataType::Float64}, + Aidge::TanhImpl_cpu_backward_kernel<double, double, double>); } // namespace } // namespace Aidge diff --git a/src/operator/ReLUImpl.cpp b/src/operator/ReLUImpl.cpp index 8de2190e..4a0fb9f5 100644 --- a/src/operator/ReLUImpl.cpp +++ b/src/operator/ReLUImpl.cpp @@ -46,20 +46,19 @@ void Aidge::ReLUImpl_cpu::forward() { void Aidge::ReLUImpl_cpu::backward() { const ReLU_Op& op_ = dynamic_cast<const ReLU_Op&>(mOp); - std::shared_ptr<Tensor> in0 = op_.getInput(0); + std::shared_ptr<Tensor> in0 = op_.getInput(0); std::shared_ptr<Tensor> out0 = op_.getOutput(0); std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad(); - std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad(); + std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad(); AIDGE_ASSERT(out0, "missing output #0 for current {} operator", op_.type()); // Find the correct kernel type auto kernelFunc = Registrar<ReLUImplBackward_cpu>::create({ - in0->dataType(), - out0->dataType(), + in0->dataType(), gra_int0->dataType(), - gra_out0->dataType() + gra_out0->dataType() }); // Call kernel - kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0)); + kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(gra_out0), getCPUPtr(gra_int0)); } diff --git a/src/operator/SigmoidImpl.cpp b/src/operator/SigmoidImpl.cpp index fe92ef43..ad69935c 100644 --- a/src/operator/SigmoidImpl.cpp +++ b/src/operator/SigmoidImpl.cpp @@ -47,20 +47,18 @@ void Aidge::SigmoidImpl_cpu::forward() { void Aidge::SigmoidImpl_cpu::backward() { const Sigmoid_Op& op_ = dynamic_cast<const Sigmoid_Op&>(mOp); - std::shared_ptr<Tensor> in0 = op_.getInput(0); std::shared_ptr<Tensor> out0 = op_.getOutput(0); - std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad(); - std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad(); + std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad(); + std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad(); AIDGE_ASSERT(out0, "missing output #0 for current {} operator", op_.type()); // Find the correct kernel type auto kernelFunc = Registrar<SigmoidImplBackward_cpu>::create({ - in0->dataType(), out0->dataType(), - gra_int0->dataType(), + gra_int0->dataType(), gra_out0->dataType() }); // Call kernel - kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0)); -} \ No newline at end of file + kernelFunc(gra_int0->size(), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0)); +} diff --git a/src/operator/TanhImpl.cpp b/src/operator/TanhImpl.cpp index 8b5988e9..a2469ed9 100644 --- a/src/operator/TanhImpl.cpp +++ b/src/operator/TanhImpl.cpp @@ -47,21 +47,19 @@ void Aidge::TanhImpl_cpu::forward() { void Aidge::TanhImpl_cpu::backward() { const Tanh_Op& op_ = dynamic_cast<const Tanh_Op&>(mOp); - std::shared_ptr<Tensor> in0 = op_.getInput(0); std::shared_ptr<Tensor> out0 = op_.getOutput(0); std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad(); - std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad(); + std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad(); AIDGE_ASSERT(out0, "missing output #0 for current {} operator", op_.type()); // Find the correct kernel type auto kernelFunc = Registrar<TanhImplBackward_cpu>::create({ - in0->dataType(), out0->dataType(), - gra_int0->dataType(), + gra_int0->dataType(), gra_out0->dataType() }); // Call kernel - kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0)); + kernelFunc(gra_int0->size(), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0)); } -- GitLab