diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index a8b4b0938dd674eec4ac67a0b5c5b05789134b27..7e022145d1eeaa8a2bd79afe69ca06ca57a62651 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -29,10 +29,10 @@ public: * @param inputIdx Index of the input analysed. * @return std::size_t */ - virtual NbElts_t getNbRequiredData(IOIndex_t inputIdx) const = 0; + virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const = 0; // Amount of input data that cannot be overwritten during the execution. - virtual NbElts_t getNbRequiredProtected(IOIndex_t inputIdx) const = 0; + virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const = 0; // Memory required at an output for a given input size. virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const = 0; @@ -43,7 +43,7 @@ public: * @param inputIdx Index of the input analysed. * @return DimSize_t */ - virtual NbElts_t getNbConsumedData(IOIndex_t inputIdx) const = 0; + virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const = 0; /** * @brief TOtal amount of produced data ready to be used on a specific output. @@ -51,7 +51,7 @@ public: * @param outputIdx Index of the output analysed. * @return DimSize_t */ - virtual NbElts_t getNbProducedData(IOIndex_t outputIdx) const = 0; + virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const = 0; virtual ~OperatorImpl() = default; }; diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index e8ce72228c9313c0de301738dfc8facd8a87bdf9..bb20177b106f81c47723020e1d674fbc5b3f7974 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -60,7 +60,7 @@ public: * @param inputIdx Index of the input analysed. * @return NbElts_t */ - NbElts_t getNbRequiredData(IOIndex_t inputIdx) const; + NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; /** * @brief Amount of data from a specific input actually used in one computation pass. @@ -68,7 +68,7 @@ public: * @param inputIdx Index of the input analysed. * @return NbElts_t */ - NbElts_t getNbConsumedData(IOIndex_t inputIdx) const; + NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; /** * @brief Amount of data ready to be used on a specific output. @@ -76,7 +76,7 @@ public: * @param outputIdx Index of the output analysed. * @return NbElts_t */ - NbElts_t getNbProducedData(IOIndex_t outputIdx) const; + NbElts_t getNbProducedData(const IOIndex_t outputIdx) const; virtual void forward(); diff --git a/python_binding/graphmatching/pybind_GRegex.cpp b/python_binding/graphmatching/pybind_GRegex.cpp index 17cfa64bce0e421f6bfebfdf7fad583167716cb0..bda8f72f389f7224a207ff4453509744d23ea296 100644 --- a/python_binding/graphmatching/pybind_GRegex.cpp +++ b/python_binding/graphmatching/pybind_GRegex.cpp @@ -17,9 +17,27 @@ namespace py = pybind11; namespace Aidge { void init_GRegex(py::module& m){ - py::class_<GRegex, std::shared_ptr<GRegex>>(m, "GRegex") - .def(py::init<const std::map<std::string,NodeRegex*>&, std::vector<std::string>&>(), py::arg("nodesRegex"), py::arg("seqRegexps")) - .def("match", &GRegex::match, py::arg("graphToMatch")) + py::class_<GRegex, std::shared_ptr<GRegex>>(m, "GRegex", "GRegex class combines a Node Regex and a list of Graph Regex that together describes a graph pattern as a graph regular expression. GRegex find patterns in a given graph that matches the graph regular expression.") + .def(py::init<const std::map<std::string,NodeRegex*>&, std::vector<std::string>&>(), py::arg("nodesRegex"), py::arg("seqRegexps"), R"mydelimiter( + Constructor of GRegex + + Parameters + ---------- + :param nodesRegex: Describe the conditions an operator has to fulfill. Dictionnary mapping a string (keys) to :py:class:`aidge_core.NodeRegex` (values). + :type nodesRegex: :py:class: `Dict` + :param seqRegexps: Describe the graph topological pattern. List of Graph Regex as strings. + :type seqRegexps: :py:class: `list` + + )mydelimiter") + .def("match", &GRegex::match, py::arg("graphToMatch"), R"mydelimiter( + Launch the graph matching algorithm on a given graph. Returns the matched graph patterns in a :py:class: `aidge_core.Match`. + + Parameters + ---------- + :param graphToMatch: The graph to perform the matching algorithm on. + :type graphToMatch: :py:class: `aidge_core.GraphView` + + )mydelimiter") ; } } diff --git a/python_binding/graphmatching/pybind_Match.cpp b/python_binding/graphmatching/pybind_Match.cpp index c36ca381fe1f370ef03f3df3d34318e800284350..600d39208e309ef64d9803c39cf23636d1a64869 100644 --- a/python_binding/graphmatching/pybind_Match.cpp +++ b/python_binding/graphmatching/pybind_Match.cpp @@ -16,10 +16,16 @@ namespace py = pybind11; namespace Aidge { void init_Match(py::module& m){ - py::class_<Match, std::shared_ptr<Match>>(m, "Match") + py::class_<Match, std::shared_ptr<Match>>(m, "Match", "Match class stores the matched patterns resulting from a graph matching query. A matched pattern is the combinaison of the graph pattern start nodes and the set of all the nodes in the matched pattern (including the start nodes)") .def(py::init<>()) - .def("get_nb_match", &Match::getNbMatch) - .def("get_start_nodes", &Match::getStartNodes) - .def("get_match_nodes", &Match::getMatchNodes); + .def("get_nb_match", &Match::getNbMatch, R"mydelimiter( + The number of graph patterns matched + )mydelimiter") + .def("get_start_nodes", &Match::getStartNodes, R"mydelimiter( + All matched graph patterns start nodes + )mydelimiter") + .def("get_match_nodes", &Match::getMatchNodes, R"mydelimiter( + All matched graph patterns sets of matched nodes + )mydelimiter"); } } diff --git a/python_binding/graphmatching/pybind_NodeRegex.cpp b/python_binding/graphmatching/pybind_NodeRegex.cpp index 76e91c495abfadbe930a260ecde2696fcd41fef4..6bbb74764afafbce3e3356027212c121895e4533 100644 --- a/python_binding/graphmatching/pybind_NodeRegex.cpp +++ b/python_binding/graphmatching/pybind_NodeRegex.cpp @@ -15,8 +15,16 @@ namespace py = pybind11; namespace Aidge { void init_NodeRegex(py::module& m){ - py::class_<NodeRegex, std::shared_ptr<NodeRegex>>(m, "NodeRegex") - .def(py::init<const std::string>(), py::arg("condition")) + py::class_<NodeRegex, std::shared_ptr<NodeRegex>>(m, "NodeRegex", "NodeRegex class describes a condition to test on any operator. Current version only tests the type of the operator.") + .def(py::init<const std::string>(), py::arg("condition"), R"mydelimiter( + Constructor of NodeRegex + + Parameters + ---------- + :param condition: Condition to be fulfilled by an operator. Currently supported conditions are only the operator types. + :type condition: `string` + + )mydelimiter") ; } } diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp index 4746aeacf2bf6ad43306b60660ecd0bace0a7ab5..b4147dcb4fb82dbfe9f5b4605604725c6945ece9 100644 --- a/python_binding/recipies/pybind_Recipies.cpp +++ b/python_binding/recipies/pybind_Recipies.cpp @@ -20,8 +20,24 @@ namespace py = pybind11; namespace Aidge { void init_Recipies(py::module &m) { - m.def("fuse_mul_add", &fuseMulAdd, py::arg("nodes")); - m.def("remove_flatten", &removeFlatten, py::arg("view")); + m.def("fuse_mul_add", &fuseMulAdd, py::arg("nodes"), R"mydelimiter( + Recipie to Fuse MatMul and Add operators into an `aidge.FC` operator. + + Parameters + ---------- + :param nodes: The MatMul and Add nodes to fuse. + :type nodes: list of `aidge.node` + + )mydelimiter"); + m.def("remove_flatten", &removeFlatten, py::arg("nodes"), R"mydelimiter( + Recipie to remove a flatten operator. + + Parameters + ---------- + :param nodes: The flatten operator to remove. + :type nodes: list of `aidge.node` + + )mydelimiter"); } } // namespace Aidge diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp index 23a9d645af52d0fe6af0b4d03f773a550b104e0c..cc3c3324e40636a1edcbc73cdc4a9dcfeec8a026 100644 --- a/src/recipies/RemoveFlatten.cpp +++ b/src/recipies/RemoveFlatten.cpp @@ -16,13 +16,9 @@ #include "aidge/utils/Recipies.hpp" namespace Aidge { - void removeFlatten(std::shared_ptr<GraphView> view) { - for (auto& nodePtr : view->getNodes()) { - if (nodePtr->type() == "Flatten") { - auto g = std::make_shared<GraphView>(); - g->add(std::set<std::shared_ptr<Node>>({nodePtr})); - g->replaceWith({}); - } - } + void removeFlatten(std::set<std::shared_ptr<Node>> nodes) { + auto g = std::make_shared<GraphView>(); + g->add(std::set<std::shared_ptr<Node>>({nodes})); + g->replaceWith({}); } } \ No newline at end of file