Skip to content
Snippets Groups Projects
Commit 67049f1c authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add allowzero attr to Reshape

parent 97d2fef9
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!122Add missing attributes to operators
...@@ -29,24 +29,25 @@ public: ...@@ -29,24 +29,25 @@ public:
void forward() override; void forward() override;
}; };
enum class ReshapeAttr { Shape }; enum class ReshapeAttr { Shape, AllowZero };
class Reshape_Op : public OperatorTensor, class Reshape_Op : public OperatorTensor,
public Registrable<Reshape_Op, std::string, std::shared_ptr<OperatorImpl>(const Reshape_Op&)>, public Registrable<Reshape_Op, std::string, std::shared_ptr<OperatorImpl>(const Reshape_Op&)>,
public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>> { public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>, bool> {
public: public:
static const std::string Type; static const std::string Type;
Reshape_Op() = delete; Reshape_Op() = delete;
using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>>; using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>, bool>;
template <ReshapeAttr e> template <ReshapeAttr e>
using attr = typename Attributes_::template attr<e>; using attr = typename Attributes_::template attr<e>;
Reshape_Op(const std::vector<std::int64_t>& shape) Reshape_Op(const std::vector<std::int64_t>& shape, bool allowzero)
: OperatorTensor(Type, 2, 0, 1), : OperatorTensor(Type, 2, 0, 1),
Attributes_(attr<ReshapeAttr::Shape>(shape)) Attributes_(attr<ReshapeAttr::Shape>(shape),
attr<ReshapeAttr::AllowZero>(allowzero))
{ {
mImpl = std::make_shared<Reshape_OpImpl>(*this); mImpl = std::make_shared<Reshape_OpImpl>(*this);
} }
...@@ -88,15 +89,16 @@ public: ...@@ -88,15 +89,16 @@ public:
}; };
inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape = {}, inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape = {},
const std::string &name = "") { bool allowzero = false,
const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases // FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name); return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape, allowzero), name);
} }
} // namespace Aidge } // namespace Aidge
namespace { namespace {
template <> template <>
const char *const EnumStrings<Aidge::ReshapeAttr>::data[] = { "Shape" }; const char *const EnumStrings<Aidge::ReshapeAttr>::data[] = { "Shape", "AllowZero" };
} }
#endif /* AIDGE_CORE_OPERATOR_RESHAPE_H_ */ #endif /* AIDGE_CORE_OPERATOR_RESHAPE_H_ */
...@@ -92,7 +92,7 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { ...@@ -92,7 +92,7 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
dimSize = 1; dimSize = 1;
negativeIndex = static_cast<DimIdx_t>(i); negativeIndex = static_cast<DimIdx_t>(i);
} }
else if (dimSize == 0) else if (dimSize == 0 && !this->template getAttr<ReshapeAttr::AllowZero>())
{ {
dimSize = getInput(0) -> dims()[i]; dimSize = getInput(0) -> dims()[i];
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment