Skip to content
Snippets Groups Projects
Commit d64f9665 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Untested PoC

parent bc289a20
No related branches found
No related tags found
1 merge request!53GraphView inputs/outputs ordering
Pipeline #34725 failed
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <map> #include <map>
#include <memory> #include <memory>
#include <set> #include <unordered_set>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -41,11 +41,11 @@ private: ...@@ -41,11 +41,11 @@ private:
/// @brief Set of nodes included in the graphview with names /// @brief Set of nodes included in the graphview with names
std::map<std::string, NodePtr> mNodeRegistry; std::map<std::string, NodePtr> mNodeRegistry;
/// @brief Nodes without input link (computable, cached) /// @brief GraphView inputs
std::set<NodePtr> mInputNodes; std::vector<std::pair<NodePtr, IOIndex_t>> mInputNodes;
/// @brief Nodes without output link (computable, cached) /// @brief GraphView outputs
std::set<NodePtr> mOutputNodes; std::vector<std::pair<NodePtr, IOIndex_t>> mOutputNodes;
public: public:
GraphView(std::string name="") GraphView(std::string name="")
...@@ -54,11 +54,21 @@ public: ...@@ -54,11 +54,21 @@ public:
// ctor // ctor
} }
// GraphView(std::set<NodePtr> nodes, std::string name="") /**
// : mName(name) * Construct a GraphView from a set of nodes. The startNode parameters
// { * allows to define a default inputs/ouputs order relative to this node.
// add(nodes); * For two topologically identical graphs, using the same topological node
// } * as starting node will lead to the same topologically ordered inputs/outputs list.
* Otherwise, inputs/outputs order will be arbitrary.
*/
GraphView(std::set<NodePtr> nodes, NodePtr startNode = nullptr, std::string name="")
: mName(name)
{
if (startNode != nullptr) {
add(startNode, false);
}
add(nodes);
}
bool operator==(const GraphView &gv) const bool operator==(const GraphView &gv) const
{ {
...@@ -110,19 +120,35 @@ public: ...@@ -110,19 +120,35 @@ public:
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
public: public:
/** @brief Get reference to the set of input Nodes. */ /** @brief Get reference to the set of input Nodes. */
inline const std::set<NodePtr>& inputNodes() const noexcept { return mInputNodes; } inline const std::set<NodePtr>& inputNodes() const noexcept {
std::set<NodePtr> nodes;
for (auto node : mInputNodes) {
nodes.insert(node.first);
}
return nodes;
}
/** @brief Get reference to the set of output Nodes. */ /** @brief Get reference to the set of output Nodes. */
inline const std::set<NodePtr>& outputNodes() const noexcept { return mOutputNodes; } inline const std::set<NodePtr>& outputNodes() const noexcept {
std::set<NodePtr> nodes;
for (auto node : mOutputNodes) {
nodes.insert(node.first);
}
return nodes;
}
/** @brief Assess if the given Node is an input Node of the GraphView object. */ /** @brief Assess if the given Node is an input Node of the GraphView object. */
inline bool isInputNode(NodePtr nodePtr) const { inline bool isInputNode(NodePtr nodePtr) const {
return (mInputNodes.find(nodePtr) != mInputNodes.end()) ? true : false; const auto nodes = inputNodes();
return (nodes.find(nodePtr) != nodes.end()) ? true : false;
} }
/** @brief Assess if the given Node is an output Node of the GraphView object. */ /** @brief Assess if the given Node is an output Node of the GraphView object. */
inline bool isOutputNode(NodePtr nodePtr) const { inline bool isOutputNode(NodePtr nodePtr) const {
return (mOutputNodes.find(nodePtr) != mOutputNodes.end()) ? true : false; const auto nodes = outputNodes();
return (nodes.find(nodePtr) != nodes.end()) ? true : false;
} }
void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs);
void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs);
/** /**
* @brief List outside data input connections of the GraphView. * @brief List outside data input connections of the GraphView.
* Data inputs exclude inputs expecting parameters (weights or bias). * Data inputs exclude inputs expecting parameters (weights or bias).
...@@ -357,11 +383,10 @@ public: ...@@ -357,11 +383,10 @@ public:
*/ */
static bool replace(const std::set<NodePtr>& oldNodes, const std::set<NodePtr>& newNodes); static bool replace(const std::set<NodePtr>& oldNodes, const std::set<NodePtr>& newNodes);
void updateInputNodes();
/** /**
* @brief Process from zero the set of output Nodes. * @brief Replace a set of Nodes in every available GraphView with a new set of Nodes with ordered inputs/outputs in a GraphView if possible.
*/ */
void updateOutputNodes(); static bool replace(const std::shared_ptr<GraphView>& oldNodes, const std::shared_ptr<GraphView>& newNodes);
/** /**
* @brief Clone the GraphView with shared Operators. It is a new GraphView, with cloned Nodes, but the new Nodes refer to the same Operators as the original ones. * @brief Clone the GraphView with shared Operators. It is a new GraphView, with cloned Nodes, but the new Nodes refer to the same Operators as the original ones.
...@@ -415,27 +440,33 @@ private: ...@@ -415,27 +440,33 @@ private:
IOIndex_t getNbDataInputs() const; IOIndex_t getNbDataInputs() const;
/** /**
* @brief Update the set of inputNodes with a new Node, checking if it can be * @brief Update inputs/outputs of the GraphView, with no particular order.
* added and removing any Node not part of mInputNode anymore. * This function DOES NOT preserve inputs/outputs order and should NOT BE USED.
* It is here only to leave time to adapt the replace() function.
*/
[[deprecated]] void updateInputsOutputsNodes();
/**
* @brief Automatically update GraphView inputs/outputs with a new Node, checking if
* it this Node becomes an input/output for the graph and if previous inputs are still
* inputs/outputs after adding this node.
* @param nodePtr * @param nodePtr
*/ */
void updateInputNodes(NodePtr node); void updateInputsOutputsNew(NodePtr newNode);
/** /**
* @brief Update the set of outputNodes with a new Node, checking if it can be * @brief Automatically update GraphView inputs/outputs with a Node removed, checking if
* added and removing any Node not part of mOutputNode anymore. * it this Node was an input/output for the graph and if this node childs become new inputs/outputs
* for the graph.
* @param nodePtr * @param nodePtr
*/ */
void updateOutputNodes(NodePtr node); void updateInputsOutputsDelete(NodePtr deletedNode);
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
// TOPOLOGY // TOPOLOGY
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
void _forwardDims(std::set<NodePtr> listNodes); void _forwardDims(std::set<NodePtr> listNodes);
void removeInputNode(const std::string nodeName);
void removeOutputNode(const std::string nodeName);
}; };
} // namespace Aidge } // namespace Aidge
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment