Skip to content
Snippets Groups Projects
Commit dd02d5ad authored by Octave Perrin's avatar Octave Perrin
Browse files

continuation GraphView

parent f7cafc75
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !245. Comments created here will be created in the context of that merge request.
......@@ -240,7 +240,7 @@ public:
*/
inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); }
/** @todo here i am
/** @todo the nullptr behavior is strange: how is a node defined as an "input node" of the graph?
* @brief List outside input connections of the GraphView. The vector
* size is guaranteed to match the number of outside inputs of the GraphView. If there is
* no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned.
......@@ -248,13 +248,14 @@ public:
*/
std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;
/**
/** @todo having two inputs function, one without args that treat graph's inpput and one with args treating a node's input is fishy
* @TODO @warning what if the node isn't found? where is the try catch of the .at?
* @brief List all input connections (within and outside) of the specified GraphView node named "name".
* @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/
std::vector<std::pair<NodePtr, IOIndex_t>> inputs(const std::string& name) const;
/**
/** @todo weird things happening here with this outsideoutputpos
* @brief List outside output connections of the GraphView. The vector
* size is guaranteed to match the number of outputs of the GraphView. If there is
* no connection to a given output, the corresponding sub-vector will be empty.
......@@ -269,7 +270,6 @@ public:
*/
std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs(
const std::string& nodeName) const;
/**
* @brief Assert Datatype, Backend, data format and dimensions along the GraphView are coherent.
* If not, apply the required transformations.
......@@ -282,13 +282,16 @@ public:
* compatible with the selected kernel.
* If not, add a Transpose Operator.
* 4 - Propagate Tensor dimensions through the consecutive Operators.
@params string: backend @todo: explain params
@params Aidge Datatype: datatype
@params vector of vector of DimSize_t: dims
*/
void compile(const std::string& backend = "cpu",
const Aidge::DataType datatype = DataType::Float32,
DeviceIdx_t device = 0,
const std::vector<std::vector<DimSize_t>> dims = {});
/**
/** @todo naming the input dims, while there exist currentTensorPtr->dims is a nice way to mess with reader's head maybe some proper variable name would be great
* @brief Compute dimensions of input/output Tensors for each Operator of the
* GraphView object's Nodes, by calling Node::forwardDims().
* This function verifies the following conditions:
......@@ -299,19 +302,27 @@ public:
*/
bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false);
/** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
/** @brief Set the same backend for each Operator of the GraphView object's Nodes.
* @param string: backend Backend name to be set
* @param DeviceIdx_t: device Backend device to be set
*/
void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const;
/** @brief Set the same data type for each Operator of the GraphView object's Nodes. */
/** @brief Set the same data type for each Operator of the GraphView object's Nodes.
* @param DataType: datatype DataType to be set
*/
void setDataType(const DataType& datatype) const;
/** @brief Set the same data format for each Operator of the GraphView object's Nodes. */
/** @brief Set the same data format for each Operator of the GraphView object's Nodes.
* @param DataFormat: dataformat DataFormat to be set
*/
void setDataFormat(const DataFormat& dataformat) const;
///////////////////////////////////////////////////////
// TOPOLOGY
///////////////////////////////////////////////////////
//@todo 50 shades of get
public:
/**
* @brief Get the parents Nodes of inputNodes.
* @brief Get the parents Nodes of inputNodes of the graph.
* @return std::set<NodePtr>
*/
std::set<NodePtr> getParents() const;
......@@ -344,8 +355,7 @@ public:
inline const std::set<NodePtr>& getNodes() const noexcept { return mNodes; }
/**
* @brief Get the operator with the corresponding name if it is in the
* GraphView.
* @brief Get the Node with the corresponding name if it is in the GraphView.
* @param nodeName Name of the node.
* @return NodePtr returns a nullptr if the one asked for
* was not found.
......@@ -370,6 +380,7 @@ public:
*/
std::pair<std::vector<NodePtr>, size_t> getRankedNodes() const;
//@todo here i am
/**
* Get the nodes name according to the GraphView nodes ranking.
* @param format The formatting string to be used with fmt::format().
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment