diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 813301a144682ba3e99de31ae324ffaedcc5209f..c0aaea9025389ace11705635491cb7ffb50aa5c9 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -62,11 +62,7 @@ public: return mNodes == gv.mNodes; } - NodePtr operator[](const std::string& name) - { - assert(mNodeRegistry.find(name) != mNodeRegistry.end() && "Could not find Node in the GraphView."); - return mNodeRegistry.at(name); - } + const NodePtr operator[](const std::string& nodeName) const; /////////////////////////////////////////////////////// // FUNCTIONAL DESCRIPTION @@ -82,14 +78,14 @@ public: * @brief Name of the node. * @return std::string */ - std::string name() const; + inline std::string name() const noexcept { return mName; } /** * @brief Set the node name. * @warning Undefined behaviour when several Nodes have the same name. * @param name New name for the node. */ - void setName(const std::string &name); + inline void setName(const std::string &name) { mName = name; } /** * @brief Save the GraphView as a Mermaid graph in a .md file at the @@ -98,11 +94,9 @@ public: */ void save(std::string path, bool verbose = false, bool showProducers = true) const; - inline bool inView(NodePtr nodePtr) const { - return mNodes.find(nodePtr) != mNodes.end(); - } + bool inView(const NodePtr& nodePtr) const; - NodePtr getRootNode() { + inline NodePtr getRootNode() const noexcept { return mRootNode; } @@ -111,37 +105,22 @@ public: /////////////////////////////////////////////////////// public: /** @brief Get reference to the set of input Nodes. */ - inline std::set<NodePtr> inputNodes() const noexcept { - std::set<NodePtr> nodes; - for (auto node : mInputNodes) { - nodes.insert(node.first); - } - return nodes; - } + std::set<NodePtr> inputNodes() const; + /** @brief Get reference to the set of output Nodes. */ - inline std::set<NodePtr> outputNodes() const noexcept { - std::set<NodePtr> nodes; - for (auto node : mOutputNodes) { - nodes.insert(node.first); - } - return nodes; - } + std::set<NodePtr> outputNodes() const; + /** @brief Assess if the given Node is an input Node of the GraphView object. */ - inline bool isInputNode(NodePtr nodePtr) const { - const auto nodes = inputNodes(); - return (nodes.find(nodePtr) != nodes.end()) ? true : false; - } + bool isInputNode(const NodePtr& nodePtr) const; + /** @brief Assess if the given Node is an output Node of the GraphView object. */ - inline bool isOutputNode(NodePtr nodePtr) const { - const auto nodes = outputNodes(); - return (nodes.find(nodePtr) != nodes.end()) ? true : false; - } + bool isOutputNode(const NodePtr& nodePtr) const; void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs); void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs); - inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() { return mInputNodes; }; - inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() { return mOutputNodes; }; + inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() const noexcept { return mInputNodes; }; + inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() const noexcept { return mOutputNodes; }; /** * @brief List outside data input connections of the GraphView. @@ -212,9 +191,9 @@ public: void forwardDims(); /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ - void setBackend(const std::string &backend, DeviceIdx_t device = 0); + void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const; /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ - void setDataType(const DataType &datatype); + void setDataType(const DataType& datatype) const; /////////////////////////////////////////////////////// // TOPOLOGY diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 968e98e75cc587977eb3033fe7f25936880755a4..5abd1d71a077eb3b7ad3fdfb3e4b23715d3a2288 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -21,6 +21,11 @@ #include "aidge/operator/OperatorTensor.hpp" #include "aidge/utils/ErrorHandling.hpp" + +const std::shared_ptr<Aidge::Node> Aidge::GraphView::operator[](const std::string& nodeName) const { + return (mNodeRegistry.find(nodeName) != mNodeRegistry.cend()) ? mNodeRegistry.at(nodeName) : nullptr; +} + /////////////////////////////////////////////////////// // FUNCTIONAL DESCRIPTION /////////////////////////////////////////////////////// @@ -50,10 +55,9 @@ Aidge::Connector Aidge::GraphView::operator()( // INNER /////////////////////////////////////////////////////// -std::string Aidge::GraphView::name() const { return mName; } - -void Aidge::GraphView::setName(const std::string &name) { mName = name; } - +bool Aidge::GraphView::inView(const std::shared_ptr<Aidge::Node>& nodePtr) const { + return mNodes.find(nodePtr) != mNodes.cend(); +} void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) const { FILE *fp = std::fopen((path + ".mmd").c_str(), "w"); @@ -154,6 +158,33 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) // TENSOR MANAGEMENT /////////////////////////////////////////////////////// +std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::inputNodes() const { + std::set<std::shared_ptr<Aidge::Node>> nodes; + for (const auto& node : mInputNodes) { + nodes.insert(node.first); + } + return nodes; +} + +std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::outputNodes() const { + std::set<std::shared_ptr<Aidge::Node>> nodes; + for (const auto& node : mOutputNodes) { + nodes.insert(node.first); + } + return nodes; +} + +bool Aidge::GraphView::isInputNode(const std::shared_ptr<Aidge::Node>& nodePtr) const { + const auto nodes = inputNodes(); + return (nodes.find(nodePtr) != nodes.cend()); +} + +bool Aidge::GraphView::isOutputNode(const std::shared_ptr<Aidge::Node>& nodePtr) const { + const auto nodes = outputNodes(); + return (nodes.find(nodePtr) != nodes.cend()); +} + + void Aidge::GraphView::setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs) { AIDGE_ASSERT(inputs.size() <= mInputNodes.size(), "too many specified number of inputs"); @@ -324,14 +355,14 @@ void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) { } } -void Aidge::GraphView::setBackend(const std::string &backend, DeviceIdx_t device) { - for (auto node : getNodes()) { +void Aidge::GraphView::setBackend(const std::string &backend, const DeviceIdx_t device) const { + for (const auto& node : getNodes()) { node->getOperator()->setBackend(backend, device); } } -void Aidge::GraphView::setDataType(const Aidge::DataType &datatype) { - for (auto node : getNodes()) { +void Aidge::GraphView::setDataType(const Aidge::DataType &datatype) const { + for (const auto& node : getNodes()) { node->getOperator()->setDataType(datatype); } } @@ -508,11 +539,9 @@ bool Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool inc } bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph) { - if (mRootNode == nullptr) { - mRootNode = graph->getRootNode(); - } - - return add(graph->getNodes(), false); + // set the rootNode to the other graphView rootNode if no rootNode yet + mRootNode = mRootNode ? mRootNode : graph->getRootNode(); + return add(graph->getNodes(), false); } void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode,