Skip to content
Snippets Groups Projects
Commit 1409df0f authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add Concat operator

parent 4a1c4536
No related branches found
No related tags found
No related merge requests found
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_CONCAT_H_
#define AIDGE_CORE_OPERATOR_CONCAT_H_
#include <cassert>
#include <memory>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
enum class ConcatAttr { Axis };
class Concat_Op : public Operator,
public Registrable<Concat_Op,
std::string,
std::unique_ptr<OperatorImpl>(const Concat_Op&)>,
public StaticAttributes<ConcatAttr, int> {
public:
// FIXME: change accessibility
IOIndex_t mNbIn;
std::vector<std::shared_ptr<Tensor>> mInputs;
const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
public:
static constexpr const char* Type = "Concat";
Concat_Op() = delete;
using Attributes_ = StaticAttributes<ConcatAttr, int>;
template <ConcatAttr e> using attr = typename Attributes_::template attr<e>;
Concat_Op(int axis)
: Operator(Type),
Attributes_(
attr<ConcatAttr::Axis>(axis))
{
setDatatype(DataType::Float32);
}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
Concat_Op(const Concat_Op& op)
: Operator(Type),
mNbIn(op.mNbIn),
Attributes_(op),
mOutput(std::make_shared<Tensor>(*op.mOutput))
{
// cpy-ctor
setDatatype(op.mOutput->dataType());
mImpl = op.mImpl ? Registrar<Concat_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
mInputs = std::vector<std::shared_ptr<Tensor>>(mNbIn);
for (std::size_t i = 0; i < mNbIn; ++i) {
mInputs[i] = std::make_shared<Tensor>();
}
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Concat_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<Concat_Op>(*this);
}
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
// assert(inputIdx < mNbIn && "operators supports only x inputs");
if (strcmp(data->type(), Tensor::Type) == 0) {
// TODO: associate input only if of type Tensor, otherwise do nothing
if(inputIdx<mInputs.size())
mInputs.insert( mInputs.begin() + inputIdx, std::dynamic_pointer_cast<Tensor>(data));
else
mInputs.emplace_back(std::dynamic_pointer_cast<Tensor>(data));
mNbIn = mInputs.size();
}
}
void computeOutputDims() override final {
if (!mInputs.empty() && !mInputs[0]->empty())
{
// mOutput->resize(mInputs[0]->dims());
Concat_Op::Attrs attr = getStaticAttributes();
const int& axis = static_cast<const int&>(std::get<0>(attr));
std::vector<DimSize_t> outputDims;
for (std::size_t i = 0; i < mInputs[0]->nbDims(); ++i) {
if(i==axis)
outputDims.push_back(mInputs.size() * mInputs[0]->dims()[i]);
else
outputDims.push_back(mInputs[0]->dims()[i]);
}
mOutput->resize(outputDims);
}
}
bool outputDimsForwarded() const override final {
return !(mOutput->empty());
}
inline Tensor& input(const IOIndex_t inputIdx) const override final {
assert((inputIdx < mNbIn) && "input index out of range for this instance of GenericOperator");
printf("Info: using input() on a GenericOperator.\n");
return *mInputs[inputIdx];
}
inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
assert((inputIdx < mNbIn) && "input index out of range for this instance of Concat operator");
return mInputs[inputIdx];
}
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
assert((outputIdx == 0) && "Concat Operator has only 1 output");
(void) outputIdx; // avoid unused warning
return mOutput;
}
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
assert((inputIdx < mNbIn) && "input index out of range for this instance of Concat operator");
return std::static_pointer_cast<Data>(mInputs[inputIdx]);
}
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
assert(outputIdx == 0 && "operator supports only 1 output");
(void) outputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mOutput);
}
void setBackend(const std::string& name) override {
mImpl = Registrar<Concat_Op>::create(name)(*this);
mOutput->setBackend(name);
// FIXME: temporary workaround
for (std::size_t i = 0; i < mNbIn; ++i) {
mInputs[i]->setBackend(name);
}
}
void setDatatype(const DataType& datatype) override {
mOutput->setDatatype(datatype);
// FIXME: temporary workaround
for (std::size_t i = 0; i < mNbIn; ++i) {
mInputs[i]->setDatatype(datatype);
}
}
inline IOIndex_t nbInputs() const noexcept override final { return mNbIn; }
inline IOIndex_t nbDataInputs() const noexcept override final { return mNbIn; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> Concat(int axis, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Concat_Op>(axis), name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::ConcatAttr>::data[] = {"Axis"};
}
#endif /* AIDGE_CORE_OPERATOR_CONCAT_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <string>
#include "aidge/operator/Concat.hpp"
#include "aidge/operator/Operator.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Concat(py::module& m) {
py::class_<Concat_Op, std::shared_ptr<Concat_Op>, Operator, Attributes>(m, "ConcatOp", py::multiple_inheritance())
.def("get_inputs_name", &Concat_Op::getInputsName)
.def("get_outputs_name", &Concat_Op::getOutputsName);
m.def("Concat", &Concat, py::arg("axis"), py::arg("name") = "");
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment