diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index e1f6f6bc155b1bec0f61fe0e62e7bbe9240d9a2d..d6c80e800e310b5d6890a317773f67c08d346da0 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -71,10 +71,11 @@ public: void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final { assert(inputIdx < 3 && "operators supports only 3 inputs"); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); - if (inputIdx == 2) { - assert(std::dynamic_pointer_cast<Tensor>(data)->size() == ((this->template getAttr<FCAttr::NoBias>()) == false ? static_cast<std::size_t>(this->template getAttr<FCAttr::OutChannels>()) : 0)); - assert(std::dynamic_pointer_cast<Tensor>(data)->nbDims() == 1); - } + // TODO: FIXME: check this, because data dims may not be initialized at this point... + //if (inputIdx == 2) { + // assert(std::dynamic_pointer_cast<Tensor>(data)->size() == ((this->template getAttr<FCAttr::NoBias>()) == false ? static_cast<std::size_t>(this->template getAttr<FCAttr::OutChannels>()) : 0)); + // assert(std::dynamic_pointer_cast<Tensor>(data)->nbDims() == 1); + //} mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); if (inputIdx == 0 && getInput(0)->nbDims() == 1) mInputs[inputIdx]->resize({1, getInput(inputIdx)->size()}); diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp index d51844e3b8452bb7e65228eddff9304c32a560f7..8f1de7c0e92558a4b47962c3a375764e1bd1c2ee 100644 --- a/include/aidge/operator/MetaOperatorDefs.hpp +++ b/include/aidge/operator/MetaOperatorDefs.hpp @@ -232,6 +232,7 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels, {inputGateX, 2}, {outputGateX, 2}, {forgetGateX, 2}, {cellCandidateX, 2}, {inputGateH, 2}, {outputGateH, 2}, {forgetGateH, 2}, {cellCandidateH, 2}, {hiddenState, 1}, {cellState, 1}}); + microGraph->setOrderedOutputs({{hiddenState, 0}, {cellState, 0}}); auto metaOp = MetaOperator("LSTM", microGraph, name); addProducer(metaOp, 1, {hidden_channels, in_channels}, "wi"); diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 3ccc1174db40e161641b98dd906a69fdc4ecee3c..4aa74b0f6b9102332aaa2db4317ca59dd7f4aa74 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -23,6 +23,8 @@ #include "aidge/data/Tensor.hpp" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/Producer.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/MetaOperator.hpp" #include "aidge/utils/ErrorHandling.hpp" /////////////////////////////////////////////////////// @@ -65,29 +67,36 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) "'fontFamily': 'Verdana' } }%%%%\nflowchart TB\n\n"); // Start by creating every node - auto namePtrTable = getRankedNodesName("{3}"); + const auto namePtrTable = getRankedNodesName("{3}"); for (const std::shared_ptr<Node> &node_ptr : mNodes) { std::string givenName = (node_ptr->name().empty()) - ? "<em>" + node_ptr->type() + "#" + namePtrTable[node_ptr] + "</em>" - : "\"" + node_ptr->name() + "\\n<sub><em>(" + node_ptr->type() + "#" + namePtrTable[node_ptr] + ")</em></sub>\""; + ? "<em>" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + "</em>" + : "\"" + node_ptr->name() + "\\n<sub><em>(" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + ")</em></sub>\""; + + std::string nodeCls = ""; + if (node_ptr->type() == "Producer") { + nodeCls = ":::producerCls"; + } + else if (std::dynamic_pointer_cast<GenericOperator_Op>(node_ptr->getOperator())) { + nodeCls = ":::genericCls"; + } + else if (const auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(node_ptr->getOperator())) { + nodeCls = ":::metaCls"; + + if (verbose) { + metaOp->getMicroGraph()->save(path + "_" + node_ptr->type() + "#" + namePtrTable.at(node_ptr), verbose, showProducers); + } + } if (node_ptr == mRootNode) { - std::fprintf(fp, "%s_%s(%s):::rootCls\n", node_ptr->type().c_str(), namePtrTable[node_ptr].c_str(), - givenName.c_str()); + nodeCls += "_rootCls"; } - else { - if (node_ptr->type() == "Producer") { - if (showProducers) { - std::fprintf(fp, "%s_%s(%s):::producerCls\n", node_ptr->type().c_str(), namePtrTable[node_ptr].c_str(), - givenName.c_str()); - } - } - else { - std::fprintf(fp, "%s_%s(%s)\n", node_ptr->type().c_str(), namePtrTable[node_ptr].c_str(), - givenName.c_str()); - } + + if (node_ptr == mRootNode || node_ptr->type() != "Producer" || showProducers) { + std::fprintf(fp, "%s_%s(%s)%s\n", node_ptr->type().c_str(), namePtrTable.at(node_ptr).c_str(), + givenName.c_str(), nodeCls.c_str()); } } @@ -111,11 +120,11 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) } if (mNodes.find(child) != mNodes.end()) { - std::fprintf(fp, "%s_%s-->|\"%u%s→%u\"|%s_%s\n", node_ptr->type().c_str(), namePtrTable[node_ptr].c_str(), - outputIdx, dims.c_str(), inputIdx, child->type().c_str(), namePtrTable[child].c_str()); + std::fprintf(fp, "%s_%s-->|\"%u%s→%u\"|%s_%s\n", node_ptr->type().c_str(), namePtrTable.at(node_ptr).c_str(), + outputIdx, dims.c_str(), inputIdx, child->type().c_str(), namePtrTable.at(child).c_str()); } else if (verbose) { - std::fprintf(fp, "%s_%s-->|\"%u%s→%u\"|%p:::externalCls\n", node_ptr->type().c_str(), namePtrTable[node_ptr].c_str(), + std::fprintf(fp, "%s_%s-->|\"%u%s→%u\"|%p:::externalCls\n", node_ptr->type().c_str(), namePtrTable.at(node_ptr).c_str(), outputIdx, dims.c_str(), inputIdx, static_cast<void*>(child.get())); } break; @@ -132,7 +141,7 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) for (auto input : mInputNodes) { if (input.first != nullptr) { std::fprintf(fp, "input%lu((in#%lu)):::inputCls--->|→%u|%s_%s\n", inputIdx, inputIdx, - input.second, input.first->type().c_str(), namePtrTable[input.first].c_str()); + input.second, input.first->type().c_str(), namePtrTable.at(input.first).c_str()); } else { std::fprintf(fp, "input%lu((in#%lu)):::inputCls\n", inputIdx, inputIdx); @@ -146,12 +155,12 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) // Add-on to display the operator's output dimensions std::string dims = ""; const auto op = std::dynamic_pointer_cast<OperatorTensor>(output.first->getOperator()); - if (op && !op->getOutput(output.second)->dims().empty()) { + if (op && op->getOutput(output.second) && !op->getOutput(output.second)->dims().empty()) { dims += " " + fmt::format("{}", op->getOutput(output.second)->dims()); } std::fprintf(fp, "%s_%s--->|\"%u%s→\"|output%lu((out#%lu)):::outputCls\n", - output.first->type().c_str(), namePtrTable[output.first].c_str(), output.second, + output.first->type().c_str(), namePtrTable.at(output.first).c_str(), output.second, dims.c_str(), outputIdx, outputIdx); } else { @@ -163,8 +172,13 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) std::fprintf(fp, "classDef inputCls fill:#afa\n"); std::fprintf(fp, "classDef outputCls fill:#ffa\n"); std::fprintf(fp, "classDef externalCls fill:#ccc\n"); - std::fprintf(fp, "classDef rootCls stroke:#f00\n"); - std::fprintf(fp, "classDef producerCls fill:#cbf\n"); + std::fprintf(fp, "classDef producerCls fill:#ccf\n"); + std::fprintf(fp, "classDef genericCls fill:#f9f9ff,stroke-width:1px,stroke-dasharray: 5 5\n"); + std::fprintf(fp, "classDef metaCls stroke-width:5px\n"); + std::fprintf(fp, "classDef _rootCls stroke:#f00\n"); + std::fprintf(fp, "classDef producerCls_rootCls stroke:#f00,fill:#ccf\n"); + std::fprintf(fp, "classDef genericCls_rootCls stroke:#f00,fill:#f9f9ff,stroke-width:1px,stroke-dasharray: 5 5\n"); + std::fprintf(fp, "classDef metaCls_rootCls stroke:#f00,stroke-width:5px\n"); std::fprintf(fp, "\n"); std::fclose(fp); } diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index 0ff758a564223b98c0fb1b422e6e8f64249a707b..7a3780036f730d1f1d635a75f99d44a2f073d1bb 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -13,10 +13,10 @@ #include "aidge/utils/ErrorHandling.hpp" Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph) - : OperatorTensor(type, graph->dataInputs().size(), (graph->inputs().size() - graph->dataInputs().size()), graph->outputs().size()), + : OperatorTensor(type, graph->dataInputs().size(), (graph->getOrderedInputs().size() - graph->dataInputs().size()), graph->getOrderedOutputs().size()), mGraph(graph) { - mInputs = std::vector<std::shared_ptr<Tensor>>(mGraph->inputs().size()); + mInputs = std::vector<std::shared_ptr<Tensor>>(mGraph->getOrderedInputs().size()); for (std::size_t i = 0; i < mInputs.size(); ++i) { mInputs[i] = std::make_shared<Tensor>(); } @@ -24,7 +24,9 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr< mOutputs = std::vector<std::shared_ptr<Tensor>>(mGraph->getOrderedOutputs().size()); for (size_t outputIdx = 0; outputIdx < mOutputs.size(); ++outputIdx) { const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx]; - mOutputs[outputIdx] = std::dynamic_pointer_cast<Tensor>(outputOp.first->getOperator()->getRawOutput(outputOp.second)); + if (outputOp.first) { + mOutputs[outputIdx] = std::dynamic_pointer_cast<Tensor>(outputOp.first->getOperator()->getRawOutput(outputOp.second)); + } } } @@ -34,7 +36,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI } else { const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; - return inputOp.first->getOperator()->getNbRequiredData(inputOp.second); + if (inputOp.first) { + return inputOp.first->getOperator()->getNbRequiredData(inputOp.second); + } + else { + return 0; + } } } @@ -44,7 +51,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) co } else { const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; - return inputOp.first->getOperator()->getNbConsumedData(inputOp.second); + if (inputOp.first) { + return inputOp.first->getOperator()->getNbConsumedData(inputOp.second); + } + else { + return 0; + } } } @@ -54,7 +66,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) c } else { const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx]; - return outputOp.first->getOperator()->getNbProducedData(outputOp.second); + if (outputOp.first) { + return outputOp.first->getOperator()->getNbProducedData(outputOp.second); + } + else { + return 0; + } } } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 16d2104ba1f1ceda4b65c1f6901f3ebc29cf8c99..fcff5b8f43440229636bc65be8100d706a74d177 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -369,10 +369,7 @@ Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared return upperInput.first->getOperator()->getNbProducedData(upperInput.second); } } - else if (input.first != nullptr) { - // Do not take into account dummy inputs from getOrderedInputs() - ++nodeInputIdx; - } + ++nodeInputIdx; } }