Commit 29568f62 authored by Olivier Antoni

Add backward functions for ReLU, Sigmoid and Tanh

parent 17948093
2 merge requests: !73 (version 0.2.3), !67 (Add backward functions for ReLU, Sigmoid and Tanh)
Pipeline #47269 failed
Showing 183 additions and 36 deletions
@@ -30,7 +30,7 @@ class ReLUImplForward_cpu
     : public Registrable<ReLUImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
 };
 class ReLUImplBackward_cpu
-    : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {
+    : public Registrable<ReLUImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> {
 };
 class ReLUImpl_cpu : public OperatorImpl {
@@ -18,28 +18,29 @@
 #include "aidge/utils/Registrar.hpp"
 namespace Aidge {
-template <class O, class GI, class GO>
+template <class I, class O, class GI, class GO>
 void ReLUImpl_cpu_backward_kernel(const std::size_t inputLenght,
-                                  const void* output_, const void* grad_output_,
-                                  void* grad_input_) {
+                                  const void* input_, const void* output_, const void* grad_output_,
+                                  void* grad_input_) {
+    const I* input = static_cast<const I*>(input_);
     const O* output = static_cast<const O*>(output_);
     const GO* grad_output = static_cast<const GO*>(grad_output_);
     GI* grad_input = static_cast<GI*>(grad_input_);
     for (std::size_t i = 0; i < inputLenght; ++i) {
-        grad_input[i] = (output[i] > GO(0)) ? GI(grad_output[i]) : GI(0);
+        grad_input[i] = (input[i] > 0) ? grad_output[i] : 0;
     }
 }
 namespace {
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float32(
-        {DataType::Float32, DataType::Float32, DataType::Float32},
-        Aidge::ReLUImpl_cpu_backward_kernel<float, float, float>);
+        {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32},
+        Aidge::ReLUImpl_cpu_backward_kernel<float, float, float, float>);
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Int32(
-        {DataType::Int32, DataType::Int32, DataType::Int32},
-        Aidge::ReLUImpl_cpu_backward_kernel<int, int, int>);
+        {DataType::Int32, DataType::Int32, DataType::Int32, DataType::Int32},
+        Aidge::ReLUImpl_cpu_backward_kernel<int, int, int, int>);
 static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float64(
-        {DataType::Float64, DataType::Float64, DataType::Float64},
-        Aidge::ReLUImpl_cpu_backward_kernel<double, double, double>);
+        {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64},
+        Aidge::ReLUImpl_cpu_backward_kernel<double, double, double, double>);
 } // namespace
 } // namespace Aidge
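Note the change of operand in the gradient formula: the old kernel masked on the cached output (output[i] > 0), while the new one reads the forward input directly. For ReLU the two are equivalent, since y_i = max(x_i, 0) is positive exactly where x_i is:

\[
\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y_i} \cdot \mathbf{1}[x_i > 0]
\]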
@@ -27,7 +27,7 @@ void ReLUImpl_cpu_forward_kernel(std::size_t inputLenght,
     //#pragma omp parallel for if (inputLenght > 1024)
     for (std::size_t i = 0; i < inputLenght; ++i) {
-        output[i] = input[i] > 0 ? input[i] : 0;
+        output[i] = (input[i] > 0) ? input[i] : 0;
     }
 }
@@ -28,7 +28,7 @@ class SigmoidImplForward_cpu
     : public Registrable<SigmoidImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
 };
 class SigmoidImplBackward_cpu
-    : public Registrable<SigmoidImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+    : public Registrable<SigmoidImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> {
 };
 class SigmoidImpl_cpu : public OperatorImpl {
@@ -40,7 +40,10 @@ public:
     }
     Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
-    void forward() override;
+    void forward() override final;
+    void backward() override final;
 };
 namespace {
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef AIDGE_CPU_OPERATOR_SIGMOIDIMPL_BACKWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_SIGMOIDIMPL_BACKWARD_KERNEL_H_

#include <cstddef>  // std::size_t

#include "aidge/backend/cpu/operator/SigmoidImpl.hpp"
#include "aidge/utils/Registrar.hpp"

namespace Aidge {
template <class I, class O, class GI, class GO>
void SigmoidImpl_cpu_backward_kernel(const std::size_t inputLenght,
                                     const void* input_, const void* output_, const void* grad_output_,
                                     void* grad_input_) {
    const I* input = static_cast<const I*>(input_);
    const O* output = static_cast<const O*>(output_);
    const GO* grad_output = static_cast<const GO*>(grad_output_);
    GI* grad_input = static_cast<GI*>(grad_input_);
    for (std::size_t i = 0; i < inputLenght; ++i) {
        grad_input[i] = output[i] * (O(1) - output[i]) * grad_output[i];
    }
}

namespace {
static Registrar<SigmoidImplBackward_cpu> registrarSigmoidImplBackward_cpu_Float32(
        {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32},
        Aidge::SigmoidImpl_cpu_backward_kernel<float, float, float, float>);
static Registrar<SigmoidImplBackward_cpu> registrarSigmoidImplBackward_cpu_Float64(
        {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64},
        Aidge::SigmoidImpl_cpu_backward_kernel<double, double, double, double>);
} // namespace
} // namespace Aidge

#endif /* AIDGE_CPU_OPERATOR_SIGMOIDIMPL_BACKWARD_KERNEL_H_ */
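The input pointer is cast but never read in this kernel, which matches the math: with y_i = sigma(x_i), the sigmoid derivative can be written entirely in terms of the cached forward output:

\[
\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad \frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y_i}\, y_i\,(1 - y_i)
\]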
@@ -19,15 +19,15 @@
 namespace Aidge {
 template <class I, class O>
 void SigmoidImpl_cpu_forward_kernel(std::size_t inputLenght,
-                                    const void* input_,
-                                    void* output_) {
+                                    const void* input_,
+                                    void* output_) {
     const I* input = static_cast<const I*>(input_);
     O* output = static_cast<O*>(output_);
     //#pragma omp parallel for if (inputLenght > 1024)
     for (std::size_t i = 0; i < inputLenght; ++i) {
-        output[i] = static_cast<O>(1.0) / (static_cast<O>(1.0) + std::exp(-input[i]));
+        output[i] = O(1) / (O(1) + std::exp(-input[i]));
     }
 }
@@ -28,7 +28,7 @@ class TanhImplForward_cpu
     : public Registrable<TanhImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
 };
 class TanhImplBackward_cpu
-    : public Registrable<TanhImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+    : public Registrable<TanhImplBackward_cpu, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> {
 };
 class TanhImpl_cpu : public OperatorImpl {
@@ -40,7 +40,10 @@ public:
     }
     Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
-    void forward() override;
+    void forward() override final;
+    void backward() override final;
 };
 namespace {
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef AIDGE_CPU_OPERATOR_TANHIMPL_BACKWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_TANHIMPL_BACKWARD_KERNEL_H_

#include <cstddef>  // std::size_t

#include "aidge/backend/cpu/operator/TanhImpl.hpp"
#include "aidge/utils/Registrar.hpp"

namespace Aidge {
template <class I, class O, class GI, class GO>
void TanhImpl_cpu_backward_kernel(const std::size_t inputLenght,
                                  const void* input_, const void* output_, const void* grad_output_,
                                  void* grad_input_) {
    const I* input = static_cast<const I*>(input_);
    const O* output = static_cast<const O*>(output_);
    const GO* grad_output = static_cast<const GO*>(grad_output_);
    GI* grad_input = static_cast<GI*>(grad_input_);
    for (std::size_t i = 0; i < inputLenght; ++i) {
        grad_input[i] = (O(1) - output[i] * output[i]) * grad_output[i];
    }
}

namespace {
static Registrar<TanhImplBackward_cpu> registrarTanhImplBackward_cpu_Float32(
        {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32},
        Aidge::TanhImpl_cpu_backward_kernel<float, float, float, float>);
static Registrar<TanhImplBackward_cpu> registrarTanhImplBackward_cpu_Float64(
        {DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64},
        Aidge::TanhImpl_cpu_backward_kernel<double, double, double, double>);
} // namespace
} // namespace Aidge

#endif /* AIDGE_CPU_OPERATOR_TANHIMPL_BACKWARD_KERNEL_H_ */
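As with the sigmoid kernel, the tanh gradient needs only the cached forward output, since tanh'(x) = 1 - tanh^2(x):

\[
\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y_i}\,\bigl(1 - y_i^2\bigr), \qquad y_i = \tanh(x_i)
\]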
@@ -28,13 +28,15 @@ Aidge::Elts_t Aidge::ReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t
 }
 void Aidge::ReLUImpl_cpu::forward() {
-    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
+    const ReLU_Op& op_ = dynamic_cast<const ReLU_Op&>(mOp);
+    std::shared_ptr<Tensor> in0 = op_.getInput(0);
+    std::shared_ptr<Tensor> out0 = op_.getOutput(0);
     AIDGE_ASSERT(in0, "missing input #0");
     // Find the correct kernel type
     auto kernelFunc = Registrar<ReLUImplForward_cpu>::create({
         in0->dataType(),
-        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
+        out0->dataType()});
     // Call kernel
     kernelFunc(in0->size(),
@@ -43,20 +45,21 @@ void Aidge::ReLUImpl_cpu::forward() {
 }
 void Aidge::ReLUImpl_cpu::backward() {
-    // reversing in and out Tensors
-    const ReLU_Op& op_ = dynamic_cast<const ReLU_Op&>(mOp);
+    const ReLU_Op& op_ = dynamic_cast<const ReLU_Op&>(mOp);
+    std::shared_ptr<Tensor> in0 = op_.getInput(0);
     std::shared_ptr<Tensor> out0 = op_.getOutput(0);
-    std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
     std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad();
-    AIDGE_ASSERT(out0, "current {} operator output#0 has not gradient Tensor.", op_.type());
+    std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
+    AIDGE_ASSERT(out0, "missing output #0 for current {} operator", op_.type());
     // Find the correct kernel type
     auto kernelFunc = Registrar<ReLUImplBackward_cpu>::create({
+        in0->dataType(),
         out0->dataType(),
-        gra_out0->dataType(),
-        gra_int0->dataType()
+        gra_int0->dataType(),
+        gra_out0->dataType()
     });
     // Call kernel
-    kernelFunc(gra_int0->size(), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
+    kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
 }
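A quick way to sanity-check the new kernel logic outside the framework is a finite-difference comparison. The sketch below is not part of the commit: relu_backward is a hypothetical free function that reproduces the ReLU backward formula without the type-erased API, checked against a central difference of y = max(x, 0) for a scalar loss L = sum(y).

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Same formula as ReLUImpl_cpu_backward_kernel, minus the void* plumbing.
static void relu_backward(std::size_t n, const float* input,
                          const float* grad_output, float* grad_input) {
    for (std::size_t i = 0; i < n; ++i) {
        grad_input[i] = (input[i] > 0) ? grad_output[i] : 0;
    }
}

int main() {
    const std::vector<float> x = {-2.0f, -0.5f, 0.5f, 3.0f};
    const std::vector<float> go(x.size(), 1.0f);  // dL/dy for L = sum(y)
    std::vector<float> gi(x.size());
    relu_backward(x.size(), x.data(), go.data(), gi.data());

    const float eps = 1e-3f;
    for (std::size_t i = 0; i < x.size(); ++i) {
        // Central difference of y = max(x, 0) with respect to x[i].
        const float yp = std::max(x[i] + eps, 0.0f);
        const float ym = std::max(x[i] - eps, 0.0f);
        const float num = (yp - ym) / (2 * eps);
        std::printf("x=%+.2f  analytic=%.3f  numeric=%.3f\n", x[i], gi[i], num);
    }
    return 0;
}

Away from the kink at x = 0 the two columns should agree; the same harness adapts to the sigmoid and tanh kernels by swapping the formula and the reference function.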
@@ -21,6 +21,7 @@
 #include "aidge/backend/cpu/operator/SigmoidImpl.hpp"
 #include "aidge/backend/cpu/operator/SigmoidImpl_forward_kernels.hpp"
+#include "aidge/backend/cpu/operator/SigmoidImpl_backward_kernels.hpp"
 Aidge::Elts_t Aidge::SigmoidImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
@@ -28,15 +29,38 @@ Aidge::Elts_t Aidge::SigmoidImpl_cpu::getNbRequiredProtected(const Aidge::IOInde
 }
 void Aidge::SigmoidImpl_cpu::forward() {
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
+    const Sigmoid_Op& op_ = dynamic_cast<const Sigmoid_Op&>(mOp);
+    std::shared_ptr<Tensor> in0 = op_.getInput(0);
+    std::shared_ptr<Tensor> out0 = op_.getOutput(0);
+    AIDGE_ASSERT(in0, "missing input #0");
     // Find the correct kernel type
     auto kernelFunc = Registrar<SigmoidImplForward_cpu>::create({
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
-        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
+        in0->dataType(),
+        out0->dataType()});
     // Call kernel
-    kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
+    kernelFunc(in0->size(),
         getCPUPtr(mOp.getRawInput(0)),
         getCPUPtr(mOp.getRawOutput(0)));
 }
+
+void Aidge::SigmoidImpl_cpu::backward() {
+    const Sigmoid_Op& op_ = dynamic_cast<const Sigmoid_Op&>(mOp);
+    std::shared_ptr<Tensor> in0 = op_.getInput(0);
+    std::shared_ptr<Tensor> out0 = op_.getOutput(0);
+    std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad();
+    std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
+    AIDGE_ASSERT(out0, "missing output #0 for current {} operator", op_.type());
+    // Find the correct kernel type
+    auto kernelFunc = Registrar<SigmoidImplBackward_cpu>::create({
+        in0->dataType(),
+        out0->dataType(),
+        gra_int0->dataType(),
+        gra_out0->dataType()
+    });
+    // Call kernel
+    kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
+}
\ No newline at end of file
@@ -21,6 +21,7 @@
 #include "aidge/backend/cpu/operator/TanhImpl.hpp"
 #include "aidge/backend/cpu/operator/TanhImpl_forward_kernels.hpp"
+#include "aidge/backend/cpu/operator/TanhImpl_backward_kernels.hpp"
 Aidge::Elts_t Aidge::TanhImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
@@ -28,15 +29,39 @@ Aidge::Elts_t Aidge::TanhImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t
 }
 void Aidge::TanhImpl_cpu::forward() {
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
+    const Tanh_Op& op_ = dynamic_cast<const Tanh_Op&>(mOp);
+    std::shared_ptr<Tensor> in0 = op_.getInput(0);
+    std::shared_ptr<Tensor> out0 = op_.getOutput(0);
+    AIDGE_ASSERT(in0, "missing input #0");
     // Find the correct kernel type
     auto kernelFunc = Registrar<TanhImplForward_cpu>::create({
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
-        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
+        in0->dataType(),
+        out0->dataType()});
     // Call kernel
-    kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
+    kernelFunc(in0->size(),
         getCPUPtr(mOp.getRawInput(0)),
         getCPUPtr(mOp.getRawOutput(0)));
 }
+
+void Aidge::TanhImpl_cpu::backward() {
+    const Tanh_Op& op_ = dynamic_cast<const Tanh_Op&>(mOp);
+    std::shared_ptr<Tensor> in0 = op_.getInput(0);
+    std::shared_ptr<Tensor> out0 = op_.getOutput(0);
+    std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad();
+    std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
+    AIDGE_ASSERT(out0, "missing output #0 for current {} operator", op_.type());
+    // Find the correct kernel type
+    auto kernelFunc = Registrar<TanhImplBackward_cpu>::create({
+        in0->dataType(),
+        out0->dataType(),
+        gra_int0->dataType(),
+        gra_out0->dataType()
+    });
+    // Call kernel
+    kernelFunc(gra_int0->size(), getCPUPtr(in0), getCPUPtr(out0), getCPUPtr(gra_out0), getCPUPtr(gra_int0));
+}
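All three backward implementations follow the same dispatch pattern: the (input, output, grad-input, grad-output) data types form a key into a static registry of type-erased kernels. The sketch below is a simplified stand-in for Aidge's Registrable/Registrar machinery, not the real implementation (which lives in aidge/utils/Registrar.hpp); the registry layout and names here are illustrative.

#include <cstddef>
#include <map>
#include <tuple>

enum class DataType { Float32, Float64, Int32 };

// Type-erased backward kernel signature, as in the new Registrable bases:
// (length, input, output, grad_output, grad_input).
using BackwardFn = void (*)(std::size_t, const void*, const void*, const void*, void*);
using Key = std::tuple<DataType, DataType, DataType, DataType>;

// Simplified registry; the real templates generalize key and signature per operator.
static std::map<Key, BackwardFn>& registry() {
    static std::map<Key, BackwardFn> r;
    return r;
}

struct Registrar {
    Registrar(Key k, BackwardFn f) { registry()[k] = f; }
    static BackwardFn create(Key k) { return registry().at(k); }  // throws if unregistered
};

template <class I, class O, class GI, class GO>
void tanh_backward_kernel(std::size_t n, const void* in_, const void* out_,
                          const void* gout_, void* gin_) {
    (void)in_;  // tanh gradient only needs the cached output
    const O* out = static_cast<const O*>(out_);
    const GO* gout = static_cast<const GO*>(gout_);
    GI* gin = static_cast<GI*>(gin_);
    for (std::size_t i = 0; i < n; ++i)
        gin[i] = (O(1) - out[i] * out[i]) * gout[i];
}

// Static registration at namespace scope, mirroring the commit's registrars.
static Registrar regF32({DataType::Float32, DataType::Float32,
                         DataType::Float32, DataType::Float32},
                        tanh_backward_kernel<float, float, float, float>);

int main() {
    float out[2] = {0.0f, 0.5f}, gout[2] = {1.0f, 1.0f}, gin[2];
    auto fn = Registrar::create({DataType::Float32, DataType::Float32,
                                 DataType::Float32, DataType::Float32});
    fn(2, nullptr, out, gout, gin);  // expect 1 - out^2 -> {1.0, 0.75}
    return 0;
}

A caller resolves the kernel at runtime from the tensors' data types, exactly as the new backward() methods do with Registrar<...>::create({...}); registering the kernel once per type tuple keeps the operator code free of switch-on-dtype logic.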