Skip to content
Snippets Groups Projects
Commit ac1d2bbb authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Changed Pad DIM handling

parent 19ab0c3e
No related branches found
No related tags found
1 merge request!34Changed Pad DIM handling
Pipeline #32988 passed
......@@ -30,7 +30,7 @@ inline std::shared_ptr<Node> PaddedConv(DimSize_t in_channels,
const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1))
{
// Construct micro-graph
auto pad = Pad(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0);
auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0);
auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : "");
// Need to specify the ordered list of input operators
const std::vector<NodePtr> orderedInputNodes = {pad, conv};
......@@ -63,7 +63,7 @@ inline std::shared_ptr<Node> PaddedConvDepthWise(const std::array<DimSize_t, DIM
const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1))
{
// Construct micro-graph
auto pad = Pad(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0);
auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0);
auto conv = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : "");
// Need to specify the ordered list of input operators
const std::vector<NodePtr> orderedInputNodes = {pad, conv};
......@@ -93,7 +93,7 @@ inline std::shared_ptr<Node> PaddedAvgPooling(const std::array<DimSize_t, DIM> &
const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0))
{
auto graph = Sequential({
Pad(padding_dims, (!name.empty()) ? name + "_pad" : ""),
Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""),
AvgPooling(kernel_dims, (!name.empty()) ? name + "_avgpooling" : "", stride_dims)
});
......@@ -118,7 +118,7 @@ inline std::shared_ptr<Node> PaddedMaxPooling(const std::array<DimSize_t, DIM> &
const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0))
{
auto graph = Sequential({
Pad(padding_dims, (!name.empty()) ? name + "_pad" : ""),
Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""),
MaxPooling(kernel_dims, (!name.empty()) ? name + "_maxpooling" : "", stride_dims)
});
......
......@@ -168,29 +168,25 @@ public:
}
};
// We use DIMX2 rather than DIM because the compiler cannot infer DIM from std::array<DimSize_t, 2*DIM>
template <std::array<DimSize_t, 1>::size_type DIMX2>
inline std::shared_ptr<Node> Pad(const std::array<DimSize_t, DIMX2> &beginEndTuples,
template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> Pad(const std::array<DimSize_t, 2*DIM> &beginEndTuples,
const std::string& name = "",
const PadBorderType &borderType = PadBorderType::Constant,
double borderValue = 0.0)
{
static_assert(DIMX2%2==0,"Invalid dimension");
static_assert(DIMX2/2<=MaxDim,"Too many kernel dimensions required by Pad, not supported");
return std::make_shared<Node>(std::make_shared<Pad_Op<static_cast<DimIdx_t>(DIMX2/2)>>(beginEndTuples, borderType, borderValue), name);
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Pad, not supported");
return std::make_shared<Node>(std::make_shared<Pad_Op<static_cast<DimIdx_t>(DIM)>>(beginEndTuples, borderType, borderValue), name);
}
// helper with C-style array instead of std::array for beginEndTuples to allow automatic template DIM deduction
template <DimSize_t DIMX2>
template <DimSize_t DIM>
inline std::shared_ptr<Node> Pad(
DimSize_t const (&beginEndTuples)[DIMX2],
DimSize_t const (&beginEndTuples)[2*DIM],
const std::string& name = "",
const PadBorderType &borderType = PadBorderType::Constant,
double borderValue = 0.0)
{
static_assert(DIMX2%2==0,"Invalid dimension");
static_assert(DIMX2/2<=MaxDim,"Too many kernel dimensions required by Pad, not supported");
return Pad<DIMX2/2>(to_array(beginEndTuples), name, borderType, borderValue);
return Pad<DIM>(to_array(beginEndTuples), name, borderType, borderValue);
}
} // namespace Aidge
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment