diff --git a/aidge_core/show_graphview.py b/aidge_core/show_graphview.py index 4f6a2960348c44dc7b8a0b957f777ddac5a8562a..14bb6c3e9a4c7be1e1aecee02a04a5dc42e0a5d4 100644 --- a/aidge_core/show_graphview.py +++ b/aidge_core/show_graphview.py @@ -4,8 +4,9 @@ import builtins import aidge_core import numpy as np from pathlib import Path +from typing import Any, Dict, List, Optional -def _retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bool, None]: +def _retrieve_operator_attrs(node : aidge_core.Node) -> Dict[str, Optional[Any]]: """ Returns the dictionary containing the attributes of a given Node. @@ -13,7 +14,7 @@ def _retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bo :type graph: aidge_core.Node :return: A dictionary with the Node's attributes. - :rtype: dict[str, int, float, bool, None] + :rtype: Dict[str, Optional[Any]] """ if node.get_operator().attr is not None: @@ -27,7 +28,7 @@ def _retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bo return node_attr_dict -def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_embed : bool, write_trainable_params_ext : bool, path_trainable_params : Path, params_file_format : str) -> dict[str, int, float, bool, None]: +def _create_dict(ordered_nodes : List[aidge_core.Node], write_trainable_params_embed : bool, write_trainable_params_ext : bool, path_trainable_params : Path, params_file_format : str) -> Dict[str, Optional[Any]]: """ Creates a dictionary to store the information of a given ordered GraphView. @@ -43,7 +44,7 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e :type params_file_format: str :return: A dictionary with the GraphView description. - :rtype: dict[str, int, float, bool, None] + :rtype: Dict[str, Optional[Any]] """ graphview_dict = {'graph': []} @@ -79,7 +80,7 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e if parents[0] is None: parents.append(parents.pop(0)) else: pass - + parents_inputs = [] input_idx = 0 for parent in node.get_parents(): @@ -88,11 +89,11 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e for child in children: if child[0] == node and child[1] == input_idx: parents_inputs.append((parent.name(), input_idx)) - + elif parent is None: if input_idx not in [item[1] for item in parents_inputs]: parents_inputs.append((None, input_idx)) - + input_idx += 1 node_dict['parents'] = parents_inputs @@ -167,7 +168,7 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e return graphview_dict -def _write_dict_json(graphview_dict : dict[str, int, float, bool, None], json_path : str) -> None: +def _write_dict_json(graphview_dict : Dict[str, Optional[Any]], json_path : str) -> None: """ Writes dictionary containing GraphView description to a JSON file. diff --git a/aidge_core/testing/utils/tree_cache.py b/aidge_core/testing/utils/tree_cache.py index 5b363c7c73ea36636a40c007b24cc244b10303c2..9bb4f7734a46c1c2ec35c573fa3f72d0cae5e736 100644 --- a/aidge_core/testing/utils/tree_cache.py +++ b/aidge_core/testing/utils/tree_cache.py @@ -43,6 +43,7 @@ For more evolved scenarii, specialize the provided FileTreeCache class. from pathlib import Path import shutil +import sys import filecmp from typing import Optional, Union, List @@ -54,6 +55,21 @@ __all__ = [ "tree_update_from_cache", ] +def is_relative_to(path: Path, other: Path) -> bool: + """ + Dynamically choose implementation based on Python version + """ + # Python 3.9+ + if sys.version_info >= (3, 9): + return path.is_relative_to(other) + + # Python 3.8 and earlier + try: + path.relative_to(other) + return True + except ValueError: + return False + class FileTreeCache(): """ @@ -66,8 +82,8 @@ class FileTreeCache(): default_tmp_prefix = "__tmp_" def __init__(self, - src_path: Union[str|Path], - cache_path: Optional[Union[str|Path]] = None + src_path: Union[str, Path], + cache_path: Optional[Union[str, Path]] = None ) -> None: self.src_path = Path(src_path).absolute() self.cache_path = ( @@ -78,7 +94,7 @@ class FileTreeCache(): ) ctx_msg = f"tree_cache: {src_path = }, {cache_path = }" assert self.src_path != self.cache_path, f"src_path and cache_path must differ on {ctx_msg}" - assert not self.src_path.is_relative_to(self.cache_path), f"src_path must not be relative to cache_path on {ctx_msg}" + assert not is_relative_to(self.src_path, self.cache_path), f"src_path must not be relative to cache_path on {ctx_msg}" self._tmp_path = ( self.src_path.parent / f"{self.default_tmp_prefix}{self.src_path.name}") @@ -92,7 +108,7 @@ class FileTreeCache(): assert not dst_cache_dir.exists() assert src_dir.is_dir() assert not cache_dir.exists() or cache_dir.is_dir() - assert not cache_dir.is_relative_to(src_dir) + assert not is_relative_to(cache_dir, src_dir) def copy_or_cache(src, dst): base_src = Path(src).relative_to(src_dir) @@ -132,8 +148,8 @@ class FileTreeCache(): def tree_update_from_cache( - src_path: Union[str|Path], - cache_path: Optional[Union[str|Path]] = None) -> None: + src_path: Union[str, Path], + cache_path: Optional[Union[str, Path]] = None) -> None: """ Update from cache the current generation of a tree from the older generations, preserving file stamps when files contents are identical. diff --git a/aidge_core/testing/utils/tree_utils.py b/aidge_core/testing/utils/tree_utils.py index 3a6b2aad88e16075ed64bee03ba8e8fa550376e2..990ab2641f5aa914a9d8c03105c00d8c05d1243f 100644 --- a/aidge_core/testing/utils/tree_utils.py +++ b/aidge_core/testing/utils/tree_utils.py @@ -5,8 +5,9 @@ Provide utility function for file trees manipulations. """ import shutil +import sys from pathlib import Path -from typing import Union, Optional +from typing import Union __all__ = [ @@ -15,8 +16,24 @@ __all__ = [ ] +def is_relative_to(path: Path, other: Path) -> bool: + """ + Dynamically choose implementation based on Python version + """ + # Python 3.9+ + if sys.version_info >= (3, 9): + return path.is_relative_to(other) + + # Python 3.8 and earlier + try: + path.relative_to(other) + return True + except ValueError: + return False + + def tree_remove( - path: Union[str|Path], + path: Union[str, Path], ignore_missing: bool = False, ) -> None: """ @@ -35,8 +52,8 @@ def tree_remove( def tree_move( - src_path: Union[str|Path], - dst_path: Union[str|Path], + src_path: Union[str, Path], + dst_path: Union[str, Path], ignore_missing: bool = False, exist_ok: bool = False, ) -> None: @@ -56,8 +73,8 @@ def tree_move( assert ignore_missing or src_path.exists(), f"src_path must exists when ignore_missing is False on {ctx_msg}" assert exist_ok or not dst_path.exists(), f"dst_path must not exists when exist_ok is False on {ctx_msg}" assert src_path != dst_path, f"paths must not be identical on {ctx_msg}" - assert not dst_path.is_relative_to(src_path), f"dst_path must not be relative to src_path on {ctx_msg}" - assert not src_path.is_relative_to(dst_path), f"src_path must not be relative to dst_path on {ctx_msg}" + assert not is_relative_to(dst_path, src_path), f"dst_path must not be relative to src_path on {ctx_msg}" + assert not is_relative_to(src_path, dst_path), f"src_path must not be relative to dst_path on {ctx_msg}" if ignore_missing and not src_path.exists(): return if exist_ok and dst_path.exists(): diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index b915cb8f16546e6626e99e41f5f9ebb1c038863e..47eb6cf97e238f820b8ef2d1f3296040e73aa43f 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -36,7 +36,10 @@ public: std::shared_ptr<SequentialScheduler> mScheduler; std::weak_ptr<Node> mUpperNode; - public: +private: + const std::shared_ptr<DynamicAttributes> mAttributes = std::make_shared<DynamicAttributes>(); + +public: MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph, const std::vector<InputCategory>& forcedInputsCategory = {}); /** @@ -92,7 +95,7 @@ public: mGraph->setDataType(datatype); } - std::shared_ptr<Attributes> attributes() const override; + inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; } Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override; Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override; diff --git a/include/aidge/operator/Stack.hpp b/include/aidge/operator/Stack.hpp index 9644620d71276c5e35fc9daaf634f4d4cdb28405..21633e451961148deda8a3f39afd1965b3002d2e 100644 --- a/include/aidge/operator/Stack.hpp +++ b/include/aidge/operator/Stack.hpp @@ -50,7 +50,7 @@ class StackOp public: static const std::string s_type; - StackOp(std::uint32_t maxElements); + StackOp(std::uint32_t maxElements = 0); /** * @brief Copy-constructor. Copy the operator attributes and its output @@ -71,6 +71,7 @@ class StackOp std::set<std::string> getAvailableBackends() const override; + bool dimsForwarded() const override final; bool forwardDims(bool allowDataDependency = false) override final; void forward() override; @@ -87,14 +88,14 @@ class StackOp } static const std::vector<std::string> getInputsName() { - return {"data_input"}; + return {"data_input", "max_elements"}; } static const std::vector<std::string> getOutputsName() { return {"data_output"}; } }; -std::shared_ptr<Node> stack(std::uint32_t maxElements, +std::shared_ptr<Node> Stack(std::uint32_t maxElements = 0, const std::string &name = ""); } // namespace Aidge diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 6e08885fc3f8966fba48be1c55a6965ac9e70775..28ecde6d9319ae05be20f591cc9a6a4e2a29acc0 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -82,6 +82,14 @@ protected: std::chrono::time_point<std::chrono::high_resolution_clock> end; /** Actual end time of execution */ }; public: + enum class AvailableDataStatus { + Connected, + UpperNodeInputFound, + UpperNodeInputConnected, + ValidTensor, + NotConnected + }; + /** * @struct PriorProducersConsumers * @brief Manages producer-consumer relationships for nodes. @@ -179,7 +187,7 @@ protected: */ std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const; - Elts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const; + Elts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx, AvailableDataStatus& status) const; /** * @brief Get the prior producers and consumers for a node. @@ -233,4 +241,23 @@ protected: }; } // namespace Aidge +namespace Aidge { +inline auto format_as(Scheduler::AvailableDataStatus status) { + switch (status) { + case Scheduler::AvailableDataStatus::Connected: + return "The input is connected to a Node."; + case Scheduler::AvailableDataStatus::UpperNodeInputFound: + return "The input is an upper node input, but is not connected in any GraphView."; + case Scheduler::AvailableDataStatus::UpperNodeInputConnected: + return "The input is an upper node input and is connected to a Node."; + case Scheduler::AvailableDataStatus::ValidTensor: + return "The input is not connected in the current GraphView but has a valid tensor assigned."; + case Scheduler::AvailableDataStatus::NotConnected: + return "The input is not connected in the current GraphView."; + default: + return "UNKNOWN STATUS."; + } +} +} + #endif /* AIDGE_CORE_SCHEDULER_SCHEDULER_H_ */ diff --git a/python_binding/operator/pybind_Stack.cpp b/python_binding/operator/pybind_Stack.cpp index 2328892108d724438a39cc37eaf97b856caa3a8a..c9bd969faf714cacb0dbf44a0b0fe6e84281ffd8 100644 --- a/python_binding/operator/pybind_Stack.cpp +++ b/python_binding/operator/pybind_Stack.cpp @@ -29,8 +29,8 @@ void init_Stack(py::module &m) { .def_readonly_static("Type", &StackOp::s_type); m.def("Stack", - &stack, - py::arg("max_elements"), + &Stack, + py::arg("max_elements") = 0, py::arg("name") = "", R"mydelimiter( Initialize a node containing a Stack operator. diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index 1708d9e36c174527c648e37b63b080211aa6df05..c74b538a4e566b3b88e77dd4d097344d52838505 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -95,7 +95,7 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) Log::debug("getBestMatch() for requirements: {}", requiredSpecs); const auto availableSpecsSet = getAvailableImplSpecs(); - AIDGE_ASSERT(availableSpecsSet.size() > 0 , + AIDGE_ASSERT(availableSpecsSet.size() > 0 , "OperatorImpl::getBestMatch(): No available specs found by" "getAvailableSpecs(). " "Cannot find best implementation for required specs, aborting."); @@ -139,7 +139,7 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) if (mandatory) { // Required attribute: if (!spec.attrs.hasAttr(name)) { - Log::debug("Could not find mandatory attribute {} value {}.", name); + Log::debug("Could not find mandatory attribute '{}'.", name); // Missing attribute match = false; break; diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 1354281933b69bb6e038587cc27ee0397d05c6f1..465359757eadd2799aa7f272e2d85b032a60cfdd 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -157,15 +157,15 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd const auto op = std::dynamic_pointer_cast<OperatorTensor>(node_ptr->getOperator()); if (op && !op->getOutput(outputIdx)->undefined()) { dims += " " + fmt::format("{}", op->getOutput(outputIdx)->dims()); - dtype += "\n" + fmt::format("{}", op->getOutput(outputIdx)->dataType()); + dtype += " " + fmt::format("{}", op->getOutput(outputIdx)->dataType()); } if (mNodes.find(child) != mNodes.end()) { - fmt::print(fp.get(), "{}_{}-->|\"{}{}{}→{}\"|{}_{}\n", node_ptr->type(), namePtrTable.at(node_ptr), + fmt::print(fp.get(), "{}_{}-->|\"{}{}{}<br/>↓<br/>{}\"|{}_{}\n", node_ptr->type(), namePtrTable.at(node_ptr), outputIdx, dims, dtype, inputIdx, child->type(), namePtrTable.at(child)); } else if (verbose) { - fmt::print(fp.get(), "{}_{}-->|\"{}{}{}→{}\"|{}:::externalCls\n", node_ptr->type(), namePtrTable.at(node_ptr), + fmt::print(fp.get(), "{}_{}-->|\"{}{}{}<br/>↓<br/>{}\"|{}:::externalCls\n", node_ptr->type(), namePtrTable.at(node_ptr), outputIdx, dims, dtype, inputIdx, static_cast<void*>(child.get())); } // Do no break here because the same child can be connected to several inputs @@ -182,11 +182,13 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd for (const auto& input : mInputNodes) { if (input.first != nullptr) { const auto& op_ = std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator()); - if (op_->getInput(input.second) && (!op_->getInput(input.second)->empty())) { - fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"→{}{}\"|{}_{}\n", inputIdx, inputIdx, - input.second, op_->getInput(input.second)->dims(), input.first->type(), namePtrTable.at(input.first)); + if (op_->getInput(input.second) && (!op_->getInput(input.second)->undefined())) { + std::string dims = " " + fmt::format("{}", op_->getInput(input.second)->dims()); + std::string dtype = " " + fmt::format("{}", op_->getInput(input.second)->dataType()); + fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"{}{}<br/>↓<br/>{}\"|{}_{}\n", inputIdx, inputIdx, + dims, dtype, input.second, input.first->type(), namePtrTable.at(input.first)); } else { - fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"→{}\"|{}_{}\n", inputIdx, inputIdx, + fmt::print(fp.get(), "input{}((in#{})):::inputCls--->|\"↓<br/>{}\"|{}_{}\n", inputIdx, inputIdx, input.second, input.first->type(), namePtrTable.at(input.first)); } } @@ -201,14 +203,16 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd if (output.first != nullptr) { // Add-on to display the operator's output dimensions std::string dims = ""; + std::string dtype = ""; const auto op = std::dynamic_pointer_cast<OperatorTensor>(output.first->getOperator()); if (op && op->getOutput(output.second) && !op->getOutput(output.second)->undefined()) { dims += " " + fmt::format("{}", op->getOutput(output.second)->dims()); + dtype += " " + fmt::format("{}", op->getOutput(output.second)->dataType()); } - fmt::print(fp.get(), "{}_{}--->|\"{}{}→\"|output{}((out#{})):::outputCls\n", + fmt::print(fp.get(), "{}_{}--->|\"{}{}{}<br/>↓\"|output{}((out#{})):::outputCls\n", output.first->type(), namePtrTable.at(output.first), output.second, - dims, outputIdx, outputIdx); + dims, dtype, outputIdx, outputIdx); } else { fmt::print(fp.get(), "output{}((out#{})):::outputCls\n", outputIdx, outputIdx); @@ -1236,7 +1240,6 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const if (removeFromGraphs) { for (const auto& g : commonGraphViews) { g -> remove(nodePtr, false); - g -> updateInputsOutputsDelete(nodePtr); } nodePtr -> resetConnections(true); } diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index 060c725482fedf4d6093e5acb988b2c721c27edc..cd307c9d15043d3ee5f5de48695e04e4ad2ada6b 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -48,6 +48,10 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shar mOutputs[outputIdx] = std::dynamic_pointer_cast<Tensor>(outputOp.first->getOperator()->getRawOutput(outputOp.second)); } } + + for (const auto& node : mGraph->getRankedNodesName("{1}_{3}")) { + mAttributes->addAttr(node.second, node.first->getOperator()->attributes()); + } } std::shared_ptr<Aidge::Operator> Aidge::MetaOperator_Op::clone() const { @@ -119,22 +123,6 @@ std::set<std::string> Aidge::MetaOperator_Op::getAvailableBackends() const { return backendsList; } -std::shared_ptr<Aidge::Attributes> Aidge::MetaOperator_Op::attributes() const { - auto attrs = std::make_shared<DynamicAttributes>(); - - for (const auto& node : mGraph->getRankedNodesName("{3}")) { - const auto attributes = node.first->getOperator()->attributes(); - if (attributes) { - const auto nodeAttrs = DynamicAttributes(attributes->getAttrs()); - attrs->addAttr(node.first->type() + "#" + node.second, nodeAttrs); - if (node.second == "0") { - attrs->addAttr(node.first->type(), nodeAttrs); - } - } - } - - return attrs; -} Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { if (mImpl) { diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index 586dbc2037d36d26f39dd06404b3b70b99270c1e..3bdb4b17127eb8a9115f8dec045db32bf041b00b 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -88,7 +88,7 @@ std::shared_ptr<Aidge::Data> Aidge::OperatorTensor::getRawOutput(const Aidge::IO } const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getOutput(const Aidge::IOIndex_t outputIdx) const { - AIDGE_ASSERT(outputIdx < nbOutputs(), "{} Operator has {} outputs", type(), nbOutputs()); + AIDGE_ASSERT(outputIdx < nbOutputs(), "{} Operator has {} outputs, asked for output#{}", type(), nbOutputs(), outputIdx); return mOutputs[outputIdx]; } diff --git a/src/operator/Shape.cpp b/src/operator/Shape.cpp index 6de0854e8cdc166a3f938a166348db481956e792..ecaa12191173ac74ace8d6d224ddfe08469eb521 100644 --- a/src/operator/Shape.cpp +++ b/src/operator/Shape.cpp @@ -20,14 +20,8 @@ #include "aidge/utils/Types.h" void Aidge::Shape_OpImpl::forward() { - const Shape_Op& op = dynamic_cast<const Shape_Op&>(mOp); - const auto start = op.start(); - const auto end = op.end(); - - op.getOutput(0)->getImpl()->copyCast(std::next(op.getInput(0)->dims().data(), - start), - DataType::UInt64, - end - start + 1); + // Do nothing... + // Output is already valid after forwardDims() } /////////////////////////////////////////////// @@ -75,6 +69,11 @@ bool Aidge::Shape_Op::forwardDims(bool /*allowDataDependency*/) { AIDGE_ASSERT(roi> 1, "Invalid ROI for Shape"); mOutputs[0]->resize({roi}); + // Ensure the output of this operator is valid after forwardDims(): + mOutputs[0]->getImpl()->copyCast(std::next(getInput(0)->dims().data(), + start), + DataType::UInt64, + end - start + 1); return true; } diff --git a/src/operator/Stack.cpp b/src/operator/Stack.cpp index 4ca7cc9831c091a8ea79051115decd489a4a03be..ab9ddc4f705cb00cebbe5b9ee68fb1433586a043 100644 --- a/src/operator/Stack.cpp +++ b/src/operator/Stack.cpp @@ -26,7 +26,7 @@ namespace Aidge { // inputSize Elts_t StackProdConso::getRequiredMemory( const Aidge::IOIndex_t inputIdx, - const std::vector<DimSize_t> &inputsSize) const { + const std::vector<DimSize_t> &/*inputsSize*/) const { assert(mOp.getRawInput(inputIdx) && "requires valid input"); const StackOp &op = dynamic_cast<const StackOp &>(mOp); @@ -62,15 +62,10 @@ void StackOpImpl::forward() { } StackOp::StackOp(std::uint32_t maxElements) - : OperatorTensor(s_type, {InputCategory::Data}, 1), + : OperatorTensor(s_type, {InputCategory::Data, InputCategory::OptionalData}, 1), mAttributes(std::make_shared<Attributes_>( attr<StackAttr::MaxElements>(maxElements), attr<StackAttr::ForwardStep>(0))) { - if (maxElements == 0) { - AIDGE_THROW_OR_ABORT( - std::invalid_argument, - "StackOp creation failed: maxElements must be greater than 0."); - } mImpl = std::make_shared<StackOpImpl>(*this); } @@ -87,8 +82,33 @@ std::shared_ptr<Aidge::Operator> Aidge::StackOp::clone() const { return std::make_shared<StackOp>(*this); } -bool Aidge::StackOp::forwardDims(bool /*allowDataDependency*/) { +bool Aidge::StackOp::dimsForwarded() const { + if ((getInput(1) && !getInput(1)->undefined())) + { + // output dims are data dependent + return false; + } + + return OperatorTensor::dimsForwarded(); +} + +bool Aidge::StackOp::forwardDims(bool allowDataDependency) { if (inputsAssociated()) { + // Copy optional input #1 first dimension, if present, to attribute MaxElements + if (getInput(1)) { + if (!allowDataDependency) { + Log::warn("StackOp: unable to forwardDims() because output dims are data dependent on input#1"); + return false; + } + + std::shared_ptr<Tensor> fallback; + const auto& maxElements = getInput(1)->refCastFrom(fallback, NativeType<std::uint32_t>::type, "cpu"); + AIDGE_ASSERT(maxElements.size() > 0, "Input#1 size should be > 0"); + this->maxElements() = static_cast<std::uint32_t*>(maxElements.getImpl()->hostPtr())[0]; + } + + AIDGE_ASSERT(this->maxElements() > 0, "Input#1 first element or MaxElements attribute should be > 0"); + auto inputDims = getInput(0)->dims(); inputDims.insert(inputDims.begin(), maxElements()); getOutput(0)->resize(inputDims); @@ -116,7 +136,7 @@ void StackOp::forward() { ++forwardStep(); } -std::shared_ptr<Node> stack(std::uint32_t maxElements, +std::shared_ptr<Node> Stack(std::uint32_t maxElements, const std::string &name) { return std::make_shared<Node>(std::make_shared<StackOp>(maxElements), name); diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 7af3c62c5d0af33b01e596ecf4c91c35ab3e17b7..2e9dc034ef0bf2b0be2ee27c26b1995a2d0e4244 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -82,6 +82,8 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera // the requiredProducers list. std::set<std::shared_ptr<Node>> consumers = mGraphView->outputNodes(); std::set<std::shared_ptr<Node>> producers; + std::string level1Diagnostic; + std::string level2Diagnostic; do { // 2) From the current consumers list, check if any prior consumer node @@ -144,22 +146,37 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera // there is multiple successive priors for example). std::set<std::shared_ptr<Node>> runnableConsumers; Log::debug("Updated list of consumers:"); + level1Diagnostic.clear(); + level2Diagnostic.clear(); for (const auto& consumer : consumers) { summarizeConsumerState(consumer, namePtrTable.at(consumer)); // debug print bool isRunnable = true; for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) { AIDGE_LOG_CONTEXT("Consumer node {} input #{}", namePtrTable.at(consumer), inputIdx); + AvailableDataStatus status; if ((consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) > - getNbAvailableData(consumer, inputIdx)) { + getNbAvailableData(consumer, inputIdx, status)) { Log::debug(" not runnable: C{} + R{} > P{} for input #{}", consumer->getOperator()->getNbConsumedData(inputIdx), consumer->getOperator()->getNbRequiredData(inputIdx), - getNbAvailableData(consumer, inputIdx), inputIdx); + getNbAvailableData(consumer, inputIdx, status), inputIdx); // not enough data to run isRunnable = false; + if (status == Scheduler::AvailableDataStatus::UpperNodeInputFound + || status == Scheduler::AvailableDataStatus::NotConnected) + { + level1Diagnostic += fmt::format("- No data available for node {} input #{}. {}\n", namePtrTable.at(consumer), inputIdx, fmt::styled(status, fmt::fg(fmt::color::red))); + } + else { + level2Diagnostic += fmt::format("- No data available for node {} input #{}. {}\n", namePtrTable.at(consumer), inputIdx, fmt::styled(status, fmt::fg(fmt::color::green))); + level2Diagnostic += fmt::format(" ↳ Available data is {}, but {} was already consummed and {} more is required.\n", + getNbAvailableData(consumer, inputIdx, status), + consumer->getOperator()->getNbConsumedData(inputIdx), + consumer->getOperator()->getNbRequiredData(inputIdx)); + } break; } } @@ -204,12 +221,13 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) { if (consumer->inputCategory(inputIdx) == InputCategory::Data) { AIDGE_LOG_CONTEXT("Consumer node {} input #{}", namePtrTable.at(consumer), inputIdx); + AvailableDataStatus status; if (consumer->getOperator()->getNbConsumedData(inputIdx) < - getNbAvailableData(consumer, inputIdx)) { + getNbAvailableData(consumer, inputIdx, status)) { Log::debug(" still consumer: C{} < P{} for input #{}", consumer->getOperator()->getNbConsumedData(inputIdx), - getNbAvailableData(consumer, inputIdx), inputIdx); + getNbAvailableData(consumer, inputIdx, status), inputIdx); // there is still data to consume isStillConsumer = true; @@ -293,7 +311,15 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera std::back_inserter(consumersName), [&namePtrTable](auto val){ return namePtrTable.at(val); }); - Log::warn("Remaining consumers: {}. Possible dead-lock.", consumersName); + Log::warn("Remaining consumers: {}.", consumersName); + + Log::info("Reasons:"); + if (!level1Diagnostic.empty()) { + Log::info(level1Diagnostic); + } + else { + Log::info(level2Diagnostic); + } } return schedule; @@ -650,23 +676,27 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getConsumers( return consumers; } -Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx, AvailableDataStatus& status) const { const auto parent = node->inputs()[inputIdx]; if (parent.first) { // Parent is connected, everything if fine! + status = AvailableDataStatus::Connected; return parent.first->getOperator()->getNbProducedData(parent.second); } else if (std::shared_ptr<Node> upperNode = mUpperNode.lock()) { - // We are inside an upper operator (for instance a MetaOperator) - // We need to connect the "local" producer-consumer model to the upper - // one, by mapping local node inputs to the upper node inputs. + // We are inside an upper operator (for instance a MetaOperator). + // Check if the node input is also an upper node input... IOIndex_t upperInputIdx = 0; for (const auto& input : mGraphView->getOrderedInputs()) { if (input.first == node && input.second == inputIdx) { - // Current node is an input + // Current node is an input! + // We need to connect the "local" producer-consumer model to the upper + // one, by mapping local node inputs to the upper node inputs. + status = AvailableDataStatus::UpperNodeInputFound; const auto upperInput = upperNode->inputs()[upperInputIdx]; if (upperInput.first) { + status = AvailableDataStatus::UpperNodeInputConnected; return upperInput.first->getOperator()->getNbProducedData(upperInput.second); } } @@ -678,6 +708,7 @@ Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& // - There is no data, it is assumed to be an optional input // - A valid tensor exists: if (node->getOperator()->getRawInput(inputIdx)) { + status = AvailableDataStatus::ValidTensor; // => This means data was fed manually to the input, without a Producer // In this case, we assume a single-use data (unlike a Producer, which // keep producing the data each time it is needed). @@ -685,6 +716,7 @@ Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& return Elts_t::DataElts(std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size()); } + status = AvailableDataStatus::NotConnected; return Elts_t::NoneElts(); } diff --git a/unit_tests/operator/Test_StackImpl.cpp b/unit_tests/operator/Test_StackImpl.cpp index d853a1ba27fea0a071c1c2373bbd7ef7f4eacd11..ccdf5787d666f030b8856704eb0e4fb108089075 100644 --- a/unit_tests/operator/Test_StackImpl.cpp +++ b/unit_tests/operator/Test_StackImpl.cpp @@ -56,9 +56,6 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") { REQUIRE(op2.maxElements() == maxElements); REQUIRE(op2.forwardStep() == 0); } - - // Invalid arguments - REQUIRE_THROWS_AS(StackOp(0), std::invalid_argument); } SECTION("forwardDims") { @@ -111,7 +108,7 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") { tensors[i]->getImpl()->setRawPtr(arrays[i], nbElems); } - auto myStack = stack(numTensors); + auto myStack = Stack(numTensors); myStack->getOperator()->setBackend("cpu"); myStack->getOperator()->setDataType(DataType::Float32); auto op =