From cea8cbcd21a40736fb82559bf548efbcb3b38c93 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Wed, 3 Apr 2024 11:33:57 +0000 Subject: [PATCH] update 'GraphView::compile()' member function --- include/aidge/graph/GraphView.hpp | 5 ++++- python_binding/graph/pybind_GraphView.cpp | 2 +- src/graph/GraphView.cpp | 12 ++++++------ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 06f73c97f..845599fd3 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -201,7 +201,10 @@ public: * If not, add a Transpose Operator. * 4 - Propagate Tensor dimensions through the consecutive Operators. */ - void compile(const std::string& backend = "cpu", const Aidge::DataType datatype = DataType::Float32, DeviceIdx_t device = 0); + void compile(const std::string& backend = "cpu", + const Aidge::DataType datatype = DataType::Float32, + DeviceIdx_t device = 0, + const std::vector<std::vector<DimSize_t>> dims = {}); /** * @brief Compute dimensions of input/output Tensors for each Operator of the diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index f06a70f32..953ec981e 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -118,7 +118,7 @@ void init_GraphView(py::module& m) { .def("get_nodes", &GraphView::getNodes) .def("get_node", &GraphView::getNode, py::arg("node_name")) .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>()) - .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0) + .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0, py::arg("dims")=std::vector<std::vector<DimSize_t>>()) .def("__call__", &GraphView::operator(), py::arg("connectors")) .def("set_datatype", &GraphView::setDataType, py::arg("datatype")) .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0) diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index f498d5e82..dcd7a06ef 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -378,7 +378,7 @@ Aidge::GraphView::inputs(const std::string& name) const { return mNodeRegistry.at(name)->inputs(); } -void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device) { +void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device, const std::vector<std::vector<DimSize_t>> dims) { // Backend // TODO: add Backend attribute to Operator setBackend(backend, device); @@ -388,7 +388,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType // Data Format // TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary // Forward dimensions - forwardDims(); + forwardDims(dims); } void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>> dims) { @@ -913,14 +913,14 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const // keep in memory every node related to the node to replace : // Parent for (std::size_t i = 0; i < oldOIn.size(); ++i) { - std::pair<NodePtr, IOIndex_t> inputParent = + std::pair<NodePtr, IOIndex_t> inputParent = oldOIn[i].first -> input(oldOIn[i].second); inputParents[i]= inputParent; // inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second); } // Children for (std::size_t i = 0; i < oldOOut.size();) { - std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> outputChild = + std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> outputChild = oldOOut[i].first -> output(oldOOut[i].second); if (outputChild.empty()) { outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex}); @@ -983,7 +983,7 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const for (std::size_t i = 0; i < oldOIn.size(); ++i) { if (inputParents[i].first) { inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second); - } + } } } else if ((oldOIn.size() == 1) && (inputParents[0].first)) { @@ -1259,7 +1259,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo if (deletedNode == mRootNode) { const std::pair<std::vector<NodePtr>, size_t> ranked_nodes = getRankedNodes(); if(ranked_nodes.second== 0 || ranked_nodes.first.size() <= 1) - { + { mRootNode = nullptr; } else { // The new root node will be the second node in the order of ranked nodes -- GitLab