From dc5ae734a8f3fbc91a3eafbd5fc5e2391be4400d Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Thu, 26 Oct 2023 14:55:24 +0000 Subject: [PATCH] [Add] Introduce compile() member function to set the GraphView ready for forward in one line --- include/aidge/graph/GraphView.hpp | 25 ++++++++++++++++++++----- src/graph/GraphView.cpp | 15 ++++++++++++++- unit_tests/graph/Test_GraphView.cpp | 2 +- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 89ba14849..e87f6a3e8 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -162,6 +162,21 @@ public: std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs( std::string nodeName) const; + /** + * @brief Assert Datatype, Backend, data format and dimensions along the GraphView are coherent. + * If not, apply the required transformations. + * @details Sets the GraphView ready for computation in four steps: + * 1 - Assert input Tensors' datatype is compatible with each Operator's datatype. + * If not, a conversion Operator is inserted. + * 2 - Assert input Tensors' backend is compatible with each Operator's backend. + * If not, add a Transmitter Operator. + * 3 - Assert data format (NCHW, NHWC, ...) of each Operator's input Tensor is + * compatible with the selected kernel. + * If not, add a Transpose Operator. + * 4 - Propagate Tensor dimensions through the consecutive Operators. + */ + void compile(const std::string& backend, const Aidge::DataType datatype); + /** * @brief Compute dimensions of input/output Tensors for each Operator of the * GraphView object's Nodes. @@ -322,17 +337,17 @@ public: /** * @brief Insert a node (newParentNode) as a parent of the passed node (childNode). - * + * * @param childNode Node that gets a new parent. * @param newParentNode Inserted Node. * @param childInputTensorIdx Index of the input Tensor for the childNode linked to the inserted Node output. * @param newParentInputTensorIdx Index of the input Tensor for the newParentNode linked to the former parent of childNode. * @param newParentOutputTensorIdx Index of the output Tensor for the newParentNode linked to the childNode's input Tensor. */ - void insertParent(NodePtr childNode, - NodePtr newParentNode, - IOIndex_t childInputTensorIdx, - IOIndex_t newParentInputTensorIdx, + void insertParent(NodePtr childNode, + NodePtr newParentNode, + IOIndex_t childInputTensorIdx, + IOIndex_t newParentInputTensorIdx, IOIndex_t newParentOutputTensorIdx); /** diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 8f8f51c89..1ca54c9c1 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -164,6 +164,19 @@ Aidge::GraphView::inputs(std::string name) const { return mNodeRegistry.at(name)->inputs(); } +void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype) { + // Backend + // TODO: add Backend attribute to Operator + setBackend(backend); + // Data type + // TODO: manage Datatype attribute in OperatorImpl + setDatatype(datatype); + // Data Format + // TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary + // Forward dimensions + forwardDims(); +} + void Aidge::GraphView::forwardDims() { // setInputs // Link every tensor to the right pointer @@ -225,7 +238,7 @@ void Aidge::GraphView::setBackend(const std::string &backend) { } } -void Aidge::GraphView::setDatatype(const DataType &datatype) { +void Aidge::GraphView::setDatatype(const Aidge::DataType &datatype) { for (auto node : getNodes()) { node->getOperator()->setDatatype(datatype); } diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 9f0143646..0811f4abf 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -244,7 +244,7 @@ TEST_CASE("[core/graph] GraphView(resetConnections)") { } } -TEST_CASE("Graph Forward dims", "[GraphView]") { +TEST_CASE("[core/graph] GraphView(forwardDims)", "[GraphView][forwardDims]") { auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); auto conv1 = Conv(3, 32, {3, 3}, "conv1"); auto conv2 = Conv(32, 64, {3, 3}, "conv2"); -- GitLab