diff --git a/include/aidge/backend/cpu/operator/PowImpl.hpp b/include/aidge/backend/cpu/operator/PowImpl.hpp index 514e63af5ae5d1d1d00f7f328f5367df2bfa163d..8d0c0831595cddd6047ccd9a5dfa1c444e9f0585 100644 --- a/include/aidge/backend/cpu/operator/PowImpl.hpp +++ b/include/aidge/backend/cpu/operator/PowImpl.hpp @@ -28,7 +28,7 @@ class PowImplForward_cpu : public Registrable<PowImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> { }; class PowImplBackward_cpu - : public Registrable<PowImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> { + : public Registrable<PowImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, std::size_t , const void*, const void*, const void*, const void*, void*, void*)> { }; class PowImpl_cpu : public OperatorImpl { diff --git a/include/aidge/backend/cpu/operator/PowImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/PowImpl_backward_kernels.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f3483eacf2bfd23b9c71a3e2bcef48904c991b84 --- /dev/null +++ b/include/aidge/backend/cpu/operator/PowImpl_backward_kernels.hpp @@ -0,0 +1,70 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CPU_OPERATOR_POWIMPL_BACKWARD_KERNEL_H_ +#define AIDGE_CPU_OPERATOR_POWIMPL_BACKWARD_KERNEL_H_ + +#include "aidge/utils/Registrar.hpp" +#include <cmath> + +#include "aidge/backend/cpu/data/Broadcasting.hpp" +#include "aidge/backend/cpu/operator/PowImpl.hpp" +#include <iostream> +#include <vector> + +namespace Aidge { +template <class I1, class I2, class O> +void PowImpl_cpu_backward_kernel(const std::vector<std::size_t>& input0Dims, + const std::vector<std::size_t>& input1Dims, + const std::vector<std::size_t>& outputDims, + std::size_t totalElements, + const void* input0_, + const void* input1_, + const void* output_, + const void* gradOutput_, + void* gradientInput0_, + void* gradientInput1_) { + const I1* input0 = static_cast<const I1*>(input0_); + I1* grad0 = static_cast<I1*>(gradientInput0_); + const I2* input1 = static_cast<const I2*>(input1_); + I2* grad1 = static_cast<I2*>(gradientInput1_); + const O* output = static_cast<const O*>(output_); + const O* gradOut = static_cast<const O*>(gradOutput_); + + for (size_t i = 0; i < totalElements; ++i) + { + std::vector<std::size_t> indexes = getMultiDimIndices(outputDims, i); + + std::size_t idx0 = getFlattenedIndex(input0Dims, indexes); + std::size_t idx1 = getFlattenedIndex(input1Dims, indexes); + + // grad0 = input1 * pow (input0, (input1 -1)) + grad0[idx0] += gradOut[i]*input1[idx1]* std::pow(input0[idx0], input1[idx1]-1); + + // grad1 = grad_output * (output * ln(input0)) + grad1[idx1] += gradOut[i]*output[i] * std::log(input0[idx0]); + } +} + +namespace { +static Registrar<PowImplBackward_cpu> registrarPowImplBackward_cpu_Float32( + {DataType::Float32, DataType::Float32, DataType::Float32}, + Aidge::PowImpl_cpu_backward_kernel<float, float, float>); +static Registrar<PowImplBackward_cpu> registrarPowImplBackward_cpu_Int32( + {DataType::Int32, DataType::Int32, DataType::Int32}, + Aidge::PowImpl_cpu_backward_kernel<int, int, int>); +static Registrar<PowImplBackward_cpu> registrarPowImplBackward_cpu_Float64( + {DataType::Float64, DataType::Float64, DataType::Float64}, + Aidge::PowImpl_cpu_backward_kernel<double, double, double>); +} // namespace +} // namespace Aidge + +#endif /* AIDGE_CPU_OPERATOR_POWIMPL_BACKWARD_KERNEL_H_ */ diff --git a/src/operator/PowImpl.cpp b/src/operator/PowImpl.cpp index 811d13804cffdd2477fc830f1779b0fb6271eb0b..97055e4b740558f16df2f29e4fb6272125b72c24 100644 --- a/src/operator/PowImpl.cpp +++ b/src/operator/PowImpl.cpp @@ -52,7 +52,7 @@ void Aidge::PowImpl_cpu::forward() { void Aidge::PowImpl_cpu::backward() { // Find the correct kernel type const Pow_Op& op_ = dynamic_cast<const Pow_Op&>(mOp); - auto kernelFunc = Registrar<PowImplForward_cpu>::create({ + auto kernelFunc = Registrar<PowImplBackward_cpu>::create({ op_.getOutput(0)->grad()->dataType(), op_.getInput(0)->grad()->dataType(), op_.getInput(1)->grad()->dataType()}); @@ -63,10 +63,14 @@ void Aidge::PowImpl_cpu::backward() { op_.getOutput(0)->grad()->dims()); // Call kernel - kernelFunc(op_.getOutput(0)->grad()->dims(), - input0gradDims, + kernelFunc(input0gradDims, input1gradDims, + op_.getOutput(0)->grad()->dims(), + op_.getOutput(0)->size(), getCPUPtr(mOp.getRawOutput(0)), getCPUPtr(mOp.getRawInput(0)), - getCPUPtr(mOp.getRawInput(1))); + getCPUPtr(mOp.getRawInput(1)), + getCPUPtr(op_.getOutput(0)->grad()), + getCPUPtr(op_.getInput(0)->grad()), + getCPUPtr(op_.getInput(1)->grad())); } \ No newline at end of file