Skip to content
Snippets Groups Projects
Commit 8aa8f001 authored by Vincent Templier's avatar Vincent Templier
Browse files

Add getParent method to Node class

parent 6b2df066
No related branches found
No related tags found
No related merge requests found
...@@ -303,7 +303,7 @@ public: ...@@ -303,7 +303,7 @@ public:
* @param inId Input index. * @param inId Input index.
* @return std::shared_ptr<Node>& * @return std::shared_ptr<Node>&
*/ */
inline NodePtr &getParents(const IOIndex_t inId) { inline NodePtr &getParent(const IOIndex_t inId) {
assert(inId != gk_IODefaultIndex); assert(inId != gk_IODefaultIndex);
return mParents.at(inId); return mParents.at(inId);
} }
......
...@@ -326,7 +326,7 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara ...@@ -326,7 +326,7 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara
// add learnable parameters to the graph // add learnable parameters to the graph
if (includeLearnableParam) { if (includeLearnableParam) {
for (IOIndex_t i = node->nbDataInputs(); i < node->nbInputs(); ++i) { for (IOIndex_t i = node->nbDataInputs(); i < node->nbInputs(); ++i) {
std::shared_ptr<Node> parentNode = node->getParents(static_cast<IOIndex_t>(i)); std::shared_ptr<Node> parentNode = node->getParent(static_cast<IOIndex_t>(i));
if (parentNode) { if (parentNode) {
parentNode->addView(shared_from_this()); parentNode->addView(shared_from_this());
mNodes.insert(parentNode); mNodes.insert(parentNode);
......
...@@ -226,7 +226,7 @@ void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t ...@@ -226,7 +226,7 @@ void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t
} }
void Aidge::Node::addParent(const std::shared_ptr<Node> other_node, const IOIndex_t inId) { void Aidge::Node::addParent(const std::shared_ptr<Node> other_node, const IOIndex_t inId) {
if (getParents(inId) != nullptr) { if (getParent(inId) != nullptr) {
printf("Warning, you're replacing a Parent.\n"); printf("Warning, you're replacing a Parent.\n");
} }
assert((inId != gk_IODefaultIndex) && (inId < nbInputs()) && "Input index out of bound."); assert((inId != gk_IODefaultIndex) && (inId < nbInputs()) && "Input index out of bound.");
......
...@@ -59,12 +59,12 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ ...@@ -59,12 +59,12 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
// Step 2 : Branch existing producers & create the others // Step 2 : Branch existing producers & create the others
// link weights & bias // link weights & bias
if (matmul->getParents(1)==nullptr) { if (matmul->getParent(1)==nullptr) {
matmul->getParents(0)->addChild(fc, 0, 1); matmul->getParent(0)->addChild(fc, 0, 1);
} else { } else {
if (matmul->getParents(0)!=nullptr) if (matmul->getParent(0)!=nullptr)
matmul->getParents(0)->addChild(fc, 0, 0); matmul->getParent(0)->addChild(fc, 0, 0);
matmul->getParents(1)->addChild(fc, 0, 1); matmul->getParent(1)->addChild(fc, 0, 1);
} }
(producer_add_bias.first)->addChild(fc,0,2); (producer_add_bias.first)->addChild(fc,0,2);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment