Skip to content
Snippets Groups Projects
Commit e32b8199 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add Pow backward kernel

parent aa30d4b3
No related branches found
No related tags found
2 merge requests!93Release v0.3.0,!90Fix/PowBackwardKernel
...@@ -28,7 +28,7 @@ class PowImplForward_cpu ...@@ -28,7 +28,7 @@ class PowImplForward_cpu
: public Registrable<PowImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> { : public Registrable<PowImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> {
}; };
class PowImplBackward_cpu class PowImplBackward_cpu
: public Registrable<PowImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> { : public Registrable<PowImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, std::size_t , const void*, const void*, const void*, const void*, void*, void*)> {
}; };
class PowImpl_cpu : public OperatorImpl { class PowImpl_cpu : public OperatorImpl {
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_POWIMPL_BACKWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_POWIMPL_BACKWARD_KERNEL_H_
#include "aidge/utils/Registrar.hpp"
#include <cmath>
#include "aidge/backend/cpu/data/Broadcasting.hpp"
#include "aidge/backend/cpu/operator/PowImpl.hpp"
#include <iostream>
#include <vector>
namespace Aidge {
template <class I1, class I2, class O>
void PowImpl_cpu_backward_kernel(const std::vector<std::size_t>& input0Dims,
const std::vector<std::size_t>& input1Dims,
const std::vector<std::size_t>& outputDims,
std::size_t totalElements,
const void* input0_,
const void* input1_,
const void* output_,
const void* gradOutput_,
void* gradientInput0_,
void* gradientInput1_) {
    // Backward pass of z = pow(x, y) with broadcasting on the two inputs.
    // Gradients are *accumulated* (+=) into gradientInput0_/gradientInput1_,
    // so callers are expected to have pre-initialised those buffers.
    const I1* base     = static_cast<const I1*>(input0_);     // x
    const I2* exponent = static_cast<const I2*>(input1_);     // y
    const O*  result   = static_cast<const O*>(output_);      // z = x^y (forward output)
    const O*  gradOut  = static_cast<const O*>(gradOutput_);  // dL/dz
    I1* gradBase     = static_cast<I1*>(gradientInput0_);     // dL/dx
    I2* gradExponent = static_cast<I2*>(gradientInput1_);     // dL/dy

    for (std::size_t outIdx = 0; outIdx < totalElements; ++outIdx) {
        // Map the flat output index back to the (possibly broadcast) input offsets.
        const std::vector<std::size_t> coords = getMultiDimIndices(outputDims, outIdx);
        const std::size_t baseIdx = getFlattenedIndex(input0Dims, coords);
        const std::size_t expIdx  = getFlattenedIndex(input1Dims, coords);

        // dL/dx = dL/dz * y * x^(y-1)
        gradBase[baseIdx] += gradOut[outIdx] * exponent[expIdx] * std::pow(base[baseIdx], exponent[expIdx] - 1);
        // dL/dy = dL/dz * z * ln(x)  (NaN when x <= 0, matching the math)
        gradExponent[expIdx] += gradOut[outIdx] * result[outIdx] * std::log(base[baseIdx]);
    }
}
namespace {
// Register the backward kernel for each supported datatype triplet
// (output-grad type, input0-grad type, input1-grad type). Registration is a
// static-initialisation side effect of constructing each Registrar object, so
// the kernels become discoverable via Registrar<PowImplBackward_cpu>::create()
// as soon as this translation unit is loaded.
static Registrar<PowImplBackward_cpu> registrarPowImplBackward_cpu_Float32(
{DataType::Float32, DataType::Float32, DataType::Float32},
Aidge::PowImpl_cpu_backward_kernel<float, float, float>);
// NOTE(review): for Int32 the double results of std::pow/std::log are
// truncated on assignment back to int — presumably intentional; confirm.
static Registrar<PowImplBackward_cpu> registrarPowImplBackward_cpu_Int32(
{DataType::Int32, DataType::Int32, DataType::Int32},
Aidge::PowImpl_cpu_backward_kernel<int, int, int>);
static Registrar<PowImplBackward_cpu> registrarPowImplBackward_cpu_Float64(
{DataType::Float64, DataType::Float64, DataType::Float64},
Aidge::PowImpl_cpu_backward_kernel<double, double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_POWIMPL_BACKWARD_KERNEL_H_ */
...@@ -52,7 +52,7 @@ void Aidge::PowImpl_cpu::forward() { ...@@ -52,7 +52,7 @@ void Aidge::PowImpl_cpu::forward() {
void Aidge::PowImpl_cpu::backward() { void Aidge::PowImpl_cpu::backward() {
// Find the correct kernel type // Find the correct kernel type
const Pow_Op& op_ = dynamic_cast<const Pow_Op&>(mOp); const Pow_Op& op_ = dynamic_cast<const Pow_Op&>(mOp);
auto kernelFunc = Registrar<PowImplForward_cpu>::create({ auto kernelFunc = Registrar<PowImplBackward_cpu>::create({
op_.getOutput(0)->grad()->dataType(), op_.getOutput(0)->grad()->dataType(),
op_.getInput(0)->grad()->dataType(), op_.getInput(0)->grad()->dataType(),
op_.getInput(1)->grad()->dataType()}); op_.getInput(1)->grad()->dataType()});
...@@ -63,10 +63,14 @@ void Aidge::PowImpl_cpu::backward() { ...@@ -63,10 +63,14 @@ void Aidge::PowImpl_cpu::backward() {
op_.getOutput(0)->grad()->dims()); op_.getOutput(0)->grad()->dims());
// Call kernel // Call kernel
kernelFunc(op_.getOutput(0)->grad()->dims(), kernelFunc(input0gradDims,
input0gradDims,
input1gradDims, input1gradDims,
op_.getOutput(0)->grad()->dims(),
op_.getOutput(0)->size(),
getCPUPtr(mOp.getRawOutput(0)), getCPUPtr(mOp.getRawOutput(0)),
getCPUPtr(mOp.getRawInput(0)), getCPUPtr(mOp.getRawInput(0)),
getCPUPtr(mOp.getRawInput(1))); getCPUPtr(mOp.getRawInput(1)),
getCPUPtr(op_.getOutput(0)->grad()),
getCPUPtr(op_.getInput(0)->grad()),
getCPUPtr(op_.getInput(1)->grad()));
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment