Skip to content
Snippets Groups Projects
Commit 5d4b2967 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Add] backward kernel for ReLU, LeakyReLU & Producer

parent 2383b5e6
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!39Scheduler backprop
Pipeline #39199 canceled
Showing
with 251 additions and 49 deletions
...@@ -12,17 +12,17 @@ ...@@ -12,17 +12,17 @@
#ifndef AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_ #ifndef AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_
#define AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_ #define AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_
#include <memory>
#include <tuple>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/LeakyReLU.hpp" #include "aidge/operator/LeakyReLU.hpp"
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h" #include "aidge/backend/cpu/data/GetCPUPtr.h"
#include <memory>
#include <vector>
namespace Aidge { namespace Aidge {
// class LeakyReLU_Op;
// compute kernel registry for forward and backward // compute kernel registry for forward and backward
class LeakyReLUImplForward_cpu class LeakyReLUImplForward_cpu
: public Registrable<LeakyReLUImplForward_cpu, std::tuple<DataType, DataType>, void(const LeakyReLU_Op::Attrs&, std::size_t, const void*, void*)> { : public Registrable<LeakyReLUImplForward_cpu, std::tuple<DataType, DataType>, void(const LeakyReLU_Op::Attrs&, std::size_t, const void*, void*)> {
...@@ -40,7 +40,10 @@ public: ...@@ -40,7 +40,10 @@ public:
} }
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
void forward() override final;
void backward() override final;
}; };
namespace { namespace {
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_BACKWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_BACKWARD_KERNEL_H_
#include "aidge/utils/Registrar.hpp"
#include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp"
namespace Aidge {
/// @brief CPU backward kernel for LeakyReLU: writes the local derivative of the
/// activation for each element (1 where the input is positive, negativeSlope otherwise).
/// @tparam I input scalar type, @tparam O output scalar type
/// @param attrs operator attributes; attribute 0 is the negative slope
/// @param inputLength number of elements to process
/// @param input_ type-erased pointer to I[inputLength]
/// @param output_ type-erased pointer to O[inputLength], receives the derivative
/// NOTE(review): this writes only the activation derivative; the caller is expected
/// to combine it with the incoming gradient — confirm against the scheduler backprop.
template <class I, class O>
void LeakyReLUImpl_cpu_backward_kernel(const LeakyReLU_Op::Attrs& attrs,
                                       std::size_t inputLength,
                                       const void* input_,
                                       void* output_) {

    const I* input = static_cast<const I*>(input_);
    O* output = static_cast<O*>(output_);
    const I negativeSlope = static_cast<I>(std::get<0>(attrs));
    for (std::size_t i = 0; i < inputLength; ++i) {
        // Explicitly typed branch values: avoids the int literal `1` being mixed
        // with I in the ternary, and makes the I -> O conversion visible.
        output[i] = static_cast<O>((input[i] > I(0)) ? I(1) : negativeSlope);
    }
}
namespace {
// Register the LeakyReLU backward kernel for each supported {input, output}
// data-type pair. The static objects run their constructors at static-init
// time, populating the LeakyReLUImplBackward_cpu registry so the kernel can
// later be retrieved by data type.
static Registrar<LeakyReLUImplBackward_cpu> registrarLeakyReLUImplBackward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::LeakyReLUImpl_cpu_backward_kernel<float, float>);
static Registrar<LeakyReLUImplBackward_cpu> registrarLeakyReLUImplBackward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::LeakyReLUImpl_cpu_backward_kernel<int, int>);
static Registrar<LeakyReLUImplBackward_cpu> registrarLeakyReLUImplBackward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::LeakyReLUImpl_cpu_backward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_BACKWARD_KERNEL_H_ */
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
namespace Aidge { namespace Aidge {
class ProducerImpl_cpu : public OperatorImpl { class ProducerImpl_cpu : public OperatorImpl {
...@@ -30,7 +29,10 @@ public: ...@@ -30,7 +29,10 @@ public:
} }
NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final; NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
void forward() override;
inline void forward() noexcept override final {}
inline void backward() noexcept override final {}
}; };
namespace { namespace {
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
#ifndef AIDGE_CPU_OPERATOR_RELUIMPL_H_ #ifndef AIDGE_CPU_OPERATOR_RELUIMPL_H_
#define AIDGE_CPU_OPERATOR_RELUIMPL_H_ #define AIDGE_CPU_OPERATOR_RELUIMPL_H_
#include <cstddef> // std::size_t
#include <memory>
#include <tuple> // std::tuple
#include <vector>
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/ReLU.hpp" #include "aidge/operator/ReLU.hpp"
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include <memory>
#include <vector>
namespace Aidge { namespace Aidge {
// class ReLU_Op; // class ReLU_Op;
...@@ -40,7 +42,10 @@ public: ...@@ -40,7 +42,10 @@ public:
} }
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
void forward() override final;
void backward() override final;
}; };
namespace { namespace {
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_RELUIMPL_BACKWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_RELUIMPL_BACKWARD_KERNEL_H_
#include <cstddef> // std::size_t
#include "aidge/utils/Registrar.hpp"
#include "aidge/backend/cpu/operator/ReLUImpl.hpp"
namespace Aidge {
/// @brief CPU backward kernel for ReLU: stores the elementwise derivative of
/// the activation — O(1) where the input is strictly positive, O(0) elsewhere.
/// @tparam I input scalar type, @tparam O output scalar type
/// @param inputLenght number of elements to process
/// @param input_ type-erased pointer to I[inputLenght]
/// @param output_ type-erased pointer to O[inputLenght], receives the derivative
template <class I, class O>
void ReLUImpl_cpu_backward_kernel(const std::size_t inputLenght,
                                  const void* input_,
                                  void* output_) {
    const I* in = static_cast<const I*>(input_);
    O* out = static_cast<O*>(output_);
    for (std::size_t idx = 0; idx != inputLenght; ++idx) {
        const bool positive = in[idx] > I(0);
        out[idx] = positive ? O(1) : O(0);
    }
}
namespace {
// Register the ReLU backward kernel for each supported {input, output}
// data-type pair. The static objects run their constructors at static-init
// time, populating the ReLUImplBackward_cpu registry so the kernel can later
// be retrieved by data type.
static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::ReLUImpl_cpu_backward_kernel<float, float>);
static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::ReLUImpl_cpu_backward_kernel<int, int>);
static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::ReLUImpl_cpu_backward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_RELUIMPL_BACKWARD_KERNEL_H_ */
...@@ -12,16 +12,17 @@ ...@@ -12,16 +12,17 @@
#ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_H_ #ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_H_
#define AIDGE_CPU_OPERATOR_SQRTIMPL_H_ #define AIDGE_CPU_OPERATOR_SQRTIMPL_H_
#include <cstddef> // std::size_t
#include <memory>
#include <tuple>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Sqrt.hpp" #include "aidge/operator/Sqrt.hpp"
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include <memory>
#include <vector>
namespace Aidge { namespace Aidge {
// class Sqrt_Op;
// compute kernel registry for forward and backward // compute kernel registry for forward and backward
class SqrtImplForward_cpu class SqrtImplForward_cpu
...@@ -40,7 +41,10 @@ public: ...@@ -40,7 +41,10 @@ public:
} }
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
void forward() override final;
void backward() override final;
}; };
namespace { namespace {
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_BACKWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_SQRTIMPL_BACKWARD_KERNEL_H_
#include <cmath> // std::sqrt
#include <cstddef> // std::size_t
#include "aidge/utils/Registrar.hpp"
#include "aidge/backend/cpu/operator/SqrtImpl.hpp"
namespace Aidge {
/// @brief CPU backward kernel for Sqrt: stores the elementwise derivative
/// d/dx sqrt(x) = 1 / (2 * sqrt(x)).
/// @tparam I input scalar type, @tparam O output scalar type
/// @param inputLenght number of elements to process
/// @param input_ type-erased pointer to I[inputLenght]
/// @param output_ type-erased pointer to O[inputLenght], receives the derivative
/// NOTE(review): the derivative is infinite at x == 0 and NaN for x < 0; callers
/// presumably guarantee positive inputs — confirm.
template <class I, class O>
void SqrtImpl_cpu_backward_kernel(const std::size_t inputLenght,
                                  const void* input_,
                                  void* output_) {
    const I* input = static_cast<const I*>(input_);
    O* output = static_cast<O*>(output_);
    for (std::size_t i = 0; i < inputLenght; ++i) {
        // Compute in double rather than float: the previous float cast truncated
        // Float64 inputs and degraded the registered <double, double> kernel.
        output[i] = static_cast<O>(0.5 / std::sqrt(static_cast<double>(input[i])));
    }
}
namespace {
// Register the Sqrt backward kernel for each supported {input, output}
// data-type pair. The static objects run their constructors at static-init
// time, populating the SqrtImplBackward_cpu registry so the kernel can later
// be retrieved by data type.
static Registrar<SqrtImplBackward_cpu> registrarSqrtImplBackward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::SqrtImpl_cpu_backward_kernel<float, float>);
static Registrar<SqrtImplBackward_cpu> registrarSqrtImplBackward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::SqrtImpl_cpu_backward_kernel<int, int>);
static Registrar<SqrtImplBackward_cpu> registrarSqrtImplBackward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::SqrtImpl_cpu_backward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_SQRTIMPL_BACKWARD_KERNEL_H_ */
...@@ -12,14 +12,16 @@ ...@@ -12,14 +12,16 @@
#ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_FORWARD_KERNEL_H_ #ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_SQRTIMPL_FORWARD_KERNEL_H_ #define AIDGE_CPU_OPERATOR_SQRTIMPL_FORWARD_KERNEL_H_
#include <cmath> // std::sqrt
#include <cstddef> // std::size_t
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
#include <cmath>
#include "aidge/backend/cpu/operator/SqrtImpl.hpp" #include "aidge/backend/cpu/operator/SqrtImpl.hpp"
namespace Aidge { namespace Aidge {
template <class I, class O> template <class I, class O>
void SqrtImpl_cpu_forward_kernel(std::size_t inputLenght, void SqrtImpl_cpu_forward_kernel(const std::size_t inputLenght,
const void* input_, const void* input_,
void* output_) { void* output_) {
...@@ -27,7 +29,7 @@ void SqrtImpl_cpu_forward_kernel(std::size_t inputLenght, ...@@ -27,7 +29,7 @@ void SqrtImpl_cpu_forward_kernel(std::size_t inputLenght,
O* output = static_cast<O*>(output_); O* output = static_cast<O*>(output_);
for (std::size_t i = 0; i < inputLenght; ++i) { for (std::size_t i = 0; i < inputLenght; ++i) {
output[i] = std::sqrt(input[i]); output[i] = static_cast<O>(std::sqrt(static_cast<float>(input[i])));
} }
} }
......
...@@ -10,17 +10,17 @@ ...@@ -10,17 +10,17 @@
********************************************************************************/ ********************************************************************************/
#include <cassert> #include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector> #include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/LeakyReLU.hpp" #include "aidge/operator/LeakyReLU.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/utils/Registrar.hpp"
#include "aidge/backend/cpu/data/GetCPUPtr.h" #include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp" #include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp"
#include "aidge/backend/cpu/operator/LeakyReLUImpl_forward_kernels.hpp" #include "aidge/backend/cpu/operator/LeakyReLUImpl_forward_kernels.hpp"
#include "aidge/backend/cpu/operator/LeakyReLUImpl_backward_kernels.hpp"
Aidge::NbElts_t Aidge::LeakyReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { Aidge::NbElts_t Aidge::LeakyReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
// this implementation can be in-place // this implementation can be in-place
...@@ -28,16 +28,36 @@ Aidge::NbElts_t Aidge::LeakyReLUImpl_cpu::getNbRequiredProtected(const Aidge::IO ...@@ -28,16 +28,36 @@ Aidge::NbElts_t Aidge::LeakyReLUImpl_cpu::getNbRequiredProtected(const Aidge::IO
} }
void Aidge::LeakyReLUImpl_cpu::forward() { void Aidge::LeakyReLUImpl_cpu::forward() {
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0"); std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));
AIDGE_ASSERT(in0, "missing input #0");
// Find the correct kernel type // Find the correct kernel type
auto kernelFunc = Registrar<LeakyReLUImplForward_cpu>::create({ auto kernelFunc = Registrar<LeakyReLUImplForward_cpu>::create({
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(), in0->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()}); out0->dataType()});
// Call kernel // Call kernel
kernelFunc(dynamic_cast<const LeakyReLU_Op&>(mOp).getStaticAttributes(), kernelFunc(dynamic_cast<const LeakyReLU_Op&>(mOp).getStaticAttributes(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(), in0->size(),
getCPUPtr(mOp.getRawInput(0)), getCPUPtr(mOp.getRawInput(0)),
getCPUPtr(mOp.getRawOutput(0))); getCPUPtr(mOp.getRawOutput(0)));
} }
void Aidge::LeakyReLUImpl_cpu::backward() {
// reversing in and out Data for backprop
std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));
std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
AIDGE_ASSERT(in0, "missing input #0");
// Find the correct kernel type
auto kernelFunc = Registrar<LeakyReLUImplForward_cpu>::create({
in0->dataType(),
out0->dataType()});
// Call kernel
kernelFunc(dynamic_cast<const LeakyReLU_Op&>(mOp).getStaticAttributes(),
in0->size(),
getCPUPtr(in0),
getCPUPtr(out0));
}
\ No newline at end of file
...@@ -10,13 +10,11 @@ ...@@ -10,13 +10,11 @@
********************************************************************************/ ********************************************************************************/
#include <cassert> #include <cassert>
#include <numeric> // std::accumulate #include <memory>
#include <vector> #include <vector>
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/ProducerImpl.hpp" #include "aidge/backend/cpu/operator/ProducerImpl.hpp"
...@@ -29,7 +27,3 @@ Aidge::DimSize_t Aidge::ProducerImpl_cpu::getNbProducedData( ...@@ -29,7 +27,3 @@ Aidge::DimSize_t Aidge::ProducerImpl_cpu::getNbProducedData(
return std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size(); return std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size();
} }
void Aidge::ProducerImpl_cpu::forward()
{
}
...@@ -9,18 +9,18 @@ ...@@ -9,18 +9,18 @@
* *
********************************************************************************/ ********************************************************************************/
#include <cassert> #include <memory>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector> #include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/ReLU.hpp" #include "aidge/operator/ReLU.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h" #include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/backend/cpu/operator/ReLUImpl.hpp" #include "aidge/backend/cpu/operator/ReLUImpl.hpp"
#include "aidge/backend/cpu/operator/ReLUImpl_forward_kernels.hpp" #include "aidge/backend/cpu/operator/ReLUImpl_forward_kernels.hpp"
#include "aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp"
Aidge::NbElts_t Aidge::ReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { Aidge::NbElts_t Aidge::ReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
// this implementation can be in-place // this implementation can be in-place
...@@ -28,15 +28,32 @@ Aidge::NbElts_t Aidge::ReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex ...@@ -28,15 +28,32 @@ Aidge::NbElts_t Aidge::ReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex
} }
void Aidge::ReLUImpl_cpu::forward() { void Aidge::ReLUImpl_cpu::forward() {
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0"); std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
AIDGE_ASSERT(in0, "missing input #0");
// Find the correct kernel type // Find the correct kernel type
auto kernelFunc = Registrar<ReLUImplForward_cpu>::create({ auto kernelFunc = Registrar<ReLUImplForward_cpu>::create({
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(), in0->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()}); std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
// Call kernel // Call kernel
kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(), kernelFunc(in0->size(),
getCPUPtr(mOp.getRawInput(0)), getCPUPtr(mOp.getRawInput(0)),
getCPUPtr(mOp.getRawOutput(0))); getCPUPtr(mOp.getRawOutput(0)));
} }
void Aidge::ReLUImpl_cpu::backward() {
// reversing in and out Tensors
std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->grad();
std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->grad();
AIDGE_ASSERT(out0, "missing input #0");
// Find the correct kernel type
auto kernelFunc = Registrar<ReLUImplBackward_cpu>::create({
in0->dataType(),
out0->dataType()
});
// Call kernel
kernelFunc(in0->size(), getCPUPtr(in0), getCPUPtr(out0));
}
...@@ -9,18 +9,18 @@ ...@@ -9,18 +9,18 @@
* *
********************************************************************************/ ********************************************************************************/
#include <cassert> #include <memory>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector> #include <vector>
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Sqrt.hpp" #include "aidge/operator/Sqrt.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/SqrtImpl.hpp" #include "aidge/backend/cpu/operator/SqrtImpl.hpp"
#include "aidge/backend/cpu/operator/SqrtImpl_forward_kernels.hpp" #include "aidge/backend/cpu/operator/SqrtImpl_forward_kernels.hpp"
#include "aidge/backend/cpu/operator/SqrtImpl_backward_kernels.hpp"
Aidge::NbElts_t Aidge::SqrtImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { Aidge::NbElts_t Aidge::SqrtImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
// this implementation can be in-place // this implementation can be in-place
...@@ -28,15 +28,34 @@ Aidge::NbElts_t Aidge::SqrtImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex ...@@ -28,15 +28,34 @@ Aidge::NbElts_t Aidge::SqrtImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex
} }
void Aidge::SqrtImpl_cpu::forward() { void Aidge::SqrtImpl_cpu::forward() {
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0"); std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));
AIDGE_ASSERT(in0, "missing input #0");
// Find the correct kernel type // Find the correct kernel type
auto kernelFunc = Registrar<SqrtImplForward_cpu>::create({ auto kernelFunc = Registrar<SqrtImplForward_cpu>::create({
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(), in0->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()}); out0->dataType()});
// Call kernel // Call kernel
kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(), kernelFunc(in0->size(),
getCPUPtr(mOp.getRawInput(0)), getCPUPtr(mOp.getRawInput(0)),
getCPUPtr(mOp.getRawOutput(0))); getCPUPtr(mOp.getRawOutput(0)));
}
void Aidge::SqrtImpl_cpu::backward() {
// reversing in and out Data for backprop
std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));
std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
AIDGE_ASSERT(out0, "missing output #0");
// Find the correct kernel type
auto kernelFunc = Registrar<SqrtImplForward_cpu>::create({
in0->dataType(),
out0->dataType()});
// Call kernel
kernelFunc(in0->size(),
getCPUPtr(in0),
getCPUPtr(out0));
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment