Skip to content
Snippets Groups Projects
Commit a5f75017 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added PaddedConvDepthWise

parent 9a75edae
No related branches found
No related tags found
1 merge request!28Added PaddedConvDepthWise
...@@ -13,11 +13,6 @@ ...@@ -13,11 +13,6 @@
#define AIDGE_CORE_OPERATOR_METAOPERATOR_H_ #define AIDGE_CORE_OPERATOR_METAOPERATOR_H_
#include "aidge/operator/Operator.hpp" #include "aidge/operator/Operator.hpp"
#include "aidge/operator/AvgPooling.hpp"
#include "aidge/operator/MaxPooling.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/Pad.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp" #include "aidge/graph/OpArgs.hpp"
#include "aidge/scheduler/Scheduler.hpp" #include "aidge/scheduler/Scheduler.hpp"
......
...@@ -13,6 +13,11 @@ ...@@ -13,6 +13,11 @@
#define AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ #define AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_
#include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/MetaOperator.hpp"
#include "aidge/operator/AvgPooling.hpp"
#include "aidge/operator/MaxPooling.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/Pad.hpp"
namespace Aidge { namespace Aidge {
template <std::array<DimSize_t, 1>::size_type DIM> template <std::array<DimSize_t, 1>::size_type DIM>
...@@ -49,6 +54,40 @@ inline std::shared_ptr<Node> PaddedConv( ...@@ -49,6 +54,40 @@ inline std::shared_ptr<Node> PaddedConv(
return PaddedConv<DIM>(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); return PaddedConv<DIM>(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims);
} }
template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> PaddedConvDepthWise(DimSize_t in_channels,
DimSize_t out_channels,
const std::array<DimSize_t, DIM> &kernel_dims,
const std::string& name = "",
const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0},
const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1))
{
// Construct micro-graph
auto pad = std::make_shared<Node>(std::make_shared<Pad_Op<static_cast<DimIdx_t>(DIM)>>(padding_dims, PadBorderType::Constant, 0.0), (!name.empty()) ? name + "_pad" : "");
auto conv = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : "");
// Need to specify the ordered list of input operators
const std::vector<NodePtr> orderedInputNodes = {pad, conv};
auto metaOp = std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedConvDepthWise", Sequential({pad, conv}), orderedInputNodes), name);
addProducer(metaOp, 1, append(out_channels, append(in_channels, kernel_dims)), "w");
addProducer(metaOp, 2, {out_channels}, "b");
return metaOp;
}
template <DimSize_t DIM>
inline std::shared_ptr<Node> PaddedConvDepthWise(
DimSize_t in_channels,
DimSize_t out_channels,
DimSize_t const (&kernel_dims)[DIM],
const std::string& name = "",
const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0},
const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1))
{
return PaddedConvDepthWise<DIM>(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims);
}
template <std::array<DimSize_t, 1>::size_type DIM> template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> PaddedAvgPooling(DimSize_t in_channels, inline std::shared_ptr<Node> PaddedAvgPooling(DimSize_t in_channels,
DimSize_t out_channels, DimSize_t out_channels,
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
********************************************************************************/ ********************************************************************************/
#include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/MetaOperator.hpp"
#include "aidge/utils/ErrorHandling.hpp"
Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph, Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph,
std::vector<NodePtr> inputNodes, std::vector<NodePtr> inputNodes,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment