/******************************************************************************** * Copyright (c) 2023 CEA-List * * This program and the accompanying materials are made available under the * terms of the Eclipse Public License 2.0 which is available at * http://www.eclipse.org/legal/epl-2.0. * * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ #ifndef AIDGE_CPU_OPERATOR_ADDIMPL_H_ #define AIDGE_CPU_OPERATOR_ADDIMPL_H_ #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Add.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" #include <memory> #include <vector> namespace Aidge { // class Add_Op<2>; // compute kernel registry for forward and backward template <DimIdx_t NUM> class AddImplForward_cpu; template <DimIdx_t NUM> class AddImplBackward_cpu; template <> class AddImplForward_cpu<1> : public Registrable<AddImplForward_cpu<1>, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {}; template <> class AddImplBackward_cpu<1> : public Registrable<AddImplBackward_cpu<1>, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {}; template <> class AddImplForward_cpu<2> : public Registrable<AddImplForward_cpu<2>, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {}; template <> class AddImplBackward_cpu<2> : public Registrable<AddImplBackward_cpu<2>, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {}; template <> class AddImplForward_cpu<3> : public Registrable<AddImplForward_cpu<3>, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> { }; template <> class AddImplBackward_cpu<3> : public Registrable<AddImplBackward_cpu<3>, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> {}; template <DimIdx_t NUM> class AddImpl_cpu : public OperatorImpl { public: AddImpl_cpu(const Add_Op<NUM>& op) : OperatorImpl(op) {} static std::unique_ptr<AddImpl_cpu<NUM>> create(const Add_Op<NUM>& op) { return std::make_unique<AddImpl_cpu<NUM>>(op); } }; template <> class AddImpl_cpu<1> : public OperatorImpl { public: AddImpl_cpu(const Add_Op<1>& op) : OperatorImpl(op) {} static std::unique_ptr<AddImpl_cpu<1>> create(const Add_Op<1>& op) { return std::make_unique<AddImpl_cpu<1>>(op); } NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final; void forward() override; }; template <> class AddImpl_cpu<2> : public OperatorImpl { public: AddImpl_cpu(const Add_Op<2>& op) : OperatorImpl(op) {} static std::unique_ptr<AddImpl_cpu<2>> create(const Add_Op<2>& op) { return std::make_unique<AddImpl_cpu<2>>(op); } NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; template <> class AddImpl_cpu<3> : public OperatorImpl { public: AddImpl_cpu(const Add_Op<3>& op) : OperatorImpl(op) {} static std::unique_ptr<AddImpl_cpu<3>> create(const Add_Op<3>& op) { return std::make_unique<AddImpl_cpu<3>>(op); } NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final; void forward() override; }; namespace { static Registrar<Add_Op<1>> registrarAddImpl1I_cpu("cpu", Aidge::AddImpl_cpu<1>::create); static Registrar<Add_Op<2>> registrarAddImpl2I_cpu("cpu", Aidge::AddImpl_cpu<2>::create); static Registrar<Add_Op<3>> registrarAddImpl3I_cpu("cpu", Aidge::AddImpl_cpu<3>::create); } // namespace } // namespace Aidge #endif /* AIDGE_CPU_OPERATOR_ADDIMPL_H_ */