Skip to content
Snippets Groups Projects
Commit 633a061f authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Fix GraphView::forwardDType, error with undefined tensors.

parent 87e65d66
No related branches found
No related tags found
No related merge requests found
Pipeline #68640 passed
This commit is part of merge request !363. Comments created here will be created in the context of that merge request.
...@@ -295,8 +295,31 @@ public: ...@@ -295,8 +295,31 @@ public:
*/ */
bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false); bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false);
/**
 * @brief Helper function to compute and forward data types throughout the graph.
 * It will try to infer the best output data type based on the input data types.
 * To do so, it relies on the ``OperatorTensor::forwardDataType()`` method.
 * A generic version of this method is defined in ``OperatorTensor`` and needs to
 * be overridden to account for special cases.
 *
 * This method does not prevent the user from manually changing the data type
 * of operators, but it is preferred over ``GraphView::setDataType``.
 *
 * @param inputTypes A vector of data types; the order of the vector should be the same
 * as the order of the inputs of the graph.
 * @return true if the function succeeds in propagating the data type throughout the graph.
 */
bool forwardDType(const std::vector<DataType>& inputTypes = {}); bool forwardDType(const std::vector<DataType>& inputTypes = {});
/**
 * @brief Helper that calls ``bool forwardDType(const std::vector<DataType>& inputTypes = {})``.
 *
 * @param inputType Data type to set for each input of the graph; it will then be forwarded.
 * @return true if the function succeeds in propagating the data type throughout the graph.
 */
bool forwardDType(DataType inputType);
/** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const; void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const;
/** @brief Set the same data type for each Operator of the GraphView object's Nodes. */ /** @brief Set the same data type for each Operator of the GraphView object's Nodes. */
...@@ -623,10 +646,10 @@ private: ...@@ -623,10 +646,10 @@ private:
* - That each node's input matches the expected output from its connected node. * - That each node's input matches the expected output from its connected node.
* - That all mandatory inputs are present and defined. * - That all mandatory inputs are present and defined.
* - Logs an error and returns `false` if any inconsistency is detected. * - Logs an error and returns `false` if any inconsistency is detected.
* * @param checkDefinedTensor if True, check that each tensors are not undefined.
* @return `true` if all connections and tensor states are valid, `false` otherwise. * @return `true` if all connections and tensor states are valid, `false` otherwise.
*/ */
bool connectionValid(); bool connectionValid(bool checkDefinedTensor = true);
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
// TOPOLOGY // TOPOLOGY
......
...@@ -80,7 +80,7 @@ void init_GraphView(py::module& m) { ...@@ -80,7 +80,7 @@ void init_GraphView(py::module& m) {
:param include_learnable_parameters: include non-data inputs, like weights and biases, default True. :param include_learnable_parameters: include non-data inputs, like weights and biases, default True.
:type include_learnable_parameters: bool, optional :type include_learnable_parameters: bool, optional
)mydelimiter") )mydelimiter")
.def("insert_parent", &GraphView::insertParent, py::arg("child_node"), py::arg("new_parent_node"), py::arg("child_input_tensor_idx"), py::arg("new_parent_input_tensor_idx"), py::arg("new_parent_output_tensor_idx"))
.def("add_child", .def("add_child",
(void (GraphView::*)(std::shared_ptr<Node>, (void (GraphView::*)(std::shared_ptr<Node>,
std::shared_ptr<Node>, std::shared_ptr<Node>,
...@@ -128,7 +128,8 @@ void init_GraphView(py::module& m) { ...@@ -128,7 +128,8 @@ void init_GraphView(py::module& m) {
.def("clone", &GraphView::clone) .def("clone", &GraphView::clone)
.def("get_nodes", &GraphView::getNodes) .def("get_nodes", &GraphView::getNodes)
.def("get_node", &GraphView::getNode, py::arg("node_name")) .def("get_node", &GraphView::getNode, py::arg("node_name"))
.def("forward_dtype", &GraphView::forwardDType, py::arg("dtypes") = std::vector<DataType>()) .def("forward_dtype", (bool(GraphView::*)(const std::vector<DataType>&)) &GraphView::forwardDType, py::arg("dtypes") = std::vector<DataType>())
.def("forward_dtype", (bool(GraphView::*)(DataType)) &GraphView::forwardDType, py::arg("dtype"))
.def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false, .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false,
R"mydelimiter( R"mydelimiter(
Compute and propagate Tensor dimensions through the GraphView. Compute and propagate Tensor dimensions through the GraphView.
......
...@@ -443,7 +443,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType ...@@ -443,7 +443,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType
forwardDims(dims); forwardDims(dims);
} }
bool Aidge::GraphView::connectionValid(){ bool Aidge::GraphView::connectionValid(bool checkDefinedTensor){
// Ensure every node in the graph is correctly connected // Ensure every node in the graph is correctly connected
Log::debug("Verifying graph connections and tensor validity"); Log::debug("Verifying graph connections and tensor validity");
for (std::shared_ptr<Node> nodePtr : getNodes()) { for (std::shared_ptr<Node> nodePtr : getNodes()) {
...@@ -462,7 +462,7 @@ bool Aidge::GraphView::connectionValid(){ ...@@ -462,7 +462,7 @@ bool Aidge::GraphView::connectionValid(){
i, nodePtr->name(), nodePtr->type()); i, nodePtr->name(), nodePtr->type());
return false; return false;
} }
if (std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->undefined()) { if (checkDefinedTensor && std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->undefined()) {
Log::error("Undefined mandatory input#{} for node [\033[1m\033[3m{}\033[0m - (\033[1m\033[3m{}\033[0m)]", Log::error("Undefined mandatory input#{} for node [\033[1m\033[3m{}\033[0m - (\033[1m\033[3m{}\033[0m)]",
i, nodePtr->name(), nodePtr->type()); i, nodePtr->name(), nodePtr->type());
return false; return false;
...@@ -473,6 +473,10 @@ bool Aidge::GraphView::connectionValid(){ ...@@ -473,6 +473,10 @@ bool Aidge::GraphView::connectionValid(){
return true; return true;
} }
// Convenience overload: broadcast a single data type to every data input of
// the graph, then delegate propagation to the vector-based overload.
bool Aidge::GraphView::forwardDType(DataType inputType){
    const std::vector<DataType> broadcastTypes(getNbDataInputs(), inputType);
    return forwardDType(broadcastTypes);
}
bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTypes){ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTypes){
if (!inputTypes.empty()){ if (!inputTypes.empty()){
auto msg = fmt::format("Manually setting GraphView input data type with provided parameters:"); auto msg = fmt::format("Manually setting GraphView input data type with provided parameters:");
...@@ -486,10 +490,12 @@ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTyp ...@@ -486,10 +490,12 @@ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTyp
const auto& currentTensorPtr = const auto& currentTensorPtr =
std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator())->getInput(input.second); std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator())->getInput(input.second);
if (i < inputTypes.size()) { if (i < inputTypes.size()) {
if (!currentTensorPtr) { // tensor detected if (!currentTensorPtr) {
Log::debug("Creating new tensor for input#{} with dtype {}", i, inputTypes[i]); Log::debug("Creating new tensor for input#{} with dtype {}", i, inputTypes[i]);
auto tensor = std::make_shared<Tensor>(inputTypes[i], DataFormat::Default); auto tensor = std::make_shared<Tensor>(inputTypes[i], DataFormat::Default);
input.first->getOperator()->setInput(input.second, tensor); input.first->getOperator()->setInput(input.second, tensor);
}else{
currentTensorPtr->setDataType(inputTypes[i]);
} }
} }
else { else {
...@@ -508,7 +514,9 @@ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTyp ...@@ -508,7 +514,9 @@ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTyp
++i; ++i;
} }
} }
if(!connectionValid()) return false;
if(!connectionValid(false)) return false;
// INITIALIZING Open and Close sets // INITIALIZING Open and Close sets
std::set<std::shared_ptr<Node>> close; // Already treated nodes std::set<std::shared_ptr<Node>> close; // Already treated nodes
std::set<std::shared_ptr<Node>> open = inputNodes(); // Nodes to treat std::set<std::shared_ptr<Node>> open = inputNodes(); // Nodes to treat
...@@ -524,6 +532,10 @@ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTyp ...@@ -524,6 +532,10 @@ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTyp
} }
} }
do{ do{
Log::debug("List of node to forward data type:");
for(auto node : open){
Log::debug("\t- Node {} (of type {})", node->name(), node->type());
}
std::set<std::shared_ptr<Node>> newOpen; std::set<std::shared_ptr<Node>> newOpen;
for (const auto& nodePtr : open) { for (const auto& nodePtr : open) {
if (nodePtr->getOperator()->operatorType() != OperatorType::Tensor) { if (nodePtr->getOperator()->operatorType() != OperatorType::Tensor) {
...@@ -552,7 +564,7 @@ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTyp ...@@ -552,7 +564,7 @@ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTyp
nodePtr->name(), nodePtr->type()); nodePtr->name(), nodePtr->type());
// Recompute every time, even if it was already computed in a // Recompute every time, even if it was already computed in a
// previous call of forwardDims(), as the graph may have changed! // previous call of forwardDType(), as the graph may have changed!
close.insert(nodePtr); close.insert(nodePtr);
for (const auto& child : nodePtr->getChildren()) { for (const auto& child : nodePtr->getChildren()) {
if (inView(child) && close.find(child) == close.end()) { if (inView(child) && close.find(child) == close.end()) {
...@@ -562,7 +574,8 @@ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTyp ...@@ -562,7 +574,8 @@ bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTyp
} }
else { else {
if (parentsForwarded) { if (parentsForwarded) {
Log::debug("Unable to forward dimensions for node {} (of type {})", nodePtr->name(), nodePtr->type()); Log::error("Unable to forward data type for node {} (of type {})", nodePtr->name(), nodePtr->type());
} }
Log::debug("Adding back node {} (of type {}) to the list of nodes to forward data type", nodePtr->name(), nodePtr->type()); Log::debug("Adding back node {} (of type {}) to the list of nodes to forward data type", nodePtr->name(), nodePtr->type());
newOpen.insert(nodePtr); newOpen.insert(nodePtr);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment