Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_CORE_GRAPH_GRAPHVIEW_H__
#define __AIDGE_CORE_GRAPH_GRAPHVIEW_H__
#include <cassert>
#include <initializer_list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "graph/Node.hpp"
#include "utils/Types.h"
namespace Aidge {
enum class DataType;
class GraphView : public std::enable_shared_from_this<GraphView> {
private:
/// @brief Name of the graphview
std::string mName;
/// @brief Set of nodes included in the GraphView
std::set<NodePtr> mNodes;
/// @brief Set of nodes included in the graphview with names
std::map<std::string, NodePtr> mNodeRegistry;
/// @brief Nodes without input link
std::set<NodePtr> mInputNodes;
/// @brief Nodes without output link
std::set<NodePtr> mOutputNodes;
public:
GraphView(std::string name="")
: mName(name)
{
// ctor
}
GraphView(std::set<NodePtr> nodes, std::string name="")
: mName(name)
{
add(nodes);
}
bool operator==(const GraphView &gv) const
{
return mNodes == gv.mNodes;
}
NodePtr operator[](std::string name)
{
assert(mNodeRegistry.find(name) != mNodeRegistry.end() && "Could not find Node in the GraphView.");
return mNodeRegistry.at(name);
}
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////
Connector operator()(const std::vector<Connector> ctors);
///////////////////////////////////////////////////////
// INNER
///////////////////////////////////////////////////////
public:
/**
* @brief Name of the node.
* @return std::string
*/
std::string name() const;
/**
* @brief Set the node name.
* @warning Undefined behaviour when several Nodes have the same name.
* @param name New name for the node.
*/
void setName(const std::string &name);
/**
* @brief Save the GraphView as a Mermaid graph in a .md file at the
* specified location.
* @param path
*/
void save(std::string path, bool verbose = false) const;
inline bool inView(NodePtr nodePtr) const {
return mNodes.find(nodePtr) != mNodes.end();
}
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
///////////////////////////////////////////////////////
public:
inline std::set<NodePtr> inputNodes() const noexcept { return mInputNodes; }
inline std::set<NodePtr> outputNodes() const noexcept { return mOutputNodes; }
inline bool isInputNode(NodePtr nodePtr) const {
return (mInputNodes.find(nodePtr) != mInputNodes.end()) ? true : false;
}
inline bool isOutputNode(NodePtr nodePtr) const {
return (mOutputNodes.find(nodePtr) != mOutputNodes.end()) ? true : false;
}
/**
* @brief List data input Tensors of the graph input nodes.
* @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/
std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const;
/**
* @brief List data input Tensors of the graph input nodes.
* @param name Name of the Node.
* @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/
inline auto dataInputs(std::string name) const { return mNodeRegistry.at(name)->dataInputs(); }
/**
* @brief List input Tensors of the graph input nodes.
* @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/
std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;
std::vector<std::pair<NodePtr, IOIndex_t>> inputs(std::string name) const;
/**
* @brief List output Tensors of the node.
* @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>
*/
std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const;
/**
* @brief Specific i-th output Tensor of the GraphView.
* @param nodeName Name of the Node of which to show the output.
* @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>
*/
std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs(
std::string nodeName) const;
void forwardDims();
void setBackend(const std::string &backend);
void setDatatype(const DataType &datatype);
///////////////////////////////////////////////////////
// TOPOLOGY
///////////////////////////////////////////////////////
public:
/**
* @brief Get the Parents of inputNodes.
* @return std::vector<NodePtr>
*/
std::set<NodePtr> getParents() const;
std::vector<NodePtr> getParents(const std::string nodeName) const;
std::vector<std::vector<NodePtr>> getOrderedParents() const;
/**
* @brief Get the Children of outputNodes.
* @return std::set<NodePtr>
*/
std::set<NodePtr> getChildren() const;
std::vector<std::vector<NodePtr>> getChildren(const std::string nodeName) const;
std::set<NodePtr> getChildren(
const NodePtr otherNode) const; // TODO change it for a vector<vector> ?
/**
* @brief Getter for Operators of the GraphView.
* @return std::set<NodePtr>
*/
inline std::set<NodePtr> getNodes() const { return mNodes; }
/**
* @brief Get the operator with the corresponding name if it is in the
* GraphView.
* @param nodeName name of the node.
* @return NodePtr return a new empty node if the one asked for
* was not found.
*/
NodePtr getNode(const char *nodeName) const;
/**
* @brief Remove a Node from the current GraphView scope without affecting its connections
* @param nodePtr Node to remove
* @param includeLearnableParam Whether learnable parameters should also be removed. Default true.
*/
void remove(NodePtr nodePtr, bool includeLearnableParam = true);
// Surrounding nodes management
void setInputId(IOIndex_t inID, IOIndex_t newNodeOutID);
/**
* @brief Includes a Node to the current GraphView
* @param other_node Node to add.
* @param includeLearnableParam Should non-data inputs, like weights and biases
* be included in the GraphView automatically. Default: true.
*/
void add(NodePtr otherNode, bool includeLearnableParam = true);
void add(std::set<NodePtr> otherNodes,
bool includeLearnableParam = true);
/**
* @brief Include every Node inside another GraphView to the current
* GraphView.
* @param other_graph GraphView containing the Nodes to include.
*/
void add(std::shared_ptr<GraphView> otherGraph);
/**
* @brief Include a Node in the current GraphView and link it to another
* already contained Node.
*
* @param toOtherNode Pointer to the Node to add.
* @param fromOutNode Pointer to the already included Node the new Node will
* be linked to (it will become a parent of the new Node). If the GraphView
* only has one output Node, then default to this Node.
* @param fromTensor Ouput Tensor ID of the already included Node. Default to
* 0.
* @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning
* first available data input for the Node.
*/
void addChild(NodePtr toOtherNode, NodePtr fromOutNode = nullptr,
const IOIndex_t fromTensor = IOIndex_t(0),
IOIndex_t toTensor = gk_IODefaultIndex);
/**
* @brief Include a Node in the current GraphView and link it to another
* already contained Node.
*
* @param toOtherNode Pointer to the Node to add.
* @param fromOutNodeName Name of the already included Node the new Node will
* be linked to (it will become a parent of the new Node). As a name is
* optional, ensure such Node is in the GraphView or it will send back an
* error message.
* @param fromTensor Ouput Tensor ID of the already included Node. Default to
* 0.
* @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning
* first available data input for the Node.
*/
inline void addChild(NodePtr toOtherNode, std::string fromOutNodeName,
const IOIndex_t fromTensor = IOIndex_t(0),
IOIndex_t toTensor = gk_IODefaultIndex) {
assert(mNodeRegistry.find(fromOutNodeName) != mNodeRegistry.end() &&
"No Node with this name found in the GraphView.");
addChild(toOtherNode, mNodeRegistry.at(fromOutNodeName), fromTensor, toTensor);
}
/**
* @brief Include a GraphView content in the current GraphView and link
* the two sets by linking one Node from each GraphView.
* @param toOtherView Pointer to the GraphView whose content should be added.
* @param fromOutNode Pair of pointer to Node and Tensor ID for specifying the
* connection. If the GraphView including the other one has only one output
* Node, then it defaults to the first output Tensor of this Node.
* @param toNode Pair of pointer to Node and Tensor ID for specifying the
* connection. If the GraphView whose content is included has only one input
* Node, then it defaults to the first available data input Tensor of this
* Node.
*/
void addChild(std::shared_ptr<GraphView> toOtherView,
std::pair<NodePtr, IOIndex_t> fromOutNode =
std::pair<NodePtr, IOIndex_t>(nullptr, IOIndex_t(0)),
std::pair<NodePtr, IOIndex_t> toNode =
std::pair<NodePtr, IOIndex_t>(nullptr, gk_IODefaultIndex));
/**
* @brief Swap two Node instances if possible.
* @param node
* @param otherNode
* @return true
* @return false
*/
bool swap(Node &node, Node &otherNode);
void link(std::string name1_inID, std::string name2_outID);
void insert(Node &newNode, Node &inNode, std::initializer_list<Node> outNodes,
IOIndex_t tensorIdx);
/**
* @brief Replace the current GraphView with the set of given Nodes if possible
* @param newNodes Set of Nodes.
* @return true
* @return false
*/
bool replaceWith(std::set<NodePtr> newNodes);
void updateInputNodes();
/**
* @brief Process from zero the set of output Nodes.
*/
void updateOutputNodes();
private:
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
///////////////////////////////////////////////////////
IONb_t getNbDataInputs() const;
IONb_t getNbFreeDataInputs() const;
void updateInputNodes(NodePtr node);
/**
* @brief Update the set of output Nodes with a new Node,checking if it can be
* added and removing any Node not part of mOutputNode anymore.
* @param nodePtr
*/
void updateOutputNodes(NodePtr node);
///////////////////////////////////////////////////////
// TOPOLOGY
///////////////////////////////////////////////////////
void _forwardDims(std::set<NodePtr> listNodes);
void removeInputNode(const std::string nodeName);
void removeOutputNode(const std::string nodeName);
};
} // namespace Aidge
#endif /* __AIDGE_CORE_GRAPH_GRAPHVIEW_H__ */