From 5d4b29671d6e9ccaa277fc80a9d2cbef973d691e Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Thu, 15 Feb 2024 14:37:46 +0000 Subject: [PATCH] [Add] backward kernel for ReLU, LeakyReLU & Producer --- .../backend/cpu/operator/LeakyReLUImpl.hpp | 13 ++++-- .../LeakyReLUImpl_backward_kernels.hpp | 45 ++++++++++++++++++ .../backend/cpu/operator/ProducerImpl.hpp | 6 ++- .../aidge/backend/cpu/operator/ReLUImpl.hpp | 13 ++++-- .../operator/ReLUImpl_backward_kernels.hpp | 45 ++++++++++++++++++ .../aidge/backend/cpu/operator/SqrtImpl.hpp | 14 ++++-- .../operator/SqrtImpl_backward_kernels.hpp | 46 +++++++++++++++++++ .../cpu/operator/SqrtImpl_forward_kernels.hpp | 8 ++-- src/operator/LeakyReLUImpl.cpp | 34 +++++++++++--- src/operator/ProducerImpl.cpp | 8 +--- src/operator/ReLUImpl.cpp | 31 ++++++++++--- src/operator/SqrtImpl.cpp | 37 +++++++++++---- 12 files changed, 251 insertions(+), 49 deletions(-) create mode 100644 include/aidge/backend/cpu/operator/LeakyReLUImpl_backward_kernels.hpp create mode 100644 include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp create mode 100644 include/aidge/backend/cpu/operator/SqrtImpl_backward_kernels.hpp diff --git a/include/aidge/backend/cpu/operator/LeakyReLUImpl.hpp b/include/aidge/backend/cpu/operator/LeakyReLUImpl.hpp index 4a1da034..a9c87b4d 100644 --- a/include/aidge/backend/cpu/operator/LeakyReLUImpl.hpp +++ b/include/aidge/backend/cpu/operator/LeakyReLUImpl.hpp @@ -12,17 +12,17 @@ #ifndef AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_ #define AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_ +#include <memory> +#include <tuple> +#include <vector> + #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/LeakyReLU.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" #include "aidge/backend/cpu/data/GetCPUPtr.h" -#include <memory> -#include <vector> namespace Aidge { -// class LeakyReLU_Op; - // compute kernel registry for forward and backward class LeakyReLUImplForward_cpu : public Registrable<LeakyReLUImplForward_cpu, std::tuple<DataType, DataType>, void(const LeakyReLU_Op::Attrs&, std::size_t, const void*, void*)> { @@ -40,7 +40,10 @@ public: } NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; - void forward() override; + + void forward() override final; + + void backward() override final; }; namespace { diff --git a/include/aidge/backend/cpu/operator/LeakyReLUImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/LeakyReLUImpl_backward_kernels.hpp new file mode 100644 index 00000000..0e2fc400 --- /dev/null +++ b/include/aidge/backend/cpu/operator/LeakyReLUImpl_backward_kernels.hpp @@ -0,0 +1,45 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_BACKWARD_KERNEL_H_ +#define AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_BACKWARD_KERNEL_H_ + +#include "aidge/utils/Registrar.hpp" + +#include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp" + +namespace Aidge { +template <class I, class O> +void LeakyReLUImpl_cpu_backward_kernel(const LeakyReLU_Op::Attrs& attrs, + std::size_t inputLenght, + const void* input_, + void* output_) { + + const I* input = static_cast<const I*>(input_); + O* output = static_cast<O*>(output_); + I negativeSlope = static_cast<I>(std::get<0>(attrs)); + + for (std::size_t i = 0; i < inputLenght; ++i) { + output[i] = input[i] > 0 ? 1 : negativeSlope; + } +} + +namespace { +static Registrar<LeakyReLUImplBackward_cpu> registrarLeakyReLUImplBackward_cpu_Float32( + {DataType::Float32, DataType::Float32}, Aidge::LeakyReLUImpl_cpu_backward_kernel<float, float>); +static Registrar<LeakyReLUImplBackward_cpu> registrarLeakyReLUImplBackward_cpu_Int32( + {DataType::Int32, DataType::Int32}, Aidge::LeakyReLUImpl_cpu_backward_kernel<int, int>); +static Registrar<LeakyReLUImplBackward_cpu> registrarLeakyReLUImplBackward_cpu_Float64( + {DataType::Float64, DataType::Float64}, Aidge::LeakyReLUImpl_cpu_backward_kernel<double, double>); +} // namespace +} // namespace Aidge + +#endif /* AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_BACKWARD_KERNEL_H_ */ diff --git a/include/aidge/backend/cpu/operator/ProducerImpl.hpp b/include/aidge/backend/cpu/operator/ProducerImpl.hpp index c1d27f7e..f1fc7a75 100644 --- a/include/aidge/backend/cpu/operator/ProducerImpl.hpp +++ b/include/aidge/backend/cpu/operator/ProducerImpl.hpp @@ -18,7 +18,6 @@ #include "aidge/operator/Producer.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" -#include "aidge/backend/cpu/data/GetCPUPtr.h" namespace Aidge { class ProducerImpl_cpu : public OperatorImpl { @@ -30,7 +29,10 @@ public: } NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final; - void forward() override; + + inline void forward() noexcept override final {} + + inline void backward() noexcept override final {} }; namespace { diff --git a/include/aidge/backend/cpu/operator/ReLUImpl.hpp b/include/aidge/backend/cpu/operator/ReLUImpl.hpp index 3338d0c4..7aff2937 100644 --- a/include/aidge/backend/cpu/operator/ReLUImpl.hpp +++ b/include/aidge/backend/cpu/operator/ReLUImpl.hpp @@ -12,13 +12,15 @@ #ifndef AIDGE_CPU_OPERATOR_RELUIMPL_H_ #define AIDGE_CPU_OPERATOR_RELUIMPL_H_ +#include <cstddef> // std::size_t +#include <memory> +#include <tuple> // std::tuple +#include <vector> + #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/ReLU.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" -#include "aidge/backend/cpu/data/GetCPUPtr.h" -#include <memory> -#include <vector> namespace Aidge { // class ReLU_Op; @@ -40,7 +42,10 @@ public: } NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; - void forward() override; + + void forward() override final; + + void backward() override final; }; namespace { diff --git a/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp new file mode 100644 index 00000000..47d95ac4 --- /dev/null +++ b/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp @@ -0,0 +1,45 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CPU_OPERATOR_RELUIMPL_BACKWARD_KERNEL_H_ +#define AIDGE_CPU_OPERATOR_RELUIMPL_BACKWARD_KERNEL_H_ + +#include <cstddef> // std::size_t + +#include "aidge/utils/Registrar.hpp" + +#include "aidge/backend/cpu/operator/ReLUImpl.hpp" + +namespace Aidge { +template <class I, class O> +void ReLUImpl_cpu_backward_kernel(const std::size_t inputLenght, + const void* input_, + void* output_) { + + const I* input = static_cast<const I*>(input_); + O* output = static_cast<O*>(output_); + + for (std::size_t i = 0; i < inputLenght; ++i) { + output[i] = (input[i] > I(0)) ? O(1) : O(0); + } +} + +namespace { +static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float32( + {DataType::Float32, DataType::Float32}, Aidge::ReLUImpl_cpu_backward_kernel<float, float>); +static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Int32( + {DataType::Int32, DataType::Int32}, Aidge::ReLUImpl_cpu_backward_kernel<int, int>); +static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float64( + {DataType::Float64, DataType::Float64}, Aidge::ReLUImpl_cpu_backward_kernel<double, double>); +} // namespace +} // namespace Aidge + +#endif /* AIDGE_CPU_OPERATOR_RELUIMPL_BACKWARD_KERNEL_H_ */ diff --git a/include/aidge/backend/cpu/operator/SqrtImpl.hpp b/include/aidge/backend/cpu/operator/SqrtImpl.hpp index b3723f27..a2c9a030 100644 --- a/include/aidge/backend/cpu/operator/SqrtImpl.hpp +++ b/include/aidge/backend/cpu/operator/SqrtImpl.hpp @@ -12,16 +12,17 @@ #ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_H_ #define AIDGE_CPU_OPERATOR_SQRTIMPL_H_ +#include <cstddef> // std::size_t +#include <memory> +#include <tuple> +#include <vector> + #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Sqrt.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" -#include "aidge/backend/cpu/data/GetCPUPtr.h" -#include <memory> -#include <vector> namespace Aidge { -// class Sqrt_Op; // compute kernel registry for forward and backward class SqrtImplForward_cpu @@ -40,7 +41,10 @@ public: } NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; - void forward() override; + + void forward() override final; + + void backward() override final; }; namespace { diff --git a/include/aidge/backend/cpu/operator/SqrtImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/SqrtImpl_backward_kernels.hpp new file mode 100644 index 00000000..9cf5118a --- /dev/null +++ b/include/aidge/backend/cpu/operator/SqrtImpl_backward_kernels.hpp @@ -0,0 +1,46 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_BACKWARD_KERNEL_H_ +#define AIDGE_CPU_OPERATOR_SQRTIMPL_BACKWARD_KERNEL_H_ + +#include <cmath> // std::sqrt +#include <cstddef> // std::size_t + +#include "aidge/utils/Registrar.hpp" + +#include "aidge/backend/cpu/operator/SqrtImpl.hpp" + +namespace Aidge { +template <class I, class O> +void SqrtImpl_cpu_backward_kernel(const std::size_t inputLenght, + const void* input_, + void* output_) { + + const I* input = static_cast<const I*>(input_); + O* output = static_cast<O*>(output_); + + for (std::size_t i = 0; i < inputLenght; ++i) { + output[i] = static_cast<O>(0.5/(std::sqrt(static_cast<float>(input[i])))); + } +} + +namespace { +static Registrar<SqrtImplBackward_cpu> registrarSqrtImplBackward_cpu_Float32( + {DataType::Float32, DataType::Float32}, Aidge::SqrtImpl_cpu_backward_kernel<float, float>); +static Registrar<SqrtImplBackward_cpu> registrarSqrtImplBackward_cpu_Int32( + {DataType::Int32, DataType::Int32}, Aidge::SqrtImpl_cpu_backward_kernel<int, int>); +static Registrar<SqrtImplBackward_cpu> registrarSqrtImplBackward_cpu_Float64( + {DataType::Float64, DataType::Float64}, Aidge::SqrtImpl_cpu_backward_kernel<double, double>); +} // namespace +} // namespace Aidge + +#endif /* AIDGE_CPU_OPERATOR_SQRTIMPL_BACKWARD_KERNEL_H_ */ diff --git a/include/aidge/backend/cpu/operator/SqrtImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/SqrtImpl_forward_kernels.hpp index a180fc2c..886b978c 100644 --- a/include/aidge/backend/cpu/operator/SqrtImpl_forward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/SqrtImpl_forward_kernels.hpp @@ -12,14 +12,16 @@ #ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_FORWARD_KERNEL_H_ #define AIDGE_CPU_OPERATOR_SQRTIMPL_FORWARD_KERNEL_H_ +#include <cmath> // std::sqrt +#include <cstddef> // std::size_t + #include "aidge/utils/Registrar.hpp" -#include <cmath> #include "aidge/backend/cpu/operator/SqrtImpl.hpp" namespace Aidge { template <class I, class O> -void SqrtImpl_cpu_forward_kernel(std::size_t inputLenght, +void SqrtImpl_cpu_forward_kernel(const std::size_t inputLenght, const void* input_, void* output_) { @@ -27,7 +29,7 @@ void SqrtImpl_cpu_forward_kernel(std::size_t inputLenght, O* output = static_cast<O*>(output_); for (std::size_t i = 0; i < inputLenght; ++i) { - output[i] = std::sqrt(input[i]); + output[i] = static_cast<O>(std::sqrt(static_cast<float>(input[i]))); } } diff --git a/src/operator/LeakyReLUImpl.cpp b/src/operator/LeakyReLUImpl.cpp index 17912eb1..4ffb230d 100644 --- a/src/operator/LeakyReLUImpl.cpp +++ b/src/operator/LeakyReLUImpl.cpp @@ -10,17 +10,17 @@ ********************************************************************************/ #include <cassert> -#include <chrono> // std::chrono::milliseconds -#include <numeric> // std::accumulate -#include <thread> // std::this_thread::sleep_for #include <vector> +#include "aidge/data/Tensor.hpp" #include "aidge/operator/LeakyReLU.hpp" #include "aidge/utils/Types.h" +#include "aidge/utils/Registrar.hpp" #include "aidge/backend/cpu/data/GetCPUPtr.h" #include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp" #include "aidge/backend/cpu/operator/LeakyReLUImpl_forward_kernels.hpp" +#include "aidge/backend/cpu/operator/LeakyReLUImpl_backward_kernels.hpp" Aidge::NbElts_t Aidge::LeakyReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place @@ -28,16 +28,36 @@ Aidge::NbElts_t Aidge::LeakyReLUImpl_cpu::getNbRequiredProtected(const Aidge::IO } void Aidge::LeakyReLUImpl_cpu::forward() { - assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0"); + std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0)); + std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)); + AIDGE_ASSERT(in0, "missing input #0"); // Find the correct kernel type auto kernelFunc = Registrar<LeakyReLUImplForward_cpu>::create({ - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(), - std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()}); + in0->dataType(), + out0->dataType()}); // Call kernel kernelFunc(dynamic_cast<const LeakyReLU_Op&>(mOp).getStaticAttributes(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(), + in0->size(), getCPUPtr(mOp.getRawInput(0)), getCPUPtr(mOp.getRawOutput(0))); } + +void Aidge::LeakyReLUImpl_cpu::backward() { + // reversing in and out Data for backprop + std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)); + std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0)); + AIDGE_ASSERT(in0, "missing input #0"); + + // Find the correct kernel type + auto kernelFunc = Registrar<LeakyReLUImplForward_cpu>::create({ + in0->dataType(), + out0->dataType()}); + + // Call kernel + kernelFunc(dynamic_cast<const LeakyReLU_Op&>(mOp).getStaticAttributes(), + in0->size(), + getCPUPtr(in0), + getCPUPtr(out0)); +} \ No newline at end of file diff --git a/src/operator/ProducerImpl.cpp b/src/operator/ProducerImpl.cpp index 4c5883a9..d5432c0d 100644 --- a/src/operator/ProducerImpl.cpp +++ b/src/operator/ProducerImpl.cpp @@ -10,13 +10,11 @@ ********************************************************************************/ #include <cassert> -#include <numeric> // std::accumulate +#include <memory> #include <vector> #include "aidge/data/Tensor.hpp" -#include "aidge/operator/Producer.hpp" #include "aidge/utils/Types.h" -#include "aidge/backend/cpu/data/GetCPUPtr.h" #include "aidge/backend/cpu/operator/ProducerImpl.hpp" @@ -29,7 +27,3 @@ Aidge::DimSize_t Aidge::ProducerImpl_cpu::getNbProducedData( return std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size(); } - -void Aidge::ProducerImpl_cpu::forward() -{ -} diff --git a/src/operator/ReLUImpl.cpp b/src/operator/ReLUImpl.cpp index 8863be28..84bb1045 100644 --- a/src/operator/ReLUImpl.cpp +++ b/src/operator/ReLUImpl.cpp @@ -9,18 +9,18 @@ * ********************************************************************************/ -#include <cassert> -#include <chrono> // std::chrono::milliseconds -#include <numeric> // std::accumulate -#include <thread> // std::this_thread::sleep_for +#include <memory> #include <vector> +#include "aidge/data/Tensor.hpp" #include "aidge/operator/ReLU.hpp" #include "aidge/utils/Types.h" #include "aidge/backend/cpu/data/GetCPUPtr.h" +#include "aidge/utils/ErrorHandling.hpp" #include "aidge/backend/cpu/operator/ReLUImpl.hpp" #include "aidge/backend/cpu/operator/ReLUImpl_forward_kernels.hpp" +#include "aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp" Aidge::NbElts_t Aidge::ReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place @@ -28,15 +28,32 @@ Aidge::NbElts_t Aidge::ReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex } void Aidge::ReLUImpl_cpu::forward() { - assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0"); + std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0)); + AIDGE_ASSERT(in0, "missing input #0"); // Find the correct kernel type auto kernelFunc = Registrar<ReLUImplForward_cpu>::create({ - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(), + in0->dataType(), std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()}); // Call kernel - kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(), + kernelFunc(in0->size(), getCPUPtr(mOp.getRawInput(0)), getCPUPtr(mOp.getRawOutput(0))); } + +void Aidge::ReLUImpl_cpu::backward() { + // reversing in and out Tensors + std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->grad(); + std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->grad(); + AIDGE_ASSERT(out0, "missing input #0"); + + // Find the correct kernel type + auto kernelFunc = Registrar<ReLUImplBackward_cpu>::create({ + in0->dataType(), + out0->dataType() + }); + + // Call kernel + kernelFunc(in0->size(), getCPUPtr(in0), getCPUPtr(out0)); +} diff --git a/src/operator/SqrtImpl.cpp b/src/operator/SqrtImpl.cpp index 2766e8ae..ba9b57e8 100644 --- a/src/operator/SqrtImpl.cpp +++ b/src/operator/SqrtImpl.cpp @@ -9,18 +9,18 @@ * ********************************************************************************/ -#include <cassert> -#include <chrono> // std::chrono::milliseconds -#include <numeric> // std::accumulate -#include <thread> // std::this_thread::sleep_for +#include <memory> #include <vector> +#include "aidge/backend/cpu/data/GetCPUPtr.h" +#include "aidge/data/Tensor.hpp" #include "aidge/operator/Sqrt.hpp" +#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Types.h" -#include "aidge/backend/cpu/data/GetCPUPtr.h" #include "aidge/backend/cpu/operator/SqrtImpl.hpp" #include "aidge/backend/cpu/operator/SqrtImpl_forward_kernels.hpp" +#include "aidge/backend/cpu/operator/SqrtImpl_backward_kernels.hpp" Aidge::NbElts_t Aidge::SqrtImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place @@ -28,15 +28,34 @@ Aidge::NbElts_t Aidge::SqrtImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex } void Aidge::SqrtImpl_cpu::forward() { - assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0"); + std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0)); + std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)); + AIDGE_ASSERT(in0, "missing input #0"); // Find the correct kernel type auto kernelFunc = Registrar<SqrtImplForward_cpu>::create({ - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(), - std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()}); + in0->dataType(), + out0->dataType()}); // Call kernel - kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(), + kernelFunc(in0->size(), getCPUPtr(mOp.getRawInput(0)), getCPUPtr(mOp.getRawOutput(0))); +} + +void Aidge::SqrtImpl_cpu::backward() { + // reversing in and out Data for backprop + std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)); + std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0)); + AIDGE_ASSERT(out0, "missing output #0"); + + // Find the correct kernel type + auto kernelFunc = Registrar<SqrtImplForward_cpu>::create({ + in0->dataType(), + out0->dataType()}); + + // Call kernel + kernelFunc(in0->size(), + getCPUPtr(in0), + getCPUPtr(out0)); } \ No newline at end of file -- GitLab