Skip to content
Snippets Groups Projects

Add a Stack operator

Merged Jerome Hue requested to merge jeromeh/aidge_core:operator-stack into dev
3 files
+ 221
0
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 88
0
#ifndef AIDGE_CORE_OPERATOR_STACK_H_
#define AIDGE_CORE_OPERATOR_STACK_H_
#include <memory>
#include <string>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class StackOpImpl : public OperatorImpl {
public:
StackOpImpl(const Operator &op, const std::string &backend = "")
: OperatorImpl(op, backend) {}
void forward() override;
};
enum class StackAttr { MaxElements };
class StackOp
: public OperatorTensor,
public Registrable<
StackOp,
std::string,
std::function<std::unique_ptr<OperatorImpl>(const StackOp &)>> {
private:
using Attributes_ = StaticAttributes<StackAttr, std::uint32_t>;
template <StackAttr e> using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes;
public:
static const std::string s_type;
StackOp(std::uint32_t maxElements);
/**
* @brief Copy-constructor. Copy the operator attributes and its output
* tensor(s), but not its input tensors (the new operator has no input
* associated).
* @param op Operator to copy.
*/
StackOp(const StackOp &op);
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::StackOp
*/
std::shared_ptr<Operator> clone() const override;
void setBackend(const std::string &name,
DeviceIdx_t device = 0) override final;
std::set<std::string> getAvailableBackends() const override;
bool forwardDims(bool allowDataDependency = false) override final;
void forward() override;
inline std::shared_ptr<Attributes> attributes() const override {
return mAttributes;
}
inline std::uint32_t &maxElements() const {
return mAttributes->template getAttr<StackAttr::MaxElements>();
}
static const std::vector<std::string> getInputsName() {
return {"data_input"};
}
static const std::vector<std::string> getOutputsName() {
return {"data_output"};
}
};
std::shared_ptr<Node> stack(std::uint32_t maxElements,
const std::string &name = "");
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::StackAttr>::data[] = {"max_elements"};
}
#endif
Loading