Skip to content
Snippets Groups Projects
Commit 1f86559c authored by Octave Perrin's avatar Octave Perrin
Browse files

continuation

parent 7fce53e6
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !231. Comments created here will be created in the context of that merge request.
...@@ -41,16 +41,6 @@ class GraphView; ...@@ -41,16 +41,6 @@ class GraphView;
/** /**
* @brief Object carrying the topological information of the computational graph. * @brief Object carrying the topological information of the computational graph.
* A Node contains :
* - mName: the name of the Node, should be unique
* - mViews: a set of pointers to GraphView instances including this Node instance
* - mOperator: a pointer to the Operator associated to the node
* - mParents: a vector of parent nodes, which are its inputs
* - mIdOutParents: a vector of indexes, which tells for all the parent nodes from which of their output we take the value
* - mChildren: a vector of vector of children nodes, which lists all the recipient nodes, for all of the outputs
* - mIdInChildren: a vector of vector of indexes, which gives for all the recipient nodes in which of their input the current value is taken
* - mforward: ?
* - mbackward: ?
*/ */
class Node : public std::enable_shared_from_this<Node> { class Node : public std::enable_shared_from_this<Node> {
private: private:
...@@ -64,8 +54,7 @@ private: ...@@ -64,8 +54,7 @@ private:
return sharedA < sharedB; // shared_ptr has a valid comparison operator return sharedA < sharedB; // shared_ptr has a valid comparison operator
} }
}; };
std::shared_ptr<DynamicAttributes> mAttrs; std::string mName; /** Name of the Node. Should be unique. */
std::string mName; /** Name of the Node. It should be unique. */
std::set<std::weak_ptr<GraphView>, weakCompare> mViews; /** Set of pointers to GraphView instances including this Node instance. */ std::set<std::weak_ptr<GraphView>, weakCompare> mViews; /** Set of pointers to GraphView instances including this Node instance. */
const std::shared_ptr<Operator> mOperator; // Pointer to the associated Operator const std::shared_ptr<Operator> mOperator; // Pointer to the associated Operator
...@@ -79,19 +68,7 @@ private: ...@@ -79,19 +68,7 @@ private:
std::deque<std::function<bool()>> mBackward; std::deque<std::function<bool()>> mBackward;
public: public:
#ifndef DOXYGEN_SHOULD_SKIP_THIS
Node() = delete; Node() = delete;
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
/**
* @brief Construct a new Node object associated with the input Operator.
* @param op Operator giving the Node its number of connections.
* @param attrs Attributes for the Node.
*/
Node(std::shared_ptr<Operator> op, std::shared_ptr<DynamicAttributes> attrs);
Node(std::shared_ptr<Operator> op, const DynamicAttributes& attrs);
/** /**
* @brief Construct a new Node object associated with the input Operator. * @brief Construct a new Node object associated with the input Operator.
...@@ -130,6 +107,7 @@ public: ...@@ -130,6 +107,7 @@ public:
/** /**
* @brief Functional operator for user-friendly connection interface using an ordered set of Connectors. * @brief Functional operator for user-friendly connection interface using an ordered set of Connectors.
* @param ctors Ordered Connectors linking their associated Node to the input of the current Node with the same index. * @param ctors Ordered Connectors linking their associated Node to the input of the current Node with the same index.
* length of ctors must be lower than the number of input of the Node
* @return Connector * @return Connector
*/ */
Connector operator()(const std::vector<Connector> &ctors); Connector operator()(const std::vector<Connector> &ctors);
...@@ -143,7 +121,7 @@ public: ...@@ -143,7 +121,7 @@ public:
* @brief Name of the Node. * @brief Name of the Node.
* @return std::string * @return std::string
*/ */
inline std::string name() const noexcept { return (mAttrs->hasAttr("name")) ? mAttrs->getAttr<std::string>("name") : ""; } inline std::string name() const noexcept { return mName; }
/** /**
* @brief Set the Node name. * @brief Set the Node name.
...@@ -155,8 +133,8 @@ public: ...@@ -155,8 +133,8 @@ public:
/** /**
* @brief Given the parameter name generate a new name which is unique * @brief Given the parameter name generate a new name which is unique
* in all the GraphView which contains this node. * in all the GraphView which contains this node.
* To generate the new name the method is called recursively and append * To generate the new name the method appends
* the caracter ``_``. * the caracter ``_`` until it is unique.
* If no duplicate return name, this is the exit condition. * If no duplicate return name, this is the exit condition.
* @param name Base name to make unique. * @param name Base name to make unique.
* @return A unique name in all the GraphView which contains this one. * @return A unique name in all the GraphView which contains this one.
...@@ -187,7 +165,7 @@ public: ...@@ -187,7 +165,7 @@ public:
* @brief Get the Operator object of the Node. * @brief Get the Operator object of the Node.
* @return std::shared_ptr<Operator> * @return std::shared_ptr<Operator>
*/ */
inline std::shared_ptr<Operator> getOperator() const { return (*mOperator)(mAttrs); } inline std::shared_ptr<Operator> getOperator() const { return mOperator; }
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
// TENSOR MANAGEMENT // TENSOR MANAGEMENT
...@@ -246,7 +224,10 @@ public: ...@@ -246,7 +224,10 @@ public:
return (i < nbInputs()) ? i : gk_IODefaultIndex; return (i < nbInputs()) ? i : gk_IODefaultIndex;
} }
/**
* @brief Returns the number of free data inputs of the Node
* @return IOIndex_t
*/
IOIndex_t getNbFreeDataInputs() const; IOIndex_t getNbFreeDataInputs() const;
/** /**
...@@ -334,6 +315,11 @@ public: ...@@ -334,6 +315,11 @@ public:
mViews.insert(std::weak_ptr<GraphView>(graphPtr)); mViews.insert(std::weak_ptr<GraphView>(graphPtr));
} }
/**
* @brief Remove a GraphView pointer from the list of GraphView containing
* the current Node.
* @param graphPtr Pointer to GraphView to remove from the list.
*/
inline void removeView(const std::shared_ptr<GraphView> &graphPtr) { inline void removeView(const std::shared_ptr<GraphView> &graphPtr) {
mViews.erase(graphPtr); mViews.erase(graphPtr);
} }
...@@ -344,7 +330,7 @@ public: ...@@ -344,7 +330,7 @@ public:
* @param outId ID of the current Node output to connect to the other Node. * @param outId ID of the current Node output to connect to the other Node.
* Default to 0. * Default to 0.
* @param otherInId ID of the other Node input to connect to the current Node. * @param otherInId ID of the other Node input to connect to the current Node.
* Default to the first avaible data input. * Default to the first available data input.
*/ */
void addChild(NodePtr otherNode, void addChild(NodePtr otherNode,
const IOIndex_t outId = IOIndex_t(0), const IOIndex_t outId = IOIndex_t(0),
...@@ -391,6 +377,11 @@ public: ...@@ -391,6 +377,11 @@ public:
*/ */
NodePtr popParent(const IOIndex_t inId); NodePtr popParent(const IOIndex_t inId);
/**
* @brief unlinks the parent from the Node and replaces it with nullptr (for coherence with remaining parents)
* @param inId Input index of the parent to be removed
* @return std::bool true if parent has been removed.
*/
bool removeParent(const IOIndex_t inId); bool removeParent(const IOIndex_t inId);
/** /**
...@@ -401,18 +392,22 @@ public: ...@@ -401,18 +392,22 @@ public:
*/ */
std::set<NodePtr> getChildren() const; std::set<NodePtr> getChildren() const;
/**
* @brief Get all sets of children of the node, grouped by which output of the Node they come from
* @returns std::vector<std::vector<std::shared_ptr<Node>>>
*/
std::vector<std::vector<NodePtr>> getOrderedChildren() const; std::vector<std::vector<NodePtr>> getOrderedChildren() const;
/** /**
* @brief Get the list of children Nodes linked to the output at specified index. * @brief Get the list of children Nodes linked to the output at specified index.
* @param outId Output index. * @param outId Output index.
* @return std::vector<std::shared_ptr<Node>> * @return std::vector<std::shared_ptr<Node>>
*/ */>
std::vector<NodePtr> getChildren(const IOIndex_t outId) const; std::vector<NodePtr> getChildren(const IOIndex_t outId) const;
/** /**
* @brief Remove registered child from children list of specified output if possible. * @brief Remove registered child from children list of specified output if possible.
* If so, also remove current Node from child Node from parent. * If so, also remove current Node from child's parent.
* @param std::shared_ptr<Node> Node to remove. * @param std::shared_ptr<Node> Node to remove.
* @param outId Output index. Default 0. * @param outId Output index. Default 0.
* @return true Child found and removed for given output index. * @return true Child found and removed for given output index.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment