From bc289a200b959cbd34b52ba1511d9cb94a2cfd47 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 14 Nov 2023 11:44:57 +0100 Subject: [PATCH] Changed outputs() and getNbDataInputs() in GraphView to achieve consistent behavior (clarified in doctring) --- include/aidge/graph/GraphView.hpp | 29 +++++++++++----- include/aidge/graph/Node.hpp | 11 ++++-- include/aidge/operator/MetaOperator.hpp | 9 +++-- include/aidge/utils/Recipies.hpp | 6 ++-- src/graph/GraphView.cpp | 45 ++++++++++++++++++------- src/operator/MetaOperator.cpp | 20 +++++------ 6 files changed, 80 insertions(+), 40 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 481099726..6fc571c66 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -41,10 +41,10 @@ private: /// @brief Set of nodes included in the graphview with names std::map<std::string, NodePtr> mNodeRegistry; - /// @brief Nodes without input link + /// @brief Nodes without input link (computable, cached) std::set<NodePtr> mInputNodes; - /// @brief Nodes without output link + /// @brief Nodes without output link (computable, cached) std::set<NodePtr> mOutputNodes; public: @@ -124,38 +124,46 @@ public: } /** - * @brief List outside dataInput connections of the GraphView object's inputNodes. + * @brief List outside data input connections of the GraphView. + * Data inputs exclude inputs expecting parameters (weights or bias). + * The vector size is garanteed to match the number of outside data inputs of the GraphView. If there is + * no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; /** - * @brief List dataInput connections of the GraphView object's inputNodes. + * @brief List all dataInput connections (within and outside) of the specified GraphView node named "name". + * Data inputs exclude inputs expecting parameters (weights or bias). * @param name Name of the Node. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); } /** - * @brief List outside input connections of the GraphView object's inputNodes. + * @brief List outside input connections of the GraphView. The vector + * size is garanteed to match the number of outside inputs of the GraphView. If there is + * no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; /** - * @brief List input connections of the specified GraphView object's inputNode. + * @brief List all input connections (within and outside) of the specified GraphView node named "name". * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs(std::string name) const; /** - * @brief List output connections of the GraphView object's outputNodes. + * @brief List outside output connections of the GraphView. The vector + * size is garanteed to match the number of outputs of the GraphView. If there is + * no connection to a given output, the corresponding sub-vector will be empty. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const; /** - * @brief Specific i-th output connection of the GraphView object. + * @brief List all output connections (within and outside) of the specified GraphView node named "name". * @param nodeName Name of the Node of which to show the output. * @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> */ @@ -388,6 +396,7 @@ public: /** * @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object. + * Data inputs exclude inputs expecting parameters (weights or bias). * @return IOIndex_t */ IOIndex_t getNbFreeDataInputs() const; @@ -398,7 +407,9 @@ private: /////////////////////////////////////////////////////// /** - * @brief Get the sum of the number of dataInput Nodes for all inputNodes of the GraphView object. + * @brief Get the number of dataInput that are outside the GraphView. + * Data inputs exclude inputs expecting parameters (weights or bias). + * This number matches the size of the vector returned by GraphView::dataInputs(). * @return IOIndex_t */ IOIndex_t getNbDataInputs() const; diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index f1d0a39d4..c1636734f 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -140,7 +140,8 @@ public: /** * @brief List of pair <Parent, ID of the data intput>. When an input is not - * linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. + * linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. + * Data inputs exclude inputs expecting parameters (weights or bias). * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; @@ -174,6 +175,7 @@ public: /** * @brief Get the lowest index in the InputData Parent list equal to the * nullptr. + * Data inputs exclude inputs expecting parameters (weights or bias). * @return std::size_t */ inline IOIndex_t getFirstFreeDataInput() const { @@ -187,7 +189,9 @@ public: IOIndex_t getNbFreeDataInputs() const; /** - * @brief List input ids of children liked to outputs of the node + * @brief List input ids of children linked to outputs of the node. The vector + * size is garanteed to match the number of outputs of the node. If there is + * no connection to a given output, the corresponding sub-vector will be empty. * @return std::vector<std::vector<std::pair<std::shared_ptr<Node>, * IOIndex_t>>> */ @@ -210,7 +214,8 @@ public: inline IOIndex_t nbInputs() const noexcept { return getOperator()->nbInputs(); } /** - * @brief Number of input specifically for data + * @brief Number of input specifically for data. + * Data inputs exclude inputs expecting parameters (weights or bias). * @details [data, data, weight, bias] => 2 * @return IOIndex_t */ diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 72058dfcb..e74ae55e6 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -27,11 +27,11 @@ public: // Micro-graph handling: std::shared_ptr<GraphView> mGraph; // Meta operator micro-graph std::shared_ptr<SequentialScheduler> mScheduler; - // Need to store an ordored list of input/output operators for the micro-graph, + // Need to store an ordored list of input/output nodes for the micro-graph, // because input/output nodes in a GraphView are unordered. // TODO: refactor GraphView to handle ordered input/output? - std::vector<std::pair<std::shared_ptr<Operator>, IOIndex_t>> mInputOps; - std::vector<std::pair<std::shared_ptr<Operator>, IOIndex_t>> mOutputOps; + std::vector<std::pair<NodePtr, IOIndex_t>> mInputNodes; + std::vector<std::pair<NodePtr, IOIndex_t>> mOutputNodes; public: MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph, @@ -47,6 +47,9 @@ public: mGraph(op.mGraph->clone()) { // cpy-ctor + // TODO: FIXME: mInputOps and mOutputOps are not populated! + // Issue: how to map new (cloned) nodes with old nodes? getNodes() does + // not garantee any order! Check issue #52. } /** diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp index c110c9cf8..7428d9d22 100644 --- a/include/aidge/utils/Recipies.hpp +++ b/include/aidge/utils/Recipies.hpp @@ -31,7 +31,7 @@ void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); /** * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node. * - * @param graphView Graph view to use graph matching on, in order to apply transfomrations. + * @param graphView Graph view to use graph matching on, in order to apply transformations. */ void fuseMulAdd(std::shared_ptr<GraphView> graphView); @@ -47,7 +47,7 @@ void removeFlatten(std::set<std::shared_ptr<Node>> nodes); /** * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. * - * @param graphView Graph view to use graph matching on, in order to apply transfomrations. + * @param graphView Graph view to use graph matching on, in order to apply transformations. */ void removeFlatten(std::shared_ptr<GraphView> graphView); @@ -64,7 +64,7 @@ void fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes); * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes. * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ * - * @param graphView Graph view to use graph matching on, in order to apply transfomrations. + * @param graphView Graph view to use graph matching on, in order to apply transformations. */ void fuseBatchNorm(std::shared_ptr<GraphView> graphView); diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 367c9f10d..e8e10798b 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -108,15 +108,26 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { Aidge::IOIndex_t Aidge::GraphView::getNbDataInputs() const { IOIndex_t nbDataInput = 0; - // assert(outputNodes().size() == static_cast<std::size_t>(1)); for (const std::shared_ptr<Node> &inNode : inputNodes()) { - nbDataInput += inNode->nbDataInputs(); + // We cannot simply add inNode->nbDataInputs(), as input nodes may already + // have some inputs connected within the GraphView, which would therefore not + // constitue inputs (from outside) for the GraphView! + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = + inNode->dataInputs(); + + for (const auto& input : inputNodeinputs) { + if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) { + ++nbDataInput; + } + } } return nbDataInput; } Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const { IOIndex_t nbIn = 0; + // Free inputs within the GraphView are logically also free inputs from outside + // the GraphView. for (const std::shared_ptr<Node>& inputNode : mInputNodes) { nbIn += inputNode->getNbFreeDataInputs(); } @@ -129,11 +140,11 @@ Aidge::GraphView::dataInputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; for (const std::shared_ptr<Node>& inputNode : mInputNodes) { - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->dataInputs(); for (const auto& input : inputNodeinputs) { - if (mNodes.find(input.first) == mNodes.end()) { + if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) { res.push_back(input); } } @@ -147,11 +158,11 @@ Aidge::GraphView::inputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; for (const std::shared_ptr<Node>& inputNode : mInputNodes) { - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->inputs(); for (const auto& input : inputNodeinputs) { - if (mNodes.find(input.first) == mNodes.end()) { + if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) { res.push_back(input); } } @@ -286,14 +297,24 @@ std::vector< std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>> Aidge::GraphView::outputs() const { std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>> - outputTensors; + outsideOutputs; for (const std::shared_ptr<Node>& outputNode : mOutputNodes) { - std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>> - tmpOutputs = (outputNode->outputs()); - outputTensors.insert(outputTensors.end(), tmpOutputs.begin(), - tmpOutputs.end()); + const std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>> + outputNodeOutputs = outputNode->outputs(); + + for (const auto& outputPos : outputNodeOutputs) { + // Keep only the nodes connected at this output position that are outside the GraphView + std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>> outsideOutputPos; + for (const auto& output : outputPos) { + if (mNodes.find(output.first) == mNodes.end()) { + outsideOutputPos.push_back(output); + } + } + + outsideOutputs.push_back(outsideOutputPos); + } } - return outputTensors; + return outsideOutputs; } std::vector< diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index c1f58c686..28fafb6a8 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -53,7 +53,7 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr< // The input is not connected inside the micro-graph // (no connection to this input or connection outside the micro-graph) // => it is therefore an input for the meta-operator - mInputOps.push_back(std::make_pair(inputNode->getOperator(), inputIdx)); + mInputNodes.push_back(std::make_pair(inputNode, inputIdx)); } ++inputIdx; @@ -67,12 +67,12 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr< outputNode->outputs(); for (size_t outputIdx = 0; outputIdx < outputNodeoutputs.size(); ++outputIdx) { - mOutputOps.push_back(std::make_pair(outputNode->getOperator(), outputIdx)); + mOutputNodes.push_back(std::make_pair(outputNode, outputIdx)); } } - AIDGE_INTERNAL_ASSERT(mInputOps.size() == mGraph->inputs().size()); - AIDGE_INTERNAL_ASSERT(mOutputOps.size() == mGraph->outputs().size()); + AIDGE_INTERNAL_ASSERT(mInputNodes.size() == mGraph->inputs().size()); + AIDGE_INTERNAL_ASSERT(mOutputNodes.size() == mGraph->outputs().size()); } Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { @@ -80,8 +80,8 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI return mImpl->getNbRequiredData(inputIdx); } else { - const auto& inputOp = mInputOps[inputIdx]; - return inputOp.first->getNbRequiredData(inputOp.second); + const auto& inputOp = mInputNodes[inputIdx]; + return inputOp.first->getOperator()->getNbRequiredData(inputOp.second); } } @@ -90,8 +90,8 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) co return mImpl->getNbConsumedData(inputIdx); } else { - const auto& inputOp = mInputOps[inputIdx]; - return inputOp.first->getNbConsumedData(inputOp.second); + const auto& inputOp = mInputNodes[inputIdx]; + return inputOp.first->getOperator()->getNbConsumedData(inputOp.second); } } @@ -100,8 +100,8 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) c return mImpl->getNbProducedData(outputIdx); } else { - const auto& outputOp = mOutputOps[outputIdx]; - return outputOp.first->getNbProducedData(outputOp.second); + const auto& outputOp = mOutputNodes[outputIdx]; + return outputOp.first->getOperator()->getNbProducedData(outputOp.second); } } -- GitLab