Skip to content
Snippets Groups Projects
Commit 1fdb9c99 authored by Noam Zerah's avatar Noam Zerah
Browse files

Merge dev updates in feat_operator_bitshift (New Registrar System)

parent 329ddc3c
No related branches found
No related tags found
No related merge requests found
Pipeline #55470 canceled
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#ifndef AIDGE_CPU_OPERATOR_BITSHIFTIMPL_H_ #ifndef AIDGE_CPU_OPERATOR_BITSHIFTIMPL_H_
#define AIDGE_CPU_OPERATOR_BITSHIFTIMPL_H_ #define AIDGE_CPU_OPERATOR_BITSHIFTIMPL_H_
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/cpu/operator/OperatorImpl.hpp"
#include "aidge/operator/BitShift.hpp" #include "aidge/operator/BitShift.hpp"
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
...@@ -21,31 +21,18 @@ ...@@ -21,31 +21,18 @@
#include <vector> #include <vector>
namespace Aidge { namespace Aidge {
// class BitShift_Op; // Operator implementation entry point for the backend
using BitShiftImpl_cpu = OperatorImpl_cpu<BitShift_Op,
// compute kernel registry for forward and backward void(const BitShift_Op::BitShiftDirection,
class BitShiftImplForward_cpu const std::vector<std::size_t>&,
: public Registrable<BitShiftImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShift_Op::BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> { const std::vector<std::size_t>&,
}; const std::vector<std::size_t>&,
class BitShiftImplBackward_cpu const void*,
: public Registrable<BitShiftImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShift_Op::BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> { const void*,
}; void*)>;
class BitShiftImpl_cpu : public OperatorImpl { // Implementation entry point registration to Operator
public: REGISTRAR(BitShift_Op,"cpu",Aidge::BitShiftImpl_cpu::create);
BitShiftImpl_cpu(const BitShift_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<BitShiftImpl_cpu> create(const BitShift_Op& op) {
return std::make_unique<BitShiftImpl_cpu>(op);
}
Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
};
namespace {
static Registrar<BitShift_Op> registrarBitShiftImpl_cpu("cpu", Aidge::BitShiftImpl_cpu::create);
}
} // namespace Aidge } // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_BITSHIFTIMPL_H_ */ #endif /* AIDGE_CPU_OPERATOR_BITSHIFTIMPL_H_ */
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
* *
********************************************************************************/ ********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_BITSHIFTIMPL_FORWARD_KERNEL_H_ #ifndef AIDGE_CPU_OPERATOR_BITSHIFTIMPL_KERNELS_H_
#define AIDGE_CPU_OPERATOR_BITSHIFTIMPL_FORWARD_KERNEL_H_ #define AIDGE_CPU_OPERATOR_BITSHIFTIMPL_KERNELS_H_
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
...@@ -57,14 +57,14 @@ void BitShiftImpl_cpu_forward_kernel( ...@@ -57,14 +57,14 @@ void BitShiftImpl_cpu_forward_kernel(
} }
} }
namespace { REGISTRAR(BitShiftImpl_cpu,
static Registrar<BitShiftImplForward_cpu> registrarBitShiftImplForward_cpu_Int32( {DataType::Int32},
{DataType::Int32, DataType::Int32, DataType::Int32}, {ProdConso::inPlaceModel,Aidge::BitShiftImpl_cpu_forward_kernel<std::int32_t, std::int32_t, std::int32_t>,nullptr});
Aidge::BitShiftImpl_cpu_forward_kernel<std::int32_t, std::int32_t, std::int32_t>); REGISTRAR(BitShiftImpl_cpu,
static Registrar<BitShiftImplForward_cpu> registrarBitShiftImplForward_cpu_Int64( {DataType::Int64},
{DataType::Int64, DataType::Int64, DataType::Int64}, {ProdConso::inPlaceModel,Aidge::BitShiftImpl_cpu_forward_kernel<std::int64_t, std::int64_t, std::int64_t>,nullptr});
Aidge::BitShiftImpl_cpu_forward_kernel<std::int64_t, std::int64_t, std::int64_t>);
} // namespace
} // namespace Aidge } // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_BitShiftIMPL_FORWARD_KERNEL_H_ */ #endif /* AIDGE_CPU_OPERATOR_BitShiftIMPL_KERNELS_H_ */
\ No newline at end of file
...@@ -21,21 +21,16 @@ ...@@ -21,21 +21,16 @@
#include "aidge/backend/cpu/data/GetCPUPtr.h" #include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/BitShiftImpl.hpp" #include "aidge/backend/cpu/operator/BitShiftImpl.hpp"
#include "aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp" #include "aidge/backend/cpu/operator/BitShiftImpl_kernels.hpp"
Aidge::Elts_t Aidge::BitShiftImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
// this implementation can be in-place
return Elts_t::DataElts(0);
}
template<>
void Aidge::BitShiftImpl_cpu::forward() { void Aidge::BitShiftImpl_cpu::forward() {
const auto& op_ = dynamic_cast<const BitShift_Op&>(mOp); const auto& op_ = dynamic_cast<const BitShift_Op&>(mOp);
auto kernelFunc = Registrar<BitShiftImplForward_cpu>::create({
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(), const auto impl = Registrar<BitShiftImpl_cpu>::create(getBestMatch(getRequiredSpec()));
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
const std::vector<std::size_t> inputDims0 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(), const std::vector<std::size_t> inputDims0 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims()); std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims());
...@@ -45,7 +40,7 @@ void Aidge::BitShiftImpl_cpu::forward() { ...@@ -45,7 +40,7 @@ void Aidge::BitShiftImpl_cpu::forward() {
BitShift_Op::BitShiftDirection direction = op_.direction(); BitShift_Op::BitShiftDirection direction = op_.direction();
// Call kernel // Call kernel
kernelFunc( impl.forward(
direction, direction,
inputDims0, inputDims0,
inputDims1, inputDims1,
...@@ -55,3 +50,8 @@ void Aidge::BitShiftImpl_cpu::forward() { ...@@ -55,3 +50,8 @@ void Aidge::BitShiftImpl_cpu::forward() {
getCPUPtr(mOp.getRawOutput(0))); getCPUPtr(mOp.getRawOutput(0)));
} }
template <>
void Aidge::BitShiftImpl_cpu::backward() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Backward not yet implemented for BitShift_Op on backend cpu");
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment