diff --git a/include/aidge/backend/cpu/operator/HeavisideImpl_kernels.hpp b/include/aidge/backend/cpu/operator/HeavisideImpl_kernels.hpp index 7fc0eb0a86210d896ad0ccec9b88150670f4e328..4e2a7db20ad1474115d31abcd4d899a59c476c7f 100644 --- a/include/aidge/backend/cpu/operator/HeavisideImpl_kernels.hpp +++ b/include/aidge/backend/cpu/operator/HeavisideImpl_kernels.hpp @@ -42,14 +42,18 @@ void HeavisideImplCpuBackwardKernel(std::size_t inputLength, const void* grad_output_, void* grad_input_) { + /* + * Heaviside is approximated by an arctan function for backward : + * S ~= \frac{1}{\pi}\text{arctan}(\pi U \frac{\alpha}{2}) + * \frac{dS}{dU} = \frac{\alpha}{2} \frac{1}{(1+(\frac{\pi U \alpha}{2})^2)}} + * */ + const O* output = static_cast<const O*>(output_); const GO* grad_output = static_cast<const GO*>(grad_output_); GI* grad_input = static_cast<GI*>(grad_input_); for (size_t i = 0; i < inputLength; ++i) { - // dx = dy * (1/PI) * (1 / (1 + (PI * x)^2)) - // grad_input[i] = (1 / M_PI) * grad_output[i] * static_cast<O>(1.0 / (1.0 + (output[i] * output[i]) * (M_PI * M_PI))); - grad_input[i] = grad_output[i] * static_cast<O>(1.0 / (1.0 + (output[i] * output[i]) * (M_PI * M_PI))); + grad_input[i] = grad_output[i] * static_cast<O>(1.0 / (1.0 + (output[i] * M_PI) * (output[i] * M_PI))); } } diff --git a/unit_tests/operator/Test_HeavisideImpl.cpp b/unit_tests/operator/Test_HeavisideImpl.cpp index 515d6802d56f2c4a90ecf0bc2d86f2e9a727aeed..e6aa38b88a37191cc7765ad9d3037bc733687481 100644 --- a/unit_tests/operator/Test_HeavisideImpl.cpp +++ b/unit_tests/operator/Test_HeavisideImpl.cpp @@ -117,13 +117,10 @@ TEST_CASE("[cpu/operator] Heaviside(backward)", "[Heaviside][CPU]") { op->setInput(IOIndex_t(0), std::make_shared<Tensor>(input)); op->forward(); - Log::info("Output : "); - op->getOutput(0)->print(); - op->getOutput(0)->setGrad(std::make_shared<Tensor>(grad)); op->backward(); - Log::info("Gradient : "); - op->getInput(0)->grad()->print(); + auto expectedResult = Tensor(Array1D<float,3>({0.0920, 0.0920, 0.0920})); + REQUIRE(approxEq<float>(*(op->getInput(0)->grad()), expectedResult)); } }