diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index ccff976cbb7cf8efc59223dfd658ca2a4d03a80b..744dbd1327a83267b7840e03ba83190326ee6cdd 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -37,7 +37,7 @@ public: std::weak_ptr<Node> mUpperNode; public: - MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph); + MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph, const std::vector<InputCategory>& forcedInputsCategory = {}); /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). @@ -113,6 +113,7 @@ public: std::shared_ptr<Node> MetaOperator(const char *type, const std::shared_ptr<GraphView>& graph, + const std::vector<InputCategory>& forcedInputsCategory = {}, const std::string& name = ""); } // namespace Aidge diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp index bc3348377525cdd2e5b2c030c8fc6b7cb8177e7b..750a808aaeb23447578501f8b27c7eba3d34234c 100644 --- a/include/aidge/operator/MetaOperatorDefs.hpp +++ b/include/aidge/operator/MetaOperatorDefs.hpp @@ -126,7 +126,7 @@ inline std::shared_ptr<Node> PaddedMaxPooling(const std::array<DimSize_t, DIM> & MaxPooling(kernel_dims, (!name.empty()) ? 
name + "_maxpooling" : "", stride_dims, ceil_mode) }); - return MetaOperator("PaddedMaxPooling", graph, name); + return MetaOperator("PaddedMaxPooling", graph, {}, name); } template <std::array<DimSize_t, 1>::size_type DIM> diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp index d021a79c5ff4e337bebf424465458ddabf056a56..afd682f3e546b408b231a14e55a7ba5432fef430 100644 --- a/python_binding/operator/pybind_MetaOperatorDefs.cpp +++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp @@ -195,14 +195,17 @@ void init_MetaOperatorDefs(py::module &m) { declare_LSTMOp(m); py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, OperatorTensor>(m, "MetaOperator_Op", py::multiple_inheritance()) - .def(py::init<const char *, const std::shared_ptr<GraphView>&>(), + .def(py::init<const char *, const std::shared_ptr<GraphView>&, const std::vector<InputCategory>&>(), py::arg("type"), - py::arg("graph")) - .def("get_micro_graph", &MetaOperator_Op::getMicroGraph); + py::arg("graph"), + py::arg("forced_inputs_category") = std::vector<InputCategory>()) + .def("get_micro_graph", &MetaOperator_Op::getMicroGraph) + .def("set_upper_node", &MetaOperator_Op::setUpperNode); m.def("meta_operator", &MetaOperator, py::arg("type"), py::arg("graph"), + py::arg("forced_inputs_category") = std::vector<InputCategory>(), py::arg("name") = "" ); diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index f7b6b97cdf2e23080e17b3a162b72a327a893ca4..b2c03e794888a0909ada5db208fc07ad266d4ae2 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -144,7 +144,9 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd } IOIndex_t outputIdx = 0; for (const auto& childs : node_ptr->getOrderedChildren()) { - for (const auto& child : childs) { + // Keep only unique childs in order to avoid duplicating connections + const auto uniqueChilds = std::set<NodePtr>(childs.begin(), childs.end()); + 
for (const auto& child : uniqueChilds) { if (child != nullptr) { IOIndex_t inputIdx = 0; for (auto parent : child->inputs()) { @@ -164,7 +166,7 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd fmt::print(fp.get(), "{}_{}-->|\"{}{}→{}\"|{}:::externalCls\n", node_ptr->type(), namePtrTable.at(node_ptr), outputIdx, dims, inputIdx, static_cast<void*>(child.get())); } - break; + // Do not break here because the same child can be connected to several inputs } ++inputIdx; } @@ -270,7 +272,10 @@ void Aidge::GraphView::setRootNode(NodePtr node) { std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::inputNodes() const { std::set<std::shared_ptr<Aidge::Node>> nodes; for (const auto& node : mInputNodes) { - nodes.insert(node.first); + // Do not include dummy inputs + if (node.first) { + nodes.insert(node.first); + } } return nodes; } @@ -278,7 +283,10 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::inputNodes() const { std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::outputNodes() const { std::set<std::shared_ptr<Aidge::Node>> nodes; for (const auto& node : mOutputNodes) { - nodes.insert(node.first); + // Do not include dummy outputs + if (node.first) { + nodes.insert(node.first); + } } return nodes; } diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index 1a71737479f0c98cddcd4d1437012bfb16d2dc85..ab6bde74fb73011f7b49e6958d8cfa8320d0bc1b 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -20,17 +20,22 @@ #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/DynamicAttributes.hpp" -Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph) - : OperatorTensor(type, [graph]() { +Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph, const std::vector<InputCategory>& forcedInputsCategory) + : OperatorTensor(type, [graph, forcedInputsCategory]() { + IOIndex_t 
inputIdx = 0; std::vector<InputCategory> inputsCategory; for (const auto& in : graph->getOrderedInputs()) { - if (in.first) { + if (inputIdx < forcedInputsCategory.size()) { + inputsCategory.push_back(forcedInputsCategory[inputIdx]); + } + else if (in.first) { inputsCategory.push_back(in.first->getOperator()->inputCategory(in.second)); } else { // Dummy input, default to OptionalData inputsCategory.push_back(InputCategory::OptionalData); } + ++inputIdx; } return inputsCategory; }(), graph->getOrderedOutputs().size()), @@ -54,6 +59,7 @@ void Aidge::MetaOperator_Op::associateInput(const IOIndex_t inputIdx, const std: AIDGE_ASSERT(inputIdx < mGraph->getOrderedInputs().size(), "associateInput(): inputIdx ({}) out of bound for MetaOperator", inputIdx); const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + AIDGE_ASSERT(inputOp.first, "associateInput(): inputIdx ({}) is a dummy input for this MetaOperator, cannot associate data!", inputIdx); inputOp.first->getOperator()->associateInput(inputOp.second, data); // Associate inputs for custom implementation @@ -64,6 +70,7 @@ void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, const std AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type()); const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + AIDGE_ASSERT(inputOp.first, "setInput(): inputIdx ({}) is a dummy input for this MetaOperator, cannot associate data!", inputIdx); inputOp.first->getOperator()->setInput(inputOp.second, data); // Associate inputs for custom implementation @@ -243,9 +250,10 @@ void Aidge::MetaOperator_Op::forward() { std::shared_ptr<Aidge::Node> Aidge::MetaOperator(const char *type, const std::shared_ptr<Aidge::GraphView>& graph, + const std::vector<InputCategory>& forcedInputsCategory, const std::string& name) { - auto op = std::make_shared<MetaOperator_Op>(type, graph); + auto op = std::make_shared<MetaOperator_Op>(type, graph, forcedInputsCategory); auto node = 
std::make_shared<Node>(op, name); op->setUpperNode(node); return node; diff --git a/src/operator/MetaOperatorDefs/LSTM.cpp b/src/operator/MetaOperatorDefs/LSTM.cpp index 910e7c67aad0068679ca2d240b23312add3e42d7..9620f040472aed984afb99018cde5476ec5f60d3 100644 --- a/src/operator/MetaOperatorDefs/LSTM.cpp +++ b/src/operator/MetaOperatorDefs/LSTM.cpp @@ -115,7 +115,7 @@ std::shared_ptr<Node> LSTM(const DimSize_t inChannel, {hiddenState, 1}, {cellState, 1}}); microGraph->setOrderedOutputs({{hiddenState, 0}, {cellState, 0}}); - auto metaOp = MetaOperator("LSTM", microGraph, name); + auto metaOp = MetaOperator("LSTM", microGraph, {}, name); addProducer(metaOp, 1, {hiddenChannel, inChannel}, "wi"); addProducer(metaOp, 2, {hiddenChannel, inChannel}, "wo"); addProducer(metaOp, 3, {hiddenChannel, inChannel}, "wf"); diff --git a/src/operator/MetaOperatorDefs/PaddedAvgPooling.cpp b/src/operator/MetaOperatorDefs/PaddedAvgPooling.cpp index ef319ef38ad18de9eaed0a1d4a92c3877ee7cf8e..c35d964d0cdd224e9d00eadf6e158bc87b4c776f 100644 --- a/src/operator/MetaOperatorDefs/PaddedAvgPooling.cpp +++ b/src/operator/MetaOperatorDefs/PaddedAvgPooling.cpp @@ -41,7 +41,7 @@ std::shared_ptr<Node> PaddedAvgPooling(const std::array<DimSize_t, DIM> &kernel_ AvgPooling(kernel_dims, (!name.empty()) ? 
name + "_avgpooling" : "", stride_dims) }); - return MetaOperator("PaddedAvgPooling", graph, name); + return MetaOperator("PaddedAvgPooling", graph, {}, name); } template std::shared_ptr<Node> PaddedAvgPooling<1>(const std::array<DimSize_t,1>&, const std::string&, const std::array<DimSize_t,1>&, const std::array<DimSize_t,2>&); diff --git a/src/operator/MetaOperatorDefs/PaddedConv.cpp b/src/operator/MetaOperatorDefs/PaddedConv.cpp index 31b1c675e9d577002350ea11dd0b42601a91ef76..49373341a3a7cd1dd764dbfcb385a1817079e8b0 100644 --- a/src/operator/MetaOperatorDefs/PaddedConv.cpp +++ b/src/operator/MetaOperatorDefs/PaddedConv.cpp @@ -43,7 +43,7 @@ std::shared_ptr<Aidge::Node> Aidge::PaddedConv(Aidge::DimSize_t in_channels, Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""), std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : "") }); - auto metaOpNode = MetaOperator("PaddedConv", graph, name); + auto metaOpNode = MetaOperator("PaddedConv", graph, {}, name); addProducer(metaOpNode, 1, append(out_channels, append(in_channels, kernel_dims)), "w"); if (!no_bias) { addProducer(metaOpNode, 2, {out_channels}, "b"); diff --git a/src/operator/MetaOperatorDefs/PaddedConvDepthWise.cpp b/src/operator/MetaOperatorDefs/PaddedConvDepthWise.cpp index 1c073b78a61763b46e330089cccfcc4bced352a4..12d980b4073c115443fe0ed8db38f978aa98dcad 100644 --- a/src/operator/MetaOperatorDefs/PaddedConvDepthWise.cpp +++ b/src/operator/MetaOperatorDefs/PaddedConvDepthWise.cpp @@ -40,7 +40,7 @@ std::shared_ptr<Aidge::Node> Aidge::PaddedConvDepthWise(const Aidge::DimSize_t n Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""), std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? 
name + "_conv_depth_wise" : "") }); - auto metaOpNode = MetaOperator("PaddedConvDepthWise", graph, name); + auto metaOpNode = MetaOperator("PaddedConvDepthWise", graph, {}, name); addProducer(metaOpNode, 1, append(nb_channels, append(Aidge::DimSize_t(1), kernel_dims)), "w"); if (!no_bias) { addProducer(metaOpNode, 2, {nb_channels}, "b"); diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 1613450508ea84a230f36ba6526a1322c6a70559..958b2543208dfdce3eee4e1ba7a22cc8bd0be74b 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -63,22 +63,15 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S std::vector<std::shared_ptr<StaticSchedulingElement>> schedule; - // 1) Initialize consumers list: - // 1.1) List of the GraphView's input nodes - std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes(); - - // 1.2) List of nodes inside the GraphView connected to an inner Producer + // 1) Initialize consumers list: start from the output nodes and + // find the required prior producers/consumers at step 2). + // Beware that generateBaseScheduling() can be called multiple times + // with some node having already produced some data. In this case, + // we should always consume available data first. This is ensured + // by setting the consumers list to the output nodes and then recursively + // finding the dependencies. + std::set<std::shared_ptr<Node>> consumers = mGraphView->outputNodes(); std::set<std::shared_ptr<Node>> producers; - for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) { - if (nodePtr->type() == Producer_Op::Type) { - for (const auto& child : nodePtr->getChildren()) { - // Do not schedule childs outside current graph! - if (mGraphView->inView(child)) { - consumers.insert(child); - } - } - } - } do { // 2) From the current consumers list, check if any prior consumer node