Skip to content
Snippets Groups Projects

Add missing attributes to operators

Merged Houssem ROUIS requested to merge hrouis/aidge_core:fix/add_missing_attr into dev
Files
8
@@ -29,24 +29,25 @@ public:
void forward() override;
};
enum class ReshapeAttr { Shape };
enum class ReshapeAttr { Shape, AllowZero };
class Reshape_Op : public OperatorTensor,
public Registrable<Reshape_Op, std::string, std::shared_ptr<OperatorImpl>(const Reshape_Op&)>,
public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>> {
public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>, bool> {
public:
static const std::string Type;
Reshape_Op() = delete;
using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>>;
using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>, bool>;
template <ReshapeAttr e>
using attr = typename Attributes_::template attr<e>;
Reshape_Op(const std::vector<std::int64_t>& shape)
Reshape_Op(const std::vector<std::int64_t>& shape, bool allowzero)
: OperatorTensor(Type, 2, 0, 1),
Attributes_(attr<ReshapeAttr::Shape>(shape))
Attributes_(attr<ReshapeAttr::Shape>(shape),
attr<ReshapeAttr::AllowZero>(allowzero))
{
mImpl = std::make_shared<Reshape_OpImpl>(*this);
}
@@ -89,15 +90,16 @@ public:
};
inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape = {},
const std::string &name = "") {
bool allowzero = false,
const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name);
return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape, allowzero), name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::ReshapeAttr>::data[] = { "Shape" };
const char *const EnumStrings<Aidge::ReshapeAttr>::data[] = { "Shape", "AllowZero" };
}
#endif /* AIDGE_CORE_OPERATOR_RESHAPE_H_ */
Loading