Skip to content
Snippets Groups Projects
Commit c01e57f1 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Minor changes

parent 13d0638c
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!77Support for recurrent networks
...@@ -98,6 +98,11 @@ public: ...@@ -98,6 +98,11 @@ public:
*/ */
void save(std::string path, bool verbose = false, bool showProducers = true) const; void save(std::string path, bool verbose = false, bool showProducers = true) const;
/**
* Check that a node is in the current GraphView.
* @param nodePtr Node to check
* @return bool True is nodePtr belongs to the GraphView.
*/
inline bool inView(NodePtr nodePtr) const { inline bool inView(NodePtr nodePtr) const {
return mNodes.find(nodePtr) != mNodes.end(); return mNodes.find(nodePtr) != mNodes.end();
} }
......
...@@ -401,6 +401,8 @@ void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/, ...@@ -401,6 +401,8 @@ void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/,
} }
void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) { void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) {
AIDGE_ASSERT(node != nullptr, "Trying to add non-existant node!");
// first node to be added to the graph is the root node by default // first node to be added to the graph is the root node by default
if (mRootNode == nullptr) { if (mRootNode == nullptr) {
mRootNode = node; mRootNode = node;
...@@ -442,7 +444,7 @@ std::pair<std::vector<Aidge::NodePtr>, size_t> Aidge::GraphView::getRankedNodes( ...@@ -442,7 +444,7 @@ std::pair<std::vector<Aidge::NodePtr>, size_t> Aidge::GraphView::getRankedNodes(
for (auto childs : curNode->getOrderedChildren()) { for (auto childs : curNode->getOrderedChildren()) {
for (auto child : childs) { for (auto child : childs) {
if (nodesToRank.find(child) != nodesToRank.end()) { if (child != nullptr && nodesToRank.find(child) != nodesToRank.end()) {
rankedNodes.push_back(child); rankedNodes.push_back(child);
nodesToRank.erase(child); nodesToRank.erase(child);
} }
...@@ -450,7 +452,7 @@ std::pair<std::vector<Aidge::NodePtr>, size_t> Aidge::GraphView::getRankedNodes( ...@@ -450,7 +452,7 @@ std::pair<std::vector<Aidge::NodePtr>, size_t> Aidge::GraphView::getRankedNodes(
} }
for (auto parent : curNode->getParents()) { for (auto parent : curNode->getParents()) {
if (nodesToRank.find(parent) != nodesToRank.end()) { if (parent != nullptr && nodesToRank.find(parent) != nodesToRank.end()) {
rankedNodes.push_back(parent); rankedNodes.push_back(parent);
nodesToRank.erase(parent); nodesToRank.erase(parent);
} }
...@@ -538,7 +540,7 @@ bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool incl ...@@ -538,7 +540,7 @@ bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool incl
for (auto childs : curNode->getOrderedChildren()) { for (auto childs : curNode->getOrderedChildren()) {
for (auto child : childs) { for (auto child : childs) {
if (nodesToRank.find(child) != nodesToRank.end()) { if (child != nullptr && nodesToRank.find(child) != nodesToRank.end()) {
rankedNodes.push_back(child); rankedNodes.push_back(child);
nodesToRank.erase(child); nodesToRank.erase(child);
...@@ -551,7 +553,7 @@ bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool incl ...@@ -551,7 +553,7 @@ bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool incl
} }
for (auto parent : curNode->getParents()) { for (auto parent : curNode->getParents()) {
if (nodesToRank.find(parent) != nodesToRank.end()) { if (parent != nullptr && nodesToRank.find(parent) != nodesToRank.end()) {
rankedNodes.push_back(parent); rankedNodes.push_back(parent);
nodesToRank.erase(parent); nodesToRank.erase(parent);
......
...@@ -152,6 +152,7 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { ...@@ -152,6 +152,7 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
} }
for (IOIndex_t i = nbData(); i < nbInputs(); ++i) { for (IOIndex_t i = nbData(); i < nbInputs(); ++i) {
AIDGE_ASSERT(getInput(i) != nullptr, "Missing input#{} for operator {}", i, type());
getInput(i)->setDataType(dataType); getInput(i)->setDataType(dataType);
} }
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment