Commit 5d4b2967 authored by Maxence Naud

[Add] backward kernel for ReLU, LeakyReLU & Producer

parent 2383b5e6
Part of 2 merge requests: !50 (version 0.2.0) and !39 (Scheduler backprop)
Pipeline #39199 canceled
Showing 251 additions and 49 deletions
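
The ReLU, LeakyReLU and Sqrt backward implementations added below all follow the same dispatch pattern as the existing forward path: a kernel template is instantiated per (input, output) data-type pair, registered in a static registry, and looked up at runtime from the operator's tensor data types. The following is a minimal, self-contained sketch of that pattern; the registry, the DataType enum and the kernel shown here are illustrative stand-ins, not Aidge's actual Registrable/Registrar API.

#include <cstddef>
#include <functional>
#include <iostream>
#include <map>
#include <utility>

// Simplified stand-ins for Aidge's DataType and kernel registry (hypothetical names).
enum class DataType { Float32, Float64 };

using BackwardKernel = std::function<void(std::size_t, const void*, void*)>;

std::map<std::pair<DataType, DataType>, BackwardKernel>& backwardRegistry() {
    static std::map<std::pair<DataType, DataType>, BackwardKernel> registry;
    return registry;
}

// Templated kernel mirroring ReLUImpl_cpu_backward_kernel: 1 where the stored
// value is positive, 0 otherwise.
template <class I, class O>
void reluBackwardKernel(std::size_t length, const void* input_, void* output_) {
    const I* input = static_cast<const I*>(input_);
    O* output = static_cast<O*>(output_);
    for (std::size_t i = 0; i < length; ++i) {
        output[i] = (input[i] > I(0)) ? O(1) : O(0);
    }
}

int main() {
    // Registration: in the real code this is done by static Registrar objects
    // in each *_backward_kernels.hpp header.
    backwardRegistry()[{DataType::Float32, DataType::Float32}] =
        reluBackwardKernel<float, float>;

    float in[4]  = {-1.f, 0.f, 2.f, 3.f};
    float out[4] = {};

    // Lookup by runtime data types, then a type-erased call, as in ReLUImpl_cpu::backward().
    backwardRegistry().at({DataType::Float32, DataType::Float32})(4, in, out);

    for (float v : out) { std::cout << v << ' '; }   // prints: 0 0 1 1
    std::cout << '\n';
    return 0;
}

Type-erasing the kernels behind void* pointers is what lets a single registry serve every data type; the templated kernels cast back to concrete types internally, exactly as the kernels in this commit do.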
LeakyReLUImpl.hpp
@@ -12,17 +12,17 @@
 #ifndef AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_
 #define AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_

+#include <memory>
+#include <tuple>
+#include <vector>
 #include "aidge/backend/OperatorImpl.hpp"
 #include "aidge/operator/LeakyReLU.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
 #include "aidge/backend/cpu/data/GetCPUPtr.h"
-#include <memory>
-#include <vector>

 namespace Aidge {
-// class LeakyReLU_Op;

 // compute kernel registry for forward and backward
 class LeakyReLUImplForward_cpu
     : public Registrable<LeakyReLUImplForward_cpu, std::tuple<DataType, DataType>, void(const LeakyReLU_Op::Attrs&, std::size_t, const void*, void*)> {
@@ -40,7 +40,10 @@
     }

     NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;

-    void forward() override;
+    void forward() override final;
+
+    void backward() override final;
 };

 namespace {
...
LeakyReLUImpl_backward_kernels.hpp (new file)
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_BACKWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_BACKWARD_KERNEL_H_
#include "aidge/utils/Registrar.hpp"
#include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp"
namespace Aidge {
template <class I, class O>
void LeakyReLUImpl_cpu_backward_kernel(const LeakyReLU_Op::Attrs& attrs,
std::size_t inputLenght,
const void* input_,
void* output_) {
const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_);
I negativeSlope = static_cast<I>(std::get<0>(attrs));
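    // derivative of LeakyReLU evaluated on the stored values: 1 for strictly positive inputs, negativeSlope otherwise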
for (std::size_t i = 0; i < inputLenght; ++i) {
output[i] = input[i] > 0 ? 1 : negativeSlope;
}
}
namespace {
static Registrar<LeakyReLUImplBackward_cpu> registrarLeakyReLUImplBackward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::LeakyReLUImpl_cpu_backward_kernel<float, float>);
static Registrar<LeakyReLUImplBackward_cpu> registrarLeakyReLUImplBackward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::LeakyReLUImpl_cpu_backward_kernel<int, int>);
static Registrar<LeakyReLUImplBackward_cpu> registrarLeakyReLUImplBackward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::LeakyReLUImpl_cpu_backward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_BACKWARD_KERNEL_H_ */
ProducerImpl.hpp
@@ -18,7 +18,6 @@
 #include "aidge/operator/Producer.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
-#include "aidge/backend/cpu/data/GetCPUPtr.h"

 namespace Aidge {
 class ProducerImpl_cpu : public OperatorImpl {
@@ -30,7 +29,10 @@
     }

     NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;

-    void forward() override;
+    inline void forward() noexcept override final {}
+
+    inline void backward() noexcept override final {}
 };

 namespace {
...
ReLUImpl.hpp
@@ -12,13 +12,15 @@
 #ifndef AIDGE_CPU_OPERATOR_RELUIMPL_H_
 #define AIDGE_CPU_OPERATOR_RELUIMPL_H_

+#include <cstddef>  // std::size_t
+#include <memory>
+#include <tuple>    // std::tuple
+#include <vector>
+
 #include "aidge/backend/OperatorImpl.hpp"
 #include "aidge/operator/ReLU.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
-#include "aidge/backend/cpu/data/GetCPUPtr.h"
-#include <memory>
-#include <vector>

 namespace Aidge {
 // class ReLU_Op;
@@ -40,7 +42,10 @@
     }

     NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;

-    void forward() override;
+    void forward() override final;
+
+    void backward() override final;
 };

 namespace {
...
ReLUImpl_backward_kernels.hpp (new file)
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_RELUIMPL_BACKWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_RELUIMPL_BACKWARD_KERNEL_H_
#include <cstddef> // std::size_t
#include "aidge/utils/Registrar.hpp"
#include "aidge/backend/cpu/operator/ReLUImpl.hpp"
namespace Aidge {
template <class I, class O>
void ReLUImpl_cpu_backward_kernel(const std::size_t inputLenght,
const void* input_,
void* output_) {
const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_);
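    // derivative of ReLU evaluated on the stored values: 1 where input > 0, else 0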
for (std::size_t i = 0; i < inputLenght; ++i) {
output[i] = (input[i] > I(0)) ? O(1) : O(0);
}
}
namespace {
static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::ReLUImpl_cpu_backward_kernel<float, float>);
static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::ReLUImpl_cpu_backward_kernel<int, int>);
static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::ReLUImpl_cpu_backward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_RELUIMPL_BACKWARD_KERNEL_H_ */
SqrtImpl.hpp
@@ -12,16 +12,17 @@
 #ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_H_
 #define AIDGE_CPU_OPERATOR_SQRTIMPL_H_

+#include <cstddef>  // std::size_t
+#include <memory>
+#include <tuple>
+#include <vector>
+
 #include "aidge/backend/OperatorImpl.hpp"
 #include "aidge/operator/Sqrt.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
-#include "aidge/backend/cpu/data/GetCPUPtr.h"
-#include <memory>
-#include <vector>

 namespace Aidge {
-// class Sqrt_Op;

 // compute kernel registry for forward and backward
 class SqrtImplForward_cpu
@@ -40,7 +41,10 @@
     }

     NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;

-    void forward() override;
+    void forward() override final;
+
+    void backward() override final;
 };

 namespace {
...
SqrtImpl_backward_kernels.hpp (new file)
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_BACKWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_SQRTIMPL_BACKWARD_KERNEL_H_
#include <cmath> // std::sqrt
#include <cstddef> // std::size_t
#include "aidge/utils/Registrar.hpp"
#include "aidge/backend/cpu/operator/SqrtImpl.hpp"
namespace Aidge {
template <class I, class O>
void SqrtImpl_cpu_backward_kernel(const std::size_t inputLenght,
const void* input_,
void* output_) {
const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_);
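    // local gradient of sqrt: 0.5 / sqrt(input[i]) for each element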
for (std::size_t i = 0; i < inputLenght; ++i) {
output[i] = static_cast<O>(0.5/(std::sqrt(static_cast<float>(input[i]))));
}
}
namespace {
static Registrar<SqrtImplBackward_cpu> registrarSqrtImplBackward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::SqrtImpl_cpu_backward_kernel<float, float>);
static Registrar<SqrtImplBackward_cpu> registrarSqrtImplBackward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::SqrtImpl_cpu_backward_kernel<int, int>);
static Registrar<SqrtImplBackward_cpu> registrarSqrtImplBackward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::SqrtImpl_cpu_backward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_SQRTIMPL_BACKWARD_KERNEL_H_ */
SqrtImpl_forward_kernels.hpp
@@ -12,14 +12,16 @@
 #ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_FORWARD_KERNEL_H_
 #define AIDGE_CPU_OPERATOR_SQRTIMPL_FORWARD_KERNEL_H_

+#include <cmath>    // std::sqrt
+#include <cstddef>  // std::size_t
+
 #include "aidge/utils/Registrar.hpp"
-#include <cmath>

 #include "aidge/backend/cpu/operator/SqrtImpl.hpp"

 namespace Aidge {
 template <class I, class O>
-void SqrtImpl_cpu_forward_kernel(std::size_t inputLenght,
+void SqrtImpl_cpu_forward_kernel(const std::size_t inputLenght,
                                  const void* input_,
                                  void* output_) {
@@ -27,7 +29,7 @@
     O* output = static_cast<O*>(output_);
     for (std::size_t i = 0; i < inputLenght; ++i) {
-        output[i] = std::sqrt(input[i]);
+        output[i] = static_cast<O>(std::sqrt(static_cast<float>(input[i])));
     }
 }
...
LeakyReLUImpl.cpp
@@ -10,17 +10,17 @@
 ********************************************************************************/

 #include <cassert>
-#include <chrono>  // std::chrono::milliseconds
-#include <numeric> // std::accumulate
-#include <thread>  // std::this_thread::sleep_for
 #include <vector>

+#include "aidge/data/Tensor.hpp"
 #include "aidge/operator/LeakyReLU.hpp"
 #include "aidge/utils/Types.h"
+#include "aidge/utils/Registrar.hpp"
 #include "aidge/backend/cpu/data/GetCPUPtr.h"
 #include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp"
 #include "aidge/backend/cpu/operator/LeakyReLUImpl_forward_kernels.hpp"
+#include "aidge/backend/cpu/operator/LeakyReLUImpl_backward_kernels.hpp"

 Aidge::NbElts_t Aidge::LeakyReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
@@ -28,16 +28,36 @@
 }

 void Aidge::LeakyReLUImpl_cpu::forward() {
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
+    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
+    std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));
+    AIDGE_ASSERT(in0, "missing input #0");

     // Find the correct kernel type
     auto kernelFunc = Registrar<LeakyReLUImplForward_cpu>::create({
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
-        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
+        in0->dataType(),
+        out0->dataType()});

     // Call kernel
     kernelFunc(dynamic_cast<const LeakyReLU_Op&>(mOp).getStaticAttributes(),
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
+        in0->size(),
         getCPUPtr(mOp.getRawInput(0)),
         getCPUPtr(mOp.getRawOutput(0)));
 }
+
+void Aidge::LeakyReLUImpl_cpu::backward() {
+    // reversing in and out Data for backprop
+    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));
+    std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
+    AIDGE_ASSERT(in0, "missing input #0");
+
+    // Find the correct kernel type
+    auto kernelFunc = Registrar<LeakyReLUImplBackward_cpu>::create({
+        in0->dataType(),
+        out0->dataType()});
+
+    // Call kernel
+    kernelFunc(dynamic_cast<const LeakyReLU_Op&>(mOp).getStaticAttributes(),
+        in0->size(),
+        getCPUPtr(in0),
+        getCPUPtr(out0));
+}
\ No newline at end of file
ProducerImpl.cpp
@@ -10,13 +10,11 @@
 ********************************************************************************/

 #include <cassert>
-#include <numeric> // std::accumulate
+#include <memory>
 #include <vector>

 #include "aidge/data/Tensor.hpp"
-#include "aidge/operator/Producer.hpp"
 #include "aidge/utils/Types.h"
-#include "aidge/backend/cpu/data/GetCPUPtr.h"
 #include "aidge/backend/cpu/operator/ProducerImpl.hpp"
@@ -29,7 +27,3 @@ Aidge::DimSize_t Aidge::ProducerImpl_cpu::getNbProducedData(
     return std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size();
 }
-
-void Aidge::ProducerImpl_cpu::forward()
-{
-}
...
ReLUImpl.cpp
@@ -9,18 +9,18 @@
 *
 ********************************************************************************/

-#include <cassert>
-#include <chrono>  // std::chrono::milliseconds
-#include <numeric> // std::accumulate
-#include <thread>  // std::this_thread::sleep_for
+#include <memory>
 #include <vector>

+#include "aidge/data/Tensor.hpp"
 #include "aidge/operator/ReLU.hpp"
 #include "aidge/utils/Types.h"
 #include "aidge/backend/cpu/data/GetCPUPtr.h"
+#include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/backend/cpu/operator/ReLUImpl.hpp"
 #include "aidge/backend/cpu/operator/ReLUImpl_forward_kernels.hpp"
+#include "aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp"

 Aidge::NbElts_t Aidge::ReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
@@ -28,15 +28,32 @@
 }

 void Aidge::ReLUImpl_cpu::forward() {
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
+    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
+    AIDGE_ASSERT(in0, "missing input #0");

     // Find the correct kernel type
     auto kernelFunc = Registrar<ReLUImplForward_cpu>::create({
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
+        in0->dataType(),
         std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});

     // Call kernel
-    kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
+    kernelFunc(in0->size(),
         getCPUPtr(mOp.getRawInput(0)),
         getCPUPtr(mOp.getRawOutput(0)));
 }
+
+void Aidge::ReLUImpl_cpu::backward() {
+    // reversing in and out Tensors
+    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->grad();
+    std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->grad();
+    AIDGE_ASSERT(out0, "missing input #0");
+
+    // Find the correct kernel type
+    auto kernelFunc = Registrar<ReLUImplBackward_cpu>::create({
+        in0->dataType(),
+        out0->dataType()
+    });
+
+    // Call kernel
+    kernelFunc(in0->size(), getCPUPtr(in0), getCPUPtr(out0));
+}
SqrtImpl.cpp
@@ -9,18 +9,18 @@
 *
 ********************************************************************************/

-#include <cassert>
-#include <chrono>  // std::chrono::milliseconds
-#include <numeric> // std::accumulate
-#include <thread>  // std::this_thread::sleep_for
+#include <memory>
 #include <vector>

+#include "aidge/backend/cpu/data/GetCPUPtr.h"
+#include "aidge/data/Tensor.hpp"
 #include "aidge/operator/Sqrt.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/Types.h"
-#include "aidge/backend/cpu/data/GetCPUPtr.h"
 #include "aidge/backend/cpu/operator/SqrtImpl.hpp"
 #include "aidge/backend/cpu/operator/SqrtImpl_forward_kernels.hpp"
+#include "aidge/backend/cpu/operator/SqrtImpl_backward_kernels.hpp"

 Aidge::NbElts_t Aidge::SqrtImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
@@ -28,15 +28,34 @@
 }

 void Aidge::SqrtImpl_cpu::forward() {
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
+    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
+    std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));
+    AIDGE_ASSERT(in0, "missing input #0");

     // Find the correct kernel type
     auto kernelFunc = Registrar<SqrtImplForward_cpu>::create({
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
-        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
+        in0->dataType(),
+        out0->dataType()});

     // Call kernel
-    kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
+    kernelFunc(in0->size(),
         getCPUPtr(mOp.getRawInput(0)),
         getCPUPtr(mOp.getRawOutput(0)));
 }
+
+void Aidge::SqrtImpl_cpu::backward() {
+    // reversing in and out Data for backprop
+    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));
+    std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
+    AIDGE_ASSERT(out0, "missing output #0");
+
+    // Find the correct kernel type
+    auto kernelFunc = Registrar<SqrtImplBackward_cpu>::create({
+        in0->dataType(),
+        out0->dataType()});
+
+    // Call kernel
+    kernelFunc(in0->size(),
+        getCPUPtr(in0),
+        getCPUPtr(out0));
+}
\ No newline at end of file