Skip to content
Snippets Groups Projects
Commit 1409df0f authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add Concat operator

parent 4a1c4536
No related branches found
No related tags found
2 merge requests!59Improvements and fixes,!47Vit operators
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_CONCAT_H_
#define AIDGE_CORE_OPERATOR_CONCAT_H_
#include <cassert>
#include <memory>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
enum class ConcatAttr { Axis };
class Concat_Op : public Operator,
public Registrable<Concat_Op,
std::string,
std::unique_ptr<OperatorImpl>(const Concat_Op&)>,
public StaticAttributes<ConcatAttr, int> {
public:
// FIXME: change accessibility
IOIndex_t mNbIn;
std::vector<std::shared_ptr<Tensor>> mInputs;
const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
public:
static constexpr const char* Type = "Concat";
Concat_Op() = delete;
using Attributes_ = StaticAttributes<ConcatAttr, int>;
template <ConcatAttr e> using attr = typename Attributes_::template attr<e>;
Concat_Op(int axis)
: Operator(Type),
Attributes_(
attr<ConcatAttr::Axis>(axis))
{
setDatatype(DataType::Float32);
}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
Concat_Op(const Concat_Op& op)
: Operator(Type),
mNbIn(op.mNbIn),
Attributes_(op),
mOutput(std::make_shared<Tensor>(*op.mOutput))
{
// cpy-ctor
setDatatype(op.mOutput->dataType());
mImpl = op.mImpl ? Registrar<Concat_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
mInputs = std::vector<std::shared_ptr<Tensor>>(mNbIn);
for (std::size_t i = 0; i < mNbIn; ++i) {
mInputs[i] = std::make_shared<Tensor>();
}
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Concat_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<Concat_Op>(*this);
}
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
// assert(inputIdx < mNbIn && "operators supports only x inputs");
if (strcmp(data->type(), Tensor::Type) == 0) {
// TODO: associate input only if of type Tensor, otherwise do nothing
if(inputIdx<mInputs.size())
mInputs.insert( mInputs.begin() + inputIdx, std::dynamic_pointer_cast<Tensor>(data));
else
mInputs.emplace_back(std::dynamic_pointer_cast<Tensor>(data));
mNbIn = mInputs.size();
}
}
void computeOutputDims() override final {
if (!mInputs.empty() && !mInputs[0]->empty())
{
// mOutput->resize(mInputs[0]->dims());
Concat_Op::Attrs attr = getStaticAttributes();
const int& axis = static_cast<const int&>(std::get<0>(attr));
std::vector<DimSize_t> outputDims;
for (std::size_t i = 0; i < mInputs[0]->nbDims(); ++i) {
if(i==axis)
outputDims.push_back(mInputs.size() * mInputs[0]->dims()[i]);
else
outputDims.push_back(mInputs[0]->dims()[i]);
}
mOutput->resize(outputDims);
}
}
bool outputDimsForwarded() const override final {
return !(mOutput->empty());
}
inline Tensor& input(const IOIndex_t inputIdx) const override final {
assert((inputIdx < mNbIn) && "input index out of range for this instance of GenericOperator");
printf("Info: using input() on a GenericOperator.\n");
return *mInputs[inputIdx];
}
inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
assert((inputIdx < mNbIn) && "input index out of range for this instance of Concat operator");
return mInputs[inputIdx];
}
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
assert((outputIdx == 0) && "Concat Operator has only 1 output");
(void) outputIdx; // avoid unused warning
return mOutput;
}
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
assert((inputIdx < mNbIn) && "input index out of range for this instance of Concat operator");
return std::static_pointer_cast<Data>(mInputs[inputIdx]);
}
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
assert(outputIdx == 0 && "operator supports only 1 output");
(void) outputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mOutput);
}
void setBackend(const std::string& name) override {
mImpl = Registrar<Concat_Op>::create(name)(*this);
mOutput->setBackend(name);
// FIXME: temporary workaround
for (std::size_t i = 0; i < mNbIn; ++i) {
mInputs[i]->setBackend(name);
}
}
void setDatatype(const DataType& datatype) override {
mOutput->setDatatype(datatype);
// FIXME: temporary workaround
for (std::size_t i = 0; i < mNbIn; ++i) {
mInputs[i]->setDatatype(datatype);
}
}
inline IOIndex_t nbInputs() const noexcept override final { return mNbIn; }
inline IOIndex_t nbDataInputs() const noexcept override final { return mNbIn; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> Concat(int axis, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Concat_Op>(axis), name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::ConcatAttr>::data[] = {"Axis"};
}
#endif /* AIDGE_CORE_OPERATOR_CONCAT_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <string>
#include "aidge/operator/Concat.hpp"
#include "aidge/operator/Operator.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Concat(py::module& m) {
py::class_<Concat_Op, std::shared_ptr<Concat_Op>, Operator, Attributes>(m, "ConcatOp", py::multiple_inheritance())
.def("get_inputs_name", &Concat_Op::getInputsName)
.def("get_outputs_name", &Concat_Op::getOutputsName);
m.def("Concat", &Concat, py::arg("axis"), py::arg("name") = "");
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment