Skip to content
Snippets Groups Projects
Commit f80511fb authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added LSTM meta-operator (not tested yet with actuel values)

parent 337c5229
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!77Support for recurrent networks
Pipeline #38758 failed
...@@ -18,6 +18,14 @@ ...@@ -18,6 +18,14 @@
#include "aidge/operator/Conv.hpp" #include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/Pad.hpp" #include "aidge/operator/Pad.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/Mul.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/Identity.hpp"
#include "aidge/operator/Concat.hpp"
#include "aidge/operator/Tanh.hpp"
#include "aidge/operator/Sigmoid.hpp"
namespace Aidge { namespace Aidge {
template <std::array<DimSize_t, 1>::size_type DIM> template <std::array<DimSize_t, 1>::size_type DIM>
...@@ -135,6 +143,90 @@ inline std::shared_ptr<Node> PaddedMaxPooling( ...@@ -135,6 +143,90 @@ inline std::shared_ptr<Node> PaddedMaxPooling(
{ {
return PaddedMaxPooling(to_array(kernel_dims), name, stride_dims, padding_dims, ceil_mode); return PaddedMaxPooling(to_array(kernel_dims), name, stride_dims, padding_dims, ceil_mode);
} }
inline std::shared_ptr<Node> LTSM(DimSize_t in_channels,
DimSize_t hidden_channels,
DimSize_t seq_length,
const std::string& name = "")
{
// Construct micro-graph
auto input = Identity((!name.empty()) ? name + "_input" : "");
auto hiddenState = Memorize(seq_length, (!name.empty()) ? name + "_hidden_state" : "");
auto cellState = Memorize(seq_length, (!name.empty()) ? name + "_cell_state" : "");
auto add = Add(2, (!name.empty()) ? name + "_add" : "");
// Forget gate
auto forgetGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_forgetGateX" : "");
input->addChild(forgetGateX, 0, 0);
auto forgetGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_forgetGateH" : "");
hiddenState->addChild(forgetGateH, 1, 0);
auto forgetGate = Add(2, (!name.empty()) ? name + "_forgetGate" : "");
forgetGateX->addChild(forgetGate, 0, 0);
forgetGateH->addChild(forgetGate, 0, 1);
auto forgetGateAct = Sigmoid((!name.empty()) ? name + "_forgetGateAct" : "");
auto forgetGateMul = Mul((!name.empty()) ? name + "_forgetGateMul" : "");
forgetGate->addChild(forgetGateAct, 0, 0);
forgetGateAct->addChild(forgetGateMul, 0, 0);
forgetGateMul->addChild(add, 0, 0);
cellState->addChild(forgetGateMul, 1, 1);
// Input gate
auto inputGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_inputGateX" : "");
input->addChild(inputGateX, 0, 0);
auto inputGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_inputGateH" : "");
hiddenState->addChild(inputGateH, 1, 0);
auto inputGate = Add(2, (!name.empty()) ? name + "_inputGate" : "");
inputGateX->addChild(inputGate, 0, 0);
inputGateH->addChild(inputGate, 0, 1);
auto inputGateAct = Sigmoid((!name.empty()) ? name + "_inputGateAct" : "");
auto inputGateMul = Mul((!name.empty()) ? name + "_inputGateMul" : "");
inputGate->addChild(inputGateAct, 0, 0);
inputGateAct->addChild(inputGateMul, 0, 0);
inputGateMul->addChild(add, 0, 1);
// Candidate for cell update
auto cellCandidateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_cellCandidateX" : "");
input->addChild(cellCandidateX, 0, 0);
auto cellCandidateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_cellCandidateH" : "");
hiddenState->addChild(cellCandidateH, 1, 0);
auto cellCandidate = Add(2, (!name.empty()) ? name + "_cellCandidate" : "");
cellCandidateX->addChild(cellCandidate, 0, 0);
cellCandidateH->addChild(cellCandidate, 0, 1);
auto cellCandidateAct = Tanh((!name.empty()) ? name + "_cellCandidateAct" : "");
cellCandidate->addChild(cellCandidateAct, 0, 0);
cellCandidateAct->addChild(inputGateMul, 0, 1);
// Output gate
auto outputGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_outputGateX" : "");
input->addChild(outputGateX, 0, 0);
auto outputGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_outputGateH" : "");
hiddenState->addChild(outputGateH, 1, 0);
auto outputGate = Add(2, (!name.empty()) ? name + "_outputGate" : "");
outputGateX->addChild(outputGate, 0, 0);
outputGateH->addChild(outputGate, 0, 1);
auto outputGateAct = Sigmoid((!name.empty()) ? name + "_outputGateAct" : "");
auto outputGateMul = Mul((!name.empty()) ? name + "_outputGateMul" : "");
outputGate->addChild(outputGateAct, 0, 0);
outputGateAct->addChild(outputGateMul, 0, 0);
// Updated cell state to help determine new hidden state
auto cellUpdatedAct = Tanh((!name.empty()) ? name + "_cellUpdatedAct" : "");
add->addChild(cellUpdatedAct, 0, 0);
cellUpdatedAct->addChild(outputGateMul, 0, 1);
outputGateMul->addChild(hiddenState, 0, 0);
add->addChild(cellState, 0, 0);
std::shared_ptr<GraphView> microGraph = std::make_shared<GraphView>();
microGraph->add(input);
microGraph->add({hiddenState, cellState, add,
forgetGateX, forgetGateH, forgetGate, forgetGateAct, forgetGateMul,
inputGateX, inputGateH, inputGate, inputGateAct, inputGateMul,
cellCandidateX, cellCandidateH, cellCandidate, cellCandidateAct,
outputGateX, outputGateH, outputGate, outputGateAct, outputGateMul,
cellUpdatedAct});
return MetaOperator("LTSM", microGraph, name);
}
} // namespace Aidge } // namespace Aidge
#endif /* AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ */ #endif /* AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ */
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <set> #include <set>
#include <string> #include <string>
#include <vector> #include <vector>
#include <map>
namespace Aidge { namespace Aidge {
class Node; class Node;
...@@ -36,6 +37,12 @@ private: ...@@ -36,6 +37,12 @@ private:
std::chrono::time_point<std::chrono::high_resolution_clock> end; std::chrono::time_point<std::chrono::high_resolution_clock> end;
}; };
struct PriorProducersConsumers {
bool isPrior = false;
std::set<std::shared_ptr<Aidge::Node>> requiredProducers;
std::set<std::shared_ptr<Aidge::Node>> priorConsumers;
};
public: public:
SequentialScheduler(std::shared_ptr<GraphView> graphView) SequentialScheduler(std::shared_ptr<GraphView> graphView)
: mGraphView(graphView) : mGraphView(graphView)
...@@ -80,6 +87,13 @@ private: ...@@ -80,6 +87,13 @@ private:
* @return std::set<std::shared_ptr<Node>> * @return std::set<std::shared_ptr<Node>>
*/ */
std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const; std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const;
PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const;
/**
* Return a std::map with corresponding node's name.
* TODO: Mutualise with similar code in GraphView::save()?
*/
std::map<std::shared_ptr<Node>, std::string> getNodesName(bool verbose) const;
/** @brief Shared ptr to the scheduled graph view */ /** @brief Shared ptr to the scheduled graph view */
std::shared_ptr<GraphView> mGraphView; std::shared_ptr<GraphView> mGraphView;
......
...@@ -50,8 +50,8 @@ std::string stringFormat(const std::string& format, Args... args) { ...@@ -50,8 +50,8 @@ std::string stringFormat(const std::string& format, Args... args) {
/** /**
* Print any iterable object in a std::string. * Print any iterable object in a std::string.
*/ */
template <class T> template <class T, typename F>
std::string print(const T& vec, const std::string& format) { std::string print(const T& vec, const std::string& format, const F& func) {
std::string str = "{"; std::string str = "{";
bool first = true; bool first = true;
for (const auto& val : vec) { for (const auto& val : vec) {
...@@ -61,11 +61,16 @@ std::string print(const T& vec, const std::string& format) { ...@@ -61,11 +61,16 @@ std::string print(const T& vec, const std::string& format) {
else { else {
first = false; first = false;
} }
str += stringFormat(format, val); str += stringFormat(format, func(val));
} }
str += "}"; str += "}";
return str; return str;
} }
template <class T>
std::string print(const T& vec, const std::string& format) {
return print(vec, format, [](auto val){ return val; });
}
} }
#endif //AIDGE_FORMATTING_H_ #endif //AIDGE_FORMATTING_H_
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/operator/Memorize.hpp"
void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") { void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") {
putchar('['); putchar('[');
...@@ -60,12 +61,22 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -60,12 +61,22 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
// runnable consumer, the list of consumer is again equal to frozenConsumers // runnable consumer, the list of consumer is again equal to frozenConsumers
// it means we are in cycle with no more scheduling update, a.k.a. a // it means we are in cycle with no more scheduling update, a.k.a. a
// frozen state. // frozen state.
std::set<std::shared_ptr<Node>> frozenConsumers; std::vector<std::set<std::shared_ptr<Node>>> frozenConsumers;
std::map<std::shared_ptr<Node>, std::string> namePtrTable = getNodesName(verbose);
do { do {
// Check required producers // From the current consumers list, check if any prior nodes are needed.
// If for a given node, only parent producers (at any depth) are needed
// to satisfy its required data, it becomes a prior.
// If the prior node is a producer, it is added to the list of required
// producers.
// If the prior node is of another type, it replaces the initial consumer
// in the new priorConsumers list. The initial consumer will necessarily
// be added again later in the consumers list.
if (verbose) printf("List of consumers with their priors:\n");
std::set<std::shared_ptr<Node>> requiredProducers; std::set<std::shared_ptr<Node>> requiredProducers;
if (verbose) printf("Required producers:\n"); std::set<std::shared_ptr<Node>> priorConsumers;
for (const auto& consumer : consumers) { for (const auto& consumer : consumers) {
if (verbose) { if (verbose) {
...@@ -74,43 +85,27 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -74,43 +85,27 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
"%s" "%s"
"\x1b[0m" "\x1b[0m"
"\n", "\n",
(consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str()); namePtrTable[consumer].c_str());
} }
std::set<std::shared_ptr<Node>> consumerRequiredProducers; const auto& prior = getPriorProducersConsumers(consumer);
bool requiredProducerOnly = true;
IOIndex_t inputIdx = 0;
for (const auto& consumerParent : consumer->inputs()) {
if (verbose) printf("\t\t#%u: ", inputIdx);
if (consumerParent.first &&
(consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) {
if (verbose) printf("required data from %s: C%zu + R%zu > P%zu\n",
consumerParent.first->type().c_str(),
consumer->getOperator()->getNbConsumedData(inputIdx),
consumer->getOperator()->getNbRequiredData(inputIdx),
consumerParent.first->getOperator()->getNbProducedData(consumerParent.second));
if (consumerParent.first->type() == Producer_Op::Type) { if (prior.isPrior) {
consumerRequiredProducers.insert(consumerParent.first); if (verbose) {
} printf("\t\trequired producers: %s\n", print(prior.requiredProducers, "%s", [&namePtrTable](auto val){ return namePtrTable[val].c_str(); }).c_str());
else { printf("\t\tprior consumers: %s\n", print(prior.priorConsumers, "%s", [&namePtrTable](auto val){ return namePtrTable[val].c_str(); }).c_str());
requiredProducerOnly = false;
break;
}
} }
else {
if (verbose) printf("no data required\n");
}
++inputIdx;
}
if (requiredProducerOnly) { requiredProducers.insert(prior.requiredProducers.cbegin(), prior.requiredProducers.cend());
requiredProducers.insert(consumerRequiredProducers.begin(), consumerRequiredProducers.end()); priorConsumers.insert(prior.priorConsumers.cbegin(), prior.priorConsumers.cend());
}
else {
priorConsumers.insert(consumer);
} }
} }
consumers.swap(priorConsumers);
// Make producers generate the required data // Make producers generate the required data
for (const auto& requiredProducer : requiredProducers) { for (const auto& requiredProducer : requiredProducers) {
requiredProducer->getOperator()->updateConsummerProducer(); requiredProducer->getOperator()->updateConsummerProducer();
...@@ -119,7 +114,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -119,7 +114,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
// find runnable consumers // find runnable consumers
std::set<std::shared_ptr<Node>> runnableConsumers; std::set<std::shared_ptr<Node>> runnableConsumers;
if (verbose) printf("List of layers receiving data:\n"); if (verbose) printf("Updated list of consumers:\n");
for (const auto& consumer : consumers) { for (const auto& consumer : consumers) {
if (verbose) { if (verbose) {
printf("\t- consumer: " printf("\t- consumer: "
...@@ -127,7 +122,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -127,7 +122,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
"%s" "%s"
"\x1b[0m" "\x1b[0m"
"\n\t\tC/R:\t", "\n\t\tC/R:\t",
(consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str()); namePtrTable[consumer].c_str());
for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
printf("%zu/%zu\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), printf("%zu/%zu\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
consumer->getOperator()->getNbRequiredData(inId)); consumer->getOperator()->getNbRequiredData(inId));
...@@ -169,15 +164,13 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -169,15 +164,13 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
// Push consumers in the list of nodes to run and update the consumer producer system // Push consumers in the list of nodes to run and update the consumer producer system
for (const auto& runnable : runnableConsumers) { for (const auto& runnable : runnableConsumers) {
if (verbose) printf("Runnable: %s\n", (runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str()); if (verbose) printf("Runnable: %s\n", namePtrTable[runnable].c_str());
runnable->getOperator()->updateConsummerProducer(); runnable->getOperator()->updateConsummerProducer();
mStaticSchedule.push_back(runnable); mStaticSchedule.push_back(runnable);
} }
if (runnableConsumers.empty()) { if (runnableConsumers.empty()) {
if (frozenConsumers.empty()) { frozenConsumers.push_back(consumers);
frozenConsumers = consumers;
}
} }
else { else {
frozenConsumers.clear(); frozenConsumers.clear();
...@@ -190,7 +183,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -190,7 +183,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
for (const auto& consumer : oldConsumers) { for (const auto& consumer : oldConsumers) {
if (verbose) { if (verbose) {
printf("\t- consumer: %s\n\t\tC/R:\t", printf("\t- consumer: %s\n\t\tC/R:\t",
(consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str()); namePtrTable[consumer].c_str());
for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
printf("%ld/%ld\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), printf("%ld/%ld\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
consumer->getOperator()->getNbRequiredData(inId)); consumer->getOperator()->getNbRequiredData(inId));
...@@ -243,7 +236,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -243,7 +236,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
} }
if (verbose) printf("********************\n"); if (verbose) printf("********************\n");
} while (!consumers.empty() && consumers != frozenConsumers); } while (!consumers.empty() && (frozenConsumers.empty() || std::find(frozenConsumers.begin(), frozenConsumers.end(), consumers) == frozenConsumers.end()));
if (verbose) { if (verbose) {
if (!consumers.empty()) { if (!consumers.empty()) {
...@@ -268,15 +261,16 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { ...@@ -268,15 +261,16 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
// Clear previous scheduling results // Clear previous scheduling results
mScheduling.clear(); mScheduling.clear();
std::map<std::shared_ptr<Node>, std::string> namePtrTable = getNodesName(verbose);
int cpt = 0; int cpt = 0;
for (const auto& runnable : mStaticSchedule) { for (const auto& runnable : mStaticSchedule) {
if (verbose) if (verbose)
printf("run: %s\n", printf("run: %s\n",
(runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str()); namePtrTable[runnable].c_str());
else else
drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50, drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50,
(std::string("running ") + runnable->type() + "_" + (std::string("running ") + namePtrTable[runnable]));
std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))));
const auto tStart = std::chrono::high_resolution_clock::now(); const auto tStart = std::chrono::high_resolution_clock::now();
runnable->forward(); runnable->forward();
const auto tEnd = std::chrono::high_resolution_clock::now(); const auto tEnd = std::chrono::high_resolution_clock::now();
...@@ -292,12 +286,12 @@ void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileNa ...@@ -292,12 +286,12 @@ void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileNa
std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%Q ms\n\n"); std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%Q ms\n\n");
if (!mScheduling.empty()) { if (!mScheduling.empty()) {
std::map<std::shared_ptr<Node>, std::string> namePtrTable = getNodesName(true);
const auto globalStart = mScheduling[0].start; const auto globalStart = mScheduling[0].start;
for (const auto& element : mScheduling) { for (const auto& element : mScheduling) {
std::fprintf(fp, "%s :%ld, %ld\n", std::fprintf(fp, "%s :%ld, %ld\n",
(element.node->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(element.node.get()))) namePtrTable[element.node].c_str(),
.c_str(),
std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count(), std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count(),
std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count()); std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count());
} }
...@@ -318,3 +312,66 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers( ...@@ -318,3 +312,66 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
return consumers; return consumers;
} }
Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers(
const std::shared_ptr<Node>& node) const
{
PriorProducersConsumers prior;
IOIndex_t inputIdx = 0;
for (const auto& parent : node->inputs()) {
if (parent.first &&
(node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) >
parent.first->getOperator()->getNbProducedData(parent.second))
{
if (parent.first->type() == Producer_Op::Type) {
prior.requiredProducers.insert(parent.first);
prior.priorConsumers.insert(node);
}
else if (parent.first->type() == Memorize_Op::Type) {
// Break cycles
return PriorProducersConsumers();
}
else {
const auto& parentPrior = getPriorProducersConsumers(parent.first);
if (!parentPrior.isPrior) {
return PriorProducersConsumers();
}
else {
prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend());
prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend());
}
}
}
++inputIdx;
}
prior.isPrior = true;
if (prior.priorConsumers.empty()) {
prior.priorConsumers.insert(node);
}
return prior;
}
std::map<std::shared_ptr<Aidge::Node>, std::string> Aidge::SequentialScheduler::getNodesName(bool verbose) const {
std::map<std::shared_ptr<Node>, std::string> namePtrTable;
if (verbose) {
std::map<const std::string, std::size_t> typeCounter;
for (const std::shared_ptr<Node> &node_ptr : mGraphView->getNodes()) {
const std::string currentType = node_ptr->type();
if (typeCounter.find(currentType) == typeCounter.end())
typeCounter[currentType] = 0;
++typeCounter[currentType];
namePtrTable[node_ptr] =
(node_ptr->name().empty())
? currentType + "#" + std::to_string(typeCounter[currentType])
: node_ptr->name() + " (" + currentType + "#" + std::to_string(typeCounter[currentType]) + ")";
}
}
return namePtrTable;
}
...@@ -51,4 +51,33 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { ...@@ -51,4 +51,33 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") {
//auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraphScheduler(); //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraphScheduler();
//REQUIRE(microGraphScheduler->getStaticScheduling().size() == 2); //REQUIRE(microGraphScheduler->getStaticScheduling().size() == 2);
} }
SECTION("LTSM") {
auto myLSTM = LTSM(32, 64, 16, "ltsm");
auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
microGraph->save("lstm", false, false);
REQUIRE(myLSTM->nbInputs() == 3);
REQUIRE(myLSTM->nbData() == 3);
REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>();
myInput->resize({32});
std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>();
myInit->resize({1, 64});
op->associateInput(0, myInput);
op->associateInput(1, myInit);
op->associateInput(2, myInit);
op->computeOutputDims();
REQUIRE(op->outputDimsForwarded());
microGraph->save("lstm_dims", false, false);
//op->updateConsummerProducer(); // require implementation
//auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler();
//microGraphScheduler->saveSchedulingDiagram("lstm_scheduling");
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment