Skip to content
Snippets Groups Projects
Commit ac1d2bbb authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Changed Pad DIM handling

parent 19ab0c3e
No related branches found
No related tags found
1 merge request!34Changed Pad DIM handling
Pipeline #32988 passed
...@@ -30,7 +30,7 @@ inline std::shared_ptr<Node> PaddedConv(DimSize_t in_channels, ...@@ -30,7 +30,7 @@ inline std::shared_ptr<Node> PaddedConv(DimSize_t in_channels,
const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1))
{ {
// Construct micro-graph // Construct micro-graph
auto pad = Pad(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0); auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0);
auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : ""); auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : "");
// Need to specify the ordered list of input operators // Need to specify the ordered list of input operators
const std::vector<NodePtr> orderedInputNodes = {pad, conv}; const std::vector<NodePtr> orderedInputNodes = {pad, conv};
...@@ -63,7 +63,7 @@ inline std::shared_ptr<Node> PaddedConvDepthWise(const std::array<DimSize_t, DIM ...@@ -63,7 +63,7 @@ inline std::shared_ptr<Node> PaddedConvDepthWise(const std::array<DimSize_t, DIM
const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1))
{ {
// Construct micro-graph // Construct micro-graph
auto pad = Pad(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0); auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0);
auto conv = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : ""); auto conv = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : "");
// Need to specify the ordered list of input operators // Need to specify the ordered list of input operators
const std::vector<NodePtr> orderedInputNodes = {pad, conv}; const std::vector<NodePtr> orderedInputNodes = {pad, conv};
...@@ -93,7 +93,7 @@ inline std::shared_ptr<Node> PaddedAvgPooling(const std::array<DimSize_t, DIM> & ...@@ -93,7 +93,7 @@ inline std::shared_ptr<Node> PaddedAvgPooling(const std::array<DimSize_t, DIM> &
const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0)) const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0))
{ {
auto graph = Sequential({ auto graph = Sequential({
Pad(padding_dims, (!name.empty()) ? name + "_pad" : ""), Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""),
AvgPooling(kernel_dims, (!name.empty()) ? name + "_avgpooling" : "", stride_dims) AvgPooling(kernel_dims, (!name.empty()) ? name + "_avgpooling" : "", stride_dims)
}); });
...@@ -118,7 +118,7 @@ inline std::shared_ptr<Node> PaddedMaxPooling(const std::array<DimSize_t, DIM> & ...@@ -118,7 +118,7 @@ inline std::shared_ptr<Node> PaddedMaxPooling(const std::array<DimSize_t, DIM> &
const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0)) const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0))
{ {
auto graph = Sequential({ auto graph = Sequential({
Pad(padding_dims, (!name.empty()) ? name + "_pad" : ""), Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""),
MaxPooling(kernel_dims, (!name.empty()) ? name + "_maxpooling" : "", stride_dims) MaxPooling(kernel_dims, (!name.empty()) ? name + "_maxpooling" : "", stride_dims)
}); });
......
...@@ -168,29 +168,25 @@ public: ...@@ -168,29 +168,25 @@ public:
} }
}; };
// We use DIMX2 rather than DIM because the compiler cannot infer DIM from std::array<DimSize_t, 2*DIM> template <std::array<DimSize_t, 1>::size_type DIM>
template <std::array<DimSize_t, 1>::size_type DIMX2> inline std::shared_ptr<Node> Pad(const std::array<DimSize_t, 2*DIM> &beginEndTuples,
inline std::shared_ptr<Node> Pad(const std::array<DimSize_t, DIMX2> &beginEndTuples,
const std::string& name = "", const std::string& name = "",
const PadBorderType &borderType = PadBorderType::Constant, const PadBorderType &borderType = PadBorderType::Constant,
double borderValue = 0.0) double borderValue = 0.0)
{ {
static_assert(DIMX2%2==0,"Invalid dimension"); static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Pad, not supported");
static_assert(DIMX2/2<=MaxDim,"Too many kernel dimensions required by Pad, not supported"); return std::make_shared<Node>(std::make_shared<Pad_Op<static_cast<DimIdx_t>(DIM)>>(beginEndTuples, borderType, borderValue), name);
return std::make_shared<Node>(std::make_shared<Pad_Op<static_cast<DimIdx_t>(DIMX2/2)>>(beginEndTuples, borderType, borderValue), name);
} }
// helper with C-style array instead of std::array for beginEndTuples to allow automatic template DIM deduction // helper with C-style array instead of std::array for beginEndTuples to allow automatic template DIM deduction
template <DimSize_t DIMX2> template <DimSize_t DIM>
inline std::shared_ptr<Node> Pad( inline std::shared_ptr<Node> Pad(
DimSize_t const (&beginEndTuples)[DIMX2], DimSize_t const (&beginEndTuples)[2*DIM],
const std::string& name = "", const std::string& name = "",
const PadBorderType &borderType = PadBorderType::Constant, const PadBorderType &borderType = PadBorderType::Constant,
double borderValue = 0.0) double borderValue = 0.0)
{ {
static_assert(DIMX2%2==0,"Invalid dimension"); return Pad<DIM>(to_array(beginEndTuples), name, borderType, borderValue);
static_assert(DIMX2/2<=MaxDim,"Too many kernel dimensions required by Pad, not supported");
return Pad<DIMX2/2>(to_array(beginEndTuples), name, borderType, borderValue);
} }
} // namespace Aidge } // namespace Aidge
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment