Skip to content
Snippets Groups Projects
Commit fafe61a4 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Support for ONNX LSTM import

parent f426ed60
No related branches found
No related tags found
No related merge requests found
...@@ -71,10 +71,11 @@ public: ...@@ -71,10 +71,11 @@ public:
void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final { void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final {
assert(inputIdx < 3 && "operators supports only 3 inputs"); assert(inputIdx < 3 && "operators supports only 3 inputs");
assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
if (inputIdx == 2) { // TODO: FIXME: check this, because data dims may not be initialized at this point...
assert(std::dynamic_pointer_cast<Tensor>(data)->size() == ((this->template getAttr<FCAttr::NoBias>()) == false ? static_cast<std::size_t>(this->template getAttr<FCAttr::OutChannels>()) : 0)); //if (inputIdx == 2) {
assert(std::dynamic_pointer_cast<Tensor>(data)->nbDims() == 1); // assert(std::dynamic_pointer_cast<Tensor>(data)->size() == ((this->template getAttr<FCAttr::NoBias>()) == false ? static_cast<std::size_t>(this->template getAttr<FCAttr::OutChannels>()) : 0));
} // assert(std::dynamic_pointer_cast<Tensor>(data)->nbDims() == 1);
//}
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
if (inputIdx == 0 && getInput(0)->nbDims() == 1) if (inputIdx == 0 && getInput(0)->nbDims() == 1)
mInputs[inputIdx]->resize({1, getInput(inputIdx)->size()}); mInputs[inputIdx]->resize({1, getInput(inputIdx)->size()});
......
...@@ -232,6 +232,7 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels, ...@@ -232,6 +232,7 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
{inputGateX, 2}, {outputGateX, 2}, {forgetGateX, 2}, {cellCandidateX, 2}, {inputGateX, 2}, {outputGateX, 2}, {forgetGateX, 2}, {cellCandidateX, 2},
{inputGateH, 2}, {outputGateH, 2}, {forgetGateH, 2}, {cellCandidateH, 2}, {inputGateH, 2}, {outputGateH, 2}, {forgetGateH, 2}, {cellCandidateH, 2},
{hiddenState, 1}, {cellState, 1}}); {hiddenState, 1}, {cellState, 1}});
microGraph->setOrderedOutputs({{hiddenState, 0}, {cellState, 0}});
auto metaOp = MetaOperator("LSTM", microGraph, name); auto metaOp = MetaOperator("LSTM", microGraph, name);
addProducer(metaOp, 1, {hidden_channels, in_channels}, "wi"); addProducer(metaOp, 1, {hidden_channels, in_channels}, "wi");
......
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
...@@ -65,29 +67,36 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) ...@@ -65,29 +67,36 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers)
"'fontFamily': 'Verdana' } }%%%%\nflowchart TB\n\n"); "'fontFamily': 'Verdana' } }%%%%\nflowchart TB\n\n");
// Start by creating every node // Start by creating every node
auto namePtrTable = getRankedNodesName("{3}"); const auto namePtrTable = getRankedNodesName("{3}");
for (const std::shared_ptr<Node> &node_ptr : mNodes) { for (const std::shared_ptr<Node> &node_ptr : mNodes) {
std::string givenName = std::string givenName =
(node_ptr->name().empty()) (node_ptr->name().empty())
? "<em>" + node_ptr->type() + "#" + namePtrTable[node_ptr] + "</em>" ? "<em>" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + "</em>"
: "\"" + node_ptr->name() + "\\n<sub><em>(" + node_ptr->type() + "#" + namePtrTable[node_ptr] + ")</em></sub>\""; : "\"" + node_ptr->name() + "\\n<sub><em>(" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + ")</em></sub>\"";
std::string nodeCls = "";
if (node_ptr->type() == "Producer") {
nodeCls = ":::producerCls";
}
else if (std::dynamic_pointer_cast<GenericOperator_Op>(node_ptr->getOperator())) {
nodeCls = ":::genericCls";
}
else if (const auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(node_ptr->getOperator())) {
nodeCls = ":::metaCls";
if (verbose) {
metaOp->getMicroGraph()->save(path + "_" + node_ptr->type() + "#" + namePtrTable.at(node_ptr), verbose, showProducers);
}
}
if (node_ptr == mRootNode) { if (node_ptr == mRootNode) {
std::fprintf(fp, "%s_%s(%s):::rootCls\n", node_ptr->type().c_str(), namePtrTable[node_ptr].c_str(), nodeCls += "_rootCls";
givenName.c_str());
} }
else {
if (node_ptr->type() == "Producer") { if (node_ptr == mRootNode || node_ptr->type() != "Producer" || showProducers) {
if (showProducers) { std::fprintf(fp, "%s_%s(%s)%s\n", node_ptr->type().c_str(), namePtrTable.at(node_ptr).c_str(),
std::fprintf(fp, "%s_%s(%s):::producerCls\n", node_ptr->type().c_str(), namePtrTable[node_ptr].c_str(), givenName.c_str(), nodeCls.c_str());
givenName.c_str());
}
}
else {
std::fprintf(fp, "%s_%s(%s)\n", node_ptr->type().c_str(), namePtrTable[node_ptr].c_str(),
givenName.c_str());
}
} }
} }
...@@ -111,11 +120,11 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) ...@@ -111,11 +120,11 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers)
} }
if (mNodes.find(child) != mNodes.end()) { if (mNodes.find(child) != mNodes.end()) {
std::fprintf(fp, "%s_%s-->|\"%u%s&rarr;%u\"|%s_%s\n", node_ptr->type().c_str(), namePtrTable[node_ptr].c_str(), std::fprintf(fp, "%s_%s-->|\"%u%s&rarr;%u\"|%s_%s\n", node_ptr->type().c_str(), namePtrTable.at(node_ptr).c_str(),
outputIdx, dims.c_str(), inputIdx, child->type().c_str(), namePtrTable[child].c_str()); outputIdx, dims.c_str(), inputIdx, child->type().c_str(), namePtrTable.at(child).c_str());
} }
else if (verbose) { else if (verbose) {
std::fprintf(fp, "%s_%s-->|\"%u%s&rarr;%u\"|%p:::externalCls\n", node_ptr->type().c_str(), namePtrTable[node_ptr].c_str(), std::fprintf(fp, "%s_%s-->|\"%u%s&rarr;%u\"|%p:::externalCls\n", node_ptr->type().c_str(), namePtrTable.at(node_ptr).c_str(),
outputIdx, dims.c_str(), inputIdx, static_cast<void*>(child.get())); outputIdx, dims.c_str(), inputIdx, static_cast<void*>(child.get()));
} }
break; break;
...@@ -132,7 +141,7 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) ...@@ -132,7 +141,7 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers)
for (auto input : mInputNodes) { for (auto input : mInputNodes) {
if (input.first != nullptr) { if (input.first != nullptr) {
std::fprintf(fp, "input%lu((in#%lu)):::inputCls--->|&rarr;%u|%s_%s\n", inputIdx, inputIdx, std::fprintf(fp, "input%lu((in#%lu)):::inputCls--->|&rarr;%u|%s_%s\n", inputIdx, inputIdx,
input.second, input.first->type().c_str(), namePtrTable[input.first].c_str()); input.second, input.first->type().c_str(), namePtrTable.at(input.first).c_str());
} }
else { else {
std::fprintf(fp, "input%lu((in#%lu)):::inputCls\n", inputIdx, inputIdx); std::fprintf(fp, "input%lu((in#%lu)):::inputCls\n", inputIdx, inputIdx);
...@@ -146,12 +155,12 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) ...@@ -146,12 +155,12 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers)
// Add-on to display the operator's output dimensions // Add-on to display the operator's output dimensions
std::string dims = ""; std::string dims = "";
const auto op = std::dynamic_pointer_cast<OperatorTensor>(output.first->getOperator()); const auto op = std::dynamic_pointer_cast<OperatorTensor>(output.first->getOperator());
if (op && !op->getOutput(output.second)->dims().empty()) { if (op && op->getOutput(output.second) && !op->getOutput(output.second)->dims().empty()) {
dims += " " + fmt::format("{}", op->getOutput(output.second)->dims()); dims += " " + fmt::format("{}", op->getOutput(output.second)->dims());
} }
std::fprintf(fp, "%s_%s--->|\"%u%s&rarr;\"|output%lu((out#%lu)):::outputCls\n", std::fprintf(fp, "%s_%s--->|\"%u%s&rarr;\"|output%lu((out#%lu)):::outputCls\n",
output.first->type().c_str(), namePtrTable[output.first].c_str(), output.second, output.first->type().c_str(), namePtrTable.at(output.first).c_str(), output.second,
dims.c_str(), outputIdx, outputIdx); dims.c_str(), outputIdx, outputIdx);
} }
else { else {
...@@ -163,8 +172,13 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) ...@@ -163,8 +172,13 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers)
std::fprintf(fp, "classDef inputCls fill:#afa\n"); std::fprintf(fp, "classDef inputCls fill:#afa\n");
std::fprintf(fp, "classDef outputCls fill:#ffa\n"); std::fprintf(fp, "classDef outputCls fill:#ffa\n");
std::fprintf(fp, "classDef externalCls fill:#ccc\n"); std::fprintf(fp, "classDef externalCls fill:#ccc\n");
std::fprintf(fp, "classDef rootCls stroke:#f00\n"); std::fprintf(fp, "classDef producerCls fill:#ccf\n");
std::fprintf(fp, "classDef producerCls fill:#cbf\n"); std::fprintf(fp, "classDef genericCls fill:#f9f9ff,stroke-width:1px,stroke-dasharray: 5 5\n");
std::fprintf(fp, "classDef metaCls stroke-width:5px\n");
std::fprintf(fp, "classDef _rootCls stroke:#f00\n");
std::fprintf(fp, "classDef producerCls_rootCls stroke:#f00,fill:#ccf\n");
std::fprintf(fp, "classDef genericCls_rootCls stroke:#f00,fill:#f9f9ff,stroke-width:1px,stroke-dasharray: 5 5\n");
std::fprintf(fp, "classDef metaCls_rootCls stroke:#f00,stroke-width:5px\n");
std::fprintf(fp, "\n"); std::fprintf(fp, "\n");
std::fclose(fp); std::fclose(fp);
} }
......
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph) Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph)
: OperatorTensor(type, graph->dataInputs().size(), (graph->inputs().size() - graph->dataInputs().size()), graph->outputs().size()), : OperatorTensor(type, graph->dataInputs().size(), (graph->getOrderedInputs().size() - graph->dataInputs().size()), graph->getOrderedOutputs().size()),
mGraph(graph) mGraph(graph)
{ {
mInputs = std::vector<std::shared_ptr<Tensor>>(mGraph->inputs().size()); mInputs = std::vector<std::shared_ptr<Tensor>>(mGraph->getOrderedInputs().size());
for (std::size_t i = 0; i < mInputs.size(); ++i) { for (std::size_t i = 0; i < mInputs.size(); ++i) {
mInputs[i] = std::make_shared<Tensor>(); mInputs[i] = std::make_shared<Tensor>();
} }
...@@ -24,7 +24,9 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr< ...@@ -24,7 +24,9 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<
mOutputs = std::vector<std::shared_ptr<Tensor>>(mGraph->getOrderedOutputs().size()); mOutputs = std::vector<std::shared_ptr<Tensor>>(mGraph->getOrderedOutputs().size());
for (size_t outputIdx = 0; outputIdx < mOutputs.size(); ++outputIdx) { for (size_t outputIdx = 0; outputIdx < mOutputs.size(); ++outputIdx) {
const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx]; const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx];
mOutputs[outputIdx] = std::dynamic_pointer_cast<Tensor>(outputOp.first->getOperator()->getRawOutput(outputOp.second)); if (outputOp.first) {
mOutputs[outputIdx] = std::dynamic_pointer_cast<Tensor>(outputOp.first->getOperator()->getRawOutput(outputOp.second));
}
} }
} }
...@@ -34,7 +36,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI ...@@ -34,7 +36,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI
} }
else { else {
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
return inputOp.first->getOperator()->getNbRequiredData(inputOp.second); if (inputOp.first) {
return inputOp.first->getOperator()->getNbRequiredData(inputOp.second);
}
else {
return 0;
}
} }
} }
...@@ -44,7 +51,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) co ...@@ -44,7 +51,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) co
} }
else { else {
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
return inputOp.first->getOperator()->getNbConsumedData(inputOp.second); if (inputOp.first) {
return inputOp.first->getOperator()->getNbConsumedData(inputOp.second);
}
else {
return 0;
}
} }
} }
...@@ -54,7 +66,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) c ...@@ -54,7 +66,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) c
} }
else { else {
const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx]; const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx];
return outputOp.first->getOperator()->getNbProducedData(outputOp.second); if (outputOp.first) {
return outputOp.first->getOperator()->getNbProducedData(outputOp.second);
}
else {
return 0;
}
} }
} }
......
...@@ -369,10 +369,7 @@ Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared ...@@ -369,10 +369,7 @@ Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared
return upperInput.first->getOperator()->getNbProducedData(upperInput.second); return upperInput.first->getOperator()->getNbProducedData(upperInput.second);
} }
} }
else if (input.first != nullptr) { ++nodeInputIdx;
// Do not take into account dummy inputs from getOrderedInputs()
++nodeInputIdx;
}
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment