diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp index 8a1b50a0e95fad842b7a2c8f7d0ec5434e13ad4e..16fa9967ce58e3dd557a62393cea30eda6c6da4b 100644 --- a/include/aidge/aidge.hpp +++ b/include/aidge/aidge.hpp @@ -33,7 +33,7 @@ #include "aidge/operator/Add.hpp" #include "aidge/operator/AvgPooling.hpp" #include "aidge/operator/BatchNorm.hpp" -// #include "aidge/operator/Concat.hpp" +#include "aidge/operator/Concat.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/FC.hpp" diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp new file mode 100644 index 0000000000000000000000000000000000000000..20ae1be634b5a1b5ff07cd71b76b0e3e102324bc --- /dev/null +++ b/include/aidge/operator/Concat.hpp @@ -0,0 +1,186 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_CONCAT_H_ +#define AIDGE_CORE_OPERATOR_CONCAT_H_ + +#include <numeric> +#include <vector> +#include <cmath> +#include <memory> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +enum class ConcatAttr { NbInputs, Axis}; + +class Concat_Op : public Operator, + public Registrable<Concat_Op, std::string, std::unique_ptr<OperatorImpl>(const Concat_Op&)>, + public StaticAttributes<ConcatAttr, IOIndex_t, DimSize_t> { +private: + // FIXME: change accessibility + std::vector<std::shared_ptr<Tensor>> mInputs; + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char* Type = "Concat"; + + using Attributes_ = StaticAttributes<ConcatAttr, IOIndex_t, DimSize_t>; + template <ConcatAttr e> + using attr = typename Attributes_::template attr<e>; + + Concat_Op(const IOIndex_t nbIn, const DimSize_t axis) + : Operator(Type), + mInputs(std::vector<std::shared_ptr<Tensor>>(nbIn, std::make_shared<Tensor>())), + Attributes_(attr<ConcatAttr::NbInputs>(nbIn), + attr<ConcatAttr::Axis>(axis)) + { + assert(nbIn > 0 && "Concat should have at least one input"); + setDatatype(DataType::Float32); + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Concat_Op(const Concat_Op& op) + : Operator(Type), + Attributes_(op), + mInputs(std::vector<std::shared_ptr<Tensor>>(op.getAttr<ConcatAttr::NbInputs>())), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + assert(op.getAttr<ConcatAttr::NbInputs>() > 0 && "Concat should have at least one input"); + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Concat_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Concat_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Concat_Op>(*this); + } + + // Data operator[](const char* inputName) override final { + // std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] : + // (strcmp(inputName, "weight") ? mInputs[1] : + // (strcmp(inputName, "bias") ? mInputs[2] : + // nullptr)); + // assert((in!=nullptr) && "No such parameter"); + // return *in; + // } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(static_cast<std::size_t>(inputIdx) < getAttr<ConcatAttr::NbInputs>() && "wrong inputIdx for Concat operator."); + assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); + + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + bool computable = !(mInputs[0]->empty()) && (getAttr<ConcatAttr::Axis>() < mInputs[0]->nbDims()); + for (const auto& input : mInputs) { + computable &= !(input->empty()); + computable &= (input->nbDims() == mInputs[0]->nbDims()); + } + // Every input is non-empty with the same number of dimensions + if (computable) { + auto outputDims = mInputs[0]->dims(); + + for (std::size_t i = 1; i < getAttr<ConcatAttr::NbInputs>(); ++i) { + outputDims[getAttr<ConcatAttr::Axis>()] += mInputs[i]->dims()[getAttr<ConcatAttr::Axis>()]; + } + mOutput->resize(outputDims); + } + } + + bool outputDimsForwarded() const override final { + return !(mOutput->empty()); + } + + // void checkDims() const override final { + // assert(outputDimsForwarded()); + // for (const auto& in : mInputs) { + // assert(in->dims() == mOutput->dims()); + // } + // } + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(static_cast<std::size_t>(inputIdx) < getAttr<ConcatAttr::NbInputs>() && "wrong inputIdx for Concat operator."); + return *(mInputs[inputIdx].get()); + } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert(static_cast<std::size_t>(inputIdx) < getAttr<ConcatAttr::NbInputs>() && "wrong inputIdx for Concat operator."); + return mInputs[inputIdx]; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "Concat Operators has only 1 outputs"); + (void) outputIdx; // avoid unused warning + return mOutput; + } + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(static_cast<std::size_t>(inputIdx) < getAttr<ConcatAttr::NbInputs>() && "wrong inputIdx for Concat operator."); + return std::static_pointer_cast<Data>(mInputs[inputIdx]); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + (void) outputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string& name) override { + mImpl = Registrar<Concat_Op>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + for (std::size_t i = 0; i < getAttr<ConcatAttr::NbInputs>(); ++i) { + mInputs[i]->setBackend(name); + } + } + + void setDatatype(const DataType& datatype) override { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + for (std::size_t i = 0; i < getAttr<ConcatAttr::NbInputs>(); ++i) { + mInputs[i]->setDatatype(datatype); + } + } + + inline IOIndex_t nbInputs() const noexcept override final { return getAttr<ConcatAttr::NbInputs>(); } + inline IOIndex_t nbDataInputs() const noexcept override final { return getAttr<ConcatAttr::NbInputs>(); } + inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + + static const std::vector<std::string> getInputsName(){ + return {"data_input_0", "data_input_n"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Concat(const IOIndex_t nbIn, const DimIdx_t axis = 0, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Concat_Op>(nbIn), axis, name); +} +} + +#endif /* AIDGE_CORE_OPERATOR_CONCAT_H_ */