diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index c6e3322aea3ee74322571b6619e5b02f857ef12e..37ddb382d6364554e5155958c427778760465f81 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -295,6 +295,8 @@ public: */ bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false); + bool forwardDType(const std::vector<DataType>& inputTypes = {}); + /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const; /** @brief Set the same data type for each Operator of the GraphView object's Nodes. */ @@ -613,6 +615,19 @@ private: */ void updateInputsOutputsDelete(NodePtr deletedNode); + /** + * @brief Validates the connectivity and tensor integrity of the graph. + * + * This function ensures that all nodes in the graph are correctly connected + * and that mandatory input tensors are properly defined. It verifies: + * - That each node's input matches the expected output from its connected node. + * - That all mandatory inputs are present and defined. + * - Logs an error and returns `false` if any inconsistency is detected. + * + * @return `true` if all connections and tensor states are valid, `false` otherwise. + */ + bool connectionValid(); + /////////////////////////////////////////////////////// // TOPOLOGY /////////////////////////////////////////////////////// diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp index 8bd8239ec664a7bcb9d520c3dc37488f932437bb..4ce9f7a494ea1089d0f28d31c664e863a111062e 100644 --- a/include/aidge/operator/Gather.hpp +++ b/include/aidge/operator/Gather.hpp @@ -111,6 +111,12 @@ public: */ bool forwardDims(bool allowDataDependency = false) override final; + /** + * @brief Forward the data type. + * @return True if successful, false otherwise. + */ + bool forwardDType() override final; + /** * @brief Set the backend for the operator. * @param name The name of the backend. diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp index c24d3ba21e1c3a01f3a0a40bbabc90eae07e4fc7..3ba37cbbb6754a61aa03b812e7adfdcd4899e96c 100644 --- a/include/aidge/operator/OperatorTensor.hpp +++ b/include/aidge/operator/OperatorTensor.hpp @@ -172,6 +172,16 @@ public: */ virtual bool forwardDims(bool allowDataDependency = false); + /** + * @brief Computes the data type of the operator's output tensor based on input data type. + * + * For each operator inputs: + * - If input is an (optional) Param, the operator will forward + * + * @return True if data types are successfully computed, false otherwise. + */ + virtual bool forwardDType(); + /** * @brief Checks if dimensions have been successfully forwarded. * @return True if dimensions are forwarded, false otherwise. @@ -189,7 +199,7 @@ public: /** * @brief Sets the data type of the operator's tensors. - * @warning Sets all outputs but only inputs of category + * @warning Sets all outputs but only inputs of category * @code InputCategory::Param @endcode & @code InputCategory::OptionnalParam @endcode * @param dataType Data type to set. */ diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp index c93ef09c9dd35ca887b0b491bd8c1177dbbb35e1..f8bfaf73bb7b4d585af28377f5314822afdbfb94 100644 --- a/include/aidge/operator/Reshape.hpp +++ b/include/aidge/operator/Reshape.hpp @@ -120,6 +120,11 @@ public: */ bool forwardDims(bool allowDataDependency = false) override final; + /** + * @brief Forward the data type. + * @return True if successful, false otherwise. + */ + bool forwardDType() override final; /** * @brief Set the backend for the Reshape operator. * @param[in] name Name of the backend. diff --git a/include/aidge/operator/Shape.hpp b/include/aidge/operator/Shape.hpp index 4028c4041584833f14a4fa4db0f944dca2c2f035..3d5d02f917c2465d587498ae65a1f6d6308f4256 100644 --- a/include/aidge/operator/Shape.hpp +++ b/include/aidge/operator/Shape.hpp @@ -108,6 +108,12 @@ public: */ bool forwardDims(bool /*allowDataDependency*/ = false) override final; + /** + * @brief Forward the data type. + * @return True if successful, false otherwise. + */ + bool forwardDType() override final; + /** * @brief Set the backend for the Shape operator. * @param[in] name Name of the backend. diff --git a/include/aidge/operator/Unsqueeze.hpp b/include/aidge/operator/Unsqueeze.hpp index 27b3851fc7b741955889f7119bdf2b829918950a..b8b367090acfe2ad18352334fd1594e44c473be8 100644 --- a/include/aidge/operator/Unsqueeze.hpp +++ b/include/aidge/operator/Unsqueeze.hpp @@ -105,6 +105,12 @@ public: * @brief Compute dimensions for the output Tensor */ bool forwardDims(bool allowDataDependency = false) override final; + /** + * @brief Forward the data type. + * @return True if successful, false otherwise. + */ + bool forwardDType() override final; + bool dimsForwarded() const override final; void setBackend(const std::string &name, diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index abb1a9eca0bc3edb1ee0faaecb9f6cd9bc52e167..1d1778c318337cc15c6330be430eb5199603a4bb 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -128,6 +128,7 @@ void init_GraphView(py::module& m) { .def("clone", &GraphView::clone) .def("get_nodes", &GraphView::getNodes) .def("get_node", &GraphView::getNode, py::arg("node_name")) + .def("forward_dtype", &GraphView::forwardDType, py::arg("dtypes") = std::vector<DataType>()) .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false, R"mydelimiter( Compute and propagate Tensor dimensions through the GraphView. @@ -209,7 +210,7 @@ void init_GraphView(py::module& m) { :param dims: input dimension to forward :type dims: List[List[Int]] - + )mydelimiter") .def("__call__", &GraphView::operator(), py::arg("connectors")) .def("set_datatype", &GraphView::setDataType, py::arg("datatype")) diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 07fb764b404016cf0182df8f258f33ea69b7656f..d28d48dd3d6b0dc7dc30cb9d699370b479bdc3e4 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -451,6 +451,147 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType forwardDims(dims); } +bool Aidge::GraphView::connectionValid(){ + // Ensure every node in the graph is correctly connected + Log::debug("Verifying graph connections and tensor validity"); + for (std::shared_ptr<Node> nodePtr : getNodes()) { + for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) { + std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i); + if (inputI.first) { + if (nodePtr->getOperator()->getRawInput(i) != inputI.first->getOperator()->getRawOutput(inputI.second)) { + Log::error("Connection mismatch: Input#{} of node [\033[1m\033[3m{}\033[0m (\033[1m\033[3m{}\033[0m)] -> Output#{} of node [\033[1m\033[3m{}\033[0m - (\033[1m\033[3m{}\033[0m)]", + i, nodePtr->name(), nodePtr->type(), inputI.second, inputI.first->name(), inputI.first->type()); + return false; + } + } else if (nodePtr->inputCategory(i) != InputCategory::OptionalData && + nodePtr->inputCategory(i) != InputCategory::OptionalParam) { + if (!nodePtr->getOperator()->getRawInput(i)) { + Log::error("Missing mandatory input#{} for node [\033[1m\033[3m{}\033[0m - (\033[1m\033[3m{}\033[0m)]", + i, nodePtr->name(), nodePtr->type()); + return false; + } + if (std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->undefined()) { + Log::error("Undefined mandatory input#{} for node [\033[1m\033[3m{}\033[0m - (\033[1m\033[3m{}\033[0m)]", + i, nodePtr->name(), nodePtr->type()); + return false; + } + } + } + } + return true; +} + +bool Aidge::GraphView::forwardDType(const std::vector<Aidge::DataType>& inputTypes){ + if (!inputTypes.empty()){ + auto msg = fmt::format("Manually setting GraphView input data type with provided parameters:"); + for (std::size_t i = 0; i< inputTypes.size(); ++i) + msg = fmt::format("{}\n\t* input#{} {}", msg, i, inputTypes[i]); + Log::info("{}", msg); + + Log::debug("Validating input dtype against existing graph inputs"); + std::size_t i = 0; + for (auto& input : mInputNodes) { + const auto& currentTensorPtr = + std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator())->getInput(input.second); + if (i < inputTypes.size()) { + if (!currentTensorPtr) { // tensor detected + Log::debug("Creating new tensor for input#{} with dtype {}", i, inputTypes[i]); + auto tensor = std::make_shared<Tensor>(inputTypes[i], DataFormat::Default); + input.first->getOperator()->setInput(input.second, tensor); + } + } + else { + const bool optional = (input.first->inputCategory(input.second) == InputCategory::OptionalData + || input.first->inputCategory(input.second) == InputCategory::OptionalParam); + + if (currentTensorPtr) { + Log::debug("Using existing data type {} for graph input#{} (matching input#{} of node [\033[1m\033[3m{}\033[0m] - [\033[1m\033[3m{}\033[0m])", + currentTensorPtr->dataType(), i, input.second, input.first->name(), input.first->type()); + } + else if (!optional) { + Log::warn("Missing data type for mandatory graph input#{} (matching input#{} of node [\033[1m\033[3m{}\033[0m] - [\033[1m\033[3m{}\033[0m])", + i, input.second, input.first->name(), input.first->type()); + } + } + ++i; + } + } + if(!connectionValid()) return false; + // INITIALIZING Open and Close sets + std::set<std::shared_ptr<Node>> close; // Already treated nodes + std::set<std::shared_ptr<Node>> open = inputNodes(); // Nodes to treat + for (const auto& nodePtr : getNodes()) { + if (nodePtr->type() == Producer_Op::Type) { + // Producers dType is set by user + // So it is considered already treated + close.insert(nodePtr); + // Producers childs are put in open list + for (const auto& child : nodePtr->getChildren()) { + if (inView(child)) open.insert(child); + } + } + } + do{ + std::set<std::shared_ptr<Node>> newOpen; + for (const auto& nodePtr : open) { + if (nodePtr->getOperator()->operatorType() != OperatorType::Tensor) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Node {} (of type {}) as it is not an OperatorTensor. ForwardDType is currently only supported for OperatorTensor.", nodePtr->name(), nodePtr->type()); + } + const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator()); + bool anyParent = false; + bool parentsForwarded = true; + for (const auto& parent : nodePtr->getParents()) { + if (parent != nullptr && inView(parent) && close.find(parent) == close.end()) { + Log::debug("Data type not forwarded for parent (node {} (of type {})) of node {} (of type {})", + parent->name(), parent->type(), nodePtr->name(), nodePtr->type()); + parentsForwarded = false; + } + else { + anyParent = true; + } + } + // Special rule for Memorize_Op, which only requires one parent + // to have its dtype forwarded. This avoids circular dependency. + if (nodePtr->type() == Memorize_Op::Type && anyParent) { + parentsForwarded = true; + } + if (parentsForwarded && op->forwardDType()) { + Log::debug("Data type forwarded for node {} (of type {})", + nodePtr->name(), nodePtr->type()); + + // Recompute every time, even if it was already computed in a + // previous call of forwardDims(), as the graph may have changed! + close.insert(nodePtr); + for (const auto& child : nodePtr->getChildren()) { + if (inView(child) && close.find(child) == close.end()) { + newOpen.insert(child); + } + } + } + else { + if (parentsForwarded) { + Log::debug("Unable to forward dimensions for node {} (of type {})", nodePtr->name(), nodePtr->type()); + } + Log::debug("Adding back node {} (of type {}) to the list of nodes to forward data type", nodePtr->name(), nodePtr->type()); + newOpen.insert(nodePtr); + } + + } + if (newOpen == open) { + // We are stuck! + std::vector<std::string> nodesName; + std::transform(newOpen.begin(), newOpen.end(), + std::back_inserter(nodesName), + [](auto val){ return val->name() + " (" + val->type() + ")"; }); + + Log::warn("Unable to forward data type (circular dependency and/or wrong dimensions and/or data dependent dimension?). Unable to compute output data type for nodes {}.", nodesName); + return false; + } + open.swap(newOpen); + }while(!open.empty()); + return open.empty(); +} + bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>>& dims, bool allowDataDependency) { Log::debug("Starting dimension forward propagation for GraphView"); // remove current Data connections and use dummy inputs to propagate dimensions @@ -499,32 +640,7 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ } } - // Ensure every node in the graph is correctly connected - Log::debug("Verifying graph connections and tensor validity"); - for (std::shared_ptr<Node> nodePtr : getNodes()) { - for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) { - std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i); - if (inputI.first) { - if (nodePtr->getOperator()->getRawInput(i) != inputI.first->getOperator()->getRawOutput(inputI.second)) { - Log::error("Connection mismatch: Input#{} of node [\033[1m\033[3m{}\033[0m (\033[1m\033[3m{}\033[0m)] -> Output#{} of node [\033[1m\033[3m{}\033[0m - (\033[1m\033[3m{}\033[0m)]", - i, nodePtr->name(), nodePtr->type(), inputI.second, inputI.first->name(), inputI.first->type()); - return false; - } - } else if (nodePtr->inputCategory(i) != InputCategory::OptionalData && - nodePtr->inputCategory(i) != InputCategory::OptionalParam) { - if (!nodePtr->getOperator()->getRawInput(i)) { - Log::error("Missing mandatory input#{} for node [\033[1m\033[3m{}\033[0m - (\033[1m\033[3m{}\033[0m)]", - i, nodePtr->name(), nodePtr->type()); - return false; - } - if (std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->undefined()) { - Log::error("Undefined mandatory input#{} for node [\033[1m\033[3m{}\033[0m - (\033[1m\033[3m{}\033[0m)]", - i, nodePtr->name(), nodePtr->type()); - return false; - } - } - } - } + if(!connectionValid()) return false; Log::debug("Initializing dimension propagation"); // Establish initial list of dims forwardable nodes: graph input node + Producers childs diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp index a4cb4aab0f10cbb3b197e743a5b40208b4a0da94..ccef5ec534b45dfe0e010c5e20d39235e88205fd 100644 --- a/src/operator/Gather.cpp +++ b/src/operator/Gather.cpp @@ -59,7 +59,14 @@ bool Aidge::Gather_Op::dimsForwarded() const { return OperatorTensor::dimsForwarded(); } - +bool Aidge::Gather_Op::forwardDType(){ + if (inputsAssociated()) { + mOutputs[0]->setDataType(getInput(0)->dataType()); + return true; + } + Log::notice("Gather_Op: No input associated, failed to forward data type."); + return false; +} bool Aidge::Gather_Op::forwardDims(bool allowDataDependency) { if (inputsAssociated()) { // Copy optional input #1, if present, to attribute Indices diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index f907c584980312e25463e201e26c044f76339f76..050d823ae9a7e7bb3270cf65f9802a8f57d9a6e2 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -169,6 +169,36 @@ bool Aidge::OperatorTensor::forwardDims(bool /*allowDataDependency*/) { return false; } +bool Aidge::OperatorTensor::forwardDType(){ + Log::debug("Running default forwardDtype for operator {}", + type()); + + if (inputsAssociated()) { + const auto expectedDType = getInput(0)->dataType(); + for (std::size_t i = 1; i < nbInputs(); ++i) { + if (inputCategory(i) == InputCategory::OptionalParam + || inputCategory(i) == InputCategory::Param){ + // Param input can be different dtype than data input + continue; + } + if (expectedDType != getInput(i)->dataType()) { + Log::notice("{} operator's inputs should have the same datatype: expected {} (input #0), given {} (input #{})", + type(), expectedDType, getInput(i)->dataType(), i); + return false; + } + } + + for (std::size_t o = 0; o < nbOutputs(); ++o) { + Log::debug("Setting output#{} dtype to {}", + o, expectedDType); + mOutputs[o]->setDataType(expectedDType); + } + return true; + } + + return false; +} + bool Aidge::OperatorTensor::dimsForwarded() const { bool forwarded = true; // check both inputs and outputs have been filled diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index b12fd486d16beb0a676e38cfdf808fa71996a5ba..b4cd272a18ab8a996507a3812a02c0735e665cca 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -59,6 +59,14 @@ bool Aidge::Reshape_Op::dimsForwarded() const { return OperatorTensor::dimsForwarded(); } +bool Aidge::Reshape_Op::forwardDType(){ + if (inputsAssociated()) { + mOutputs[0]->setDataType(getInput(0)->dataType()); + return true; + } + Log::notice("Reshape_Op: No input associated, failed to forward data type."); + return false; +} bool Aidge::Reshape_Op::forwardDims(bool allowDataDependency) { if (inputsAssociated()) { diff --git a/src/operator/Shape.cpp b/src/operator/Shape.cpp index 4db4704739b362426adb1831c1c95b3796aa918a..4791a14a5e4fa10a58bceccc46e537df8ac63cd0 100644 --- a/src/operator/Shape.cpp +++ b/src/operator/Shape.cpp @@ -49,7 +49,10 @@ Aidge::Shape_Op::Shape_Op(const Aidge::Shape_Op& op) std::shared_ptr<Aidge::Operator> Aidge::Shape_Op::clone() const { return std::make_shared<Shape_Op>(*this); } - +bool Aidge::Shape_Op::forwardDType(){ + mOutputs[0]->setDataType(DataType::Int64); + return true; +} bool Aidge::Shape_Op::forwardDims(bool /*allowDataDependency*/) { if (inputsAssociated()) { if (this->start() < 0) diff --git a/src/operator/Unsqueeze.cpp b/src/operator/Unsqueeze.cpp index 679b420ec3d794f7efbbe730dd0d75fde4553dea..23d310bbeeff0a42d17a36099365908fa10045f2 100644 --- a/src/operator/Unsqueeze.cpp +++ b/src/operator/Unsqueeze.cpp @@ -55,6 +55,17 @@ bool Aidge::Unsqueeze_Op::dimsForwarded() const { return OperatorTensor::dimsForwarded(); } +bool Aidge::Unsqueeze_Op::forwardDType(){ + if (inputsAssociated()) { + Log::debug("Unsqueeze_Op: setting output dtype to {}", + getInput(0)->dataType()); + mOutputs[0]->setDataType(getInput(0)->dataType()); + return true; + } + Log::notice("Unsqueeze_Op: No input associated, failed to forward data type."); + return false; +} + bool Unsqueeze_Op::forwardDims(bool allowDataDependency) { // error checking if (!inputsAssociated(true)) {