From f80511fbf2b55d1c4ab77215f061880e354e3c64 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 9 Feb 2024 15:12:53 +0100
Subject: [PATCH] Added LSTM meta-operator (not tested yet with actuel values)

---
 include/aidge/operator/MetaOperatorDefs.hpp |  92 ++++++++++++
 include/aidge/scheduler/Scheduler.hpp       |  14 ++
 include/aidge/utils/Formatting.hpp          |  11 +-
 src/scheduler/Scheduler.cpp                 | 147 ++++++++++++++------
 unit_tests/operator/Test_MetaOperator.cpp   |  29 ++++
 5 files changed, 245 insertions(+), 48 deletions(-)

diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp
index 2832f9fce..492bcd95d 100644
--- a/include/aidge/operator/MetaOperatorDefs.hpp
+++ b/include/aidge/operator/MetaOperatorDefs.hpp
@@ -18,6 +18,14 @@
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/ConvDepthWise.hpp"
 #include "aidge/operator/Pad.hpp"
+#include "aidge/operator/Memorize.hpp"
+#include "aidge/operator/Add.hpp"
+#include "aidge/operator/Mul.hpp"
+#include "aidge/operator/FC.hpp"
+#include "aidge/operator/Identity.hpp"
+#include "aidge/operator/Concat.hpp"
+#include "aidge/operator/Tanh.hpp"
+#include "aidge/operator/Sigmoid.hpp"
 
 namespace Aidge {
 template <std::array<DimSize_t, 1>::size_type DIM>
@@ -135,6 +143,90 @@ inline std::shared_ptr<Node> PaddedMaxPooling(
 {
     return PaddedMaxPooling(to_array(kernel_dims), name, stride_dims, padding_dims, ceil_mode);
 }
+
+inline std::shared_ptr<Node> LTSM(DimSize_t in_channels,
+                                  DimSize_t hidden_channels,
+                                  DimSize_t seq_length,
+                                  const std::string& name = "")
+{
+    // Construct micro-graph
+    auto input = Identity((!name.empty()) ? name + "_input" : "");
+    auto hiddenState = Memorize(seq_length, (!name.empty()) ? name + "_hidden_state" : "");
+    auto cellState = Memorize(seq_length, (!name.empty()) ? name + "_cell_state" : "");
+    auto add = Add(2, (!name.empty()) ? name + "_add" : "");
+
+    // Forget gate
+    auto forgetGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_forgetGateX" : "");
+    input->addChild(forgetGateX, 0, 0);
+    auto forgetGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_forgetGateH" : "");
+    hiddenState->addChild(forgetGateH, 1, 0);
+    auto forgetGate = Add(2, (!name.empty()) ? name + "_forgetGate" : "");
+    forgetGateX->addChild(forgetGate, 0, 0);
+    forgetGateH->addChild(forgetGate, 0, 1);
+    auto forgetGateAct = Sigmoid((!name.empty()) ? name + "_forgetGateAct" : "");
+    auto forgetGateMul = Mul((!name.empty()) ? name + "_forgetGateMul" : "");
+    forgetGate->addChild(forgetGateAct, 0, 0);
+    forgetGateAct->addChild(forgetGateMul, 0, 0);
+    forgetGateMul->addChild(add, 0, 0);
+    cellState->addChild(forgetGateMul, 1, 1);
+
+    // Input gate
+    auto inputGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_inputGateX" : "");
+    input->addChild(inputGateX, 0, 0);
+    auto inputGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_inputGateH" : "");
+    hiddenState->addChild(inputGateH, 1, 0);
+    auto inputGate = Add(2, (!name.empty()) ? name + "_inputGate" : "");
+    inputGateX->addChild(inputGate, 0, 0);
+    inputGateH->addChild(inputGate, 0, 1);
+    auto inputGateAct = Sigmoid((!name.empty()) ? name + "_inputGateAct" : "");
+    auto inputGateMul = Mul((!name.empty()) ? name + "_inputGateMul" : "");
+    inputGate->addChild(inputGateAct, 0, 0);
+    inputGateAct->addChild(inputGateMul, 0, 0);
+    inputGateMul->addChild(add, 0, 1);
+
+    // Candidate for cell update
+    auto cellCandidateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_cellCandidateX" : "");
+    input->addChild(cellCandidateX, 0, 0);
+    auto cellCandidateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_cellCandidateH" : "");
+    hiddenState->addChild(cellCandidateH, 1, 0);
+    auto cellCandidate = Add(2, (!name.empty()) ? name + "_cellCandidate" : "");
+    cellCandidateX->addChild(cellCandidate, 0, 0);
+    cellCandidateH->addChild(cellCandidate, 0, 1);
+    auto cellCandidateAct = Tanh((!name.empty()) ? name + "_cellCandidateAct" : "");
+    cellCandidate->addChild(cellCandidateAct, 0, 0);
+    cellCandidateAct->addChild(inputGateMul, 0, 1);
+
+    // Output gate
+    auto outputGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_outputGateX" : "");
+    input->addChild(outputGateX, 0, 0);
+    auto outputGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_outputGateH" : "");
+    hiddenState->addChild(outputGateH, 1, 0);
+    auto outputGate = Add(2, (!name.empty()) ? name + "_outputGate" : "");
+    outputGateX->addChild(outputGate, 0, 0);
+    outputGateH->addChild(outputGate, 0, 1);
+    auto outputGateAct = Sigmoid((!name.empty()) ? name + "_outputGateAct" : "");
+    auto outputGateMul = Mul((!name.empty()) ? name + "_outputGateMul" : "");
+    outputGate->addChild(outputGateAct, 0, 0);
+    outputGateAct->addChild(outputGateMul, 0, 0);
+
+    // Updated cell state to help determine new hidden state
+    auto cellUpdatedAct = Tanh((!name.empty()) ? name + "_cellUpdatedAct" : "");
+    add->addChild(cellUpdatedAct, 0, 0);
+    cellUpdatedAct->addChild(outputGateMul, 0, 1);
+    outputGateMul->addChild(hiddenState, 0, 0);
+    add->addChild(cellState, 0, 0);
+
+    std::shared_ptr<GraphView> microGraph = std::make_shared<GraphView>();
+    microGraph->add(input);
+    microGraph->add({hiddenState, cellState, add,
+        forgetGateX, forgetGateH, forgetGate, forgetGateAct, forgetGateMul,
+        inputGateX, inputGateH, inputGate, inputGateAct, inputGateMul,
+        cellCandidateX, cellCandidateH, cellCandidate, cellCandidateAct,
+        outputGateX, outputGateH, outputGate, outputGateAct, outputGateMul,
+        cellUpdatedAct});
+
+    return MetaOperator("LTSM", microGraph, name);
+}
 }  // namespace Aidge
 
 #endif /* AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ */
diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp
index 6dcec5aaa..81dc6d7cc 100644
--- a/include/aidge/scheduler/Scheduler.hpp
+++ b/include/aidge/scheduler/Scheduler.hpp
@@ -17,6 +17,7 @@
 #include <set>
 #include <string>
 #include <vector>
+#include <map>
 
 namespace Aidge {
 class Node;
@@ -36,6 +37,12 @@ private:
         std::chrono::time_point<std::chrono::high_resolution_clock> end;
     };
 
+    struct PriorProducersConsumers {
+        bool isPrior = false;
+        std::set<std::shared_ptr<Aidge::Node>> requiredProducers;
+        std::set<std::shared_ptr<Aidge::Node>> priorConsumers;
+    };
+
 public:
     SequentialScheduler(std::shared_ptr<GraphView> graphView)
         : mGraphView(graphView)
@@ -80,6 +87,13 @@ private:
      * @return std::set<std::shared_ptr<Node>>
      */
     std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const;
+    PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const;
+
+    /**
+     * Return a std::map with corresponding node's name.
+     * TODO: Mutualise with similar code in GraphView::save()?
+    */
+    std::map<std::shared_ptr<Node>, std::string> getNodesName(bool verbose) const;
 
     /** @brief Shared ptr to the scheduled graph view */
     std::shared_ptr<GraphView> mGraphView;
diff --git a/include/aidge/utils/Formatting.hpp b/include/aidge/utils/Formatting.hpp
index e96e87e35..4bcb5cad6 100644
--- a/include/aidge/utils/Formatting.hpp
+++ b/include/aidge/utils/Formatting.hpp
@@ -50,8 +50,8 @@ std::string stringFormat(const std::string& format, Args... args) {
 /**
  * Print any iterable object in a std::string.
 */
-template <class T>
-std::string print(const T& vec, const std::string& format) {
+template <class T, typename F>
+std::string print(const T& vec, const std::string& format, const F& func) {
     std::string str = "{";
     bool first = true;
     for (const auto& val : vec) {
@@ -61,11 +61,16 @@ std::string print(const T& vec, const std::string& format) {
         else {
             first = false;
         }
-        str += stringFormat(format, val);
+        str += stringFormat(format, func(val));
     }
     str += "}";
     return str;
 }
+
+template <class T>
+std::string print(const T& vec, const std::string& format) {
+    return print(vec, format, [](auto val){ return val; });
+}
 }
 
 #endif //AIDGE_FORMATTING_H_
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index 2093aa5af..8273df21b 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -21,6 +21,7 @@
 #include "aidge/utils/Types.h"
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/Producer.hpp"
+#include "aidge/operator/Memorize.hpp"
 
 void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") {
     putchar('[');
@@ -60,12 +61,22 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
     // runnable consumer, the list of consumer is again equal to frozenConsumers
     // it means we are in cycle with no more scheduling update, a.k.a. a
     // frozen state.
-    std::set<std::shared_ptr<Node>> frozenConsumers;
+    std::vector<std::set<std::shared_ptr<Node>>> frozenConsumers;
+
+    std::map<std::shared_ptr<Node>, std::string> namePtrTable = getNodesName(verbose);
 
     do {
-        // Check required producers
+        // From the current consumers list, check if any prior nodes are needed.
+        // If for a given node, only parent producers (at any depth) are needed
+        // to satisfy its required data, it becomes a prior.
+        // If the prior node is a producer, it is added to the list of required
+        // producers.
+        // If the prior node is of another type, it replaces the initial consumer
+        // in the new priorConsumers list. The initial consumer will necessarily
+        // be added again later in the consumers list.
+        if (verbose) printf("List of consumers with their priors:\n");
         std::set<std::shared_ptr<Node>> requiredProducers;
-        if (verbose) printf("Required producers:\n");
+        std::set<std::shared_ptr<Node>> priorConsumers;
 
         for (const auto& consumer : consumers) {
             if (verbose) {
@@ -74,43 +85,27 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
                        "%s"
                        "\x1b[0m"
                        "\n",
-                       (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str());
+                       namePtrTable[consumer].c_str());
             }
 
-            std::set<std::shared_ptr<Node>> consumerRequiredProducers;
-            bool requiredProducerOnly = true;
-            IOIndex_t inputIdx = 0;
-            for (const auto& consumerParent : consumer->inputs()) {
-                if (verbose) printf("\t\t#%u: ", inputIdx);
-
-                if (consumerParent.first &&
-                    (consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
-                            consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) {
-                    if (verbose) printf("required data from %s: C%zu + R%zu > P%zu\n",
-                        consumerParent.first->type().c_str(),
-                        consumer->getOperator()->getNbConsumedData(inputIdx),
-                        consumer->getOperator()->getNbRequiredData(inputIdx),
-                        consumerParent.first->getOperator()->getNbProducedData(consumerParent.second));
+            const auto& prior = getPriorProducersConsumers(consumer);
 
-                    if (consumerParent.first->type() == Producer_Op::Type) {
-                        consumerRequiredProducers.insert(consumerParent.first);
-                    }
-                    else {
-                        requiredProducerOnly = false;
-                        break;
-                    }
+            if (prior.isPrior) {
+                if (verbose) {
+                    printf("\t\trequired producers: %s\n", print(prior.requiredProducers, "%s", [&namePtrTable](auto val){ return namePtrTable[val].c_str(); }).c_str());
+                    printf("\t\tprior consumers: %s\n", print(prior.priorConsumers, "%s", [&namePtrTable](auto val){ return namePtrTable[val].c_str(); }).c_str());
                 }
-                else {
-                    if (verbose) printf("no data required\n");
-                }
-                ++inputIdx;
-            }
 
-            if (requiredProducerOnly) {
-                requiredProducers.insert(consumerRequiredProducers.begin(), consumerRequiredProducers.end());
+                requiredProducers.insert(prior.requiredProducers.cbegin(), prior.requiredProducers.cend());
+                priorConsumers.insert(prior.priorConsumers.cbegin(), prior.priorConsumers.cend());
+            }
+            else {
+                priorConsumers.insert(consumer);
             }
         }
 
+        consumers.swap(priorConsumers);
+
         // Make producers generate the required data
         for (const auto& requiredProducer : requiredProducers) {
             requiredProducer->getOperator()->updateConsummerProducer();
@@ -119,7 +114,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
 
         // find runnable consumers
         std::set<std::shared_ptr<Node>> runnableConsumers;
-        if (verbose) printf("List of layers receiving data:\n");
+        if (verbose) printf("Updated list of consumers:\n");
         for (const auto& consumer : consumers) {
             if (verbose) {
                 printf("\t- consumer: "
@@ -127,7 +122,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
                        "%s"
                        "\x1b[0m"
                        "\n\t\tC/R:\t",
-                       (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str());
+                       namePtrTable[consumer].c_str());
                 for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
                     printf("%zu/%zu\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
                            consumer->getOperator()->getNbRequiredData(inId));
@@ -169,15 +164,13 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
 
         // Push consumers in the list of nodes to run and update the consumer producer system
         for (const auto& runnable : runnableConsumers) {
-            if (verbose) printf("Runnable: %s\n", (runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str());
+            if (verbose) printf("Runnable: %s\n", namePtrTable[runnable].c_str());
             runnable->getOperator()->updateConsummerProducer();
             mStaticSchedule.push_back(runnable);
         }
 
         if (runnableConsumers.empty()) {
-            if (frozenConsumers.empty()) {
-                frozenConsumers = consumers;
-            }
+            frozenConsumers.push_back(consumers);
         }
         else {
             frozenConsumers.clear();
@@ -190,7 +183,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
         for (const auto& consumer : oldConsumers) {
             if (verbose) {
                 printf("\t- consumer: %s\n\t\tC/R:\t",
-                       (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str());
+                       namePtrTable[consumer].c_str());
                 for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
                     printf("%ld/%ld\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
                            consumer->getOperator()->getNbRequiredData(inId));
@@ -243,7 +236,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
         }
 
         if (verbose) printf("********************\n");
-    } while (!consumers.empty() && consumers != frozenConsumers);
+    } while (!consumers.empty() && (frozenConsumers.empty() || std::find(frozenConsumers.begin(), frozenConsumers.end(), consumers) == frozenConsumers.end()));
 
     if (verbose) {
         if (!consumers.empty()) {
@@ -268,15 +261,16 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
     // Clear previous scheduling results
     mScheduling.clear();
 
+    std::map<std::shared_ptr<Node>, std::string> namePtrTable = getNodesName(verbose);
+
     int cpt = 0;
     for (const auto& runnable : mStaticSchedule) {
         if (verbose)
             printf("run: %s\n",
-                    (runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str());
+                    namePtrTable[runnable].c_str());
         else
             drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50,
-                            (std::string("running ") + runnable->type() + "_" +
-                                std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))));
+                            (std::string("running ") + namePtrTable[runnable]));
         const auto tStart = std::chrono::high_resolution_clock::now();
         runnable->forward();
         const auto tEnd = std::chrono::high_resolution_clock::now();
@@ -292,12 +286,12 @@ void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileNa
     std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%Q ms\n\n");
 
     if (!mScheduling.empty()) {
+        std::map<std::shared_ptr<Node>, std::string> namePtrTable = getNodesName(true);
         const auto globalStart = mScheduling[0].start;
 
         for (const auto& element : mScheduling) {
             std::fprintf(fp, "%s :%ld, %ld\n",
-                         (element.node->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(element.node.get())))
-                                 .c_str(),
+                         namePtrTable[element.node].c_str(),
                          std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count(),
                          std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count());
         }
@@ -318,3 +312,66 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
 
     return consumers;
 }
+
+Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers(
+    const std::shared_ptr<Node>& node) const
+{
+    PriorProducersConsumers prior;
+
+    IOIndex_t inputIdx = 0;
+    for (const auto& parent : node->inputs()) {
+        if (parent.first &&
+            (node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) >
+                    parent.first->getOperator()->getNbProducedData(parent.second))
+        {
+            if (parent.first->type() == Producer_Op::Type) {
+                prior.requiredProducers.insert(parent.first);
+                prior.priorConsumers.insert(node);
+            }
+            else if (parent.first->type() == Memorize_Op::Type) {
+                // Break cycles
+                return PriorProducersConsumers();
+            }
+            else {
+                const auto& parentPrior = getPriorProducersConsumers(parent.first);
+
+                if (!parentPrior.isPrior) {
+                    return PriorProducersConsumers();
+                }
+                else {
+                    prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend());
+                    prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend());
+                }
+            }
+        }
+        ++inputIdx;
+    }
+
+    prior.isPrior = true;
+    if (prior.priorConsumers.empty()) {
+        prior.priorConsumers.insert(node);
+    }
+    return prior;
+}
+
+std::map<std::shared_ptr<Aidge::Node>, std::string> Aidge::SequentialScheduler::getNodesName(bool verbose) const {
+    std::map<std::shared_ptr<Node>, std::string> namePtrTable;
+
+    if (verbose) {
+        std::map<const std::string, std::size_t> typeCounter;
+
+        for (const std::shared_ptr<Node> &node_ptr : mGraphView->getNodes()) {
+            const std::string currentType = node_ptr->type();
+            if (typeCounter.find(currentType) == typeCounter.end())
+                typeCounter[currentType] = 0;
+            ++typeCounter[currentType];
+
+            namePtrTable[node_ptr] =
+                (node_ptr->name().empty())
+                    ? currentType + "#" + std::to_string(typeCounter[currentType])
+                    : node_ptr->name() + " (" + currentType + "#" + std::to_string(typeCounter[currentType]) + ")";
+        }
+    }
+
+    return namePtrTable;
+}
diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp
index 68e2d4d4d..59e768d44 100644
--- a/unit_tests/operator/Test_MetaOperator.cpp
+++ b/unit_tests/operator/Test_MetaOperator.cpp
@@ -51,4 +51,33 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") {
         //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraphScheduler();
         //REQUIRE(microGraphScheduler->getStaticScheduling().size() == 2);
     }
+
+    SECTION("LTSM") {
+        auto myLSTM = LTSM(32, 64, 16, "ltsm");
+        auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
+
+        auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
+        microGraph->save("lstm", false, false);
+
+        REQUIRE(myLSTM->nbInputs() == 3);
+        REQUIRE(myLSTM->nbData() == 3);
+        REQUIRE(myLSTM->nbOutputs() == 2);
+
+        std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>();
+        myInput->resize({32});
+        std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>();
+        myInit->resize({1, 64});
+
+        op->associateInput(0, myInput);
+        op->associateInput(1, myInit);
+        op->associateInput(2, myInit);
+
+        op->computeOutputDims();
+        REQUIRE(op->outputDimsForwarded());
+        microGraph->save("lstm_dims", false, false);
+
+        //op->updateConsummerProducer();  // require implementation
+        //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler();
+        //microGraphScheduler->saveSchedulingDiagram("lstm_scheduling");
+    }
 }
-- 
GitLab