Skip to content
Snippets Groups Projects
Commit dd02d5ad authored by Octave Perrin's avatar Octave Perrin
Browse files

continuation GraphView

parent f7cafc75
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !245. Comments created here will be created in the context of that merge request.
...@@ -240,7 +240,7 @@ public: ...@@ -240,7 +240,7 @@ public:
*/ */
inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); } inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); }
/** @todo here i am /** @todo the nullptr behavior is strange: how is a node defined as an "input node" of the graph?
* @brief List outside input connections of the GraphView. The vector * @brief List outside input connections of the GraphView. The vector
* size is guaranteed to match the number of outside inputs of the GraphView. If there is * size is guaranteed to match the number of outside inputs of the GraphView. If there is
* no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned. * no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned.
...@@ -248,13 +248,14 @@ public: ...@@ -248,13 +248,14 @@ public:
*/ */
std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;
/** /** @todo having two inputs function, one without args that treat graph's inpput and one with args treating a node's input is fishy
* @TODO @warning what if the node isn't found? where is the try catch of the .at?
* @brief List all input connections (within and outside) of the specified GraphView node named "name". * @brief List all input connections (within and outside) of the specified GraphView node named "name".
* @return std::vector<std::pair<NodePtr, IOIndex_t>> * @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/ */
std::vector<std::pair<NodePtr, IOIndex_t>> inputs(const std::string& name) const; std::vector<std::pair<NodePtr, IOIndex_t>> inputs(const std::string& name) const;
/** /** @todo weird things happening here with this outsideoutputpos
* @brief List outside output connections of the GraphView. The vector * @brief List outside output connections of the GraphView. The vector
* size is guaranteed to match the number of outputs of the GraphView. If there is * size is guaranteed to match the number of outputs of the GraphView. If there is
* no connection to a given output, the corresponding sub-vector will be empty. * no connection to a given output, the corresponding sub-vector will be empty.
...@@ -269,7 +270,6 @@ public: ...@@ -269,7 +270,6 @@ public:
*/ */
std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs( std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs(
const std::string& nodeName) const; const std::string& nodeName) const;
/** /**
* @brief Assert Datatype, Backend, data format and dimensions along the GraphView are coherent. * @brief Assert Datatype, Backend, data format and dimensions along the GraphView are coherent.
* If not, apply the required transformations. * If not, apply the required transformations.
...@@ -282,13 +282,16 @@ public: ...@@ -282,13 +282,16 @@ public:
* compatible with the selected kernel. * compatible with the selected kernel.
* If not, add a Transpose Operator. * If not, add a Transpose Operator.
* 4 - Propagate Tensor dimensions through the consecutive Operators. * 4 - Propagate Tensor dimensions through the consecutive Operators.
@params string: backend @todo: explain params
@params Aidge Datatype: datatype
@params vector of vector of DimSize_t: dims
*/ */
void compile(const std::string& backend = "cpu", void compile(const std::string& backend = "cpu",
const Aidge::DataType datatype = DataType::Float32, const Aidge::DataType datatype = DataType::Float32,
DeviceIdx_t device = 0, DeviceIdx_t device = 0,
const std::vector<std::vector<DimSize_t>> dims = {}); const std::vector<std::vector<DimSize_t>> dims = {});
/** /** @todo naming the input dims, while there exist currentTensorPtr->dims is a nice way to mess with reader's head maybe some proper variable name would be great
* @brief Compute dimensions of input/output Tensors for each Operator of the * @brief Compute dimensions of input/output Tensors for each Operator of the
* GraphView object's Nodes, by calling Node::forwardDims(). * GraphView object's Nodes, by calling Node::forwardDims().
* This function verifies the following conditions: * This function verifies the following conditions:
...@@ -299,19 +302,27 @@ public: ...@@ -299,19 +302,27 @@ public:
*/ */
bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false); bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false);
/** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ /** @brief Set the same backend for each Operator of the GraphView object's Nodes.
* @param string: backend Backend name to be set
* @param DeviceIdx_t: device Backend device to be set
*/
void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const; void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const;
/** @brief Set the same data type for each Operator of the GraphView object's Nodes. */ /** @brief Set the same data type for each Operator of the GraphView object's Nodes.
* @param DataType: datatype DataType to be set
*/
void setDataType(const DataType& datatype) const; void setDataType(const DataType& datatype) const;
/** @brief Set the same data format for each Operator of the GraphView object's Nodes. */ /** @brief Set the same data format for each Operator of the GraphView object's Nodes.
* @param DataFormat: dataformat DataFormat to be set
*/
void setDataFormat(const DataFormat& dataformat) const; void setDataFormat(const DataFormat& dataformat) const;
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
// TOPOLOGY // TOPOLOGY
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
//@todo 50 shades of get
public: public:
/** /**
* @brief Get the parents Nodes of inputNodes. * @brief Get the parents Nodes of inputNodes of the graph.
* @return std::set<NodePtr> * @return std::set<NodePtr>
*/ */
std::set<NodePtr> getParents() const; std::set<NodePtr> getParents() const;
...@@ -344,8 +355,7 @@ public: ...@@ -344,8 +355,7 @@ public:
inline const std::set<NodePtr>& getNodes() const noexcept { return mNodes; } inline const std::set<NodePtr>& getNodes() const noexcept { return mNodes; }
/** /**
* @brief Get the operator with the corresponding name if it is in the * @brief Get the Node with the corresponding name if it is in the GraphView.
* GraphView.
* @param nodeName Name of the node. * @param nodeName Name of the node.
* @return NodePtr returns a nullptr if the one asked for * @return NodePtr returns a nullptr if the one asked for
* was not found. * was not found.
...@@ -370,6 +380,7 @@ public: ...@@ -370,6 +380,7 @@ public:
*/ */
std::pair<std::vector<NodePtr>, size_t> getRankedNodes() const; std::pair<std::vector<NodePtr>, size_t> getRankedNodes() const;
//@todo here i am
/** /**
* Get the nodes name according to the GraphView nodes ranking. * Get the nodes name according to the GraphView nodes ranking.
* @param format The formatting string to be used with fmt::format(). * @param format The formatting string to be used with fmt::format().
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment