From 33c779b288d8e1ca8ee46800e43d34b4dc011f9d Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Mon, 9 Oct 2023 18:43:40 +0200
Subject: [PATCH] Fixed inputs/outputs mechanism
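
The previous implementation resolved meta-operator inputs/outputs by
iterating over GraphView::inputNodes() / outputNodes(), which are
unordered, so the mapping between a meta-operator I/O index and the
corresponding micro-graph operator was not deterministic.

MetaOperator_Op now builds an explicit ordered mapping (mInputOps /
mOutputOps) in its constructor, from an optional ordered list of
input/output nodes, and all I/O queries go through this mapping.
Also adds PaddedAvgPooling/PaddedMaxPooling helpers and an AIDGE_ASSERT
macro, and makes SequentialScheduler::forward() reuse an existing
static scheduling instead of regenerating it on every call.

Minimal usage sketch (illustrative, not part of the patch):

    // PaddedConv wraps a Pad -> Conv micro-graph; its ordered inputs
    // are data (Pad), then weight and bias (Conv).
    auto conv = PaddedConv<2>(3, 16, {3, 3}, "conv1");
    // associateInput(1, w) now reliably reaches Conv's weight input
    // through mInputOps, whatever the GraphView iteration order.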

---
 include/aidge/operator/MetaOperator.hpp | 171 +++++++++++++++---
 include/aidge/scheduler/Scheduler.hpp   |   5 -
 include/aidge/utils/Utils.hpp           |   8 +++-
 src/scheduler/Scheduler.cpp             |  16 +--
 4 files changed, 121 insertions(+), 79 deletions(-)

diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp
index b45c0ae8a..6018c7c15 100644
--- a/include/aidge/operator/MetaOperator.hpp
+++ b/include/aidge/operator/MetaOperator.hpp
@@ -13,6 +13,8 @@
 #define AIDGE_CORE_OPERATOR_METAOPERATOR_H_
 
 #include "aidge/operator/Operator.hpp"
+#include "aidge/operator/AvgPooling.hpp"
+#include "aidge/operator/MaxPooling.hpp"
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/Pad.hpp"
 #include "aidge/graph/GraphView.hpp"
@@ -25,11 +27,23 @@ class MetaOperator_Op : public Operator,
 public:
     std::vector<std::shared_ptr<Tensor>> mInputs;
     std::vector<std::shared_ptr<Tensor>> mOutputs; // These are shared with micro-graph outputs tensors
+
+    // Micro-graph handling:
     std::shared_ptr<GraphView> mGraph; // Meta operator micro-graph
     std::shared_ptr<SequentialScheduler> mScheduler;
+    // Need to store an ordered list of input/output operators for the micro-graph,
+    // because input/output nodes in a GraphView are unordered.
+    // TODO: refactor GraphView to handle ordered input/output?
+    std::vector<std::pair<std::shared_ptr<Operator>, IOIndex_t>> mInputOps;
+    std::vector<std::pair<std::shared_ptr<Operator>, IOIndex_t>> mOutputOps;
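+    // For the PaddedConv helper below, for instance, this yields:
+    //   mInputOps  = {(pad, 0), (conv, 1), (conv, 2)}  // data, weight, bias
+    //   mOutputOps = {(conv, 0)}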
 
    public:
-    MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph)
+    MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph,
+        std::vector<NodePtr> inputNodes = std::vector<NodePtr>(),
+        std::vector<NodePtr> outputNodes = std::vector<NodePtr>())
         : Operator(type),
           mGraph(graph)
     {
@@ -41,6 +55,48 @@ public:
         for (std::size_t i = 0; i < mOutputs.size(); ++i) {
             mOutputs[i] = std::make_shared<Tensor>();
         }
+
+        // Fill inputNodes and outputNodes when there is no ambiguity
+        if (inputNodes.empty()) {
+            AIDGE_ASSERT(mGraph->inputNodes().size() == 1, "ambiguous input mapping: the ordered list of input nodes must be specified");
+            inputNodes.push_back(*mGraph->inputNodes().begin());
+        }
+
+        if (outputNodes.empty()) {
+            AIDGE_ASSERT(mGraph->outputNodes().size() == 1, "ambiguous output mapping: the ordered list of output nodes must be specified");
+            outputNodes.push_back(*mGraph->outputNodes().begin());
+        }
+
+        // Identify inputs that are outside the micro-graph
+        for (const auto& inputNode : inputNodes) {
+            AIDGE_ASSERT(mGraph->inView(inputNode), "input node must be in the graph");
+            const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
+                inputNode->inputs();
+
+            IOIndex_t inputIdx = 0;   // input idx relative to the current node
+            for (const auto& in : inputNodeinputs) {
+                if (in.first == nullptr || !mGraph->inView(in.first)) {
+                    // The input is not connected inside the micro-graph
+                    // (no connection to this input or connection outside the micro-graph)
+                    // => it is therefore an input for the meta-operator
+                    mInputOps.push_back(std::make_pair(inputNode->getOperator(), inputIdx));
+                }
+
+                ++inputIdx;
+            }
+        }
+
+        // The outputs of the output nodes are also the outputs of the meta-operator
+        for (const auto& outputNode : outputNodes) {
+            AIDGE_ASSERT(mGraph->inView(outputNode), "output node must be in the graph");
+            const std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> outputNodeoutputs =
+                outputNode->outputs();
+
+            // output idx relative to the current node
+            for (IOIndex_t outputIdx = 0; outputIdx < outputNodeoutputs.size(); ++outputIdx) {
+                mOutputOps.push_back(std::make_pair(outputNode->getOperator(), outputIdx));
+            }
+        }
     }
 
     /**
@@ -73,20 +129,9 @@
     void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
         assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
 
-        // Associate micro-graph inputs
-        std::size_t nbGraphIn = 0U;
-        // TODO: FIXME: inputNodes() is unordered!
-        for (const std::shared_ptr<Node>& inputNode : mGraph->inputNodes()) {
-            const std::size_t nbIn = inputNode->nbInputs();
-
-            if (inputIdx < nbGraphIn + nbIn) {
-                // FIXME: !!!workaround only for the PaddedConv unit test!!!
-                inputNode->getOperator()->associateInput(inputIdx /*- nbGraphIn*/, data);
-                break;
-            }
-
-            nbGraphIn += nbIn;
-        }
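+        // Direct lookup of the micro-graph (operator, input index) pair for this input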
+        const auto& inputOp = mInputOps[inputIdx];
+        inputOp.first->associateInput(inputOp.second, data);
 
         // Associate inputs for custom implementation
         mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
@@ -97,16 +142,10 @@
         mGraph->forwardDims();
 
         // Associate outputs to micro-graph outputs for custom implementation
-        std::size_t nbGraphOut = 0U;
-        // TODO: FIXME: inputNodes() is unordered!
-        for (const std::shared_ptr<Node>& outputNode : mGraph->outputNodes()) {
-            const std::size_t nbOut = outputNode->nbOutputs();
-
-            for (size_t outputIdx = nbGraphOut; outputIdx < nbGraphOut + nbOut; ++outputIdx) {
-                mOutputs[outputIdx] = outputNode->getOperator()->getOutput(outputIdx - nbGraphOut);
-            }
-
-            nbGraphOut += nbOut;
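+        // The meta-operator outputs are shared with the micro-graph output tensors (no copy)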
+        for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) {
+            const auto& outputOp = mOutputOps[outputIdx];
+            mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second);
         }
     }
 
@@ -171,19 +210,9 @@
             return mImpl->getNbRequiredData(inputIdx);
         }
         else {
-            std::size_t nbGraphIn = 0U;
-            // TODO: FIXME: inputNodes() is unordered!
-            for (const std::shared_ptr<Node>& inputNode : mGraph->inputNodes()) {
-                const std::size_t nbIn = inputNode->nbInputs();
-
-                if (inputIdx < nbGraphIn + nbIn) {
-                    return inputNode->getOperator()->getNbRequiredData(inputIdx - nbGraphIn);
-                }
-
-                nbGraphIn += nbIn;
-            }
-
-            assert(false && "inputIdx out of range");
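+            // No custom implementation: defer to the micro-graph operator consuming this input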
+            const auto& inputOp = mInputOps[inputIdx];
+            return inputOp.first->getNbRequiredData(inputOp.second);
         }
     }
 
@@ -192,19 +221,8 @@
             return mImpl->getNbConsumedData(inputIdx);
         }
         else {
-            std::size_t nbGraphIn = 0U;
-            // TODO: FIXME: inputNodes() is unordered!
-            for (const std::shared_ptr<Node>& inputNode : mGraph->inputNodes()) {
-                const std::size_t nbIn = inputNode->nbInputs();
-
-                if (inputIdx < nbGraphIn + nbIn) {
-                    return inputNode->getOperator()->getNbConsumedData(inputIdx - nbGraphIn);
-                }
-
-                nbGraphIn += nbIn;
-            }
-
-            assert(false && "inputIdx out of range");
+            const auto& inputOp = mInputOps[inputIdx];
+            return inputOp.first->getNbConsumedData(inputOp.second);
         }
     }
 
@@ -213,19 +231,8 @@
             return mImpl->getNbProducedData(outputIdx);
         }
         else {
-            std::size_t nbGraphOut = 0U;
-            // TODO: FIXME: outputNodes() is unordered!
-            for (const std::shared_ptr<Node>& outputNode : mGraph->outputNodes()) {
-                const std::size_t nbOut = outputNode->nbOutputs();
-
-                if (outputIdx < nbGraphOut + nbOut) {
-                    return outputNode->getOperator()->getNbProducedData(outputIdx - nbGraphOut);
-                }
-
-                nbGraphOut += nbOut;
-            }
-
-            assert(false && "outputIdx out of range");
+            const auto& outputOp = mOutputOps[outputIdx];
+            return outputOp.first->getNbProducedData(outputOp.second);
         }
     }
 
@@ -296,7 +303,10 @@ inline std::shared_ptr<Node> PaddedConv(DimSize_t in_channels,
     graph->add(pad, false);
     graph->add(conv, false);
 
-    return std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedConv", graph), name);
+    // The micro-graph has two input nodes (Pad for the data, Conv for the
+    // weight and bias), so the ordered input mapping must be given explicitly
+    const std::vector<NodePtr> orderedInputNodes = {pad, conv};
+    return std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedConv", graph, orderedInputNodes), name);
 }
 
 template <DimSize_t DIM>
@@ -311,6 +321,37 @@ inline std::shared_ptr<Node> PaddedConv(
 {
     return PaddedConv<DIM>(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims);
 }
+
+template <std::array<DimSize_t, 1>::size_type DIM>
+inline std::shared_ptr<Node> PaddedAvgPooling(const std::array<DimSize_t, DIM> &kernel_dims,
+                                  const std::string& name = "",
+                                  const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
+                                  const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0})
+{
+    auto graph = Sequential({
+        Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""),
+        AvgPooling(kernel_dims, (!name.empty()) ? name + "_avgpooling" : "", stride_dims)
+    });
+
+    return std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedAvgPooling", graph), name);
+}
+
+template <std::array<DimSize_t, 1>::size_type DIM>
+inline std::shared_ptr<Node> PaddedMaxPooling(const std::array<DimSize_t, DIM> &kernel_dims,
+                                  const std::string& name = "",
+                                  const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
+                                  const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0})
+{
+    auto graph = Sequential({
+        Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""),
+        MaxPooling(kernel_dims, (!name.empty()) ? name + "_maxpooling" : "", stride_dims)
+    });
+
+    return std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedMaxPooling", graph), name);
+}
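+
+// Usage sketch (illustrative): 2x2 average pooling with a 1-pixel padded border:
+//   auto pool = PaddedAvgPooling<2>({2, 2}, "pool1", {2, 2}, {{{1, 1}, {1, 1}}});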
 }  // namespace Aidge
 
 #endif /* MetaOperator_H_ */
diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp
index 9916ee200..1896894ee 100644
--- a/include/aidge/scheduler/Scheduler.hpp
+++ b/include/aidge/scheduler/Scheduler.hpp
@@ -89,11 +89,6 @@ private:
      *
      */
     std::vector<std::shared_ptr<Node>> mStaticSchedule;
-    /**
-     * @brief Number of computation node (i.e: nb nodes != Producer)
-     *
-     */
-    std::size_t mComputationNumber = 0; // TODO: Check if not inferable from mStaticSchedule
 };
 } // namespace Aidge
 
diff --git a/include/aidge/utils/Utils.hpp b/include/aidge/utils/Utils.hpp
index 89dc25bee..5facc80ea 100644
--- a/include/aidge/utils/Utils.hpp
+++ b/include/aidge/utils/Utils.hpp
@@ -35,4 +35,10 @@ do { \
 } while (false)
 #endif
 
-#endif //AIDGE_UTILS_H_
\ No newline at end of file
+#define AIDGE_ASSERT(stm, ...) \
+do { \
+    if (!(stm)) { AIDGE_THROW_OR_ABORT(std::runtime_error, __VA_ARGS__); } \
+} while (false)
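+// Usage (illustrative): AIDGE_ASSERT(idx < vec.size(), "index out of range");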
+
+#endif //AIDGE_UTILS_H_
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index 939289414..554d745e6 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -40,13 +40,10 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
     // TODO: optimize memory usage
 
     // setup initial producers list
-    mComputationNumber = 0;
     std::set<std::shared_ptr<Node>> producers;
     for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) {
         if (nodePtr->type() == "Producer") {
             producers.insert(nodePtr);
-        } else {
-            ++mComputationNumber;
         }
     }
     // add Data Input
@@ -178,15 +175,19 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
 
 // TODO: handle multiple inputs/outputs
 void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
+    // Forward dims (if requested)
     if (forwardDims) {mGraphView->forwardDims(); }
 
-    // add each Producer Node.
-    std::set<std::shared_ptr<Node>> computationOver;
+    // Generate the scheduling only if it is empty.
+    // If a scheduling was already generated (in one or several steps, i.e. by one
+    // or several successive calls to generateScheduling()), do not generate it twice.
+    if (mStaticSchedule.empty()) {
+        this->generateScheduling();
+    }
 
+    // Clear previous scheduling results
     mScheduling.clear();
-    mStaticSchedule.clear();
 
-    this->generateScheduling();
     int cpt = 0;
     for (const auto& runnable : mStaticSchedule) {
         if (verbose)
@@ -204,7 +205,6 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
     }
     if (!verbose) drawProgressBar(1.0, 50, "                                   ");
     printf("\n");
-
 }
 
 void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
-- 
GitLab