Skip to content
Snippets Groups Projects
Commit e704a292 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Merge branch 'main' of https://git-dscin.intra.cea.fr/aidge/aidge_core into main

parents 730ec4ab 7e3a1f92
No related branches found
No related tags found
No related merge requests found
...@@ -29,10 +29,10 @@ public: ...@@ -29,10 +29,10 @@ public:
* @param inputIdx Index of the input analysed. * @param inputIdx Index of the input analysed.
* @return std::size_t * @return std::size_t
*/ */
virtual NbElts_t getNbRequiredData(IOIndex_t inputIdx) const = 0; virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const = 0;
// Amount of input data that cannot be overwritten during the execution. // Amount of input data that cannot be overwritten during the execution.
virtual NbElts_t getNbRequiredProtected(IOIndex_t inputIdx) const = 0; virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const = 0;
// Memory required at an output for a given input size. // Memory required at an output for a given input size.
virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const = 0; virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const = 0;
...@@ -43,7 +43,7 @@ public: ...@@ -43,7 +43,7 @@ public:
* @param inputIdx Index of the input analysed. * @param inputIdx Index of the input analysed.
* @return DimSize_t * @return DimSize_t
*/ */
virtual NbElts_t getNbConsumedData(IOIndex_t inputIdx) const = 0; virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const = 0;
/** /**
* @brief TOtal amount of produced data ready to be used on a specific output. * @brief TOtal amount of produced data ready to be used on a specific output.
...@@ -51,7 +51,7 @@ public: ...@@ -51,7 +51,7 @@ public:
* @param outputIdx Index of the output analysed. * @param outputIdx Index of the output analysed.
* @return DimSize_t * @return DimSize_t
*/ */
virtual NbElts_t getNbProducedData(IOIndex_t outputIdx) const = 0; virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const = 0;
virtual ~OperatorImpl() = default; virtual ~OperatorImpl() = default;
}; };
......
...@@ -60,7 +60,7 @@ public: ...@@ -60,7 +60,7 @@ public:
* @param inputIdx Index of the input analysed. * @param inputIdx Index of the input analysed.
* @return NbElts_t * @return NbElts_t
*/ */
NbElts_t getNbRequiredData(IOIndex_t inputIdx) const; NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const;
/** /**
* @brief Amount of data from a specific input actually used in one computation pass. * @brief Amount of data from a specific input actually used in one computation pass.
...@@ -68,7 +68,7 @@ public: ...@@ -68,7 +68,7 @@ public:
* @param inputIdx Index of the input analysed. * @param inputIdx Index of the input analysed.
* @return NbElts_t * @return NbElts_t
*/ */
NbElts_t getNbConsumedData(IOIndex_t inputIdx) const; NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const;
/** /**
* @brief Amount of data ready to be used on a specific output. * @brief Amount of data ready to be used on a specific output.
...@@ -76,7 +76,7 @@ public: ...@@ -76,7 +76,7 @@ public:
* @param outputIdx Index of the output analysed. * @param outputIdx Index of the output analysed.
* @return NbElts_t * @return NbElts_t
*/ */
NbElts_t getNbProducedData(IOIndex_t outputIdx) const; NbElts_t getNbProducedData(const IOIndex_t outputIdx) const;
virtual void forward(); virtual void forward();
......
...@@ -17,9 +17,27 @@ ...@@ -17,9 +17,27 @@
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_GRegex(py::module& m){ void init_GRegex(py::module& m){
py::class_<GRegex, std::shared_ptr<GRegex>>(m, "GRegex") py::class_<GRegex, std::shared_ptr<GRegex>>(m, "GRegex", "GRegex class combines a Node Regex and a list of Graph Regex that together describes a graph pattern as a graph regular expression. GRegex find patterns in a given graph that matches the graph regular expression.")
.def(py::init<const std::map<std::string,NodeRegex*>&, std::vector<std::string>&>(), py::arg("nodesRegex"), py::arg("seqRegexps")) .def(py::init<const std::map<std::string,NodeRegex*>&, std::vector<std::string>&>(), py::arg("nodesRegex"), py::arg("seqRegexps"), R"mydelimiter(
.def("match", &GRegex::match, py::arg("graphToMatch")) Constructor of GRegex
Parameters
----------
:param nodesRegex: Describe the conditions an operator has to fulfill. Dictionnary mapping a string (keys) to :py:class:`aidge_core.NodeRegex` (values).
:type nodesRegex: :py:class: `Dict`
:param seqRegexps: Describe the graph topological pattern. List of Graph Regex as strings.
:type seqRegexps: :py:class: `list`
)mydelimiter")
.def("match", &GRegex::match, py::arg("graphToMatch"), R"mydelimiter(
Launch the graph matching algorithm on a given graph. Returns the matched graph patterns in a :py:class: `aidge_core.Match`.
Parameters
----------
:param graphToMatch: The graph to perform the matching algorithm on.
:type graphToMatch: :py:class: `aidge_core.GraphView`
)mydelimiter")
; ;
} }
} }
...@@ -16,10 +16,16 @@ ...@@ -16,10 +16,16 @@
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_Match(py::module& m){ void init_Match(py::module& m){
py::class_<Match, std::shared_ptr<Match>>(m, "Match") py::class_<Match, std::shared_ptr<Match>>(m, "Match", "Match class stores the matched patterns resulting from a graph matching query. A matched pattern is the combinaison of the graph pattern start nodes and the set of all the nodes in the matched pattern (including the start nodes)")
.def(py::init<>()) .def(py::init<>())
.def("get_nb_match", &Match::getNbMatch) .def("get_nb_match", &Match::getNbMatch, R"mydelimiter(
.def("get_start_nodes", &Match::getStartNodes) The number of graph patterns matched
.def("get_match_nodes", &Match::getMatchNodes); )mydelimiter")
.def("get_start_nodes", &Match::getStartNodes, R"mydelimiter(
All matched graph patterns start nodes
)mydelimiter")
.def("get_match_nodes", &Match::getMatchNodes, R"mydelimiter(
All matched graph patterns sets of matched nodes
)mydelimiter");
} }
} }
...@@ -15,8 +15,16 @@ ...@@ -15,8 +15,16 @@
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_NodeRegex(py::module& m){ void init_NodeRegex(py::module& m){
py::class_<NodeRegex, std::shared_ptr<NodeRegex>>(m, "NodeRegex") py::class_<NodeRegex, std::shared_ptr<NodeRegex>>(m, "NodeRegex", "NodeRegex class describes a condition to test on any operator. Current version only tests the type of the operator.")
.def(py::init<const std::string>(), py::arg("condition")) .def(py::init<const std::string>(), py::arg("condition"), R"mydelimiter(
Constructor of NodeRegex
Parameters
----------
:param condition: Condition to be fulfilled by an operator. Currently supported conditions are only the operator types.
:type condition: `string`
)mydelimiter")
; ;
} }
} }
...@@ -20,8 +20,24 @@ namespace py = pybind11; ...@@ -20,8 +20,24 @@ namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_Recipies(py::module &m) { void init_Recipies(py::module &m) {
m.def("fuse_mul_add", &fuseMulAdd, py::arg("nodes")); m.def("fuse_mul_add", &fuseMulAdd, py::arg("nodes"), R"mydelimiter(
m.def("remove_flatten", &removeFlatten, py::arg("view")); Recipie to Fuse MatMul and Add operators into an `aidge.FC` operator.
Parameters
----------
:param nodes: The MatMul and Add nodes to fuse.
:type nodes: list of `aidge.node`
)mydelimiter");
m.def("remove_flatten", &removeFlatten, py::arg("nodes"), R"mydelimiter(
Recipie to remove a flatten operator.
Parameters
----------
:param nodes: The flatten operator to remove.
:type nodes: list of `aidge.node`
)mydelimiter");
} }
} // namespace Aidge } // namespace Aidge
...@@ -16,13 +16,9 @@ ...@@ -16,13 +16,9 @@
#include "aidge/utils/Recipies.hpp" #include "aidge/utils/Recipies.hpp"
namespace Aidge { namespace Aidge {
void removeFlatten(std::shared_ptr<GraphView> view) { void removeFlatten(std::set<std::shared_ptr<Node>> nodes) {
for (auto& nodePtr : view->getNodes()) { auto g = std::make_shared<GraphView>();
if (nodePtr->type() == "Flatten") { g->add(std::set<std::shared_ptr<Node>>({nodes}));
auto g = std::make_shared<GraphView>(); g->replaceWith({});
g->add(std::set<std::shared_ptr<Node>>({nodePtr}));
g->replaceWith({});
}
}
} }
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment