Skip to content
Snippets Groups Projects
Commit 246a0c8b authored by Grégoire Kubler's avatar Grégoire Kubler Committed by Maxence Naud
Browse files

feat : added padding attribute to resize operator

parent 6d809877
No related branches found
No related tags found
No related merge requests found
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "aidge/data/Interpolation.hpp" #include "aidge/data/Interpolation.hpp"
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Pad.hpp"
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
...@@ -36,6 +37,7 @@ enum class ResizeAttr { ...@@ -36,6 +37,7 @@ enum class ResizeAttr {
// extrapolation_value, // extrapolation_value,
// keep_aspect_ratio_policy, // keep_aspect_ratio_policy,
InterpolationMode, InterpolationMode,
PaddingMode,
}; };
/** /**
...@@ -91,7 +93,8 @@ class Resize_Op ...@@ -91,7 +93,8 @@ class Resize_Op
StaticAttributes<ResizeAttr, StaticAttributes<ResizeAttr,
Interpolation::CoordinateTransformation, Interpolation::CoordinateTransformation,
float, float,
Interpolation::Mode>; Interpolation::Mode,
PadBorderType>;
template <ResizeAttr e> template <ResizeAttr e>
using attr = typename Attributes_::template attr<e>; using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes; const std::shared_ptr<Attributes_> mAttributes;
...@@ -115,10 +118,12 @@ class Resize_Op ...@@ -115,10 +118,12 @@ class Resize_Op
* set, forward will fail. * set, forward will fail.
* @return NodePtr * @return NodePtr
*/ */
explicit Resize_Op(Interpolation::CoordinateTransformation coordTransfoMode, explicit Resize_Op(
Interpolation::Mode interpol_mode = Interpolation::CoordinateTransformation coordTransfoMode,
Interpolation::Mode::NearestRoundPreferFloor, Interpolation::Mode interpol_mode =
float cubic_coef_a = -.75f) Interpolation::Mode::NearestRoundPreferFloor,
float cubic_coef_a = -.75f,
PadBorderType paddingMode = PadBorderType::Constant)
: OperatorTensor(Type, : OperatorTensor(Type,
{InputCategory::Data, {InputCategory::Data,
InputCategory::OptionalData, InputCategory::OptionalData,
...@@ -128,7 +133,8 @@ class Resize_Op ...@@ -128,7 +133,8 @@ class Resize_Op
mAttributes(std::make_shared<Attributes_>( mAttributes(std::make_shared<Attributes_>(
attr<ResizeAttr::CubicCoeffA>(cubic_coef_a), attr<ResizeAttr::CubicCoeffA>(cubic_coef_a),
attr<ResizeAttr::CoordinateTransformationMode>(coordTransfoMode), attr<ResizeAttr::CoordinateTransformationMode>(coordTransfoMode),
attr<ResizeAttr::InterpolationMode>(interpol_mode))) {} attr<ResizeAttr::InterpolationMode>(interpol_mode),
attr<ResizeAttr::PaddingMode>(paddingMode))) {}
/** /**
* @brief Copy-constructor. Copy the operator attributes and its output * @brief Copy-constructor. Copy the operator attributes and its output
...@@ -166,12 +172,26 @@ class Resize_Op ...@@ -166,12 +172,26 @@ class Resize_Op
return mAttributes return mAttributes
->template getAttr<ResizeAttr::CoordinateTransformationMode>(); ->template getAttr<ResizeAttr::CoordinateTransformationMode>();
} }
Interpolation::CoordinateTransformation &
coordinateTransformationMode() const noexcept {
return mAttributes
->template getAttr<ResizeAttr::CoordinateTransformationMode>();
}
float &cubicCoefA() { float &cubicCoefA() {
return mAttributes->template getAttr<ResizeAttr::CubicCoeffA>(); return mAttributes->template getAttr<ResizeAttr::CubicCoeffA>();
} }
Interpolation::Mode &interpolationmode() { Interpolation::Mode &interpolationMode() const noexcept {
return mAttributes->template getAttr<ResizeAttr::InterpolationMode>(); return mAttributes->template getAttr<ResizeAttr::InterpolationMode>();
} }
Interpolation::Mode &interpolationMode() {
return mAttributes->template getAttr<ResizeAttr::InterpolationMode>();
}
PadBorderType &paddingMode() const noexcept {
return mAttributes->template getAttr<ResizeAttr::PaddingMode>();
}
PadBorderType &paddingMode() {
return mAttributes->template getAttr<ResizeAttr::PaddingMode>();
}
// bool &excludeOutside() { // bool &excludeOutside() {
// return mAttributes->template getAttr<ResizeAttr::excludeOutside>(); // return mAttributes->template getAttr<ResizeAttr::excludeOutside>();
// } // }
...@@ -199,6 +219,7 @@ class Resize_Op ...@@ -199,6 +219,7 @@ class Resize_Op
* used if interpolation_mode = Interpolation::Mode::Cubic * used if interpolation_mode = Interpolation::Mode::Cubic
* @warning Scales & ROI input cannot be set simultaneously. If bot are set, * @warning Scales & ROI input cannot be set simultaneously. If bot are set,
* forward will fail. * forward will fail.
* @warning Padding mode will tell how values out of bound are treated.
* @return NodePtr * @return NodePtr
*/ */
std::shared_ptr<Node> std::shared_ptr<Node>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment