Commit 1ab9eb04 authored by Olivier Antoni
Add backward functions for ReLU, Sigmoid and Tanh

parent 33d46ff3
Part of 2 merge requests: !73 "version 0.2.3" and !67 "Add backward functions for ReLU, Sigmoid and Tanh"
Pipeline #47525 passed
@@ -30,7 +30,7 @@ class ReLUImplForward_cpu
     : public Registrable<ReLUImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
 };
 class ReLUImplBackward_cpu
-    : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> {
+    : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {
 };
 class ReLUImpl_cpu : public OperatorImpl {

@@ -18,12 +18,11 @@
 #include "aidge/utils/Registrar.hpp"
 namespace Aidge {
-template <class I, class O, class GI, class GO>
+template <class I, class GI, class GO>
 void ReLUImpl_cpu_backward_kernel(const std::size_t inputLenght,
-                                  const void* input_, const void* output_, const void* grad_output_,
-                                  void* grad_input_) {
+                                  const void* input_, const void* grad_output_,
+                                  void* grad_input_) {
     const I* input = static_cast<const I*>(input_);
-    //const O* output = static_cast<const O*>(output_);
     const GO* grad_output = static_cast<const GO*>(grad_output_);
     GI* grad_input = static_cast<GI*>(grad_input_);
     for (std::size_t i = 0; i < inputLenght; ++i) {
@@ -33,14 +32,14 @@ void ReLUImpl_cpu_backward_kernel(const std::size_t inputLenght,
 namespace {
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float32(
-        {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32},
-        Aidge::ReLUImpl_cpu_backward_kernel<float, float, float, float>);
+        {DataType::Float32, DataType::Float32, DataType::Float32},
+        Aidge::ReLUImpl_cpu_backward_kernel<float, float, float>);
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Int32(
-        {DataType::Int32, DataType::Int32, DataType::Int32, DataType::Int32},
-        Aidge::ReLUImpl_cpu_backward_kernel<int, int, int, int>);
+        {DataType::Int32, DataType::Int32, DataType::Int32},
+        Aidge::ReLUImpl_cpu_backward_kernel<int, int, int>);
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float64(
-        {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64},
-        Aidge::ReLUImpl_cpu_backward_kernel<double, double, double, double>);
+        {DataType::Float64, DataType::Float64, DataType::Float64},
+        Aidge::ReLUImpl_cpu_backward_kernel<double, double, double>);
 } // namespace
 } // namespace Aidge
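Note: the loop body is collapsed in this view. For context, a minimal sketch of what a ReLU backward kernel with the new three-parameter signature computes is given below; the loop body is an assumption based on the standard ReLU derivative, not the commit's verbatim code.

// Editor's sketch, not the commit's verbatim loop body: ReLU's derivative is
// 1 where the forward input is positive and 0 elsewhere, so the kernel needs
// only the input and the incoming gradient -- which is exactly why the output
// pointer (and its template parameter O) could be dropped from the signature.
#include <cstddef>

template <class I, class GI, class GO>
void ReLUBackwardSketch(const std::size_t inputLenght,
                        const void* input_, const void* grad_output_,
                        void* grad_input_) {
    const I* input = static_cast<const I*>(input_);
    const GO* grad_output = static_cast<const GO*>(grad_output_);
    GI* grad_input = static_cast<GI*>(grad_input_);
    for (std::size_t i = 0; i < inputLenght; ++i) {
        // Pass the gradient through only where the input was positive.
        grad_input[i] = (input[i] > I(0)) ? static_cast<GI>(grad_output[i]) : GI(0);
    }
}
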
@@ -28,7 +28,7 @@ class SigmoidImplForward_cpu
     : public Registrable<SigmoidImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
 };
 class SigmoidImplBackward_cpu
-    : public Registrable<SigmoidImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> {
+    : public Registrable<SigmoidImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {
 };
 class SigmoidImpl_cpu : public OperatorImpl {

@@ -18,11 +18,10 @@
 #include "aidge/utils/Registrar.hpp"
 namespace Aidge {
-template <class I, class O, class GI, class GO>
+template <class O, class GI, class GO>
 void SigmoidImpl_cpu_backward_kernel(const std::size_t inputLenght,
-                                     const void* input_, const void* output_, const void* grad_output_,
-                                     void* grad_input_) {
-    //const I* input = static_cast<const I*>(input_);
+                                     const void* output_, const void* grad_output_,
+                                     void* grad_input_) {
     const O* output = static_cast<const O*>(output_);
     const GO* grad_output = static_cast<const GO*>(grad_output_);
     GI* grad_input = static_cast<GI*>(grad_input_);
@@ -33,11 +32,11 @@ void SigmoidImpl_cpu_backward_kernel(const std::size_t inputLenght,
 namespace {
 static Registrar<SigmoidImplBackward_cpu> registrarSigmoidImplBackward_cpu_Float32(
-        {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32},
-        Aidge::SigmoidImpl_cpu_backward_kernel<float, float, float, float>);
+        {DataType::Float32, DataType::Float32, DataType::Float32},
+        Aidge::SigmoidImpl_cpu_backward_kernel<float, float, float>);
 static Registrar<SigmoidImplBackward_cpu> registrarSigmoidImplBackward_cpu_Float64(
-        {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64},
-        Aidge::SigmoidImpl_cpu_backward_kernel<double, double, double, double>);
+        {DataType::Float64, DataType::Float64, DataType::Float64},
+        Aidge::SigmoidImpl_cpu_backward_kernel<double, double, double>);
 } // namespace
 } // namespace Aidge
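Note: here the commit drops the input instead of the output, because sigmoid's derivative can be written purely in terms of the forward output y as y * (1 - y). The collapsed loop body plausibly looks like the following sketch (an assumption, not the commit's verbatim code).

// Editor's sketch, assuming the standard identity d(sigmoid)/dx = y * (1 - y),
// where y is the forward output. Only the output and the incoming gradient
// are required, matching the new three-parameter signature.
#include <cstddef>

template <class O, class GI, class GO>
void SigmoidBackwardSketch(const std::size_t inputLenght,
                           const void* output_, const void* grad_output_,
                           void* grad_input_) {
    const O* output = static_cast<const O*>(output_);
    const GO* grad_output = static_cast<const GO*>(grad_output_);
    GI* grad_input = static_cast<GI*>(grad_input_);
    for (std::size_t i = 0; i < inputLenght; ++i) {
        grad_input[i] = static_cast<GI>(grad_output[i] * output[i] * (O(1) - output[i]));
    }
}
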
@@ -28,7 +28,7 @@ class TanhImplForward_cpu
     : public Registrable<TanhImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
 };
 class TanhImplBackward_cpu
-    : public Registrable<TanhImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> {
+    : public Registrable<TanhImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {
 };
 class TanhImpl_cpu : public OperatorImpl {

@@ -18,11 +18,10 @@
 #include "aidge/utils/Registrar.hpp"
 namespace Aidge {
-template <class I, class O, class GI, class GO>
+template <class O, class GI, class GO>
 void TanhImpl_cpu_backward_kernel(const std::size_t inputLenght,
-                                  const void* input_, const void* output_, const void* grad_output_,
-                                  void* grad_input_) {
-    //const I* input = static_cast<const I*>(input_);
+                                  const void* output_, const void* grad_output_,
+                                  void* grad_input_) {
     const O* output = static_cast<const O*>(output_);
     const GO* grad_output = static_cast<const GO*>(grad_output_);
     GI* grad_input = static_cast<GI*>(grad_input_);
@@ -33,11 +32,11 @@ void TanhImpl_cpu_backward_kernel(const std::size_t inputLenght,
 namespace {
 static Registrar<TanhImplBackward_cpu> registrarTanhImplBackward_cpu_Float32(
-        {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32},
-        Aidge::TanhImpl_cpu_backward_kernel<float, float, float, float>);
+        {DataType::Float32, DataType::Float32, DataType::Float32},
+        Aidge::TanhImpl_cpu_backward_kernel<float, float, float>);
 static Registrar<TanhImplBackward_cpu> registrarTanhImplBackward_cpu_Float64(
-        {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64},
-        Aidge::TanhImpl_cpu_backward_kernel<double, double, double, double>);
+        {DataType::Float64, DataType::Float64, DataType::Float64},
+        Aidge::TanhImpl_cpu_backward_kernel<double, double, double>);
 } // namespace
 } // namespace Aidge
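Note: same situation as Sigmoid. Tanh's derivative is 1 - y^2 with y the forward output, so the input is not needed. A hedged sketch of the collapsed loop body (an assumption, not the commit's verbatim code):

// Editor's sketch, assuming the standard identity d(tanh)/dx = 1 - y^2,
// where y is the forward output; the input pointer is therefore unnecessary.
#include <cstddef>

template <class O, class GI, class GO>
void TanhBackwardSketch(const std::size_t inputLenght,
                        const void* output_, const void* grad_output_,
                        void* grad_input_) {
    const O* output = static_cast<const O*>(output_);
    const GO* grad_output = static_cast<const GO*>(grad_output_);
    GI* grad_input = static_cast<GI*>(grad_input_);
    for (std::size_t i = 0; i < inputLenght; ++i) {
        grad_input[i] = static_cast<GI>(grad_output[i] * (O(1) - output[i] * output[i]));
    }
}
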
@@ -46,20 +46,19 @@ void Aidge::ReLUImpl_cpu::forward() {
 void Aidge::ReLUImpl_cpu::backward() {
     const ReLU_Op& op_ = dynamic_cast<const ReLU_Op&>(mOp);
-    std::shared_ptr<Tensor> in0 = op_.getInput(0);
+    std::shared_ptr<Tensor> in0 = op_.getInput(0);
     std::shared_ptr<Tensor> out0 = op_.getOutput(0);
     std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad();
-    std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
+    std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
     AIDGE_ASSERT(out0, "missing output #0 for current {} operator", op_.type());
     // Find the correct kernel type
     auto kernelFunc = Registrar<ReLUImplBackward_cpu>::create({
-        in0->dataType(),
-        out0->dataType(),
+        in0->dataType(),
         gra_int0->dataType(),
-        gra_out0->dataType()
+        gra_out0->dataType()
     });
     // Call kernel
-    kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
+    kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
 }
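Note: for readers unfamiliar with the Registrar pattern used here, kernels are registered under a tuple of DataTypes and looked up at runtime via Registrar<...>::create. A minimal self-contained sketch of the idea (illustration only; Aidge's actual Registrable/Registrar templates differ) shows why dropping the output shrinks the lookup key from four entries to three:

// Editor's sketch of the registration/lookup pattern -- illustration only,
// not Aidge's implementation. The key mirrors the tuple built in backward():
// {input DataType, grad_input DataType, grad_output DataType}.
#include <cstddef>
#include <map>
#include <stdexcept>
#include <tuple>

enum class DataType { Float32, Float64, Int32 };

using BackwardFn = void (*)(const std::size_t, const void*, const void*, void*);
using KernelKey  = std::tuple<DataType, DataType, DataType>;

std::map<KernelKey, BackwardFn>& backwardRegistry() {
    static std::map<KernelKey, BackwardFn> registry;
    return registry;
}

struct BackwardRegistrar {
    // Registration: constructing a static BackwardRegistrar (as the
    // registrarReLUImplBackward_cpu_* objects above do) records the kernel.
    BackwardRegistrar(const KernelKey& key, BackwardFn fn) {
        backwardRegistry()[key] = fn;
    }
    // Lookup: the analogue of Registrar<ReLUImplBackward_cpu>::create({...}).
    static BackwardFn create(const KernelKey& key) {
        const auto it = backwardRegistry().find(key);
        if (it == backwardRegistry().end()) {
            throw std::runtime_error("no backward kernel registered for key");
        }
        return it->second;
    }
};
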
@@ -47,20 +47,18 @@ void Aidge::SigmoidImpl_cpu::forward() {
 void Aidge::SigmoidImpl_cpu::backward() {
     const Sigmoid_Op& op_ = dynamic_cast<const Sigmoid_Op&>(mOp);
-    std::shared_ptr<Tensor> in0 = op_.getInput(0);
     std::shared_ptr<Tensor> out0 = op_.getOutput(0);
-    std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad();
-    std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
+    std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad();
+    std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
     AIDGE_ASSERT(out0, "missing output #0 for current {} operator", op_.type());
     // Find the correct kernel type
     auto kernelFunc = Registrar<SigmoidImplBackward_cpu>::create({
-        in0->dataType(),
         out0->dataType(),
-        gra_int0->dataType(),
+        gra_int0->dataType(),
         gra_out0->dataType()
     });
     // Call kernel
-    kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
-}
\ No newline at end of file
+    kernelFunc(gra_int0->size(), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
+}
@@ -47,21 +47,19 @@ void Aidge::TanhImpl_cpu::forward() {
 void Aidge::TanhImpl_cpu::backward() {
     const Tanh_Op& op_ = dynamic_cast<const Tanh_Op&>(mOp);
-    std::shared_ptr<Tensor> in0 = op_.getInput(0);
     std::shared_ptr<Tensor> out0 = op_.getOutput(0);
     std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad();
-    std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
+    std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
     AIDGE_ASSERT(out0, "missing output #0 for current {} operator", op_.type());
     // Find the correct kernel type
     auto kernelFunc = Registrar<TanhImplBackward_cpu>::create({
-        in0->dataType(),
         out0->dataType(),
-        gra_int0->dataType(),
+        gra_int0->dataType(),
         gra_out0->dataType()
     });
     // Call kernel
-    kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
+    kernelFunc(gra_int0->size(), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
 }
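Note: taken together, the three signature changes follow from the standard derivatives of these activations. ReLU's gradient depends on the forward input x, while the sigmoid and tanh gradients can be expressed from the forward output y alone, so those kernels keep the output pointer and drop the input:

% Standard activation derivatives (x is the forward input, y the forward output):
\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y_i}\,\mathbf{1}[x_i > 0]
  \quad\text{(ReLU)}
\qquad
\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y_i}\, y_i\,(1 - y_i)
  \quad\text{(Sigmoid)}
\qquad
\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y_i}\,\bigl(1 - y_i^2\bigr)
  \quad\text{(Tanh)}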