diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 89ba148497709f0af475bbf953ff285c88036102..e87f6a3e88c996ecd53aa5ad98bd7733f02f67a9 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -162,6 +162,21 @@ public: std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs( std::string nodeName) const; + /** + * @brief Assert Datatype, Backend, data format and dimensions along the GraphView are coherent. + * If not, apply the required transformations. + * @details Sets the GraphView ready for computation in four steps: + * 1 - Assert input Tensors' datatype is compatible with each Operator's datatype. + * If not, a conversion Operator is inserted. + * 2 - Assert input Tensors' backend is compatible with each Operator's backend. + * If not, add a Transmitter Operator. + * 3 - Assert data format (NCHW, NHWC, ...) of each Operator's input Tensor is + * compatible with the selected kernel. + * If not, add a Transpose Operator. + * 4 - Propagate Tensor dimensions through the consecutive Operators. + */ + void compile(const std::string& backend, const Aidge::DataType datatype); + /** * @brief Compute dimensions of input/output Tensors for each Operator of the * GraphView object's Nodes. @@ -322,17 +337,17 @@ public: /** * @brief Insert a node (newParentNode) as a parent of the passed node (childNode). - * + * * @param childNode Node that gets a new parent. * @param newParentNode Inserted Node. * @param childInputTensorIdx Index of the input Tensor for the childNode linked to the inserted Node output. * @param newParentInputTensorIdx Index of the input Tensor for the newParentNode linked to the former parent of childNode. * @param newParentOutputTensorIdx Index of the output Tensor for the newParentNode linked to the childNode's input Tensor. */ - void insertParent(NodePtr childNode, - NodePtr newParentNode, - IOIndex_t childInputTensorIdx, - IOIndex_t newParentInputTensorIdx, + void insertParent(NodePtr childNode, + NodePtr newParentNode, + IOIndex_t childInputTensorIdx, + IOIndex_t newParentInputTensorIdx, IOIndex_t newParentOutputTensorIdx); /** diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 8f8f51c89bbcc380963f355f781e8fda940dcffc..1ca54c9c194a6b0a1fcf932a1f0f92d3b251d312 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -164,6 +164,19 @@ Aidge::GraphView::inputs(std::string name) const { return mNodeRegistry.at(name)->inputs(); } +void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype) { + // Backend + // TODO: add Backend attribute to Operator + setBackend(backend); + // Data type + // TODO: manage Datatype attribute in OperatorImpl + setDatatype(datatype); + // Data Format + // TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary + // Forward dimensions + forwardDims(); +} + void Aidge::GraphView::forwardDims() { // setInputs // Link every tensor to the right pointer @@ -225,7 +238,7 @@ void Aidge::GraphView::setBackend(const std::string &backend) { } } -void Aidge::GraphView::setDatatype(const DataType &datatype) { +void Aidge::GraphView::setDatatype(const Aidge::DataType &datatype) { for (auto node : getNodes()) { node->getOperator()->setDatatype(datatype); } diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 9f014364636c70031b522b09c893e1144af3f133..0811f4abfe5504e5210f09f66b6774ba8362e28b 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -244,7 +244,7 @@ TEST_CASE("[core/graph] GraphView(resetConnections)") { } } -TEST_CASE("Graph Forward dims", "[GraphView]") { +TEST_CASE("[core/graph] GraphView(forwardDims)", "[GraphView][forwardDims]") { auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); auto conv1 = Conv(3, 32, {3, 3}, "conv1"); auto conv2 = Conv(32, 64, {3, 3}, "conv2");