From de3026b043f5df8a3db4322e7a05e7896145e653 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Wed, 9 Aug 2023 13:51:50 +0000 Subject: [PATCH] [Bug] Solve memory leak induced by circular reference using shared_ptr between Node/Node and Node/GRaphView" - Children Nodes are referenced by weak_ptr in Node - GraphViews are referenced by weak_ptr in Node --- include/aidge/graph/Node.hpp | 64 ++++++++++++++++++++-------------- src/graph/Node.cpp | 67 ++++++++++++++++++++---------------- 2 files changed, 77 insertions(+), 54 deletions(-) diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index fabbe5845..8c0216e5d 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -34,13 +34,23 @@ class GraphView; */ class Node : public std::enable_shared_from_this<Node> { private: + struct weakCompare { + bool operator()(const std::weak_ptr<Aidge::GraphView>& a, const std::weak_ptr<Aidge::GraphView>& b) const { + // Compare the content of the weak_ptrs + auto sharedA = a.lock(); + auto sharedB = b.lock(); + if (!sharedB) return false; // nothing after expired pointer + if (!sharedA) return true; + return sharedA < sharedB; // Assuming GraphView has a valid comparison operator + } + }; std::string mName; /** Name of the Node. Should be unique. */ - std::set<std::shared_ptr<GraphView>> mViews = std::set<std::shared_ptr<GraphView>>(); /** Set of pointers to GraphView instances including this Node instance. */ + std::set<std::weak_ptr<GraphView>, weakCompare> mViews; /** Set of pointers to GraphView instances including this Node instance. */ const std::shared_ptr<Operator> mOperator; // Pointer to the associated Operator std::vector<NodePtr> mParents; /** List of parent node for each input (Parent --> Node --> Child) */ - std::vector<std::vector<NodePtr>> mChildren; /** List of children nodes for each output (Parent --> Node --> Child) */ + std::vector<std::vector<std::weak_ptr<Node>>> mChildren; /** List of children nodes for each output (Parent --> Node --> Child) */ std::vector<std::vector<IOIndex_t>> mIdInChildren; /** List of input index for each Node linked to each output of the Node. */ std::vector<IOIndex_t> mIdOutParents; /** index of the output linked to each input of the Node. Default: gk_IODefaultIndex. */ @@ -70,7 +80,7 @@ public: * @param ctors Ordered Connectors linking their associated Node to the input of the current Node with the same index. * @return Connector */ - Connector operator()(const std::vector<Connector> ctors); + Connector operator()(const std::vector<Connector> &ctors); public: /////////////////////////////////////////////////////// @@ -131,14 +141,14 @@ public: /** * @brief List of pair <Parent, ID of the data intput>. When an input is not * linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. - * @return std::vector<std::pair<NodePtr, IOIndex_t>> + * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; /** * @brief List of pair <Parent, ID of the parent output>. When an input is not linked * to any Parent, the pair is <nullptr, gk_IODefaultIndex>. - * @return std::vector<std::pair<NodePtr, IOIndex_t>> + * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; @@ -146,7 +156,7 @@ public: * @brief Parent and its output Tensor ID linked to the inID-th input Tensor. * If the input is not linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. * @param inID - * @return std::pair<NodePtr, IOIndex_t> + * @return std::pair<std::shared_ptr<Node>, IOIndex_t> */ inline std::pair<NodePtr, IOIndex_t> input(const IOIndex_t inID) const { assert((inID != gk_IODefaultIndex) && (inID < nbInputs()) && "Input index out of bound."); @@ -178,19 +188,19 @@ public: /** * @brief List input ids of children liked to outputs of the node - * @return std::vector<std::vector<std::pair<NodePtr, + * @return std::vector<std::vector<std::pair<std::shared_ptr<Node>, * IOIndex_t>>> */ std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const; /** - * @brief Children and their input Tensor ID linked to the outID-th output + * @brief Children and their input Tensor ID linked to the outId-th output * Tensor. - * @param outID - * @return std::vector<std::pair<NodePtr, IOIndex_t>> + * @param outId + * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> - output(IOIndex_t outID) const; + output(IOIndex_t outId) const; /** * @brief Number of inputs, including both data and learnable parameters. @@ -231,7 +241,11 @@ public: * @return std::vector<GraphView> */ inline std::set<std::shared_ptr<GraphView>> views() const noexcept { - return mViews; + std::set<std::shared_ptr<GraphView>> res; + for (const auto &v : mViews) { + res.insert(v.lock()); + } + return res; } /** @@ -239,14 +253,14 @@ public: * the current Node. This feature allows transparent GraphViews. * @param graphPtr Pointer to GraphView to add to the list. */ - inline void addView(const std::shared_ptr<GraphView> graphPtr) { - mViews.insert(graphPtr); + inline void addView(const std::shared_ptr<GraphView> &graphPtr) { + mViews.insert(std::weak_ptr<GraphView>(graphPtr)); } - inline void removeView(const std::shared_ptr<GraphView> graphPtr) { - if (mViews.find(graphPtr) != mViews.end()) { - mViews.erase(graphPtr); - } + inline void removeView(const std::shared_ptr<GraphView> &graphPtr) { + std::set<std::weak_ptr<GraphView>, weakCompare>::const_iterator viewIt = mViews.cbegin(); + for (; (viewIt != mViews.cend()) && ((*viewIt).lock() != graphPtr) ; ++viewIt) {} + mViews.erase(*viewIt); } /** @@ -280,14 +294,14 @@ public: /** * @brief Get the list of parent Nodes. As an input is linked to a unique Node, * if none is linked then the parent is a nullptr. - * @return std::vector<NodePtr> + * @return std::vector<std::shared_ptr<Node>> */ std::vector<NodePtr> getParents() const; /** * @brief Get the pointer to parent of the specified input index. This pointer is nullptr if no parent is linked. * @param inId Input index. - * @return NodePtr& + * @return std::shared_ptr<Node>& */ inline NodePtr &getParents(const IOIndex_t inId) { assert(inId != gk_IODefaultIndex); @@ -298,7 +312,7 @@ public: * @brief Unlink the parent Node at the specified input index and return its pointer. * Return a nullptr is no parent was linked. * @param inId Input index. - * @return NodePtr + * @return std::shared_ptr<Node> */ NodePtr popParent(const IOIndex_t inId); @@ -308,7 +322,7 @@ public: * @brief Get the set of pointers to children Nodes linked to the current Node.object. * @details The returned set does not include any nullptr as an output maybe linked to * an undifined number of Nodes. It does not change the computation of its associated Operator. - * @return std::set<NodePtr>> + * @return std::set<std::shared_ptr<Node>>> */ std::set<NodePtr> getChildren() const; @@ -317,14 +331,14 @@ public: /** * @brief Get the list of children Nodes linked to the output at specified index. * @param outId Output index. - * @return std::vector<NodePtr> + * @return std::vector<std::shared_ptr<Node>> */ - std::vector<NodePtr> getChildren(const IOIndex_t outID) const; + std::vector<NodePtr> getChildren(const IOIndex_t outId) const; /** * @brief Remove registered child from children list of specified output if possible. * If so, also remove current Node from child Node from parent. - * @param nodePtr Node to remove. + * @param std::shared_ptr<Node> Node to remove. * @param outId Output index. Default 0. * @return true Child found and removed for given output index. * @return false Child not found at given index. Nothing removed. diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 5568e4b59..286ed7136 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -21,8 +21,8 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const char *name) : mName((name == nullptr) ? std::string() : std::string(name)), mOperator(op), mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), nullptr)), - mChildren(std::vector<std::vector<std::shared_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()), - std::vector<std::shared_ptr<Node>>())), + mChildren(std::vector<std::vector<std::weak_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()), + std::vector<std::weak_ptr<Node>>())), mIdInChildren( std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), std::vector<IOIndex_t>())), mIdOutParents(std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) { @@ -33,7 +33,7 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const char *name) // FUNCTIONAL DESCRIPTION /////////////////////////////////////////////////////// -Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> ctors) { +Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) { assert((ctors.size() == nbDataInputs()) && "Wrong number of arguments.\n"); for (__attribute__((unused)) std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inputs()) { assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n"); @@ -134,12 +134,12 @@ Aidge::Node::outputs() const { } std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> -Aidge::Node::output(Aidge::IOIndex_t outID) const { +Aidge::Node::output(Aidge::IOIndex_t outId) const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outID].size()); - for (std::size_t i = 0; i < mIdInChildren[outID].size(); ++i) { + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outId].size()); + for (std::size_t i = 0; i < mIdInChildren[outId].size(); ++i) { listOutputs[i] = - std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outID][i], mIdInChildren[outID][i]); + std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(), mIdInChildren[outId][i]); } return listOutputs; } @@ -161,7 +161,7 @@ Aidge::IOIndex_t Aidge::Node::nbValidOutputs() const { return counter; } -void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeOutID) { +void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) { assert(inId != gk_IODefaultIndex && (inId < nbInputs()) && "Must be a valid index"); if (mIdOutParents[inId] != gk_IODefaultIndex) { std::printf("Warning: filling a Tensor already attributed\n"); @@ -171,7 +171,7 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeOutID) // find first occurence of child in the output's children originalParent.first->removeChild(shared_from_this(), originalParent.second); } - mIdOutParents[inId] = newNodeOutID; + mIdOutParents[inId] = newNodeoutId; } /////////////////////////////////////////////////////// @@ -179,9 +179,8 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeOutID) /////////////////////////////////////////////////////// void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, const IOIndex_t otherInId) { - assert((otherInId != gk_IODefaultIndex) && (otherInId < otherNode->nbInputs()) && - "Input index out of bound."); - assert((outId != gk_IODefaultIndex) && (outId < nbOutputs()) && "Output index out of bound."); + assert((otherInId < otherNode->nbInputs()) && "Input index out of bound."); + assert((outId < nbOutputs()) && "Output index out of bound."); if (otherNode->input(otherInId).second != gk_IODefaultIndex) { std::printf("Warning, the %d-th Parent of the child node already existed.\n", otherInId); } @@ -189,24 +188,22 @@ void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t ou otherNode->setInputId(otherInId, outId); otherNode->getOperator()->associateInput(otherInId, getOperator()->getRawOutput(outId)); // manage nodes - mChildren[outId].push_back(otherNode); + mChildren[outId].push_back(std::weak_ptr<Node>(otherNode)); mIdInChildren[outId].push_back(otherInId); otherNode->addParent(shared_from_this(), otherInId); } -void Aidge::Node::addChildView(std::shared_ptr<GraphView> other_graph, const IOIndex_t outID, +void Aidge::Node::addChildView(std::shared_ptr<GraphView> otherGraph, const IOIndex_t outId, std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { - assert((otherInId.second != gk_IODefaultIndex) && - (otherInId.second < otherInId.first->nbInputs()) && - "Other graph input index out of bound."); - assert((outID != gk_IODefaultIndex) && (outID < nbOutputs()) && "Output index out of bound."); - std::set<std::shared_ptr<Node>> inNodes = other_graph->inputNodes(); + assert((otherInId.second < otherInId.first->nbInputs()) && "Other graph input index out of bound."); + assert((outId < nbOutputs()) && "Output index out of bound."); + std::set<std::shared_ptr<Node>> inNodes = otherGraph->inputNodes(); if (inNodes.size() == std::size_t(0)) { // no input Node printf("Cannot add GraphView to the Node. No input node detected.\n"); } else // inNodes.size() >= 1 { assert((inNodes.find(otherInId.first) != inNodes.end())); // assert it really is an input node - addChildOp(otherInId.first, outID, otherInId.second); + addChildOp(otherInId.first, outId, otherInId.second); } } @@ -256,24 +253,36 @@ bool Aidge::Node::removeParent(const IOIndex_t inId) { std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { std::set<std::shared_ptr<Node>> children; - for (const std::vector<std::shared_ptr<Node>> &childrenOfOneOutput : mChildren) { - children.insert(childrenOfOneOutput.begin(), childrenOfOneOutput.end()); + for (const auto &childrenOfOneOutput : mChildren) { + for (const auto &oneChild : childrenOfOneOutput) { + children.insert(oneChild.lock()); + } } return children; } -std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const { return mChildren; } +std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const { + std::vector<std::vector<std::shared_ptr<Node>>> children = std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size()); + for (std::size_t outId = 0; outId < mChildren.size(); ++outId) { + children[outId] = getChildren(outId); + } + return children; +} -std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outID) const { - assert((outID != gk_IODefaultIndex) && (outID < nbOutputs()) && "Output index out of bound."); - return mChildren[outID]; +std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outId) const { + assert((outId < nbOutputs()) && "Output index out of bound."); + std::vector<std::shared_ptr<Node>> children = std::vector<std::shared_ptr<Node>>(mChildren[outId].size()); + for (std::size_t i = 0; i < mChildren[outId].size(); ++i) { + children.push_back(mChildren[outId][i].lock()); + } + return children; } bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, const Aidge::IOIndex_t outId) { - assert((outId != gk_IODefaultIndex) && (outId < nbOutputs()) && "Child index out of bound."); + assert((outId < nbOutputs()) && "Child index out of bound."); bool removed = false; for (std::size_t j = 0; j < mChildren[outId].size(); ++j) { - if (mChildren[outId][j] == nodePtr) { + if (mChildren[outId][j].lock() == nodePtr) { mChildren[outId].erase(mChildren[outId].begin() + j); mIdInChildren[outId].erase(mIdInChildren[outId].begin() + j); removed = true; @@ -301,7 +310,7 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { for (std::pair<std::shared_ptr<Node>, IOIndex_t> child : output(i)) { child.first->removeParent(child.second); } - mChildren[i] = std::vector<std::shared_ptr<Node>>(); + mChildren[i] = std::vector<std::weak_ptr<Node>>(); mIdInChildren[i] = std::vector<IOIndex_t>(); } // removing this Node from every GraphView it belongs to -- GitLab