diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 37ddb382d6364554e5155958c427778760465f81..081c429e869a7897d9a24b4633f87a7f6efd68e3 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -295,8 +295,31 @@ public: */ bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false); + /** + * @brief Helper function to compute and forward data type throughout the graph + * It will try to infer the best output datatype based on the input datatype which. + * To do so it will based itself on the ``OperatorTensor::forwardDataType()`` method. + * A generic version of this method is defined in ``OperatorTensor`` and need to + * be override to account for special case. + * + * This method doesn't substitute itself to the user changing manually the data type + * of operators but it is preferred to use over ``GraphView::setDataType``. + * + * @param inputTypes A vector of data type, the order of the vector should be the same + * as the order of the inputs of the graph. + * @return true if the function succeed to propagate datatype throughout the graph. + */ bool forwardDType(const std::vector<DataType>& inputTypes = {}); + + /** + * @brief Helper that call ``bool forwardDType(const std::vector<DataType>& inputTypes = {})``. + * + * @param inputType Data type to set for each input of the graph. That will be forwarded. + * @return true true if the function succeed to propagate data type throughout the graph. + */ + bool forwardDType(DataType inputType); + /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const; /** @brief Set the same data type for each Operator of the GraphView object's Nodes. */ @@ -623,10 +646,10 @@ private: * - That each node's input matches the expected output from its connected node. * - That all mandatory inputs are present and defined. * - Logs an error and returns `false` if any inconsistency is detected. - * + * @param checkDefinedTensor if True, check that each tensors are not undefined. * @return `true` if all connections and tensor states are valid, `false` otherwise. */ - bool connectionValid(); + bool connectionValid(bool checkDefinedTensor = true); /////////////////////////////////////////////////////// // TOPOLOGY diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 1d1778c318337cc15c6330be430eb5199603a4bb..d1b99c305d0e067a74c13a33cde062b2c6f2ddfa 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -80,7 +80,7 @@ void init_GraphView(py::module& m) { :param include_learnable_parameters: include non-data inputs, like weights and biases, default True. :type include_learnable_parameters: bool, optional )mydelimiter") - + .def("insert_parent", &GraphView::insertParent, py::arg("child_node"), py::arg("new_parent_node"), py::arg("child_input_tensor_idx"), py::arg("new_parent_input_tensor_idx"), py::arg("new_parent_output_tensor_idx")) .def("add_child", (void (GraphView::*)(std::shared_ptr<Node>, std::shared_ptr<Node>, @@ -128,7 +128,8 @@ void init_GraphView(py::module& m) { .def("clone", &GraphView::clone) .def("get_nodes", &GraphView::getNodes) .def("get_node", &GraphView::getNode, py::arg("node_name")) - .def("forward_dtype", &GraphView::forwardDType, py::arg("dtypes") = std::vector<DataType>()) + .def("forward_dtype", (bool(GraphView::*)(const std::vector<DataType>&)) &GraphView::forwardDType, py::arg("dtypes") = std::vector<DataType>()) + .def("forward_dtype", (bool(GraphView::*)(DataType)) &GraphView::forwardDType, py::arg("dtype")) .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false, R"mydelimiter( Compute and propagate Tensor dimensions through the GraphView. diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 64e5c331083a8c7e06b056ab915478e9fc07f718..d2e4079e115f8493331381f50f314e44cdc6e0d1 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -443,7 +443,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType forwardDims(dims); } -bool Aidge::GraphView::connectionValid(){ +bool Aidge::GraphView::connectionValid(bool checkDefinedTensor){ // Ensure every node in the graph is correctly connected Log::debug("Verifying graph connections and tensor validity"); for (std::shared_ptr<Node> nodePtr : getNodes()) { @@ -462,7 +462,7 @@ bool Aidge::GraphView::connectionValid(){ i, nodePtr->name(), nodePtr->type()); return false; } - if (std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->undefined()) { + if (checkDefinedTensor && std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->undefined()) { Log::error("Undefined mandatory input#{} for node [\033[1m\033[3m{}\033[0m - (\033[1m\033[3m{}\033[0m)]", i, nodePtr->name(), nodePtr->type()); return false; @@ -473,6 +473,10 @@ bool Aidge::GraphView::connectionValid(){ return true; } +bool Aidge::GraphView::forwardDType(DataType inputType){ + return forwardDType(std::vector<DataType>(getNbDataInputs(), inputType)); +} + bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTypes){ if (!inputTypes.empty()){ auto msg = fmt::format("Manually setting GraphView input data type with provided parameters:"); @@ -486,10 +490,12 @@ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTyp const auto& currentTensorPtr = std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator())->getInput(input.second); if (i < inputTypes.size()) { - if (!currentTensorPtr) { // tensor detected + if (!currentTensorPtr) { Log::debug("Creating new tensor for input#{} with dtype {}", i, inputTypes[i]); auto tensor = std::make_shared<Tensor>(inputTypes[i], DataFormat::Default); input.first->getOperator()->setInput(input.second, tensor); + }else{ + currentTensorPtr->setDataType(inputTypes[i]); } } else { @@ -508,7 +514,9 @@ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTyp ++i; } } - if(!connectionValid()) return false; + + if(!connectionValid(false)) return false; + // INITIALIZING Open and Close sets std::set<std::shared_ptr<Node>> close; // Already treated nodes std::set<std::shared_ptr<Node>> open = inputNodes(); // Nodes to treat @@ -524,6 +532,10 @@ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTyp } } do{ + Log::debug("List of node to forward data type:"); + for(auto node : open){ + Log::debug("\t- Node {} (of type {})", node->name(), node->type()); + } std::set<std::shared_ptr<Node>> newOpen; for (const auto& nodePtr : open) { if (nodePtr->getOperator()->operatorType() != OperatorType::Tensor) { @@ -552,7 +564,7 @@ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTyp nodePtr->name(), nodePtr->type()); // Recompute every time, even if it was already computed in a - // previous call of forwardDims(), as the graph may have changed! + // previous call of forwardDType(), as the graph may have changed! close.insert(nodePtr); for (const auto& child : nodePtr->getChildren()) { if (inView(child) && close.find(child) == close.end()) { @@ -562,7 +574,8 @@ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTyp } else { if (parentsForwarded) { - Log::debug("Unable to forward dimensions for node {} (of type {})", nodePtr->name(), nodePtr->type()); + Log::error("Unable to forward data type for node {} (of type {})", nodePtr->name(), nodePtr->type()); + } Log::debug("Adding back node {} (of type {}) to the list of nodes to forward data type", nodePtr->name(), nodePtr->type()); newOpen.insert(nodePtr);