From de3026b043f5df8a3db4322e7a05e7896145e653 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Wed, 9 Aug 2023 13:51:50 +0000
Subject: [PATCH] [Bug] Solve memory leak induced by circular reference using
 shared_ptr between Node/Node and Node/GraphView

- Children Nodes are referenced by weak_ptr in Node
- GraphViews are referenced by weak_ptr in Node
---
 include/aidge/graph/Node.hpp | 64 ++++++++++++++++++++--------------
 src/graph/Node.cpp           | 67 ++++++++++++++++++++----------------
 2 files changed, 77 insertions(+), 54 deletions(-)

diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp
index fabbe5845..8c0216e5d 100644
--- a/include/aidge/graph/Node.hpp
+++ b/include/aidge/graph/Node.hpp
@@ -34,13 +34,23 @@ class GraphView;
  */
 class Node : public std::enable_shared_from_this<Node> {
 private:
+  struct weakCompare {
+      // Strict weak ordering for std::set<std::weak_ptr<GraphView>>, based on ownership
+      // (control block) rather than on the pointer returned by lock(): the result of
+      // lock() changes once the GraphView expires, which would silently reorder
+      // elements already stored in the set and break its ordering invariant.
+      bool operator()(const std::weak_ptr<Aidge::GraphView>& a,
+                      const std::weak_ptr<Aidge::GraphView>& b) const noexcept {
+          return a.owner_before(b); // stable even after the GraphView is destroyed
+      }
+  };
   std::string mName; /** Name of the Node. Should be unique. */
 
-  std::set<std::shared_ptr<GraphView>> mViews = std::set<std::shared_ptr<GraphView>>(); /** Set of pointers to GraphView instances including this Node instance. */
+  std::set<std::weak_ptr<GraphView>, weakCompare> mViews; /** Set of pointers to GraphView instances including this Node instance. */
   const std::shared_ptr<Operator> mOperator; // Pointer to the associated Operator
 
   std::vector<NodePtr> mParents; /** List of parent node for each input (Parent --> Node --> Child) */
-  std::vector<std::vector<NodePtr>> mChildren; /** List of children nodes for each output (Parent --> Node --> Child) */
+  std::vector<std::vector<std::weak_ptr<Node>>> mChildren; /** List of children nodes for each output (Parent --> Node --> Child) */
   std::vector<std::vector<IOIndex_t>> mIdInChildren; /** List of input index for each Node linked to each output of the Node. */
   std::vector<IOIndex_t> mIdOutParents; /** index of the output linked to each input of the Node. Default: gk_IODefaultIndex. */
 
@@ -70,7 +80,7 @@ public:
    * @param ctors Ordered Connectors linking their associated Node to the input of the current Node with the same index.
    * @return Connector 
    */
-  Connector operator()(const std::vector<Connector> ctors);
+  Connector operator()(const std::vector<Connector> &ctors);
 
 public:
   ///////////////////////////////////////////////////////
@@ -131,14 +141,14 @@ public:
   /**
    * @brief List of pair <Parent, ID of the data intput>. When an input is not
    * linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
-   * @return std::vector<std::pair<NodePtr, IOIndex_t>>
+   * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>
    */
   std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const;
 
   /**
    * @brief List of pair <Parent, ID of the parent output>. When an input is not linked
    * to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
-   * @return std::vector<std::pair<NodePtr, IOIndex_t>>
+   * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>
    */
   std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;
 
@@ -146,7 +156,7 @@ public:
    * @brief Parent and its output Tensor ID linked to the inID-th input Tensor.
    * If the input is not linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
    * @param inID
-   * @return std::pair<NodePtr, IOIndex_t>
+   * @return std::pair<std::shared_ptr<Node>, IOIndex_t>
    */
   inline std::pair<NodePtr, IOIndex_t> input(const IOIndex_t inID) const {
     assert((inID != gk_IODefaultIndex) && (inID < nbInputs()) && "Input index out of bound.");
@@ -178,19 +188,19 @@ public:
 
   /**
    * @brief List input ids of children liked to outputs of the node
-   * @return std::vector<std::vector<std::pair<NodePtr,
+   * @return std::vector<std::vector<std::pair<std::shared_ptr<Node>,
    * IOIndex_t>>>
    */
   std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const;
 
   /**
-   * @brief Children and their input Tensor ID linked to the outID-th output
+   * @brief Children and their input Tensor ID linked to the outId-th output
    * Tensor.
-   * @param outID
-   * @return std::vector<std::pair<NodePtr, IOIndex_t>>
+   * @param outId
+   * @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>
    */
   std::vector<std::pair<NodePtr, IOIndex_t>>
-  output(IOIndex_t outID) const;
+  output(IOIndex_t outId) const;
 
   /**
    * @brief Number of inputs, including both data and learnable parameters.
@@ -231,7 +241,11 @@ public:
    * @return std::vector<GraphView>
    */
   inline std::set<std::shared_ptr<GraphView>> views() const noexcept {
-    return mViews;
+    std::set<std::shared_ptr<GraphView>> res; // owning snapshot of the GraphViews that are still alive
+    for (const auto &v : mViews) {
+      if (auto view = v.lock()) { res.insert(view); } // skip expired GraphViews: never hand out nullptr
+    }
+    return res;
   }
 
   /**
@@ -239,14 +253,14 @@ public:
    * the current Node. This feature allows transparent GraphViews.
    * @param graphPtr Pointer to GraphView to add to the list.
    */
-  inline void addView(const std::shared_ptr<GraphView> graphPtr) {
-    mViews.insert(graphPtr);
+  inline void addView(const std::shared_ptr<GraphView> &graphPtr) {
+    mViews.emplace(graphPtr); // stored as a non-owning std::weak_ptr<GraphView>
   }
 
-  inline void removeView(const std::shared_ptr<GraphView> graphPtr) {
-    if (mViews.find(graphPtr) != mViews.end()) {
-      mViews.erase(graphPtr);
-    }
+  inline void removeView(const std::shared_ptr<GraphView> &graphPtr) {
+    auto viewIt = mViews.cbegin(); // linear scan: match on the locked pointer, independent of set ordering
+    for (; (viewIt != mViews.cend()) && ((*viewIt).lock() != graphPtr) ; ++viewIt) {}
+    if (viewIt != mViews.cend()) { mViews.erase(viewIt); } // no-op when graphPtr is not registered (never dereference cend())
   }
 
   /**
@@ -280,14 +294,14 @@ public:
   /**
    * @brief Get the list of parent Nodes. As an input is linked to a unique Node,
    * if none is linked then the parent is a nullptr.
-   * @return std::vector<NodePtr>
+   * @return std::vector<std::shared_ptr<Node>>
    */
   std::vector<NodePtr> getParents() const;
 
   /**
    * @brief Get the pointer to parent of the specified input index. This pointer is nullptr if no parent is linked.
    * @param inId Input index.
-   * @return NodePtr& 
+   * @return std::shared_ptr<Node>& 
    */
   inline NodePtr &getParents(const IOIndex_t inId) {
     assert(inId != gk_IODefaultIndex);
@@ -298,7 +312,7 @@ public:
    * @brief Unlink the parent Node at the specified input index and return its pointer.
    * Return a nullptr is no parent was linked.
    * @param inId Input index.
-   * @return NodePtr 
+   * @return std::shared_ptr<Node> 
    */
   NodePtr popParent(const IOIndex_t inId);
 
@@ -308,7 +322,7 @@ public:
    * @brief Get the set of pointers to children Nodes linked to the current Node.object.
    * @details The returned set does not include any nullptr as an output maybe linked to
    * an undifined number of Nodes. It does not change the computation of its associated Operator.
-   * @return std::set<NodePtr>>
+   * @return std::set<std::shared_ptr<Node>>
    */
   std::set<NodePtr> getChildren() const;
 
@@ -317,14 +331,14 @@ public:
   /**
    * @brief Get the list of children Nodes linked to the output at specified index.
    * @param outId Output index.
-   * @return std::vector<NodePtr> 
+   * @return std::vector<std::shared_ptr<Node>> 
    */
-  std::vector<NodePtr> getChildren(const IOIndex_t outID) const;
+  std::vector<NodePtr> getChildren(const IOIndex_t outId) const;
 
   /**
    * @brief Remove registered child from children list of specified output if possible.
    * If so, also remove current Node from child Node from parent.
-   * @param nodePtr Node to remove.
+   * @param nodePtr Node to remove.
    * @param outId Output index. Default 0.
    * @return true Child found and removed for given output index.
    * @return false Child not found at given index. Nothing removed.
diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp
index 5568e4b59..286ed7136 100644
--- a/src/graph/Node.cpp
+++ b/src/graph/Node.cpp
@@ -21,8 +21,8 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const char *name)
     : mName((name == nullptr) ? std::string() : std::string(name)),
       mOperator(op),
       mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), nullptr)),
-      mChildren(std::vector<std::vector<std::shared_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()),
-                                                                std::vector<std::shared_ptr<Node>>())),
+      mChildren(std::vector<std::vector<std::weak_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()),
+                                                                std::vector<std::weak_ptr<Node>>())),
       mIdInChildren(
               std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), std::vector<IOIndex_t>())),
       mIdOutParents(std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) {
@@ -33,7 +33,7 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const char *name)
 //        FUNCTIONAL DESCRIPTION
 ///////////////////////////////////////////////////////
 
-Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> ctors) {
+Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) {
     assert((ctors.size() == nbDataInputs()) && "Wrong number of arguments.\n");
     for (__attribute__((unused)) std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inputs()) {
         assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n");
@@ -134,12 +134,12 @@ Aidge::Node::outputs() const {
 }
 
 std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>
-Aidge::Node::output(Aidge::IOIndex_t outID) const {
+Aidge::Node::output(Aidge::IOIndex_t outId) const {
     std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs =
-            std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outID].size());
-    for (std::size_t i = 0; i < mIdInChildren[outID].size(); ++i) {
+            std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outId].size());
+    for (std::size_t i = 0; i < mIdInChildren[outId].size(); ++i) { // mChildren[outId] and mIdInChildren[outId] are read at the same index
         listOutputs[i] =
-                std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outID][i], mIdInChildren[outID][i]);
+                std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(), mIdInChildren[outId][i]); // .first is nullptr if the child has expired
     }
     return listOutputs;
 }
@@ -161,7 +161,7 @@ Aidge::IOIndex_t Aidge::Node::nbValidOutputs() const {
     return counter;
 }
 
-void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeOutID) {
+void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) {
     assert(inId != gk_IODefaultIndex && (inId < nbInputs()) && "Must be a valid index");
     if (mIdOutParents[inId] != gk_IODefaultIndex) {
         std::printf("Warning: filling a Tensor already attributed\n");
@@ -171,7 +171,7 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeOutID)
         // find first occurence of child in the output's children
         originalParent.first->removeChild(shared_from_this(), originalParent.second);
     }
-    mIdOutParents[inId] = newNodeOutID;
+    mIdOutParents[inId] = newNodeoutId;
 }
 
 ///////////////////////////////////////////////////////
@@ -179,9 +179,8 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeOutID)
 ///////////////////////////////////////////////////////
 
 void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, const IOIndex_t otherInId) {
-    assert((otherInId != gk_IODefaultIndex) && (otherInId < otherNode->nbInputs()) &&
-           "Input index out of bound.");
-    assert((outId != gk_IODefaultIndex) && (outId < nbOutputs()) && "Output index out of bound.");
+    assert((otherInId < otherNode->nbInputs()) && "Input index out of bound.");
+    assert((outId < nbOutputs()) && "Output index out of bound.");
     if (otherNode->input(otherInId).second != gk_IODefaultIndex) {
         std::printf("Warning, the %d-th Parent of the child node already existed.\n", otherInId);
     }
@@ -189,24 +188,22 @@ void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t ou
     otherNode->setInputId(otherInId, outId);
     otherNode->getOperator()->associateInput(otherInId, getOperator()->getRawOutput(outId));
     // manage nodes
-    mChildren[outId].push_back(otherNode);
+    mChildren[outId].push_back(std::weak_ptr<Node>(otherNode));
     mIdInChildren[outId].push_back(otherInId);
     otherNode->addParent(shared_from_this(), otherInId);
 }
 
-void Aidge::Node::addChildView(std::shared_ptr<GraphView> other_graph, const IOIndex_t outID,
+void Aidge::Node::addChildView(std::shared_ptr<GraphView> otherGraph, const IOIndex_t outId,
                                std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) {
-    assert((otherInId.second != gk_IODefaultIndex) &&
-           (otherInId.second < otherInId.first->nbInputs()) &&
-           "Other graph input index out of bound.");
-    assert((outID != gk_IODefaultIndex) && (outID < nbOutputs()) && "Output index out of bound.");
-    std::set<std::shared_ptr<Node>> inNodes = other_graph->inputNodes();
+    assert((otherInId.second < otherInId.first->nbInputs()) && "Other graph input index out of bound."); // NOTE(review): presumably also rejects gk_IODefaultIndex - confirm it is never below nbInputs()
+    assert((outId < nbOutputs()) && "Output index out of bound.");
+    std::set<std::shared_ptr<Node>> inNodes = otherGraph->inputNodes(); // candidate nodes that otherInId.first is checked against below
     if (inNodes.size() == std::size_t(0)) {  // no input Node
         printf("Cannot add GraphView to the Node. No input node detected.\n");
     } else  // inNodes.size() >= 1
     {
         assert((inNodes.find(otherInId.first) != inNodes.end()));  // assert it really is an input node
-        addChildOp(otherInId.first, outID, otherInId.second);
+        addChildOp(otherInId.first, outId, otherInId.second); // the Node-to-Node connection itself is delegated to addChildOp
     }
 }
 
@@ -256,24 +253,36 @@ bool Aidge::Node::removeParent(const IOIndex_t inId) {
 
 std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const {
     std::set<std::shared_ptr<Node>> children;
-    for (const std::vector<std::shared_ptr<Node>> &childrenOfOneOutput : mChildren) {
-        children.insert(childrenOfOneOutput.begin(), childrenOfOneOutput.end());
+    for (const auto &childrenOfOneOutput : mChildren) { // one list of weak references per output
+        for (const auto &oneChild : childrenOfOneOutput) {
+            if (auto child = oneChild.lock()) { children.insert(child); } // skip expired children: the returned set never holds nullptr
+        }
     }
     return children;
 }
 
-std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const { return mChildren; }
+std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const {
+    std::vector<std::vector<std::shared_ptr<Node>>> orderedChildren(mChildren.size());
+    for (std::size_t outId = 0; outId < orderedChildren.size(); ++outId) {
+        orderedChildren[outId] = getChildren(outId); // one list per output, kept in output order
+    }
+    return orderedChildren;
+}
 
-std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outID) const {
-    assert((outID != gk_IODefaultIndex) && (outID < nbOutputs()) && "Output index out of bound.");
-    return mChildren[outID];
+std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outId) const {
+    assert((outId < nbOutputs()) && "Output index out of bound.");
+    std::vector<std::shared_ptr<Node>> children(mChildren[outId].size()); // pre-sized: fill by index (push_back would double the size)
+    for (std::size_t i = 0; i < mChildren[outId].size(); ++i) {
+        children[i] = mChildren[outId][i].lock(); // nullptr if the child has expired; keeps positions aligned with mIdInChildren[outId]
+    }
+    return children;
 }
 
 bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, const Aidge::IOIndex_t outId) {
-    assert((outId != gk_IODefaultIndex) && (outId < nbOutputs()) && "Child index out of bound.");
+    assert((outId < nbOutputs()) && "Child index out of bound.");
     bool removed = false;
     for (std::size_t j = 0; j < mChildren[outId].size(); ++j) {
-        if (mChildren[outId][j] == nodePtr) {
+        if (mChildren[outId][j].lock() == nodePtr) {
             mChildren[outId].erase(mChildren[outId].begin() + j);
             mIdInChildren[outId].erase(mIdInChildren[outId].begin() + j);
             removed = true;
@@ -301,7 +310,7 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) {
         for (std::pair<std::shared_ptr<Node>, IOIndex_t> child : output(i)) {
             child.first->removeParent(child.second);
         }
-        mChildren[i] = std::vector<std::shared_ptr<Node>>();
+        mChildren[i] = std::vector<std::weak_ptr<Node>>();
         mIdInChildren[i] = std::vector<IOIndex_t>();
     }
     // removing this Node from every GraphView it belongs to
-- 
GitLab