Commit f43121fa authored by Houssem ROUIS

Fix Pow Backward kernel

parent e32b8199
2 merge requests: !93 Release v0.3.0, !90 Fix/PowBackwardKernel
@@ -28,7 +28,7 @@ class PowImplForward_cpu
: public Registrable<PowImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> {
};
class PowImplBackward_cpu
-: public Registrable<PowImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, std::size_t , const void*, const void*, const void*, const void*, void*, void*)> {
+: public Registrable<PowImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, const void*, void*, void*)> {
};
class PowImpl_cpu : public OperatorImpl {
@@ -12,37 +12,41 @@
#ifndef AIDGE_CPU_OPERATOR_POWIMPL_BACKWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_POWIMPL_BACKWARD_KERNEL_H_
-#include "aidge/utils/Registrar.hpp"
#include <cmath>
+#include <numeric>
+#include <vector>
#include "aidge/backend/cpu/data/Broadcasting.hpp"
#include "aidge/backend/cpu/operator/PowImpl.hpp"
-#include <iostream>
-#include <vector>
+#include "aidge/utils/Registrar.hpp"
namespace Aidge {
template <class I1, class I2, class O>
void PowImpl_cpu_backward_kernel(const std::vector<std::size_t>& input0Dims,
const std::vector<std::size_t>& input1Dims,
const std::vector<std::size_t>& outputDims,
-std::size_t totalElements,
-const void* input0_,
-const void* input1_,
-const void* output_,
-const void* gradOutput_,
+const void* input0_,
+const void* input1_,
+const void* gradOutput_,
void* gradientInput0_,
void* gradientInput1_) {
const I1* input0 = static_cast<const I1*>(input0_);
I1* grad0 = static_cast<I1*>(gradientInput0_);
const I2* input1 = static_cast<const I2*>(input1_);
I2* grad1 = static_cast<I2*>(gradientInput1_);
-const O* output = static_cast<const O*>(output_);
const O* gradOut = static_cast<const O*>(gradOutput_);
+auto totalElements = std::accumulate(outputDims.cbegin(), outputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
+auto input0Elements = std::accumulate(input0Dims.cbegin(), input0Dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
+std::fill(grad0, grad0 + input0Elements, I1(0));
+auto input1Elements = std::accumulate(input1Dims.cbegin(), input1Dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
+std::fill(grad1, grad1 + input1Elements, I1(0));
for (size_t i = 0; i < totalElements; ++i)
{
std::vector<std::size_t> indexes = getMultiDimIndices(outputDims, i);
std::size_t idx0 = getFlattenedIndex(input0Dims, indexes);
std::size_t idx1 = getFlattenedIndex(input1Dims, indexes);
@@ -50,7 +54,7 @@ void PowImpl_cpu_backward_kernel(const std::vector<std::size_t>& input0Dims,
grad0[idx0] += gradOut[i]*input1[idx1]* std::pow(input0[idx0], input1[idx1]-1);
// grad1 = grad_output * (output * ln(input0))
-grad1[idx1] += gradOut[i]*output[i] * std::log(input0[idx0]);
+grad1[idx1] += gradOut[i] * std::pow(input0[idx0], input1[idx1]) * std::log(input0[idx0]);
}
}
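
Side note (not part of the commit): the two accumulations above are the standard partial derivatives of z = x^p. The fix recomputes pow(input0, input1) inside the grad1 term instead of reading the forward output buffer, which is why the output pointer disappears from the kernel signature.

\frac{\partial z}{\partial x} = p\,x^{p-1}, \qquad
\frac{\partial z}{\partial p} = x^{p}\ln x = z\ln x
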
@@ -21,6 +21,7 @@
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/PowImpl.hpp"
#include "aidge/backend/cpu/operator/PowImpl_backward_kernels.hpp"
#include "aidge/backend/cpu/operator/PowImpl_forward_kernels.hpp"
Aidge::Elts_t Aidge::PowImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
@@ -57,17 +58,15 @@ void Aidge::PowImpl_cpu::backward() {
op_.getInput(0)->grad()->dataType(),
op_.getInput(1)->grad()->dataType()});
-const std::vector<std::size_t> input0gradDims = getBroadcastedDims(op_.getInput(0)->grad()->dims(),
-op_.getOutput(0)->grad()->dims());
-const std::vector<std::size_t> input1gradDims = getBroadcastedDims(op_.getInput(1)->grad()->dims(),
-op_.getOutput(0)->grad()->dims());
+const std::vector<std::size_t> input0gradDims = getBroadcastedDims(op_.getOutput(0)->grad()->dims(),
+op_.getInput(0)->grad()->dims());
+const std::vector<std::size_t> input1gradDims = getBroadcastedDims(op_.getOutput(0)->grad()->dims(),
+op_.getInput(1)->grad()->dims());
// Call kernel
kernelFunc(input0gradDims,
input1gradDims,
op_.getOutput(0)->grad()->dims(),
-op_.getOutput(0)->size(),
-getCPUPtr(mOp.getRawOutput(0)),
getCPUPtr(mOp.getRawInput(0)),
getCPUPtr(mOp.getRawInput(1)),
getCPUPtr(op_.getOutput(0)->grad()),
......
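
Illustration only, not part of the commit and independent of the Aidge API (all names below are hypothetical): a minimal standalone check that the analytic gradients used by the backward kernel agree with central finite differences for scalar operands.

#include <cmath>
#include <cstdio>

// Analytic gradients of z = pow(x, p), the same formulas the backward kernel accumulates.
static double dz_dx(double x, double p) { return p * std::pow(x, p - 1.0); }
static double dz_dp(double x, double p) { return std::pow(x, p) * std::log(x); }

int main() {
    const double x = 1.7, p = 2.3, eps = 1e-6;

    // Central finite differences as a reference.
    const double fd_dx = (std::pow(x + eps, p) - std::pow(x - eps, p)) / (2.0 * eps);
    const double fd_dp = (std::pow(x, p + eps) - std::pow(x, p - eps)) / (2.0 * eps);

    std::printf("dz/dx analytic %.6f vs finite diff %.6f\n", dz_dx(x, p), fd_dx);
    std::printf("dz/dp analytic %.6f vs finite diff %.6f\n", dz_dp(x, p), fd_dp);
    return 0;
}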