From a5f75017c0699440e5c38b1a3c68df08d1921539 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 13 Oct 2023 15:56:57 +0000 Subject: [PATCH] Added PaddedConvDepthWise --- include/aidge/operator/MetaOperator.hpp | 5 --- include/aidge/operator/MetaOperatorDefs.hpp | 39 +++++++++++++++++++++ src/operator/MetaOperator.cpp | 1 + 3 files changed, 40 insertions(+), 5 deletions(-) diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index ae62a1181..bb34fd9c7 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -13,11 +13,6 @@ #define AIDGE_CORE_OPERATOR_METAOPERATOR_H_ #include "aidge/operator/Operator.hpp" -#include "aidge/operator/AvgPooling.hpp" -#include "aidge/operator/MaxPooling.hpp" -#include "aidge/operator/Conv.hpp" -#include "aidge/operator/Conv.hpp" -#include "aidge/operator/Pad.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/OpArgs.hpp" #include "aidge/scheduler/Scheduler.hpp" diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp index 346905dc9..df66cec7e 100644 --- a/include/aidge/operator/MetaOperatorDefs.hpp +++ b/include/aidge/operator/MetaOperatorDefs.hpp @@ -13,6 +13,11 @@ #define AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ #include "aidge/operator/MetaOperator.hpp" +#include "aidge/operator/AvgPooling.hpp" +#include "aidge/operator/MaxPooling.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/ConvDepthWise.hpp" +#include "aidge/operator/Pad.hpp" namespace Aidge { template <std::array<DimSize_t, 1>::size_type DIM> @@ -49,6 +54,40 @@ inline std::shared_ptr<Node> PaddedConv( return PaddedConv<DIM>(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); } +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> PaddedConvDepthWise(DimSize_t in_channels, + DimSize_t out_channels, + const std::array<DimSize_t, DIM> &kernel_dims, + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0}, + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) +{ + // Construct micro-graph + auto pad = std::make_shared<Node>(std::make_shared<Pad_Op<static_cast<DimIdx_t>(DIM)>>(padding_dims, PadBorderType::Constant, 0.0), (!name.empty()) ? name + "_pad" : ""); + auto conv = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : ""); + // Need to specify the ordered list of input operators + const std::vector<NodePtr> orderedInputNodes = {pad, conv}; + + auto metaOp = std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedConvDepthWise", Sequential({pad, conv}), orderedInputNodes), name); + addProducer(metaOp, 1, append(out_channels, append(in_channels, kernel_dims)), "w"); + addProducer(metaOp, 2, {out_channels}, "b"); + return metaOp; +} + +template <DimSize_t DIM> +inline std::shared_ptr<Node> PaddedConvDepthWise( + DimSize_t in_channels, + DimSize_t out_channels, + DimSize_t const (&kernel_dims)[DIM], + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0}, + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) +{ + return PaddedConvDepthWise<DIM>(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); +} + template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> PaddedAvgPooling(DimSize_t in_channels, DimSize_t out_channels, diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index d33376e4f..c1f58c686 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -10,6 +10,7 @@ ********************************************************************************/ #include "aidge/operator/MetaOperator.hpp" +#include "aidge/utils/ErrorHandling.hpp" Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph, std::vector<NodePtr> inputNodes, -- GitLab