/******************************************************************************** * Copyright (c) 2023 CEA-List * * This program and the accompanying materials are made available under the * terms of the Eclipse Public License 2.0 which is available at * http://www.eclipse.org/legal/epl-2.0. * * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ #ifndef __AIDGE_CPU_OPERATOR_ADDIMPL_H__ #define __AIDGE_CPU_OPERATOR_ADDIMPL_H__ #include "backend/OperatorImpl.hpp" #include "operator/Add.hpp" #include "utils/Registrar.hpp" #include "utils/Types.h" #include <memory> #include <vector> namespace Aidge { // class Add_Op<2>; // compute kernel registry for forward and backward template <DimIdx_t NUM> class AddImplForward_cpu; template <DimIdx_t NUM> class AddImplBackward_cpu; template <> class AddImplForward_cpu<1> : public Registrable<AddImplForward_cpu<1>, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {}; template <> class AddImplBackward_cpu<1> : public Registrable<AddImplBackward_cpu<1>, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {}; template <> class AddImplForward_cpu<2> : public Registrable<AddImplForward_cpu<2>, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {}; template <> class AddImplBackward_cpu<2> : public Registrable<AddImplBackward_cpu<2>, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, void*)> {}; template <> class AddImplForward_cpu<3> : public Registrable<AddImplForward_cpu<3>, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> { }; template <> class AddImplBackward_cpu<3> : public Registrable<AddImplBackward_cpu<3>, std::tuple<DataType, DataType, DataType, DataType>, void(const std::size_t, const void*, const void*, const void*, void*)> {}; template <DimIdx_t NUM> class AddImpl_cpu : public OperatorImpl { private: const Add_Op<NUM>& mOp; std::array<NbElts_t, NUM> mNbConsumedData = {}; std::array<NbElts_t, 1> mNbProducedData = {}; public: AddImpl_cpu(const Add_Op<NUM>& op) : mOp(op) {} static std::unique_ptr<AddImpl_cpu<NUM>> create(const Add_Op<NUM>& op) { return std::make_unique<AddImpl_cpu<NUM>>(op); } public: NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final { assert(mOp.getInput(inputIdx) && "requires valid input"); // Requires the whole tensors const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->dims(); return std::accumulate(inputDims.begin(), inputDims.end(), NbElts_t(1), std::multiplies<NbElts_t>()); } NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final { // for the direct convolution algorithm, convolutions can be in-place, if there is no padding! return 0; } NbElts_t getRequiredMemory(__attribute__((unused)) const IOIndex_t outputIdx, const std::vector<DimSize_t>& inputsSize) const override final { // Requires the whole tensors, regardless of available data on inputs assert(outputIdx == 0 && "operator has only one output"); const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getOutput(0))->dims(); return std::accumulate(outputDims.begin(), outputDims.end(), NbElts_t(1), std::multiplies<NbElts_t>()); } NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override final { assert(inputIdx < mNbConsumedData.size()); return mNbConsumedData[inputIdx]; } NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final { assert(outputIdx < mNbProducedData.size()); return mNbProducedData[outputIdx]; } void forward() { // nothing } void backward() { printf("Not implemented yet.\n"); } }; template <> class AddImpl_cpu<1> : public OperatorImpl { private: const Add_Op<1>& mOp; std::array<NbElts_t, 1> mNbConsumedData; std::array<NbElts_t, 1> mNbProducedData; public: AddImpl_cpu(const Add_Op<1>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {} static std::unique_ptr<AddImpl_cpu<1>> create(const Add_Op<1>& op) { return std::make_unique<AddImpl_cpu<1>>(op); } public: NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final; NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final; NbElts_t getRequiredMemory(__attribute__((unused)) const IOIndex_t outputIdx, __attribute__((unused)) const std::vector<DimSize_t> &inputsSize) const override final; NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final; NbElts_t getNbProducedData(const IOIndex_t /*outputIdx*/) const override final; void forward(); void backward(); }; template <> class AddImpl_cpu<2> : public OperatorImpl { private: const Add_Op<2>& mOp; std::array<NbElts_t, 2> mNbConsumedData; std::array<NbElts_t, 1> mNbProducedData; public: AddImpl_cpu(const Add_Op<2>& op) : mOp(op), mNbConsumedData({0, 0}), mNbProducedData({0}) {} static std::unique_ptr<AddImpl_cpu<2>> create(const Add_Op<2>& op) { return std::make_unique<AddImpl_cpu<2>>(op); } public: NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final; NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; NbElts_t getRequiredMemory(__attribute__((unused)) const IOIndex_t outputIdx, __attribute__((unused)) const std::vector<DimSize_t>& inputsSize) const override final; NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override final; NbElts_t getNbProducedData(const IOIndex_t /*outputIdx*/) const override final; void forward(); void backward(); }; template <> class AddImpl_cpu<3> : public OperatorImpl { private: const Add_Op<3>& mOp; std::array<NbElts_t, 3> mNbConsumedData; std::array<NbElts_t, 1> mNbProducedData; public: AddImpl_cpu(const Add_Op<3>& op) : mOp(op), mNbConsumedData({0, 0, 0}), mNbProducedData({0}) {} static std::unique_ptr<AddImpl_cpu<3>> create(const Add_Op<3>& op) { return std::make_unique<AddImpl_cpu<3>>(op); } public: NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final; NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final; NbElts_t getRequiredMemory(__attribute__((unused)) const IOIndex_t outputIdx, const std::vector<DimSize_t>& /*inputsSize*/) const override final; NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override final; NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final; void forward(); void backward(); }; namespace { static Registrar<Add_Op<1>> registrarAddImpl1I_cpu("cpu", Aidge::AddImpl_cpu<1>::create); static Registrar<Add_Op<2>> registrarAddImpl2I_cpu("cpu", Aidge::AddImpl_cpu<2>::create); static Registrar<Add_Op<3>> registrarAddImpl3I_cpu("cpu", Aidge::AddImpl_cpu<3>::create); } // namespace } // namespace Aidge #endif /* __AIDGE_CPU_OPERATOR_ADDIMPL_H__ */