From 67049f1cc990797791017d4717b8f94ce78f0361 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Fri, 17 May 2024 15:01:28 +0200 Subject: [PATCH] add allowzero attr to Reshape --- include/aidge/operator/Reshape.hpp | 18 ++++++++++-------- src/operator/Reshape.cpp | 2 +- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp index aa1f4f697..01c32004e 100644 --- a/include/aidge/operator/Reshape.hpp +++ b/include/aidge/operator/Reshape.hpp @@ -29,24 +29,25 @@ public: void forward() override; }; -enum class ReshapeAttr { Shape }; +enum class ReshapeAttr { Shape, AllowZero }; class Reshape_Op : public OperatorTensor, public Registrable<Reshape_Op, std::string, std::shared_ptr<OperatorImpl>(const Reshape_Op&)>, - public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>> { + public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>, bool> { public: static const std::string Type; Reshape_Op() = delete; - using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>>; + using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>, bool>; template <ReshapeAttr e> using attr = typename Attributes_::template attr<e>; - Reshape_Op(const std::vector<std::int64_t>& shape) + Reshape_Op(const std::vector<std::int64_t>& shape, bool allowzero) : OperatorTensor(Type, 2, 0, 1), - Attributes_(attr<ReshapeAttr::Shape>(shape)) + Attributes_(attr<ReshapeAttr::Shape>(shape), + attr<ReshapeAttr::AllowZero>(allowzero)) { mImpl = std::make_shared<Reshape_OpImpl>(*this); } @@ -88,15 +89,16 @@ public: }; inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape = {}, - const std::string &name = "") { + bool allowzero = false, + const std::string &name = "") { // FIXME: properly handle default w&b initialization in every cases - return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name); + return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape, allowzero), name); } } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::ReshapeAttr>::data[] = { "Shape" }; +const char *const EnumStrings<Aidge::ReshapeAttr>::data[] = { "Shape", "AllowZero" }; } #endif /* AIDGE_CORE_OPERATOR_RESHAPE_H_ */ diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index 0cce7a5b9..18b35548e 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -92,7 +92,7 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { dimSize = 1; negativeIndex = static_cast<DimIdx_t>(i); } - else if (dimSize == 0) + else if (dimSize == 0 && !this->template getAttr<ReshapeAttr::AllowZero>()) { dimSize = getInput(0) -> dims()[i]; } -- GitLab