Skip to content
Snippets Groups Projects
Commit 6a299939 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd] GraphView hpp and cpp

parent 78230a51
No related branches found
No related tags found
3 merge requests!105version 0.2.0,!88Basic supervised learning,!79Scheduler backward
......@@ -62,11 +62,7 @@ public:
return mNodes == gv.mNodes;
}
NodePtr operator[](const std::string& name)
{
assert(mNodeRegistry.find(name) != mNodeRegistry.end() && "Could not find Node in the GraphView.");
return mNodeRegistry.at(name);
}
const NodePtr operator[](const std::string& nodeName) const;
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
......@@ -82,14 +78,14 @@ public:
* @brief Name of the node.
* @return std::string
*/
std::string name() const;
inline std::string name() const noexcept { return mName; }
/**
* @brief Set the node name.
* @warning Undefined behaviour when several Nodes have the same name.
* @param name New name for the node.
*/
void setName(const std::string &name);
inline void setName(const std::string &name) { mName = name; }
/**
* @brief Save the GraphView as a Mermaid graph in a .md file at the
......@@ -98,11 +94,9 @@ public:
*/
void save(std::string path, bool verbose = false, bool showProducers = true) const;
inline bool inView(NodePtr nodePtr) const {
return mNodes.find(nodePtr) != mNodes.end();
}
bool inView(const NodePtr& nodePtr) const;
NodePtr getRootNode() {
inline NodePtr getRootNode() const noexcept {
return mRootNode;
}
......@@ -111,31 +105,16 @@ public:
///////////////////////////////////////////////////////
public:
/** @brief Get reference to the set of input Nodes. */
inline std::set<NodePtr> inputNodes() const noexcept {
std::set<NodePtr> nodes;
for (auto node : mInputNodes) {
nodes.insert(node.first);
}
return nodes;
}
std::set<NodePtr> inputNodes() const;
/** @brief Get reference to the set of output Nodes. */
inline std::set<NodePtr> outputNodes() const noexcept {
std::set<NodePtr> nodes;
for (auto node : mOutputNodes) {
nodes.insert(node.first);
}
return nodes;
}
std::set<NodePtr> outputNodes() const;
/** @brief Assess if the given Node is an input Node of the GraphView object. */
inline bool isInputNode(NodePtr nodePtr) const {
const auto nodes = inputNodes();
return (nodes.find(nodePtr) != nodes.end()) ? true : false;
}
bool isInputNode(const NodePtr& nodePtr) const;
/** @brief Assess if the given Node is an output Node of the GraphView object. */
inline bool isOutputNode(NodePtr nodePtr) const {
const auto nodes = outputNodes();
return (nodes.find(nodePtr) != nodes.end()) ? true : false;
}
bool isOutputNode(const NodePtr& nodePtr) const;
void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs);
void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs);
......@@ -212,9 +191,9 @@ public:
void forwardDims();
/** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
void setBackend(const std::string &backend, DeviceIdx_t device = 0);
void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const;
/** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
void setDataType(const DataType &datatype);
void setDataType(const DataType& datatype) const;
///////////////////////////////////////////////////////
// TOPOLOGY
......
......@@ -21,6 +21,11 @@
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
const std::shared_ptr<Aidge::Node> Aidge::GraphView::operator[](const std::string& nodeName) const {
return (mNodeRegistry.find(nodeName) != mNodeRegistry.cend()) ? mNodeRegistry.at(nodeName) : nullptr;
}
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////
......@@ -50,10 +55,9 @@ Aidge::Connector Aidge::GraphView::operator()(
// INNER
///////////////////////////////////////////////////////
std::string Aidge::GraphView::name() const { return mName; }
void Aidge::GraphView::setName(const std::string &name) { mName = name; }
bool Aidge::GraphView::inView(const std::shared_ptr<Aidge::Node>& nodePtr) const {
return mNodes.find(nodePtr) != mNodes.cend();
}
void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) const {
FILE *fp = std::fopen((path + ".mmd").c_str(), "w");
......@@ -154,6 +158,33 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers)
// TENSOR MANAGEMENT
///////////////////////////////////////////////////////
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::inputNodes() const {
std::set<std::shared_ptr<Aidge::Node>> nodes;
for (const auto& node : mInputNodes) {
nodes.insert(node.first);
}
return nodes;
}
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::outputNodes() const {
std::set<std::shared_ptr<Aidge::Node>> nodes;
for (const auto& node : mOutputNodes) {
nodes.insert(node.first);
}
return nodes;
}
bool Aidge::GraphView::isInputNode(const std::shared_ptr<Aidge::Node>& nodePtr) const {
const auto nodes = inputNodes();
return (nodes.find(nodePtr) != nodes.cend());
}
bool Aidge::GraphView::isOutputNode(const std::shared_ptr<Aidge::Node>& nodePtr) const {
const auto nodes = outputNodes();
return (nodes.find(nodePtr) != nodes.cend());
}
void Aidge::GraphView::setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs) {
AIDGE_ASSERT(inputs.size() <= mInputNodes.size(), "too many specified number of inputs");
......@@ -324,14 +355,14 @@ void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) {
}
}
void Aidge::GraphView::setBackend(const std::string &backend, DeviceIdx_t device) {
for (auto node : getNodes()) {
void Aidge::GraphView::setBackend(const std::string &backend, const DeviceIdx_t device) const {
for (const auto& node : getNodes()) {
node->getOperator()->setBackend(backend, device);
}
}
void Aidge::GraphView::setDataType(const Aidge::DataType &datatype) {
for (auto node : getNodes()) {
void Aidge::GraphView::setDataType(const Aidge::DataType &datatype) const {
for (const auto& node : getNodes()) {
node->getOperator()->setDataType(datatype);
}
}
......@@ -508,11 +539,9 @@ bool Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool inc
}
bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
if (mRootNode == nullptr) {
mRootNode = graph->getRootNode();
}
return add(graph->getNodes(), false);
// set the rootNode to the other graphView rootNode if no rootNode yet
mRootNode = mRootNode ? mRootNode : graph->getRootNode();
return add(graph->getNodes(), false);
}
void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment