From 1b929c79842c709ac178f88da69a72acfdc779cd Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Wed, 22 Nov 2023 15:53:55 +0000 Subject: [PATCH] Multiple changes - Remove setInput in Node - Change setDatatype to setDataType in GraphView and Tensor binding - Add namespace comment - Update Node includes - Run forwardDims() only if operators use Tensors --- include/aidge/graph/Node.hpp | 2 +- include/aidge/recipies/Recipies.hpp | 2 +- python_binding/data/pybind_Tensor.cpp | 2 +- python_binding/graph/pybind_GraphView.cpp | 2 +- python_binding/graph/pybind_Node.cpp | 6 +-- src/graph/GraphView.cpp | 64 +++++++++++++---------- src/graph/Node.cpp | 25 ++++----- 7 files changed, 55 insertions(+), 48 deletions(-) diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index b81f5288e..9717999da 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -169,7 +169,7 @@ public: * @param idx Input index. * @param tensor Constant Tensor to add as parent for specified index. */ - void setInput(const IOIndex_t idx, const std::shared_ptr<Tensor> tensor); + // void setInput(const IOIndex_t idx, const std::shared_ptr<Tensor> tensor); /** * @brief Get the lowest index in the InputData Parent list equal to the diff --git a/include/aidge/recipies/Recipies.hpp b/include/aidge/recipies/Recipies.hpp index 97544937e..26f4cc9da 100644 --- a/include/aidge/recipies/Recipies.hpp +++ b/include/aidge/recipies/Recipies.hpp @@ -89,6 +89,6 @@ void fuseBatchNorm(std::shared_ptr<GraphView> graphView); // std::set<std::shared_ptr<Node>> getHorizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices); // void horizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices); -} +} // namespace Aidge #endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */ diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index 31470e0eb..babc534bd 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -35,7 +35,7 @@ void addCtor(py::class_<Tensor, /* Request a buffer descriptor from Python */ py::buffer_info info = b.request(); Tensor* newTensor = new Tensor(); - newTensor->setDatatype(NativeType<T>::type); + newTensor->setDataType(NativeType<T>::type); const std::vector<DimSize_t> dims(info.shape.begin(), info.shape.end()); newTensor->resize(dims); // TODO : Find a better way to choose backend diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 6ac2199b4..6a29c6941 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -89,7 +89,7 @@ void init_GraphView(py::module& m) { .def("get_node", &GraphView::getNode, py::arg("node_name")) .def("forward_dims", &GraphView::forwardDims) .def("__call__", &GraphView::operator(), py::arg("connectors")) - .def("set_datatype", &GraphView::setDatatype, py::arg("datatype")) + .def("set_datatype", &GraphView::setDataType, py::arg("datatype")) .def("set_backend", &GraphView::setBackend, py::arg("backend")) // .def("__getitem__", [](Tensor& b, size_t idx)-> py::object { // // TODO : Should return error if backend not compatible with get diff --git a/python_binding/graph/pybind_Node.cpp b/python_binding/graph/pybind_Node.cpp index e3666d247..3b63189c8 100644 --- a/python_binding/graph/pybind_Node.cpp +++ b/python_binding/graph/pybind_Node.cpp @@ -90,7 +90,7 @@ void init_Node(py::module& m) { .def("input", &Node::input, py::arg("in_id"), R"mydelimiter( Get the parent Node and the associated output index connected to the i-th input of the current Node. - + :param in_id: input index of the current Node object. :type in_id: int :return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index) @@ -108,7 +108,7 @@ void init_Node(py::module& m) { .def("output", &Node::output, py::arg("out_id"), R"mydelimiter( Get a list of the children Node for a specific output and the associated input index connected to it. - + :param out_id: input index of the current Node object. :type out_id: int :return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index) @@ -122,7 +122,7 @@ void init_Node(py::module& m) { :rtype: int )mydelimiter") - .def("get_nb_datainputs", &Node::nbDataInputs, + .def("get_nb_data", &Node::nbData, R"mydelimiter( Number of data inputs. diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index af3e24c20..2306ec8ab 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -17,6 +17,7 @@ #include "aidge/utils/Types.h" #include "aidge/graph/GraphView.hpp" #include "aidge/data/Tensor.hpp" +#include "aidge/operator/OperatorTensor.hpp" #include "aidge/utils/ErrorHandling.hpp" /////////////////////////////////////////////////////// @@ -171,7 +172,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType setBackend(backend); // Data type // TODO: manage Datatype attribute in OperatorImpl - setDatatype(datatype); + setDataType(datatype); // Data Format // TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary // Forward dimensions @@ -208,41 +209,46 @@ void Aidge::GraphView::forwardDims() { } void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) { - // TODO: support multi-inputs/outputs - std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>(); - for (std::shared_ptr<Node> nodePtr : listNodes) { - if (!nodePtr->getOperator()->outputDimsForwarded()) { - nodePtr->getOperator()->computeOutputDims(); - } - if (!nodePtr->getOperator()->outputDimsForwarded()) { - nextList.insert(nodePtr); - } else { - std::set<std::shared_ptr<Node>> children = nodePtr->getChildren(); - nextList.insert(children.begin(), children.end()); + // TODO: support multi-inputs/outputs + std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>(); + for (std::shared_ptr<Node> nodePtr : listNodes) { + if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) { + const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator()); + if (!op->outputDimsForwarded()) { + op->computeOutputDims(); + } + if (!op->outputDimsForwarded()) { // try to compute output dimensions again later + nextList.insert(nodePtr); + } else { // compute output dimensions of children + std::set<std::shared_ptr<Node>> children = nodePtr->getChildren(); + nextList.insert(children.begin(), children.end()); + } + } } - } - if (nextList.empty()) { - for (std::shared_ptr<Node> nodePtr : getNodes()) { - if (!nodePtr->getOperator()->outputDimsForwarded()) { - nextList.insert(nodePtr); - } + if (nextList.empty()) { + for (std::shared_ptr<Node> nodePtr : getNodes()) { + if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) { + if (!std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator())->outputDimsForwarded()) { + nextList.insert(nodePtr); + } + } + } + } + if (!nextList.empty()) { + _forwardDims(nextList); } - } - if (!nextList.empty()) { - _forwardDims(nextList); - } } void Aidge::GraphView::setBackend(const std::string &backend) { - for (auto node : getNodes()) { - node->getOperator()->setBackend(backend); - } + for (auto node : getNodes()) { + node->getOperator()->setBackend(backend); + } } -void Aidge::GraphView::setDatatype(const Aidge::DataType &datatype) { - for (auto node : getNodes()) { - node->getOperator()->setDatatype(datatype); - } +void Aidge::GraphView::setDataType(const Aidge::DataType &datatype) { + for (auto node : getNodes()) { + node->getOperator()->setDataType(datatype); + } } void Aidge::GraphView::updateOutputNodes() { diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 6dca8eaaf..5a7b05e46 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -15,6 +15,7 @@ #include "aidge/operator/Producer.hpp" #include <memory> #include <vector> +#include "aidge/operator/OperatorTensor.hpp" #include "aidge/utils/Types.h" Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) @@ -111,18 +112,18 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::No return res; } -void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) { - assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound."); - if (mParents[idx] != nullptr) { - mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]); - removeParent(idx); - } - std::shared_ptr<Node> newConstantNode = Producer(tensor); - newConstantNode->addChild(shared_from_this(), 0, idx); - for (auto& graphPtr : views()) { - graphPtr->add(newConstantNode); - } -} +// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) { +// assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound."); +// if (mParents[idx] != nullptr) { +// mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]); +// removeParent(idx); +// } +// std::shared_ptr<Node> newConstantNode = Producer(tensor); +// newConstantNode->addChild(shared_from_this(), 0, idx); +// for (auto& graphPtr : views()) { +// graphPtr->add(newConstantNode); +// } +// } std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>> Aidge::Node::outputs() const { -- GitLab