From f80511fbf2b55d1c4ab77215f061880e354e3c64 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 9 Feb 2024 15:12:53 +0100 Subject: [PATCH] Added LSTM meta-operator (not tested yet with actuel values) --- include/aidge/operator/MetaOperatorDefs.hpp | 92 ++++++++++++ include/aidge/scheduler/Scheduler.hpp | 14 ++ include/aidge/utils/Formatting.hpp | 11 +- src/scheduler/Scheduler.cpp | 147 ++++++++++++++------ unit_tests/operator/Test_MetaOperator.cpp | 29 ++++ 5 files changed, 245 insertions(+), 48 deletions(-) diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp index 2832f9fce..492bcd95d 100644 --- a/include/aidge/operator/MetaOperatorDefs.hpp +++ b/include/aidge/operator/MetaOperatorDefs.hpp @@ -18,6 +18,14 @@ #include "aidge/operator/Conv.hpp" #include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/Pad.hpp" +#include "aidge/operator/Memorize.hpp" +#include "aidge/operator/Add.hpp" +#include "aidge/operator/Mul.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/operator/Identity.hpp" +#include "aidge/operator/Concat.hpp" +#include "aidge/operator/Tanh.hpp" +#include "aidge/operator/Sigmoid.hpp" namespace Aidge { template <std::array<DimSize_t, 1>::size_type DIM> @@ -135,6 +143,90 @@ inline std::shared_ptr<Node> PaddedMaxPooling( { return PaddedMaxPooling(to_array(kernel_dims), name, stride_dims, padding_dims, ceil_mode); } + +inline std::shared_ptr<Node> LTSM(DimSize_t in_channels, + DimSize_t hidden_channels, + DimSize_t seq_length, + const std::string& name = "") +{ + // Construct micro-graph + auto input = Identity((!name.empty()) ? name + "_input" : ""); + auto hiddenState = Memorize(seq_length, (!name.empty()) ? name + "_hidden_state" : ""); + auto cellState = Memorize(seq_length, (!name.empty()) ? name + "_cell_state" : ""); + auto add = Add(2, (!name.empty()) ? name + "_add" : ""); + + // Forget gate + auto forgetGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_forgetGateX" : ""); + input->addChild(forgetGateX, 0, 0); + auto forgetGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_forgetGateH" : ""); + hiddenState->addChild(forgetGateH, 1, 0); + auto forgetGate = Add(2, (!name.empty()) ? name + "_forgetGate" : ""); + forgetGateX->addChild(forgetGate, 0, 0); + forgetGateH->addChild(forgetGate, 0, 1); + auto forgetGateAct = Sigmoid((!name.empty()) ? name + "_forgetGateAct" : ""); + auto forgetGateMul = Mul((!name.empty()) ? name + "_forgetGateMul" : ""); + forgetGate->addChild(forgetGateAct, 0, 0); + forgetGateAct->addChild(forgetGateMul, 0, 0); + forgetGateMul->addChild(add, 0, 0); + cellState->addChild(forgetGateMul, 1, 1); + + // Input gate + auto inputGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_inputGateX" : ""); + input->addChild(inputGateX, 0, 0); + auto inputGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_inputGateH" : ""); + hiddenState->addChild(inputGateH, 1, 0); + auto inputGate = Add(2, (!name.empty()) ? name + "_inputGate" : ""); + inputGateX->addChild(inputGate, 0, 0); + inputGateH->addChild(inputGate, 0, 1); + auto inputGateAct = Sigmoid((!name.empty()) ? name + "_inputGateAct" : ""); + auto inputGateMul = Mul((!name.empty()) ? name + "_inputGateMul" : ""); + inputGate->addChild(inputGateAct, 0, 0); + inputGateAct->addChild(inputGateMul, 0, 0); + inputGateMul->addChild(add, 0, 1); + + // Candidate for cell update + auto cellCandidateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_cellCandidateX" : ""); + input->addChild(cellCandidateX, 0, 0); + auto cellCandidateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_cellCandidateH" : ""); + hiddenState->addChild(cellCandidateH, 1, 0); + auto cellCandidate = Add(2, (!name.empty()) ? name + "_cellCandidate" : ""); + cellCandidateX->addChild(cellCandidate, 0, 0); + cellCandidateH->addChild(cellCandidate, 0, 1); + auto cellCandidateAct = Tanh((!name.empty()) ? name + "_cellCandidateAct" : ""); + cellCandidate->addChild(cellCandidateAct, 0, 0); + cellCandidateAct->addChild(inputGateMul, 0, 1); + + // Output gate + auto outputGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_outputGateX" : ""); + input->addChild(outputGateX, 0, 0); + auto outputGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_outputGateH" : ""); + hiddenState->addChild(outputGateH, 1, 0); + auto outputGate = Add(2, (!name.empty()) ? name + "_outputGate" : ""); + outputGateX->addChild(outputGate, 0, 0); + outputGateH->addChild(outputGate, 0, 1); + auto outputGateAct = Sigmoid((!name.empty()) ? name + "_outputGateAct" : ""); + auto outputGateMul = Mul((!name.empty()) ? name + "_outputGateMul" : ""); + outputGate->addChild(outputGateAct, 0, 0); + outputGateAct->addChild(outputGateMul, 0, 0); + + // Updated cell state to help determine new hidden state + auto cellUpdatedAct = Tanh((!name.empty()) ? name + "_cellUpdatedAct" : ""); + add->addChild(cellUpdatedAct, 0, 0); + cellUpdatedAct->addChild(outputGateMul, 0, 1); + outputGateMul->addChild(hiddenState, 0, 0); + add->addChild(cellState, 0, 0); + + std::shared_ptr<GraphView> microGraph = std::make_shared<GraphView>(); + microGraph->add(input); + microGraph->add({hiddenState, cellState, add, + forgetGateX, forgetGateH, forgetGate, forgetGateAct, forgetGateMul, + inputGateX, inputGateH, inputGate, inputGateAct, inputGateMul, + cellCandidateX, cellCandidateH, cellCandidate, cellCandidateAct, + outputGateX, outputGateH, outputGate, outputGateAct, outputGateMul, + cellUpdatedAct}); + + return MetaOperator("LTSM", microGraph, name); +} } // namespace Aidge #endif /* AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ */ diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 6dcec5aaa..81dc6d7cc 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -17,6 +17,7 @@ #include <set> #include <string> #include <vector> +#include <map> namespace Aidge { class Node; @@ -36,6 +37,12 @@ private: std::chrono::time_point<std::chrono::high_resolution_clock> end; }; + struct PriorProducersConsumers { + bool isPrior = false; + std::set<std::shared_ptr<Aidge::Node>> requiredProducers; + std::set<std::shared_ptr<Aidge::Node>> priorConsumers; + }; + public: SequentialScheduler(std::shared_ptr<GraphView> graphView) : mGraphView(graphView) @@ -80,6 +87,13 @@ private: * @return std::set<std::shared_ptr<Node>> */ std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const; + PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const; + + /** + * Return a std::map with corresponding node's name. + * TODO: Mutualise with similar code in GraphView::save()? + */ + std::map<std::shared_ptr<Node>, std::string> getNodesName(bool verbose) const; /** @brief Shared ptr to the scheduled graph view */ std::shared_ptr<GraphView> mGraphView; diff --git a/include/aidge/utils/Formatting.hpp b/include/aidge/utils/Formatting.hpp index e96e87e35..4bcb5cad6 100644 --- a/include/aidge/utils/Formatting.hpp +++ b/include/aidge/utils/Formatting.hpp @@ -50,8 +50,8 @@ std::string stringFormat(const std::string& format, Args... args) { /** * Print any iterable object in a std::string. */ -template <class T> -std::string print(const T& vec, const std::string& format) { +template <class T, typename F> +std::string print(const T& vec, const std::string& format, const F& func) { std::string str = "{"; bool first = true; for (const auto& val : vec) { @@ -61,11 +61,16 @@ std::string print(const T& vec, const std::string& format) { else { first = false; } - str += stringFormat(format, val); + str += stringFormat(format, func(val)); } str += "}"; return str; } + +template <class T> +std::string print(const T& vec, const std::string& format) { + return print(vec, format, [](auto val){ return val; }); +} } #endif //AIDGE_FORMATTING_H_ diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 2093aa5af..8273df21b 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -21,6 +21,7 @@ #include "aidge/utils/Types.h" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/Producer.hpp" +#include "aidge/operator/Memorize.hpp" void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") { putchar('['); @@ -60,12 +61,22 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // runnable consumer, the list of consumer is again equal to frozenConsumers // it means we are in cycle with no more scheduling update, a.k.a. a // frozen state. - std::set<std::shared_ptr<Node>> frozenConsumers; + std::vector<std::set<std::shared_ptr<Node>>> frozenConsumers; + + std::map<std::shared_ptr<Node>, std::string> namePtrTable = getNodesName(verbose); do { - // Check required producers + // From the current consumers list, check if any prior nodes are needed. + // If for a given node, only parent producers (at any depth) are needed + // to satisfy its required data, it becomes a prior. + // If the prior node is a producer, it is added to the list of required + // producers. + // If the prior node is of another type, it replaces the initial consumer + // in the new priorConsumers list. The initial consumer will necessarily + // be added again later in the consumers list. + if (verbose) printf("List of consumers with their priors:\n"); std::set<std::shared_ptr<Node>> requiredProducers; - if (verbose) printf("Required producers:\n"); + std::set<std::shared_ptr<Node>> priorConsumers; for (const auto& consumer : consumers) { if (verbose) { @@ -74,43 +85,27 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { "%s" "\x1b[0m" "\n", - (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str()); + namePtrTable[consumer].c_str()); } - std::set<std::shared_ptr<Node>> consumerRequiredProducers; - bool requiredProducerOnly = true; - IOIndex_t inputIdx = 0; - for (const auto& consumerParent : consumer->inputs()) { - if (verbose) printf("\t\t#%u: ", inputIdx); - - if (consumerParent.first && - (consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) > - consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) { - if (verbose) printf("required data from %s: C%zu + R%zu > P%zu\n", - consumerParent.first->type().c_str(), - consumer->getOperator()->getNbConsumedData(inputIdx), - consumer->getOperator()->getNbRequiredData(inputIdx), - consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)); + const auto& prior = getPriorProducersConsumers(consumer); - if (consumerParent.first->type() == Producer_Op::Type) { - consumerRequiredProducers.insert(consumerParent.first); - } - else { - requiredProducerOnly = false; - break; - } + if (prior.isPrior) { + if (verbose) { + printf("\t\trequired producers: %s\n", print(prior.requiredProducers, "%s", [&namePtrTable](auto val){ return namePtrTable[val].c_str(); }).c_str()); + printf("\t\tprior consumers: %s\n", print(prior.priorConsumers, "%s", [&namePtrTable](auto val){ return namePtrTable[val].c_str(); }).c_str()); } - else { - if (verbose) printf("no data required\n"); - } - ++inputIdx; - } - if (requiredProducerOnly) { - requiredProducers.insert(consumerRequiredProducers.begin(), consumerRequiredProducers.end()); + requiredProducers.insert(prior.requiredProducers.cbegin(), prior.requiredProducers.cend()); + priorConsumers.insert(prior.priorConsumers.cbegin(), prior.priorConsumers.cend()); + } + else { + priorConsumers.insert(consumer); } } + consumers.swap(priorConsumers); + // Make producers generate the required data for (const auto& requiredProducer : requiredProducers) { requiredProducer->getOperator()->updateConsummerProducer(); @@ -119,7 +114,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // find runnable consumers std::set<std::shared_ptr<Node>> runnableConsumers; - if (verbose) printf("List of layers receiving data:\n"); + if (verbose) printf("Updated list of consumers:\n"); for (const auto& consumer : consumers) { if (verbose) { printf("\t- consumer: " @@ -127,7 +122,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { "%s" "\x1b[0m" "\n\t\tC/R:\t", - (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str()); + namePtrTable[consumer].c_str()); for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { printf("%zu/%zu\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), consumer->getOperator()->getNbRequiredData(inId)); @@ -169,15 +164,13 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // Push consumers in the list of nodes to run and update the consumer producer system for (const auto& runnable : runnableConsumers) { - if (verbose) printf("Runnable: %s\n", (runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str()); + if (verbose) printf("Runnable: %s\n", namePtrTable[runnable].c_str()); runnable->getOperator()->updateConsummerProducer(); mStaticSchedule.push_back(runnable); } if (runnableConsumers.empty()) { - if (frozenConsumers.empty()) { - frozenConsumers = consumers; - } + frozenConsumers.push_back(consumers); } else { frozenConsumers.clear(); @@ -190,7 +183,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { for (const auto& consumer : oldConsumers) { if (verbose) { printf("\t- consumer: %s\n\t\tC/R:\t", - (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str()); + namePtrTable[consumer].c_str()); for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { printf("%ld/%ld\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), consumer->getOperator()->getNbRequiredData(inId)); @@ -243,7 +236,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { } if (verbose) printf("********************\n"); - } while (!consumers.empty() && consumers != frozenConsumers); + } while (!consumers.empty() && (frozenConsumers.empty() || std::find(frozenConsumers.begin(), frozenConsumers.end(), consumers) == frozenConsumers.end())); if (verbose) { if (!consumers.empty()) { @@ -268,15 +261,16 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { // Clear previous scheduling results mScheduling.clear(); + std::map<std::shared_ptr<Node>, std::string> namePtrTable = getNodesName(verbose); + int cpt = 0; for (const auto& runnable : mStaticSchedule) { if (verbose) printf("run: %s\n", - (runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str()); + namePtrTable[runnable].c_str()); else drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50, - (std::string("running ") + runnable->type() + "_" + - std::to_string(reinterpret_cast<uintptr_t>(runnable.get())))); + (std::string("running ") + namePtrTable[runnable])); const auto tStart = std::chrono::high_resolution_clock::now(); runnable->forward(); const auto tEnd = std::chrono::high_resolution_clock::now(); @@ -292,12 +286,12 @@ void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileNa std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%Q ms\n\n"); if (!mScheduling.empty()) { + std::map<std::shared_ptr<Node>, std::string> namePtrTable = getNodesName(true); const auto globalStart = mScheduling[0].start; for (const auto& element : mScheduling) { std::fprintf(fp, "%s :%ld, %ld\n", - (element.node->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(element.node.get()))) - .c_str(), + namePtrTable[element.node].c_str(), std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count(), std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count()); } @@ -318,3 +312,66 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers( return consumers; } + +Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers( + const std::shared_ptr<Node>& node) const +{ + PriorProducersConsumers prior; + + IOIndex_t inputIdx = 0; + for (const auto& parent : node->inputs()) { + if (parent.first && + (node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) > + parent.first->getOperator()->getNbProducedData(parent.second)) + { + if (parent.first->type() == Producer_Op::Type) { + prior.requiredProducers.insert(parent.first); + prior.priorConsumers.insert(node); + } + else if (parent.first->type() == Memorize_Op::Type) { + // Break cycles + return PriorProducersConsumers(); + } + else { + const auto& parentPrior = getPriorProducersConsumers(parent.first); + + if (!parentPrior.isPrior) { + return PriorProducersConsumers(); + } + else { + prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend()); + prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend()); + } + } + } + ++inputIdx; + } + + prior.isPrior = true; + if (prior.priorConsumers.empty()) { + prior.priorConsumers.insert(node); + } + return prior; +} + +std::map<std::shared_ptr<Aidge::Node>, std::string> Aidge::SequentialScheduler::getNodesName(bool verbose) const { + std::map<std::shared_ptr<Node>, std::string> namePtrTable; + + if (verbose) { + std::map<const std::string, std::size_t> typeCounter; + + for (const std::shared_ptr<Node> &node_ptr : mGraphView->getNodes()) { + const std::string currentType = node_ptr->type(); + if (typeCounter.find(currentType) == typeCounter.end()) + typeCounter[currentType] = 0; + ++typeCounter[currentType]; + + namePtrTable[node_ptr] = + (node_ptr->name().empty()) + ? currentType + "#" + std::to_string(typeCounter[currentType]) + : node_ptr->name() + " (" + currentType + "#" + std::to_string(typeCounter[currentType]) + ")"; + } + } + + return namePtrTable; +} diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index 68e2d4d4d..59e768d44 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -51,4 +51,33 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraphScheduler(); //REQUIRE(microGraphScheduler->getStaticScheduling().size() == 2); } + + SECTION("LTSM") { + auto myLSTM = LTSM(32, 64, 16, "ltsm"); + auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator()); + + auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph(); + microGraph->save("lstm", false, false); + + REQUIRE(myLSTM->nbInputs() == 3); + REQUIRE(myLSTM->nbData() == 3); + REQUIRE(myLSTM->nbOutputs() == 2); + + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(); + myInput->resize({32}); + std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>(); + myInit->resize({1, 64}); + + op->associateInput(0, myInput); + op->associateInput(1, myInit); + op->associateInput(2, myInit); + + op->computeOutputDims(); + REQUIRE(op->outputDimsForwarded()); + microGraph->save("lstm_dims", false, false); + + //op->updateConsummerProducer(); // require implementation + //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler(); + //microGraphScheduler->saveSchedulingDiagram("lstm_scheduling"); + } } -- GitLab