From 3ba465b3820e002e31bffe7060ea88a0618d82a1 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sun, 29 Sep 2024 17:33:29 +0200
Subject: [PATCH] Added meta op input category

---
 include/aidge/operator/MetaOperator.hpp            |  3 ++-
 include/aidge/operator/MetaOperatorDefs.hpp        |  2 +-
 .../operator/pybind_MetaOperatorDefs.cpp           |  6 ++++--
 src/operator/MetaOperator.cpp                      | 14 ++++++++++----
 src/operator/MetaOperatorDefs/LSTM.cpp             |  2 +-
 src/operator/MetaOperatorDefs/PaddedAvgPooling.cpp |  2 +-
 src/operator/MetaOperatorDefs/PaddedConv.cpp       |  2 +-
 .../MetaOperatorDefs/PaddedConvDepthWise.cpp       |  2 +-
 8 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp
index ccff976cb..744dbd132 100644
--- a/include/aidge/operator/MetaOperator.hpp
+++ b/include/aidge/operator/MetaOperator.hpp
@@ -37,7 +37,7 @@ public:
     std::weak_ptr<Node> mUpperNode;
 
 public:
-    MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph);
+    MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph, const std::vector<InputCategory>& forcedInputsCategory = {});
 
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
@@ -113,6 +113,7 @@ public:
 
 std::shared_ptr<Node> MetaOperator(const char *type,
                                    const std::shared_ptr<GraphView>& graph,
+                                   const std::vector<InputCategory>& forcedInputsCategory = {},
                                    const std::string& name = "");
 } // namespace Aidge
 
diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp
index bc3348377..750a808aa 100644
--- a/include/aidge/operator/MetaOperatorDefs.hpp
+++ b/include/aidge/operator/MetaOperatorDefs.hpp
@@ -126,7 +126,7 @@ inline std::shared_ptr<Node> PaddedMaxPooling(const std::array<DimSize_t, DIM> &
         MaxPooling(kernel_dims, (!name.empty()) ? name + "_maxpooling" : "", stride_dims, ceil_mode)
     });
 
-    return MetaOperator("PaddedMaxPooling", graph, name);
+    return MetaOperator("PaddedMaxPooling", graph, {}, name);
 }
 
 template <std::array<DimSize_t, 1>::size_type DIM>
diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp
index 8ad7b5c3b..afd682f3e 100644
--- a/python_binding/operator/pybind_MetaOperatorDefs.cpp
+++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp
@@ -195,15 +195,17 @@ void init_MetaOperatorDefs(py::module &m) {
   declare_LSTMOp(m);
 
   py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, OperatorTensor>(m, "MetaOperator_Op", py::multiple_inheritance())
-  .def(py::init<const char *, const std::shared_ptr<GraphView>&>(),
+  .def(py::init<const char *, const std::shared_ptr<GraphView>&, const std::vector<InputCategory>&>(),
        py::arg("type"),
-       py::arg("graph"))
+       py::arg("graph"),
+       py::arg("forced_inputs_category") = std::vector<InputCategory>())
   .def("get_micro_graph", &MetaOperator_Op::getMicroGraph)
   .def("set_upper_node", &MetaOperator_Op::setUpperNode);
 
   m.def("meta_operator", &MetaOperator,
     py::arg("type"),
     py::arg("graph"),
+    py::arg("forced_inputs_category") = std::vector<InputCategory>(),
    py::arg("name") = ""
   );
 
diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp
index 372c8b953..ab6bde74f 100644
--- a/src/operator/MetaOperator.cpp
+++ b/src/operator/MetaOperator.cpp
@@ -20,17 +20,22 @@
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/DynamicAttributes.hpp"
 
-Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph)
-    : OperatorTensor(type, [graph]() {
+Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph, const std::vector<InputCategory>& forcedInputsCategory)
+    : OperatorTensor(type, [graph, forcedInputsCategory]() {
+        IOIndex_t inputIdx = 0;
         std::vector<InputCategory> inputsCategory;
         for (const auto& in : graph->getOrderedInputs()) {
-            if (in.first) {
+            if (inputIdx < forcedInputsCategory.size()) {
+                inputsCategory.push_back(forcedInputsCategory[inputIdx]);
+            }
+            else if (in.first) {
                 inputsCategory.push_back(in.first->getOperator()->inputCategory(in.second));
             }
             else {
                 // Dummy input, default to OptionalData
                 inputsCategory.push_back(InputCategory::OptionalData);
             }
+            ++inputIdx;
         }
         return inputsCategory;
     }(), graph->getOrderedOutputs().size()),
@@ -245,9 +250,10 @@ void Aidge::MetaOperator_Op::forward() {
 
 std::shared_ptr<Aidge::Node> Aidge::MetaOperator(const char *type,
                                   const std::shared_ptr<Aidge::GraphView>& graph,
+                                  const std::vector<InputCategory>& forcedInputsCategory,
                                   const std::string& name)
 {
-    auto op = std::make_shared<MetaOperator_Op>(type, graph);
+    auto op = std::make_shared<MetaOperator_Op>(type, graph, forcedInputsCategory);
     auto node = std::make_shared<Node>(op, name);
     op->setUpperNode(node);
     return node;
diff --git a/src/operator/MetaOperatorDefs/LSTM.cpp b/src/operator/MetaOperatorDefs/LSTM.cpp
index 910e7c67a..9620f0404 100644
--- a/src/operator/MetaOperatorDefs/LSTM.cpp
+++ b/src/operator/MetaOperatorDefs/LSTM.cpp
@@ -115,7 +115,7 @@ std::shared_ptr<Node> LSTM(const DimSize_t inChannel,
                                {hiddenState, 1}, {cellState, 1}});
     microGraph->setOrderedOutputs({{hiddenState, 0}, {cellState, 0}});
 
-    auto metaOp = MetaOperator("LSTM", microGraph, name);
+    auto metaOp = MetaOperator("LSTM", microGraph, {}, name);
     addProducer(metaOp, 1, {hiddenChannel, inChannel}, "wi");
     addProducer(metaOp, 2, {hiddenChannel, inChannel}, "wo");
     addProducer(metaOp, 3, {hiddenChannel, inChannel}, "wf");
diff --git a/src/operator/MetaOperatorDefs/PaddedAvgPooling.cpp b/src/operator/MetaOperatorDefs/PaddedAvgPooling.cpp
index ef319ef38..c35d964d0 100644
--- a/src/operator/MetaOperatorDefs/PaddedAvgPooling.cpp
+++ b/src/operator/MetaOperatorDefs/PaddedAvgPooling.cpp
@@ -41,7 +41,7 @@ std::shared_ptr<Node> PaddedAvgPooling(const std::array<DimSize_t, DIM> &kernel_
         AvgPooling(kernel_dims, (!name.empty()) ? name + "_avgpooling" : "", stride_dims)
     });
 
-    return MetaOperator("PaddedAvgPooling", graph, name);
+    return MetaOperator("PaddedAvgPooling", graph, {}, name);
 }
 
 template std::shared_ptr<Node> PaddedAvgPooling<1>(const std::array<DimSize_t,1>&, const std::string&, const std::array<DimSize_t,1>&, const std::array<DimSize_t,2>&);
diff --git a/src/operator/MetaOperatorDefs/PaddedConv.cpp b/src/operator/MetaOperatorDefs/PaddedConv.cpp
index 31b1c675e..49373341a 100644
--- a/src/operator/MetaOperatorDefs/PaddedConv.cpp
+++ b/src/operator/MetaOperatorDefs/PaddedConv.cpp
@@ -43,7 +43,7 @@ std::shared_ptr<Aidge::Node> Aidge::PaddedConv(Aidge::DimSize_t in_channels,
         Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""),
         std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : "")
     });
-    auto metaOpNode = MetaOperator("PaddedConv", graph, name);
+    auto metaOpNode = MetaOperator("PaddedConv", graph, {}, name);
     addProducer(metaOpNode, 1, append(out_channels, append(in_channels, kernel_dims)), "w");
     if (!no_bias) {
         addProducer(metaOpNode, 2, {out_channels}, "b");
diff --git a/src/operator/MetaOperatorDefs/PaddedConvDepthWise.cpp b/src/operator/MetaOperatorDefs/PaddedConvDepthWise.cpp
index 1c073b78a..12d980b40 100644
--- a/src/operator/MetaOperatorDefs/PaddedConvDepthWise.cpp
+++ b/src/operator/MetaOperatorDefs/PaddedConvDepthWise.cpp
@@ -40,7 +40,7 @@ std::shared_ptr<Aidge::Node> Aidge::PaddedConvDepthWise(const Aidge::DimSize_t n
         Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""),
         std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv_depth_wise" : "")
     });
-    auto metaOpNode = MetaOperator("PaddedConvDepthWise", graph, name);
+    auto metaOpNode = MetaOperator("PaddedConvDepthWise", graph, {}, name);
     addProducer(metaOpNode, 1, append(nb_channels, append(Aidge::DimSize_t(1), kernel_dims)), "w");
     if (!no_bias) {
         addProducer(metaOpNode, 2, {nb_channels}, "b");
-- 
GitLab