From 3ba465b3820e002e31bffe7060ea88a0618d82a1 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sun, 29 Sep 2024 17:33:29 +0200
Subject: [PATCH] Added meta op input category

---
 include/aidge/operator/MetaOperator.hpp            |  3 ++-
 include/aidge/operator/MetaOperatorDefs.hpp        |  2 +-
 .../operator/pybind_MetaOperatorDefs.cpp           |  6 ++++--
 src/operator/MetaOperator.cpp                      | 14 ++++++++++----
 src/operator/MetaOperatorDefs/LSTM.cpp             |  2 +-
 src/operator/MetaOperatorDefs/PaddedAvgPooling.cpp |  2 +-
 src/operator/MetaOperatorDefs/PaddedConv.cpp       |  2 +-
 .../MetaOperatorDefs/PaddedConvDepthWise.cpp       |  2 +-
 8 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp
index ccff976cb..744dbd132 100644
--- a/include/aidge/operator/MetaOperator.hpp
+++ b/include/aidge/operator/MetaOperator.hpp
@@ -37,7 +37,7 @@ public:
     std::weak_ptr<Node> mUpperNode;
 
    public:
-    MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph);
+    MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph, const std::vector<InputCategory>& forcedInputsCategory = {});
 
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
@@ -113,6 +113,7 @@ public:
 
 std::shared_ptr<Node> MetaOperator(const char *type,
                                   const std::shared_ptr<GraphView>& graph,
+                                  const std::vector<InputCategory>& forcedInputsCategory = {},
                                   const std::string& name = "");
 }  // namespace Aidge
 
diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp
index bc3348377..750a808aa 100644
--- a/include/aidge/operator/MetaOperatorDefs.hpp
+++ b/include/aidge/operator/MetaOperatorDefs.hpp
@@ -126,7 +126,7 @@ inline std::shared_ptr<Node> PaddedMaxPooling(const std::array<DimSize_t, DIM> &
         MaxPooling(kernel_dims, (!name.empty()) ? name + "_maxpooling" : "", stride_dims, ceil_mode)
     });
 
-    return MetaOperator("PaddedMaxPooling", graph, name);
+    return MetaOperator("PaddedMaxPooling", graph, {}, name);
 }
 
 template <std::array<DimSize_t, 1>::size_type DIM>
diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp
index 8ad7b5c3b..afd682f3e 100644
--- a/python_binding/operator/pybind_MetaOperatorDefs.cpp
+++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp
@@ -195,15 +195,17 @@ void init_MetaOperatorDefs(py::module &m) {
   declare_LSTMOp(m);
 
   py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, OperatorTensor>(m, "MetaOperator_Op", py::multiple_inheritance())
-  .def(py::init<const char *, const std::shared_ptr<GraphView>&>(),
+  .def(py::init<const char *, const std::shared_ptr<GraphView>&, const std::vector<InputCategory>&>(),
           py::arg("type"),
-          py::arg("graph"))
+          py::arg("graph"),
+          py::arg("forced_inputs_category") = std::vector<InputCategory>())
   .def("get_micro_graph", &MetaOperator_Op::getMicroGraph)
   .def("set_upper_node", &MetaOperator_Op::setUpperNode);
 
   m.def("meta_operator", &MetaOperator,
     py::arg("type"),
     py::arg("graph"),
+    py::arg("forced_inputs_category") = std::vector<InputCategory>(),
     py::arg("name") = ""
   );
 
diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp
index 372c8b953..ab6bde74f 100644
--- a/src/operator/MetaOperator.cpp
+++ b/src/operator/MetaOperator.cpp
@@ -20,17 +20,22 @@
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/DynamicAttributes.hpp"
 
-Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph)
-    : OperatorTensor(type, [graph]() {
+Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph, const std::vector<InputCategory>& forcedInputsCategory)
+    : OperatorTensor(type, [graph, forcedInputsCategory]() {
+        IOIndex_t inputIdx = 0;
         std::vector<InputCategory> inputsCategory;
         for (const auto& in : graph->getOrderedInputs()) {
-            if (in.first) {
+            if (inputIdx < forcedInputsCategory.size()) {
+                inputsCategory.push_back(forcedInputsCategory[inputIdx]);
+            }
+            else if (in.first) {
                 inputsCategory.push_back(in.first->getOperator()->inputCategory(in.second));
             }
             else {
                 // Dummy input, default to OptionalData
                 inputsCategory.push_back(InputCategory::OptionalData);
             }
+            ++inputIdx;
         }
         return inputsCategory;
     }(), graph->getOrderedOutputs().size()),
@@ -245,9 +250,10 @@ void Aidge::MetaOperator_Op::forward() {
 
 std::shared_ptr<Aidge::Node> Aidge::MetaOperator(const char *type,
                                   const std::shared_ptr<Aidge::GraphView>& graph,
+                                  const std::vector<InputCategory>& forcedInputsCategory,
                                   const std::string& name)
 {
-    auto op = std::make_shared<MetaOperator_Op>(type, graph);
+    auto op = std::make_shared<MetaOperator_Op>(type, graph, forcedInputsCategory);
     auto node = std::make_shared<Node>(op, name);
     op->setUpperNode(node);
     return node;
diff --git a/src/operator/MetaOperatorDefs/LSTM.cpp b/src/operator/MetaOperatorDefs/LSTM.cpp
index 910e7c67a..9620f0404 100644
--- a/src/operator/MetaOperatorDefs/LSTM.cpp
+++ b/src/operator/MetaOperatorDefs/LSTM.cpp
@@ -115,7 +115,7 @@ std::shared_ptr<Node> LSTM(const DimSize_t inChannel,
         {hiddenState, 1}, {cellState, 1}});
     microGraph->setOrderedOutputs({{hiddenState, 0}, {cellState, 0}});
 
-    auto metaOp = MetaOperator("LSTM", microGraph, name);
+    auto metaOp = MetaOperator("LSTM", microGraph, {}, name);
     addProducer(metaOp, 1, {hiddenChannel, inChannel}, "wi");
     addProducer(metaOp, 2, {hiddenChannel, inChannel}, "wo");
     addProducer(metaOp, 3, {hiddenChannel, inChannel}, "wf");
diff --git a/src/operator/MetaOperatorDefs/PaddedAvgPooling.cpp b/src/operator/MetaOperatorDefs/PaddedAvgPooling.cpp
index ef319ef38..c35d964d0 100644
--- a/src/operator/MetaOperatorDefs/PaddedAvgPooling.cpp
+++ b/src/operator/MetaOperatorDefs/PaddedAvgPooling.cpp
@@ -41,7 +41,7 @@ std::shared_ptr<Node> PaddedAvgPooling(const std::array<DimSize_t, DIM> &kernel_
         AvgPooling(kernel_dims, (!name.empty()) ? name + "_avgpooling" : "", stride_dims)
     });
 
-    return MetaOperator("PaddedAvgPooling", graph, name);
+    return MetaOperator("PaddedAvgPooling", graph, {}, name);
 }
 
 template std::shared_ptr<Node> PaddedAvgPooling<1>(const std::array<DimSize_t,1>&, const std::string&, const std::array<DimSize_t,1>&, const std::array<DimSize_t,2>&);
diff --git a/src/operator/MetaOperatorDefs/PaddedConv.cpp b/src/operator/MetaOperatorDefs/PaddedConv.cpp
index 31b1c675e..49373341a 100644
--- a/src/operator/MetaOperatorDefs/PaddedConv.cpp
+++ b/src/operator/MetaOperatorDefs/PaddedConv.cpp
@@ -43,7 +43,7 @@ std::shared_ptr<Aidge::Node> Aidge::PaddedConv(Aidge::DimSize_t in_channels,
         Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""),
         std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : "")
     });
-    auto metaOpNode = MetaOperator("PaddedConv", graph, name);
+    auto metaOpNode = MetaOperator("PaddedConv", graph, {}, name);
     addProducer(metaOpNode, 1, append(out_channels, append(in_channels, kernel_dims)), "w");
     if (!no_bias) {
         addProducer(metaOpNode, 2, {out_channels}, "b");
diff --git a/src/operator/MetaOperatorDefs/PaddedConvDepthWise.cpp b/src/operator/MetaOperatorDefs/PaddedConvDepthWise.cpp
index 1c073b78a..12d980b40 100644
--- a/src/operator/MetaOperatorDefs/PaddedConvDepthWise.cpp
+++ b/src/operator/MetaOperatorDefs/PaddedConvDepthWise.cpp
@@ -40,7 +40,7 @@ std::shared_ptr<Aidge::Node> Aidge::PaddedConvDepthWise(const Aidge::DimSize_t n
         Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""),
         std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv_depth_wise" : "")
     });
-    auto metaOpNode = MetaOperator("PaddedConvDepthWise", graph, name);
+    auto metaOpNode = MetaOperator("PaddedConvDepthWise", graph, {}, name);
     addProducer(metaOpNode, 1, append(nb_channels, append(Aidge::DimSize_t(1), kernel_dims)), "w");
     if (!no_bias) {
         addProducer(metaOpNode, 2, {nb_channels}, "b");
-- 
GitLab