Skip to content
Snippets Groups Projects
Commit 1fdb9c99 authored by Noam Zerah's avatar Noam Zerah
Browse files

Merge dev updates in feat_operator_bitshift (New Registrar System)

parent 329ddc3c
No related branches found
No related tags found
1 merge request!88Added Bitshift Operator
Pipeline #55471 passed
......@@ -12,7 +12,7 @@
#ifndef AIDGE_CPU_OPERATOR_BITSHIFTIMPL_H_
#define AIDGE_CPU_OPERATOR_BITSHIFTIMPL_H_
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/backend/cpu/operator/OperatorImpl.hpp"
#include "aidge/operator/BitShift.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
......@@ -21,31 +21,18 @@
#include <vector>
namespace Aidge {
// class BitShift_Op;
// compute kernel registry for forward and backward
class BitShiftImplForward_cpu
: public Registrable<BitShiftImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShift_Op::BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> {
};
class BitShiftImplBackward_cpu
: public Registrable<BitShiftImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShift_Op::BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> {
};
class BitShiftImpl_cpu : public OperatorImpl {
public:
BitShiftImpl_cpu(const BitShift_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<BitShiftImpl_cpu> create(const BitShift_Op& op) {
return std::make_unique<BitShiftImpl_cpu>(op);
}
Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
};
namespace {
static Registrar<BitShift_Op> registrarBitShiftImpl_cpu("cpu", Aidge::BitShiftImpl_cpu::create);
}
// Operator implementation entry point for the backend
using BitShiftImpl_cpu = OperatorImpl_cpu<BitShift_Op,
void(const BitShift_Op::BitShiftDirection,
const std::vector<std::size_t>&,
const std::vector<std::size_t>&,
const std::vector<std::size_t>&,
const void*,
const void*,
void*)>;
// Implementation entry point registration to Operator
REGISTRAR(BitShift_Op,"cpu",Aidge::BitShiftImpl_cpu::create);
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_BITSHIFTIMPL_H_ */
......@@ -9,8 +9,8 @@
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_BITSHIFTIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_BITSHIFTIMPL_FORWARD_KERNEL_H_
#ifndef AIDGE_CPU_OPERATOR_BITSHIFTIMPL_KERNELS_H_
#define AIDGE_CPU_OPERATOR_BITSHIFTIMPL_KERNELS_H_
#include "aidge/utils/Registrar.hpp"
......@@ -57,14 +57,14 @@ void BitShiftImpl_cpu_forward_kernel(
}
}
namespace {
static Registrar<BitShiftImplForward_cpu> registrarBitShiftImplForward_cpu_Int32(
{DataType::Int32, DataType::Int32, DataType::Int32},
Aidge::BitShiftImpl_cpu_forward_kernel<std::int32_t, std::int32_t, std::int32_t>);
static Registrar<BitShiftImplForward_cpu> registrarBitShiftImplForward_cpu_Int64(
{DataType::Int64, DataType::Int64, DataType::Int64},
Aidge::BitShiftImpl_cpu_forward_kernel<std::int64_t, std::int64_t, std::int64_t>);
} // namespace
REGISTRAR(BitShiftImpl_cpu,
{DataType::Int32},
{ProdConso::inPlaceModel,Aidge::BitShiftImpl_cpu_forward_kernel<std::int32_t, std::int32_t, std::int32_t>,nullptr});
REGISTRAR(BitShiftImpl_cpu,
{DataType::Int64},
{ProdConso::inPlaceModel,Aidge::BitShiftImpl_cpu_forward_kernel<std::int64_t, std::int64_t, std::int64_t>,nullptr});
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_BitShiftIMPL_FORWARD_KERNEL_H_ */
#endif /* AIDGE_CPU_OPERATOR_BitShiftIMPL_KERNELS_H_ */
\ No newline at end of file
......@@ -21,21 +21,16 @@
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/BitShiftImpl.hpp"
#include "aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp"
Aidge::Elts_t Aidge::BitShiftImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
// this implementation can be in-place
return Elts_t::DataElts(0);
}
#include "aidge/backend/cpu/operator/BitShiftImpl_kernels.hpp"
template<>
void Aidge::BitShiftImpl_cpu::forward() {
const auto& op_ = dynamic_cast<const BitShift_Op&>(mOp);
auto kernelFunc = Registrar<BitShiftImplForward_cpu>::create({
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
const auto impl = Registrar<BitShiftImpl_cpu>::create(getBestMatch(getRequiredSpec()));
const std::vector<std::size_t> inputDims0 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims());
......@@ -45,7 +40,7 @@ void Aidge::BitShiftImpl_cpu::forward() {
BitShift_Op::BitShiftDirection direction = op_.direction();
// Call kernel
kernelFunc(
impl.forward(
direction,
inputDims0,
inputDims1,
......@@ -55,3 +50,8 @@ void Aidge::BitShiftImpl_cpu::forward() {
getCPUPtr(mOp.getRawOutput(0)));
}
template <>
void Aidge::BitShiftImpl_cpu::backward() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Backward not yet implemented for BitShift_Op on backend cpu");
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment