From 246a0c8b554b8550314a577f130486974e10e007 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gr=C3=A9goire=20KUBLER?= <gregoire.kubler@proton.me>
Date: Mon, 28 Oct 2024 13:21:50 +0100
Subject: [PATCH] feat : added padding attribute to resize operator

---
 include/aidge/operator/Resize.hpp | 35 ++++++++++++++++++++++++-------
 1 file changed, 28 insertions(+), 7 deletions(-)

diff --git a/include/aidge/operator/Resize.hpp b/include/aidge/operator/Resize.hpp
index 5e7100075..ee6009dae 100644
--- a/include/aidge/operator/Resize.hpp
+++ b/include/aidge/operator/Resize.hpp
@@ -20,6 +20,7 @@
 #include "aidge/data/Interpolation.hpp"
 #include "aidge/graph/Node.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/operator/Pad.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Types.h"
@@ -36,6 +37,7 @@ enum class ResizeAttr {
     //   extrapolation_value,
     //   keep_aspect_ratio_policy,
     InterpolationMode,
+    PaddingMode,
 };
 
 /**
@@ -91,7 +93,8 @@ class Resize_Op
         StaticAttributes<ResizeAttr,
                          Interpolation::CoordinateTransformation,
                          float,
-                         Interpolation::Mode>;
+                         Interpolation::Mode,
+                         PadBorderType>;
     template <ResizeAttr e>
     using attr = typename Attributes_::template attr<e>;
     const std::shared_ptr<Attributes_> mAttributes;
@@ -115,10 +118,12 @@ class Resize_Op
      * set, forward will fail.
      * @return NodePtr
      */
-    explicit Resize_Op(Interpolation::CoordinateTransformation coordTransfoMode,
-              Interpolation::Mode interpol_mode =
-                  Interpolation::Mode::NearestRoundPreferFloor,
-              float cubic_coef_a = -.75f) 
+    explicit Resize_Op(
+        Interpolation::CoordinateTransformation coordTransfoMode,
+        Interpolation::Mode interpol_mode =
+            Interpolation::Mode::NearestRoundPreferFloor,
+        float cubic_coef_a = -.75f,
+        PadBorderType paddingMode = PadBorderType::Constant)
         : OperatorTensor(Type,
                          {InputCategory::Data,
                           InputCategory::OptionalData,
@@ -128,7 +133,8 @@ class Resize_Op
           mAttributes(std::make_shared<Attributes_>(
               attr<ResizeAttr::CubicCoeffA>(cubic_coef_a),
               attr<ResizeAttr::CoordinateTransformationMode>(coordTransfoMode),
-              attr<ResizeAttr::InterpolationMode>(interpol_mode))) {}
+              attr<ResizeAttr::InterpolationMode>(interpol_mode),
+              attr<ResizeAttr::PaddingMode>(paddingMode))) {}
 
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output
@@ -166,12 +172,26 @@ class Resize_Op
         return mAttributes
             ->template getAttr<ResizeAttr::CoordinateTransformationMode>();
     }
+    Interpolation::CoordinateTransformation &
+    coordinateTransformationMode() const noexcept {
+        return mAttributes
+            ->template getAttr<ResizeAttr::CoordinateTransformationMode>();
+    }
     float &cubicCoefA() {
         return mAttributes->template getAttr<ResizeAttr::CubicCoeffA>();
     }
-    Interpolation::Mode &interpolationmode() {
+    Interpolation::Mode &interpolationMode() const noexcept {
         return mAttributes->template getAttr<ResizeAttr::InterpolationMode>();
     }
+    Interpolation::Mode &interpolationMode() {
+        return mAttributes->template getAttr<ResizeAttr::InterpolationMode>();
+    }
+    PadBorderType &paddingMode() const noexcept {
+        return mAttributes->template getAttr<ResizeAttr::PaddingMode>();
+    }
+    PadBorderType &paddingMode() {
+        return mAttributes->template getAttr<ResizeAttr::PaddingMode>();
+    }
     // bool &excludeOutside() {
     //   return mAttributes->template getAttr<ResizeAttr::excludeOutside>();
     // }
@@ -199,6 +219,7 @@ class Resize_Op
  * used if interpolation_mode = Interpolation::Mode::Cubic
  * @warning Scales & ROI input cannot be set simultaneously. If bot are set,
  * forward will fail.
+ * @warning Padding mode will tell how values out of bound are treated. 
  * @return NodePtr
  */
 std::shared_ptr<Node>
-- 
GitLab