Skip to content
Snippets Groups Projects
Commit 28ca848f authored by Jerome Hue's avatar Jerome Hue Committed by Olivier BICHLER
Browse files

feat: Add surrogate backward function for Heaviside operator

parent da5f38fa
No related branches found
No related tags found
1 merge request!146Implement a backward for Heaviside
......@@ -19,7 +19,6 @@
#include "aidge/backend/cpu/operator/HeavisideImpl.hpp"
#include "aidge/utils/ErrorHandling.hpp"
namespace Aidge {
template <class I, class O>
......@@ -35,6 +34,24 @@ void HeavisideImplCpuForwardKernel(std::size_t inputLength,
}
}
// Surrogate Gradient
template <class O, class GO, class GI>
void HeavisideImplCpuBackwardKernel(std::size_t inputLength,
const void* output_,
const void* grad_output_,
void* grad_input_) {
const O* output = static_cast<const O*>(output_);
const GO* grad_output = static_cast<const GO*>(grad_output_);
GI* grad_input = static_cast<GI*>(grad_input_);
for (size_t i = 0; i < inputLength; ++i) {
// dx = dy * (1/PI) * (1 / (1 + (PI * x)^2))
grad_input[i] = (1 / M_PI) * grad_output[i] * static_cast<O>(1.0 / (1.0 + output[i] * output[i]));
}
}
// Kernels registration to implementation entry point
REGISTRAR(HeavisideImplCpu,
{DataType::Float32},
......
......@@ -32,6 +32,23 @@ template <> void Aidge::HeavisideImplCpu::forward() {
op_.value());
}
template <> void Aidge::HeavisideImplCpu::backward() {
template <>
void Aidge::HeavisideImplCpu::backward() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Heaviside backward not implemented yet");
// TODO: The following lines are assuming that the surrogate gradient is Atan
// remove that assumption by providing an attribute to Heaviside,
// allowing to choose between different surrogate gradients.
// const Heavisde_Op& op_ = dynamic_cast<const Heavisie_Op &>(mOp);
// ! backward of hs = forward of atan
//const auto impl = Registrar<HeavisideImplCpu>::create(getBestMatch(getRequiredSpec()));
// std::shared_ptr<Tensor> in0 = op_.getInput(0);
// std::shared_ptr<Tensor> out0 = op_.getOutput(0);
//impl.forward()
}
......@@ -11,6 +11,7 @@
#include "aidge/backend/cpu/operator/HeavisideImpl_kernels.hpp"
#include <aidge/operator/Memorize.hpp>
#include <memory>
#include <cstdlib>
#include <random>
......@@ -22,6 +23,8 @@
#include "aidge/graph/Node.hpp"
#include "aidge/utils/TensorUtils.hpp"
#include "aidge/operator/Add.hpp"
namespace Aidge
{
......@@ -95,4 +98,12 @@ TEST_CASE("[cpu/operator] Heaviside(forward)", "[Heaviside][CPU]") {
REQUIRE(approxEq<float>(*(op->getOutput(0)), *T1));
}
}
TEST_CASE("[cpu/operator] Heaviside(backward)", "[Heaviside][CPU]") {
auto add = Add();
auto mem = Memorize(2);
auto hs = Heaviside(1);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment