Skip to content
Snippets Groups Projects

Add a Stack operator

Merged Jerome Hue requested to merge jeromeh/aidge_core:operator-stack into dev
Files
6
+ 106
0
#ifndef AIDGE_CORE_OPERATOR_STACK_H_
#define AIDGE_CORE_OPERATOR_STACK_H_
#include <memory>
#include <string>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class StackProdConso : public ProdConso {
public:
StackProdConso(const Operator &op) : ProdConso(op) {}
Elts_t getRequiredMemory(
const IOIndex_t outputIdx,
const std::vector<DimSize_t> &inputsSize) const override final;
void resetConsummerProducer() override;
};
class StackOpImpl : public OperatorImpl {
public:
StackOpImpl(const Operator &op, const std::string &backend = "")
: OperatorImpl(op, backend) {}
std::shared_ptr<ProdConso> getProdConso() const override {
return std::make_shared<StackProdConso>(mOp);
};
void forward() override;
};
enum class StackAttr { ForwardStep, MaxElements };
class StackOp
: public OperatorTensor,
public Registrable<
StackOp,
std::string,
std::function<std::unique_ptr<OperatorImpl>(const StackOp &)>> {
private:
using Attributes_ =
StaticAttributes<StackAttr, std::uint32_t, std::uint32_t>;
template <StackAttr e> using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes;
public:
static const std::string s_type;
StackOp(std::uint32_t maxElements);
/**
* @brief Copy-constructor. Copy the operator attributes and its output
* tensor(s), but not its input tensors (the new operator has no input
* associated).
* @param op Operator to copy.
*/
StackOp(const StackOp &op);
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::StackOp
*/
std::shared_ptr<Operator> clone() const override;
void setBackend(const std::string &name,
DeviceIdx_t device = 0) override final;
std::set<std::string> getAvailableBackends() const override;
bool forwardDims(bool allowDataDependency = false) override final;
void forward() override;
inline std::shared_ptr<Attributes> attributes() const override {
return mAttributes;
}
inline std::uint32_t &maxElements() const {
return mAttributes->template getAttr<StackAttr::MaxElements>();
}
inline std::uint32_t &forwardStep() const {
return mAttributes->template getAttr<StackAttr::ForwardStep>();
}
static const std::vector<std::string> getInputsName() {
return {"data_input"};
}
static const std::vector<std::string> getOutputsName() {
return {"data_output"};
}
};
std::shared_ptr<Node> stack(std::uint32_t maxElements,
const std::string &name = "");
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::StackAttr>::data[] = {"max_elements"};
}
#endif
Loading