Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Showing 381 additions and 187 deletions
@@ -12,13 +12,15 @@
#ifndef AIDGE_CPU_OPERATOR_RELUIMPL_H_
#define AIDGE_CPU_OPERATOR_RELUIMPL_H_
#include <cstddef> // std::size_t
#include <memory>
#include <tuple> // std::tuple
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include <memory>
#include <vector>
namespace Aidge {
// class ReLU_Op;
@@ -33,14 +35,17 @@ class ReLUImplBackward_cpu
class ReLUImpl_cpu : public OperatorImpl {
public:
ReLUImpl_cpu(const ReLU_Op& op) : OperatorImpl(op) {}
ReLUImpl_cpu(const ReLU_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<ReLUImpl_cpu> create(const ReLU_Op& op) {
return std::make_unique<ReLUImpl_cpu>(op);
}
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
void forward() override final;
void backward() override final;
};
namespace {
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_RELUIMPL_BACKWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_RELUIMPL_BACKWARD_KERNEL_H_
#include <cstddef> // std::size_t
#include "aidge/utils/Registrar.hpp"
#include "aidge/backend/cpu/operator/ReLUImpl.hpp"
namespace Aidge {
template <class I, class O>
void ReLUImpl_cpu_backward_kernel(const std::size_t inputLenght,
const void* input_,
void* output_) {
const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_);
for (std::size_t i = 0; i < inputLenght; ++i) {
output[i] = (input[i] > I(0)) ? static_cast<O>(input[i]) : O(0);
}
}
namespace {
static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::ReLUImpl_cpu_backward_kernel<float, float>);
static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::ReLUImpl_cpu_backward_kernel<int, int>);
static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::ReLUImpl_cpu_backward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_RELUIMPL_BACKWARD_KERNEL_H_ */
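A minimal usage sketch for the kernel registered above, assuming only the Registrar API already used elsewhere in this diff (the driver function and buffers are illustrative, not part of the change):
#include <vector>
// Hypothetical standalone driver for the Float32 backward kernel.
void reluBackwardSketch() {
    std::vector<float> in  = {-1.5f, 0.0f, 2.5f};
    std::vector<float> out(in.size());
    // Look up the {Float32, Float32} entry registered above.
    auto kernel = Aidge::Registrar<Aidge::ReLUImplBackward_cpu>::create(
        {Aidge::DataType::Float32, Aidge::DataType::Float32});
    kernel(in.size(), in.data(), out.data());
    // out == {0.0f, 0.0f, 2.5f}: negative and zero entries are clamped to 0.
}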
@@ -25,55 +25,22 @@
namespace Aidge {
// class ReduceMean_Op;
// compute kernel registry for forward and backward
// DIM 1
class ReduceMeanImpl1DForward_cpu
: public Registrable<ReduceMeanImpl1DForward_cpu,
// Every DIM
class ReduceMeanImplForward_cpu
: public Registrable<ReduceMeanImplForward_cpu,
std::tuple<DataType, DataType>,
void(const ReduceMean_Op<1>::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
void(const ReduceMean_Op::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
class ReduceMeanImpl1DBackward_cpu
: public Registrable<ReduceMeanImpl1DBackward_cpu,
std::tuple<DataType, DataType>,
void(const ReduceMean_Op<1>::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
void(const ReduceMean_Op::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
// DIM 2
class ReduceMeanImpl2DForward_cpu
: public Registrable<ReduceMeanImpl2DForward_cpu,
std::tuple<DataType, DataType>,
void(const ReduceMean_Op<2>::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
class ReduceMeanImpl2DBackward_cpu
: public Registrable<ReduceMeanImpl2DBackward_cpu,
std::tuple<DataType, DataType>,
void(const ReduceMean_Op<2>::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
// DIM 3
class ReduceMeanImpl3DForward_cpu
: public Registrable<ReduceMeanImpl3DForward_cpu,
std::tuple<DataType, DataType>,
void(const ReduceMean_Op<3>::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
class ReduceMeanImpl3DBackward_cpu
: public Registrable<ReduceMeanImpl3DBackward_cpu,
std::tuple<DataType, DataType>,
void(const ReduceMean_Op<3>::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
class ReduceMeanImpl1D_cpu : public OperatorImpl {
public:
ReduceMeanImpl1D_cpu(const ReduceMean_Op<1>& op) : OperatorImpl(op) {}
static std::unique_ptr<ReduceMeanImpl1D_cpu> create(const ReduceMean_Op<1> &op) {
return std::make_unique<ReduceMeanImpl1D_cpu>(op);
}
public:
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
};
class ReduceMeanImpl2D_cpu : public OperatorImpl {
class ReduceMeanImpl_cpu : public OperatorImpl {
public:
ReduceMeanImpl2D_cpu(const ReduceMean_Op<2>& op) : OperatorImpl(op) {}
ReduceMeanImpl_cpu(const ReduceMean_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<ReduceMeanImpl2D_cpu> create(const ReduceMean_Op<2> &op) {
return std::make_unique<ReduceMeanImpl2D_cpu>(op);
static std::unique_ptr<ReduceMeanImpl_cpu> create(const ReduceMean_Op &op) {
return std::make_unique<ReduceMeanImpl_cpu>(op);
}
public:
@@ -81,23 +48,80 @@ class ReduceMeanImpl2D_cpu : public OperatorImpl {
void forward() override;
};
class ReduceMeanImpl3D_cpu : public OperatorImpl {
public:
ReduceMeanImpl3D_cpu(const ReduceMean_Op<3>& op) : OperatorImpl(op) {}
static std::unique_ptr<ReduceMeanImpl3D_cpu> create(const ReduceMean_Op<3> &op) {
return std::make_unique<ReduceMeanImpl3D_cpu>(op);
}
public:
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
};
// // compute kernel registry for forward and backward
// // DIM 1
// class ReduceMeanImpl1DForward_cpu
// : public Registrable<ReduceMeanImpl1DForward_cpu,
// std::tuple<DataType, DataType>,
// void(const ReduceMean_Op<1>::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
// class ReduceMeanImpl1DBackward_cpu
// : public Registrable<ReduceMeanImpl1DBackward_cpu,
// std::tuple<DataType, DataType>,
// void(const ReduceMean_Op<1>::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
// // DIM 2
// class ReduceMeanImpl2DForward_cpu
// : public Registrable<ReduceMeanImpl2DForward_cpu,
// std::tuple<DataType, DataType>,
// void(const ReduceMean_Op<2>::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
// class ReduceMeanImpl2DBackward_cpu
// : public Registrable<ReduceMeanImpl2DBackward_cpu,
// std::tuple<DataType, DataType>,
// void(const ReduceMean_Op<2>::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
// // DIM 3
// class ReduceMeanImpl3DForward_cpu
// : public Registrable<ReduceMeanImpl3DForward_cpu,
// std::tuple<DataType, DataType>,
// void(const ReduceMean_Op<3>::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
// class ReduceMeanImpl3DBackward_cpu
// : public Registrable<ReduceMeanImpl3DBackward_cpu,
// std::tuple<DataType, DataType>,
// void(const ReduceMean_Op<3>::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
// class ReduceMeanImpl1D_cpu : public OperatorImpl {
// public:
// ReduceMeanImpl1D_cpu(const ReduceMean_Op<1>& op) : OperatorImpl(op, "cpu") {}
// static std::unique_ptr<ReduceMeanImpl1D_cpu> create(const ReduceMean_Op<1> &op) {
// return std::make_unique<ReduceMeanImpl1D_cpu>(op);
// }
// public:
// NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
// void forward() override;
// };
// class ReduceMeanImpl2D_cpu : public OperatorImpl {
// public:
// ReduceMeanImpl2D_cpu(const ReduceMean_Op<2>& op) : OperatorImpl(op, "cpu") {}
// static std::unique_ptr<ReduceMeanImpl2D_cpu> create(const ReduceMean_Op<2> &op) {
// return std::make_unique<ReduceMeanImpl2D_cpu>(op);
// }
// public:
// NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
// void forward() override;
// };
// class ReduceMeanImpl3D_cpu : public OperatorImpl {
// public:
// ReduceMeanImpl3D_cpu(const ReduceMean_Op<3>& op) : OperatorImpl(op, "cpu") {}
// static std::unique_ptr<ReduceMeanImpl3D_cpu> create(const ReduceMean_Op<3> &op) {
// return std::make_unique<ReduceMeanImpl3D_cpu>(op);
// }
// public:
// NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
// void forward() override;
// };
namespace {
// add cpu backend to ReduceMean_Op implementation registry
static Registrar<ReduceMean_Op<1>> registrarReduceMeanImpl1D_cpu("cpu", Aidge::ReduceMeanImpl1D_cpu::create);
static Registrar<ReduceMean_Op<2>> registrarReduceMeanImpl2D_cpu("cpu", Aidge::ReduceMeanImpl2D_cpu::create);
static Registrar<ReduceMean_Op<3>> registrarReduceMeanImpl3D_cpu("cpu", Aidge::ReduceMeanImpl3D_cpu::create);
static Registrar<ReduceMean_Op> registrarReduceMeanImpl_cpu("cpu", Aidge::ReduceMeanImpl_cpu::create);
// static Registrar<ReduceMean_Op<1>> registrarReduceMeanImpl1D_cpu("cpu", Aidge::ReduceMeanImpl1D_cpu::create);
// static Registrar<ReduceMean_Op<2>> registrarReduceMeanImpl2D_cpu("cpu", Aidge::ReduceMeanImpl2D_cpu::create);
// static Registrar<ReduceMean_Op<3>> registrarReduceMeanImpl3D_cpu("cpu", Aidge::ReduceMeanImpl3D_cpu::create);
} // namespace
} // namespace Aidge
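With the DIM template parameter removed, the single "cpu" entry above serves ReduceMean nodes of any rank. A hedged sketch of how an implementation might be resolved through that entry (assumes an existing ReduceMean_Op instance `op`; the string-keyed lookup mirrors the Registrar pattern used throughout this diff):
// Hypothetical resolution of the unified entry; not part of the change itself.
// std::unique_ptr<Aidge::ReduceMeanImpl_cpu> impl =
//     Aidge::Registrar<Aidge::ReduceMean_Op>::create("cpu")(op);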
......
@@ -12,10 +12,12 @@
#ifndef AIDGE_CPU_OPERATOR_REDUCEMEANIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_REDUCEMEANIMPL_FORWARD_KERNEL_H_
#include <cstddef>
#include <algorithm> // std::copy, std::for_each
#include <numeric> //std::accumulate
#include <algorithm> // std::for_each
#include <cstddef> // std::size_t
#include <cstdint> // std::int32_t
#include <functional> //std::multiplies
#include <numeric> //std::accumulate
#include <vector>
#include "aidge/backend/cpu/operator/ReduceMeanImpl.hpp"
#include "aidge/data/Data.hpp"
@@ -23,8 +25,8 @@
#include "aidge/utils/Registrar.hpp"
namespace Aidge {
template <class I, class O, DimSize_t DIM>
void ReduceMeanImpl_cpu_forward_kernel(const typename ReduceMean_Op<DIM>::Attrs& attrs,
template <class I, class O>
void ReduceMeanImpl_cpu_forward_kernel(const typename ReduceMean_Op::Attrs& attrs,
const std::vector<DimSize_t>& inputDims,
const void* input_,
void* output_) {
@@ -32,14 +34,15 @@ void ReduceMeanImpl_cpu_forward_kernel(const typename ReduceMean_Op<DIM>::Attrs&
const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_);
const std::vector<std::int32_t>& axes = std::get<0>(attrs);
const std::size_t nb_dims = inputDims.size();
const std::size_t totalElements = std::accumulate(inputDims.cbegin(), inputDims.cend(), 1, std::multiplies<std::size_t>());
if (DIM == 1) {
const std::size_t stride_pre = std::accumulate(inputDims.cbegin(), inputDims.cbegin() + std::get<0>(attrs)[0], 1, std::multiplies<std::size_t>());
const std::size_t stride_post = std::accumulate(inputDims.crbegin(), inputDims.crbegin() + nb_dims -1 - std::get<0>(attrs)[0], 1, std::multiplies<std::size_t>());
if (axes.size() == 1) {
const std::size_t stride_pre = std::accumulate(inputDims.cbegin(), inputDims.cbegin() + axes[0], 1, std::multiplies<std::size_t>());
const std::size_t stride_post = std::accumulate(inputDims.crbegin(), inputDims.crbegin() + nb_dims -1 - axes[0], 1, std::multiplies<std::size_t>());
const std::size_t dim_i = inputDims[std::get<0>(attrs)[0]];
const std::size_t dim_i = inputDims[axes[0]];
for (std::size_t pre = 0; pre < stride_pre; ++pre) {
for (std::size_t post = 0; post < stride_post; ++post) {
const std::size_t idx_i = pre * dim_i * stride_post + post;
@@ -68,7 +71,7 @@ void ReduceMeanImpl_cpu_forward_kernel(const typename ReduceMean_Op<DIM>::Attrs&
const I* inputAccumulation = input;
I* outputAccumulation = nullptr;
for (const auto& axisInt : std::get<0>(attrs)) {
for (const auto& axisInt : axes) {
const std::size_t a = static_cast<std::size_t>(axisInt);
outputElements /= inputDims[a];
outputAccumulation = new I[outputElements];
@@ -93,7 +96,7 @@ void ReduceMeanImpl_cpu_forward_kernel(const typename ReduceMean_Op<DIM>::Attrs&
// Copy elements from inputAccumulation to output while dividing by divisor
I divisor = totalElements / outputElements;
std::transform(inputAccumulation, inputAccumulation + outputElements, output,
[divisor](int element) { return element / divisor; });
[divisor](I element) { return element / divisor; });
if (outputAccumulation) {
delete[] outputAccumulation;
}
@@ -103,29 +106,36 @@ void ReduceMeanImpl_cpu_forward_kernel(const typename ReduceMean_Op<DIM>::Attrs&
}
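// Worked example of the single-axis fast path above (assumed shape):
//   inputDims = {2, 3, 4}, axes = {1}
//   stride_pre  = 2   (product of dims before the reduced axis)
//   stride_post = 4   (product of dims after the reduced axis)
//   dim_i       = 3   (extent being averaged)
// output[pre*4 + post] = mean over i of input[pre*12 + i*4 + post],
// i.e. 2*4 = 8 output elements, each averaging 3 input values.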
namespace {
// DIM = 1
static Registrar<ReduceMeanImpl1DForward_cpu> registrarReduceMeanImplForward_1D_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::ReduceMeanImpl_cpu_forward_kernel<float, float,1>);
static Registrar<ReduceMeanImpl1DForward_cpu> registrarReduceMeanImplForward_1D_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::ReduceMeanImpl_cpu_forward_kernel<int, int,1>);
static Registrar<ReduceMeanImpl1DForward_cpu> registrarReduceMeanImplForward_1D_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::ReduceMeanImpl_cpu_forward_kernel<double, double,1>);
// DIM = 2
static Registrar<ReduceMeanImpl2DForward_cpu> registrarReduceMeanImplForward_2D_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::ReduceMeanImpl_cpu_forward_kernel<float, float,2>);
static Registrar<ReduceMeanImpl2DForward_cpu> registrarReduceMeanImplForward_2D_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::ReduceMeanImpl_cpu_forward_kernel<int, int,2>);
static Registrar<ReduceMeanImpl2DForward_cpu> registrarReduceMeanImplForward_2D_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::ReduceMeanImpl_cpu_forward_kernel<double, double,2>);
// DIM = 3
static Registrar<ReduceMeanImpl3DForward_cpu> registrarReduceMeanImplForward_3D_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::ReduceMeanImpl_cpu_forward_kernel<float, float,3>);
static Registrar<ReduceMeanImpl3DForward_cpu> registrarReduceMeanImplForward_3D_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::ReduceMeanImpl_cpu_forward_kernel<int, int,3>);
static Registrar<ReduceMeanImpl3DForward_cpu> registrarReduceMeanImplForward_3D_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::ReduceMeanImpl_cpu_forward_kernel<double, double,3>);
static Registrar<ReduceMeanImplForward_cpu> registrarReduceMeanImplForward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::ReduceMeanImpl_cpu_forward_kernel<float, float>);
static Registrar<ReduceMeanImplForward_cpu> registrarReduceMeanImplForward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::ReduceMeanImpl_cpu_forward_kernel<int, int>);
static Registrar<ReduceMeanImplForward_cpu> registrarReduceMeanImplForward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::ReduceMeanImpl_cpu_forward_kernel<double, double>);
// // DIM = 1
// static Registrar<ReduceMeanImpl1DForward_cpu> registrarReduceMeanImplForward_1D_cpu_Float32(
// {DataType::Float32, DataType::Float32}, Aidge::ReduceMeanImpl_cpu_forward_kernel<float, float,1>);
// static Registrar<ReduceMeanImpl1DForward_cpu> registrarReduceMeanImplForward_1D_cpu_Int32(
// {DataType::Int32, DataType::Int32}, Aidge::ReduceMeanImpl_cpu_forward_kernel<int, int,1>);
// static Registrar<ReduceMeanImpl1DForward_cpu> registrarReduceMeanImplForward_1D_cpu_Float64(
// {DataType::Float64, DataType::Float64}, Aidge::ReduceMeanImpl_cpu_forward_kernel<double, double,1>);
// // DIM = 2
// static Registrar<ReduceMeanImpl2DForward_cpu> registrarReduceMeanImplForward_2D_cpu_Float32(
// {DataType::Float32, DataType::Float32}, Aidge::ReduceMeanImpl_cpu_forward_kernel<float, float,2>);
// static Registrar<ReduceMeanImpl2DForward_cpu> registrarReduceMeanImplForward_2D_cpu_Int32(
// {DataType::Int32, DataType::Int32}, Aidge::ReduceMeanImpl_cpu_forward_kernel<int, int,2>);
// static Registrar<ReduceMeanImpl2DForward_cpu> registrarReduceMeanImplForward_2D_cpu_Float64(
// {DataType::Float64, DataType::Float64}, Aidge::ReduceMeanImpl_cpu_forward_kernel<double, double,2>);
// // DIM = 3
// static Registrar<ReduceMeanImpl3DForward_cpu> registrarReduceMeanImplForward_3D_cpu_Float32(
// {DataType::Float32, DataType::Float32}, Aidge::ReduceMeanImpl_cpu_forward_kernel<float, float,3>);
// static Registrar<ReduceMeanImpl3DForward_cpu> registrarReduceMeanImplForward_3D_cpu_Int32(
// {DataType::Int32, DataType::Int32}, Aidge::ReduceMeanImpl_cpu_forward_kernel<int, int,3>);
// static Registrar<ReduceMeanImpl3DForward_cpu> registrarReduceMeanImplForward_3D_cpu_Float64(
// {DataType::Float64, DataType::Float64}, Aidge::ReduceMeanImpl_cpu_forward_kernel<double, double,3>);
} // namespace
} // namespace Aidge
......
@@ -32,7 +32,7 @@ class ReshapeImplBackward_cpu
class ReshapeImpl_cpu : public OperatorImpl {
public:
ReshapeImpl_cpu(const Reshape_Op& op) : OperatorImpl(op) {}
ReshapeImpl_cpu(const Reshape_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<ReshapeImpl_cpu> create(const Reshape_Op& op) {
return std::make_unique<ReshapeImpl_cpu>(op);
......
@@ -34,7 +34,7 @@ class ScalingImplBackward_cpu
class ScalingImpl_cpu : public OperatorImpl {
public:
ScalingImpl_cpu(const Scaling_Op& op) : OperatorImpl(op) {}
ScalingImpl_cpu(const Scaling_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<ScalingImpl_cpu> create(const Scaling_Op& op) {
return std::make_unique<ScalingImpl_cpu>(op);
......
@@ -33,7 +33,7 @@ class SigmoidImplBackward_cpu
class SigmoidImpl_cpu : public OperatorImpl {
public:
SigmoidImpl_cpu(const Sigmoid_Op& op) : OperatorImpl(op) {}
SigmoidImpl_cpu(const Sigmoid_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<SigmoidImpl_cpu> create(const Sigmoid_Op& op) {
return std::make_unique<SigmoidImpl_cpu>(op);
......
@@ -40,7 +40,7 @@ class SliceImplBackward_cpu
class SliceImpl_cpu : public OperatorImpl {
public:
SliceImpl_cpu(const Slice_Op& op) : OperatorImpl(op) {}
SliceImpl_cpu(const Slice_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<SliceImpl_cpu> create(const Slice_Op& op) {
return std::make_unique<SliceImpl_cpu>(op);
......
@@ -33,7 +33,7 @@ class SoftmaxImplBackward_cpu
class SoftmaxImpl_cpu : public OperatorImpl {
public:
SoftmaxImpl_cpu(const Softmax_Op& op) : OperatorImpl(op) {}
SoftmaxImpl_cpu(const Softmax_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<SoftmaxImpl_cpu> create(const Softmax_Op& op) {
return std::make_unique<SoftmaxImpl_cpu>(op);
......
@@ -12,16 +12,17 @@
#ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_H_
#define AIDGE_CPU_OPERATOR_SQRTIMPL_H_
#include <cstddef> // std::size_t
#include <memory>
#include <tuple>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Sqrt.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include <memory>
#include <vector>
namespace Aidge {
// class Sqrt_Op;
// compute kernel registry for forward and backward
class SqrtImplForward_cpu
@@ -33,14 +34,17 @@ class SqrtImplBackward_cpu
class SqrtImpl_cpu : public OperatorImpl {
public:
SqrtImpl_cpu(const Sqrt_Op& op) : OperatorImpl(op) {}
SqrtImpl_cpu(const Sqrt_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<SqrtImpl_cpu> create(const Sqrt_Op& op) {
return std::make_unique<SqrtImpl_cpu>(op);
}
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
void forward() override final;
void backward() override final;
};
namespace {
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_BACKWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_SQRTIMPL_BACKWARD_KERNEL_H_
#include <cmath> // std::sqrt
#include <cstddef> // std::size_t
#include "aidge/utils/Registrar.hpp"
#include "aidge/backend/cpu/operator/SqrtImpl.hpp"
namespace Aidge {
template <class I, class O>
void SqrtImpl_cpu_backward_kernel(const std::size_t inputLenght,
const void* input_,
void* output_) {
const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_);
for (std::size_t i = 0; i < inputLenght; ++i) {
output[i] = static_cast<O>(0.5/(std::sqrt(static_cast<float>(input[i]))));
}
}
namespace {
static Registrar<SqrtImplBackward_cpu> registrarSqrtImplBackward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::SqrtImpl_cpu_backward_kernel<float, float>);
static Registrar<SqrtImplBackward_cpu> registrarSqrtImplBackward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::SqrtImpl_cpu_backward_kernel<int, int>);
static Registrar<SqrtImplBackward_cpu> registrarSqrtImplBackward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::SqrtImpl_cpu_backward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_SQRTIMPL_BACKWARD_KERNEL_H_ */
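A minimal sketch exercising the Float64 kernel registered above (the driver function and buffers are illustrative; only the Registrar call mirrors this diff):
#include <vector>
void sqrtBackwardSketch() {
    std::vector<double> in  = {1.0, 4.0, 16.0};
    std::vector<double> out(in.size());
    auto kernel = Aidge::Registrar<Aidge::SqrtImplBackward_cpu>::create(
        {Aidge::DataType::Float64, Aidge::DataType::Float64});
    kernel(in.size(), in.data(), out.data());
    // out == {0.5, 0.25, 0.125}, i.e. d(sqrt(x))/dx = 0.5 / sqrt(x) at each point.
}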
@@ -12,14 +12,16 @@
#ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_SQRTIMPL_FORWARD_KERNEL_H_
#include <cmath> // std::sqrt
#include <cstddef> // std::size_t
#include "aidge/utils/Registrar.hpp"
#include <cmath>
#include "aidge/backend/cpu/operator/SqrtImpl.hpp"
namespace Aidge {
template <class I, class O>
void SqrtImpl_cpu_forward_kernel(std::size_t inputLenght,
void SqrtImpl_cpu_forward_kernel(const std::size_t inputLenght,
const void* input_,
void* output_) {
@@ -27,7 +29,7 @@ void SqrtImpl_cpu_forward_kernel(std::size_t inputLenght,
O* output = static_cast<O*>(output_);
for (std::size_t i = 0; i < inputLenght; ++i) {
output[i] = std::sqrt(input[i]);
output[i] = static_cast<O>(std::sqrt(static_cast<float>(input[i])));
}
}
......
@@ -33,7 +33,7 @@ class SubImplBackward_cpu
class SubImpl_cpu : public OperatorImpl {
public:
SubImpl_cpu(const Sub_Op& op) : OperatorImpl(op) {}
SubImpl_cpu(const Sub_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<SubImpl_cpu> create(const Sub_Op& op) {
return std::make_unique<SubImpl_cpu>(op);
......
@@ -33,7 +33,7 @@ class TanhImplBackward_cpu
class TanhImpl_cpu : public OperatorImpl {
public:
TanhImpl_cpu(const Tanh_Op& op) : OperatorImpl(op) {}
TanhImpl_cpu(const Tanh_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<TanhImpl_cpu> create(const Tanh_Op& op) {
return std::make_unique<TanhImpl_cpu>(op);
......
@@ -57,7 +57,7 @@ class TransposeImpl6DBackward_cpu
class TransposeImpl2D_cpu : public OperatorImpl {
public:
TransposeImpl2D_cpu(const Transpose_Op<2>& op) : OperatorImpl(op) {}
TransposeImpl2D_cpu(const Transpose_Op<2>& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<TransposeImpl2D_cpu> create(const Transpose_Op<2>& op) {
return std::make_unique<TransposeImpl2D_cpu>(op);
@@ -68,7 +68,7 @@ public:
};
class TransposeImpl3D_cpu : public OperatorImpl {
public:
TransposeImpl3D_cpu(const Transpose_Op<3>& op) : OperatorImpl(op) {}
TransposeImpl3D_cpu(const Transpose_Op<3>& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<TransposeImpl3D_cpu> create(const Transpose_Op<3>& op) {
return std::make_unique<TransposeImpl3D_cpu>(op);
@@ -79,7 +79,7 @@ public:
};
class TransposeImpl4D_cpu : public OperatorImpl {
public:
TransposeImpl4D_cpu(const Transpose_Op<4>& op) : OperatorImpl(op) {}
TransposeImpl4D_cpu(const Transpose_Op<4>& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<TransposeImpl4D_cpu> create(const Transpose_Op<4>& op) {
return std::make_unique<TransposeImpl4D_cpu>(op);
@@ -90,7 +90,7 @@ public:
};
class TransposeImpl5D_cpu : public OperatorImpl {
public:
TransposeImpl5D_cpu(const Transpose_Op<5>& op) : OperatorImpl(op) {}
TransposeImpl5D_cpu(const Transpose_Op<5>& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<TransposeImpl5D_cpu> create(const Transpose_Op<5>& op) {
return std::make_unique<TransposeImpl5D_cpu>(op);
@@ -101,7 +101,7 @@ public:
};
class TransposeImpl6D_cpu : public OperatorImpl {
public:
TransposeImpl6D_cpu(const Transpose_Op<6>& op) : OperatorImpl(op) {}
TransposeImpl6D_cpu(const Transpose_Op<6>& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<TransposeImpl6D_cpu> create(const Transpose_Op<6>& op) {
return std::make_unique<TransposeImpl6D_cpu>(op);
......
@@ -9,17 +9,18 @@
*
********************************************************************************/
#include "aidge/backend/cpu/operator/AddImpl.hpp"
#include <cassert>
#include <numeric> // std::accumulate
#include <vector>
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/AddImpl_forward_kernels.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/backend/cpu/operator/AddImpl.hpp"
#include "aidge/backend/cpu/operator/AddImpl_forward_kernels.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
Aidge::NbElts_t Aidge::AddImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
// this implementation can be in-place
@@ -27,15 +28,18 @@ Aidge::NbElts_t Aidge::AddImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex
}
void Aidge::AddImpl_cpu::forward() {
assert(mOp.getRawInput(0) && "missing input in Add operator");
DataType datatypeFirstInput = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType();
for (IOIndex_t i = 1; i < mOp.nbInputs(); ++i) {
assert(mOp.getRawInput(i) && "missing input in Add operator");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(i))->dataType() == datatypeFirstInput);
const auto& opTensor = static_cast<const OperatorTensor&>(mOp);
AIDGE_ASSERT(opTensor.getInput(0)->hasImpl(), "cannot run Add forward because the 0-th input has no implementation.");
assert(opTensor.getInput(0) && "missing input in Add operator");
DataType datatypeFirstInput = opTensor.getInput(0)->dataType();
for (IOIndex_t i = 1; i < opTensor.nbInputs(); ++i) {
AIDGE_ASSERT(opTensor.getInput(i)->hasImpl(), "cannot run Add forward because the {}-th input has no implementation.", i);
assert(opTensor.getInput(i) && "missing input in Add operator");
assert(opTensor.getInput(i)->dataType() == datatypeFirstInput);
}
// Find the correct kernel type
const auto outputDataType = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType();
const auto outputDataType = opTensor.getOutput(0)->dataType();
const Registrar<AddImplForward_cpu>::registrar_key registrarKey = {
datatypeFirstInput,
outputDataType};
......@@ -55,26 +59,26 @@ void Aidge::AddImpl_cpu::forward() {
// TODO: right now, if needed, memory will be allocated/deallocated at each
// call to forward(). We might put the following shared_ptr as members of
// this class to avoid that.
std::size_t nbDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->nbDims();
const std::size_t nbDims = opTensor.getOutput(0)->nbDims();
std::vector<std::vector<std::size_t>> inputsDims;
std::vector<const void*> opInputs;
std::vector<std::shared_ptr<Tensor>> inputsFallback(mOp.nbInputs());
for (IOIndex_t i = 0; i < mOp.nbInputs(); ++i) {
std::vector<std::shared_ptr<Tensor>> inputsFallback(opTensor.nbInputs());
for (IOIndex_t i = 0; i < opTensor.nbInputs(); ++i) {
std::vector<std::size_t> inputDims(nbDims, 1);
auto dims = std::static_pointer_cast<Tensor>(mOp.getRawInput(i))->dims();
auto dims = opTensor.getInput(i)->dims();
for(std::size_t j=dims.size()-1; j+1>0; --j)
{
std::size_t idx = nbDims - (dims.size()-j);
inputDims[idx] = dims[j];
}
inputsDims.push_back(inputDims);
const auto& input = std::static_pointer_cast<Tensor>(mOp.getRawInput(i))->refCastFrom(inputsFallback[i], *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input = opTensor.getInput(i)->refCastFrom(inputsFallback[i], *opTensor.getOutput(0));
opInputs.push_back(input.getImpl()->rawPtr());
}
kernelFunc(opInputs,
inputsDims,
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
getCPUPtr(mOp.getRawOutput(0)));
opTensor.getOutput(0)->size(),
opTensor.getOutput(0)->dims(),
getCPUPtr(opTensor.getRawOutput(0)));
}
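// Worked example of the dimension right-alignment above (assumed shapes):
//   output rank nbDims = 4, one input with dims = {3, 5}
//   -> inputDims = {1, 1, 3, 5}
// Trailing axes line up with the output; the padded leading 1s are the
// axes along which this input broadcasts.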
@@ -28,17 +28,19 @@ Aidge::NbElts_t Aidge::ConvImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputI
}
void Aidge::ConvImpl2D_cpu::forward() {
const auto& opTensor = static_cast<const OperatorTensor&>(mOp);
// FIXME: uncomment the following code once memory handling will work
assert(mOp.getRawInput(0) && "missing input #0");
assert(mOp.getRawInput(1) && "missing input #1");
assert(mOp.getRawInput(2) && "missing input #2");
// Find the correct kernel type
const auto outputDataType = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType();
const auto outputDataType = opTensor.getOutput(0)->dataType();
const Registrar<ConvImpl2DForward_cpu>::registrar_key registrarKey = {
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType(),
opTensor.getInput(0)->dataType(),
opTensor.getInput(1)->dataType(),
opTensor.getInput(2)->dataType(),
outputDataType};
Registrar<ConvImpl2DForward_cpu>::registrar_type kernelFunc;
@@ -57,12 +59,12 @@ void Aidge::ConvImpl2D_cpu::forward() {
// call to forward(). We might put the following shared_ptr as members of
// this class to avoid that.
std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input0 = opTensor.getInput(0)->refCastFrom(input0Fallback, *opTensor.getOutput(0));
const auto& input1 = opTensor.getInput(1)->refCastFrom(input1Fallback, *opTensor.getOutput(0));
const auto& input2 = opTensor.getInput(2)->refCastFrom(input2Fallback, *opTensor.getOutput(0));
// Call kernel
kernelFunc(dynamic_cast<const Conv_Op<2>&>(mOp).getStaticAttributes(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(),
kernelFunc(dynamic_cast<const Conv_Op<2>&>(mOp).getStaticAttributes(), opTensor.getInput(0)->template dims<4>(),
input0.getImpl()->rawPtr(), input1.getImpl()->rawPtr(), input2.getImpl()->rawPtr(),
getCPUPtr(mOp.getRawOutput(0)));
}
@@ -57,17 +57,18 @@ void Aidge::DivImpl_cpu::forward() {
// 3. Compute the highest number of contiguous data -> 7
// 4. Compute stride and offset step for the broadcast mechanism
// 5. Call a simple kernel
const auto& opTensor = static_cast<const Div_Op&>(mOp);
// Find the correct kernel type
auto kernelFunc = Registrar<DivImplForward_cpu>::create({
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
opTensor.getInput(0)->dataType(),
opTensor.getInput(1)->dataType(),
opTensor.getOutput(0)->dataType()});
// Compute compatible input dimensions
std::vector<std::size_t> dims0 = static_cast<const Div_Op&>(mOp).getInput(0)->dims();
std::vector<std::size_t> dims1 = static_cast<const Div_Op&>(mOp).getInput(1)->dims();
const std::vector<std::size_t>& outDims = static_cast<const Div_Op&>(mOp).getOutput(0)->dims();
std::vector<std::size_t> dims0 = opTensor.getInput(0)->dims();
std::vector<std::size_t> dims1 = opTensor.getInput(1)->dims();
const std::vector<std::size_t>& outDims = opTensor.getOutput(0)->dims();
// if (dims0 == dims1) {
// const std::size_t input0_contiguous_size = std::accumulate(dims0.cbegin(), dims0.cend(), std::size_t(1), std::multiplies<std::size_t>());
@@ -108,24 +109,24 @@ void Aidge::DivImpl_cpu::forward() {
const std::size_t output_contiguous_size = std::accumulate(outDims.cbegin()+contiguousIdx, outDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
// initialize strides to iterate through data because of broadcasting
std::size_t *stride_post0;
std::size_t *stride_post1;
std::int32_t *stride_post0;
std::int32_t *stride_post1;
std::int32_t *stride_step0;
std::int32_t *stride_step1;
if (contiguousIdx > 0) {
stride_post0 = new std::size_t[contiguousIdx];
stride_post0 = new std::int32_t[contiguousIdx];
stride_post0[contiguousIdx - 1] = 1;
stride_post1 = new std::size_t[contiguousIdx];
stride_post1 = new std::int32_t[contiguousIdx];
stride_post1[contiguousIdx - 1] = 1;
for (std::size_t i = contiguousIdx - 2; i != static_cast<std::size_t>(-1); --i) {
stride_post0[i] = stride_post0[i+1]*dims0[i+1];
stride_post1[i] = stride_post1[i+1]*dims1[i+1];
stride_post0[i] = stride_post0[i+1]*static_cast<std::int32_t>(dims0[i+1]);
stride_post1[i] = stride_post1[i+1]*static_cast<std::int32_t>(dims1[i+1]);
}
stride_step0 = new std::int32_t[contiguousIdx];
stride_step1 = new std::int32_t[contiguousIdx];
for (std::size_t i = 0; i != contiguousIdx; ++i) {
stride_step0[i] = (dims0[i] == 1) ? 1 - static_cast<std::int32_t>(stride_post0[i]) : 1;
stride_step1[i] = (dims1[i] == 1) ? 1 - static_cast<std::int32_t>(stride_post1[i]) : 1;
stride_step0[i] = (dims0[i] == 1) ? 1 - stride_post0[i] : 1;
stride_step1[i] = (dims1[i] == 1) ? 1 - stride_post1[i] : 1;
}
}
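// Worked example of the stride setup above (assumed shapes):
//   dims0 = {1, 3}, dims1 = {2, 1}, contiguousIdx = 2
//   stride_post0 = {3, 1},  stride_post1 = {1, 1}
//   stride_step0 = {1-3, 1} = {-2, 1}  (axis 0 of input0 broadcasts: rewind)
//   stride_step1 = {1, 1-1} = { 1, 0}  (axis 1 of input1 broadcasts: hold)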
......
@@ -9,32 +9,34 @@
*
********************************************************************************/
#include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include "aidge/backend/cpu/operator/ErfImpl.hpp"
#include <memory>
#include <vector>
#include "aidge/backend/cpu/operator/ErfImpl_forward_kernels.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Erf.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/operator/ErfImpl.hpp"
#include "aidge/backend/cpu/operator/ErfImpl_forward_kernels.hpp"
Aidge::NbElts_t Aidge::ErfImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
// this implementation can be in-place
return 0;
}
void Aidge::ErfImpl_cpu::forward() {
const Erf_Op& op = static_cast<const Erf_Op&>(mOp);
// Find the correct kernel type
auto kernelFunc = Registrar<ErfImplForward_cpu>::create({
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
op.getInput(0)->dataType(),
op.getOutput(0)->dataType()
});
// Call kernel
kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
kernelFunc(
op.getInput(0)->size(),
op.getInput(0)->getImpl()->rawPtr(),
op.getOutput(0)->getImpl()->rawPtr()
);
}
@@ -9,31 +9,34 @@
*
********************************************************************************/
#include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector>
#include "aidge/backend/cpu/operator/FCImpl.hpp"
#include <cstddef> // std::size_t
#include <functional>
#include <memory>
#include <tuple>
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/FCImpl_backward_kernels.hpp"
#include "aidge/backend/cpu/operator/FCImpl_forward_kernels.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/FCImpl.hpp"
#include "aidge/backend/cpu/operator/FCImpl_forward_kernels.hpp"
void Aidge::FCImpl_cpu::forward()
{
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(1)) && "missing input #1");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(2)) && "missing input #2");
const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
AIDGE_ASSERT(op_.getInput(0), "missing input #0");
AIDGE_ASSERT(op_.getInput(1), "missing input #1");
AIDGE_ASSERT(op_.getInput(2), "missing input #2");
// Find the correct kernel type
const auto outputDataType = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType();
const auto outputDataType = op_.getOutput(0)->dataType();
const Registrar<FCImplForward_cpu>::registrar_key registrarKey = {
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType(),
op_.getInput(0)->dataType(),
op_.getInput(1)->dataType(),
op_.getInput(2)->dataType(),
outputDataType};
Registrar<FCImplForward_cpu>::registrar_type kernelFunc;
@@ -52,9 +55,9 @@ void Aidge::FCImpl_cpu::forward()
// call to forward(). We might put the following shared_ptr as members of
// this class to avoid that.
std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
const auto& input0 = op_.getInput(0)->refCastFrom(input0Fallback, *(op_.getOutput(0)));
const auto& input1 = op_.getInput(1)->refCastFrom(input1Fallback, *(op_.getOutput(0)));
const auto& input2 = op_.getInput(2)->refCastFrom(input2Fallback, *(op_.getOutput(0)));
// Call kernel
const auto batchSize = (input0.dims().size() > 1) ? input0.dims()[0] : 1;
@@ -64,3 +67,49 @@ void Aidge::FCImpl_cpu::forward()
input0.getImpl()->rawPtr(), input1.getImpl()->rawPtr(), input2.getImpl()->rawPtr(),
getCPUPtr(mOp.getRawOutput(0)));
}
void Aidge::FCImpl_cpu::backward()
{
const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
const auto& fc_grad = op_.getOutput(0)->grad();
assert(fc_grad && "missing output #0 gradient");
// Find the correct kernel type
const Registrar<FCImplBackward_cpu>::registrar_key registrarKey = {
fc_grad->dataType(),
op_.getInput(0)->grad()->dataType(),
op_.getInput(1)->grad()->dataType(),
op_.getInput(2)->grad()->dataType()};
Registrar<FCImplBackward_cpu>::registrar_type kernelFunc;
if (Registrar<FCImplBackward_cpu>::exists(registrarKey)) {
// One exists with the right inputs/output types
kernelFunc = Registrar<FCImplBackward_cpu>::create(registrarKey);
}
else {
// Otherwise, fallback to the kernel with all types matching output type
kernelFunc = Registrar<FCImplBackward_cpu>::create({
fc_grad->dataType(), fc_grad->dataType(), fc_grad->dataType(), fc_grad->dataType()});
}
// Convert input data (no overhead if not needed!)
// TODO: right now, if needed, memory will be allocated/deallocated at each
// call to forward(). We might put the following shared_ptr as members of
// this class to avoid that.
std::shared_ptr<Tensor> input0gradFallback, input1gradFallback, input2gradFallback;
const auto& input0grad = op_.getInput(0)->grad()->refCastFrom(input0gradFallback, *(op_.getOutput(0)));
const auto& input1grad = op_.getInput(1)->grad()->refCastFrom(input1gradFallback, *(op_.getOutput(0)));
const auto& input2grad = op_.getInput(2)->grad()->refCastFrom(input2gradFallback, *(op_.getOutput(0)));
// Call kernel
const auto batchSize = (input0grad.dims().size() > 1) ? input0grad.dims()[0] : 1;
kernelFunc(dynamic_cast<const FC_Op&>(mOp).getStaticAttributes(),
batchSize,
input0grad.size() / batchSize,
getCPUPtr(fc_grad),
getCPUPtr(op_.getInput(0)),
getCPUPtr(mOp.getRawInput(1)),
input0grad.getImpl()->rawPtr(),
input1grad.getImpl()->rawPtr(),
input2grad.getImpl()->rawPtr());
}
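Both forward() and backward() above share the same lookup-with-fallback pattern. A minimal sketch factoring it out, assuming only the exists()/create() Registrar calls already present in this diff (the helper itself is illustrative, not part of the codebase):
static Aidge::Registrar<Aidge::FCImplBackward_cpu>::registrar_type
pickBackwardKernel(const Aidge::Registrar<Aidge::FCImplBackward_cpu>::registrar_key& key,
                   Aidge::DataType fallbackType) {
    if (Aidge::Registrar<Aidge::FCImplBackward_cpu>::exists(key)) {
        // A kernel with the exact gradient types exists: use it.
        return Aidge::Registrar<Aidge::FCImplBackward_cpu>::create(key);
    }
    // Otherwise fall back to the kernel whose four types all match the
    // output gradient's type.
    return Aidge::Registrar<Aidge::FCImplBackward_cpu>::create(
        {fallbackType, fallbackType, fallbackType, fallbackType});
}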