diff --git a/include/aidge/graphRegex/GraphRegex.hpp b/include/aidge/graphRegex/GraphRegex.hpp index 12a5139a36135979639d2447b869568b943ee840..b62a42fcfeb258e5c659eaeb6681190482f37aa4 100644 --- a/include/aidge/graphRegex/GraphRegex.hpp +++ b/include/aidge/graphRegex/GraphRegex.hpp @@ -11,6 +11,11 @@ namespace Aidge{ +/** + * type for recipes function use in query and resolve +*/ +using RecipesFunctionType = std::function<void(std::shared_ptr<MatchSolution>)>; + /** * @brief class which is the hight level interface for graph matching, used to simplify match definition * @@ -19,9 +24,10 @@ class GraphRegex{ private: - std::vector<std::string> mQuery; + //std::vector<std::string> mQuery; std::vector<std::shared_ptr<ConditionalInterpreter>> mAllTest; std::map<std::string, std::function<bool(NodePtr)>> mAllLambda; + std::map<std::string,RecipesFunctionType> mQueryRecipe; public: GraphRegex(){}; @@ -31,7 +37,15 @@ class GraphRegex{ * @brief add a topology query to the match * @param query the topology query to find **/ - void addQuery(const std::string query); + //void addQuery(const std::string query); + + /** + * @brief add a topology query to the match and a function for recipe + * @param query the topology query to find + * @param f the funct + **/ + void addQuery(const std::string query,RecipesFunctionType f = nullptr); + /** * @brief get all the types of a graph and set it as type key in the query @@ -53,13 +67,19 @@ class GraphRegex{ **/ void setNodeKey(const std::string key,std::function<bool(NodePtr)> f); - /*** + /** * @brief brief match the queries in the graph - * @param Reference the graph were the querys in search + * @param ref the graph were the querys in search * @return the result */ std::set<std::shared_ptr<MatchSolution>> match(std::shared_ptr<GraphView> ref); + /*** + * @brief match the queries in the graph and applied the recipes fuction + * @param ref the graph were the querys in search + */ + void appliedRecipes(std::shared_ptr<GraphView> ref); + private: void _generateCombinationsStart(const std::set<NodePtr>& elements, std::size_t n, std::size_t index, diff --git a/include/aidge/graphRegex/matchFsm/FsmEdge.hpp b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp index 3e63f92337f6394382f6d92ef9f6dd7b5098a454..a6cc3e59247d4be98caa9881182bfba1c44e0178 100644 --- a/include/aidge/graphRegex/matchFsm/FsmEdge.hpp +++ b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp @@ -116,7 +116,7 @@ namespace Aidge{ }; /** - * @brief class spesialisation for not commun node (node that must be match one Unique) transition + * @brief class specialization for not commun node (node that must be match one Unique) transition */ class FsmEdgeUnique:public FsmEdge { @@ -127,7 +127,7 @@ namespace Aidge{ }; /** - * @brief class spesialisation for commun node transition + * @brief class specialization for commun node transition * @see FsmEdge */ class FsmEdgeCommon:public FsmEdge @@ -181,7 +181,7 @@ namespace Aidge{ }; /** - * @brief class spesialisation for ref empty transition + * @brief class specialization for ref empty transition * @see FsmEdge */ class FsmEdgeEmpty:public FsmEdge @@ -195,6 +195,20 @@ namespace Aidge{ }; + /** + * @brief class specialization for ref empty transition + * @see FsmEdge + */ + class FsmEdgeNone:public FsmEdge + { + + public: + FsmEdgeNone(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest); + const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> /*stmContext*/) override; + + }; + + //////////////////////// // FACTORY diff --git a/include/aidge/operator/Identity.hpp b/include/aidge/operator/Identity.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c5cd9bb62e0097c9a0e646caaf14cddd73bf512d --- /dev/null +++ b/include/aidge/operator/Identity.hpp @@ -0,0 +1,126 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_IDENTITY_H_ +#define AIDGE_CORE_OPERATOR_IDENTITY_H_ + +#include <cassert> +#include <memory> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/data/Data.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" + +namespace Aidge { + +/** + * @brief Indentity_Op is an helper operator made to ease the declaration of MetaNodes. + * This Operator has no Implementation, it just forward its input Tensor. + * Note: Error may occur if new methods are added in Operator which use an implementation. + * Has we need to update this class to remove the use of Impl. + * + */ +class Identity_Op : public OperatorTensor, + public Registrable<Identity_Op, std::string, std::unique_ptr<OperatorImpl>(const Identity_Op&)> { +public: + static constexpr const char* Type = "Identity"; + + Identity_Op() + : OperatorTensor(Type, 1, 0, 0) + { + mImpl = std::make_shared<OperatorImpl>(*this); + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Identity_Op(const Identity_Op& op) + : OperatorTensor(op) + { + mImpl = std::make_shared<OperatorImpl>(*this); + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Identity_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Identity_Op>(*this); + } + + void computeOutputDims() override final {} // Do nothing + + bool outputDimsForwarded() const override final { + if (mInputs[0]) + return !mInputs[0]->empty(); + else + return false; + } + + + void forward() override final { runHooks(); } + + void backward() override final { } + + void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override final { + if (strcmp(data->type(), "Tensor") != 0) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as outputs", type().c_str()); + } + if (outputIdx >= nbInputs()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbInputs()); + } + *mInputs[outputIdx] = *std::dynamic_pointer_cast<Tensor>(data); + } + + void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override final { + if (strcmp(data->type(), "Tensor") != 0) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str()); + } + if (outputIdx >= nbInputs()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbInputs()); + } + *mInputs[outputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data)); + } + + const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const override final { + if (outputIdx >= nbInputs()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbInputs()); + } + return mInputs[outputIdx]; + } + void setBackend(const std::string& name) override final { + // setBackend do nothing, Identity node has no backend it just pass the same Tensor + } + void setDataType(const DataType& dataType) const override final { + // setDatatype do nothing, Identity node has no backend it just pass the same Tensor + } + + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Identity(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Identity_Op>(), name); +} +} + +#endif /* AIDGE_CORE_OPERATOR_IDENTITY_H_ */ diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 5775dd24d7d8b59b2ea945bcd15a38ddfdb71cd0..4c8feb46c3e3db33bd380302e3e0683f1b8734f5 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -74,12 +74,6 @@ public: void computeOutputDims() override final { // Forward dims of micro-graph mGraph->forwardDims(); - - // Associate outputs to micro-graph outputs for custom implementation - for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) { - const auto& outputOp = mOutputOps[outputIdx]; - mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second); - } } diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 6a29c6941d454a010c0e05fd4d47e99bf10ac5b2..61392470adaeb7db8812a3063edc5f8eee1d3083 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -51,9 +51,18 @@ void init_GraphView(py::module& m) { Include a Node to the current GraphView object. :param other_node: Node to add - :type oth_Node: Node - :param includeLearnableParameter: include non-data inputs, like weights and biases. Default True. - :type includeLearnableParameter: bool + :type other_node: Node + :param include_learnable_parameters: include non-data inputs, like weights and biases, default True. + :type include_learnable_parameters: bool, optional + )mydelimiter") + + .def("add", (void (GraphView::*)(std::shared_ptr<GraphView>)) & GraphView::add, + py::arg("other_graph"), + R"mydelimiter( + Include a GraphView to the current GraphView object. + + :param other_graph: GraphView to add + :type other_graph: GraphView )mydelimiter") .def("add_child", @@ -105,4 +114,4 @@ void init_GraphView(py::module& m) { // }) ; } -} // namespace Aidge \ No newline at end of file +} // namespace Aidge diff --git a/python_binding/graph/pybind_Node.cpp b/python_binding/graph/pybind_Node.cpp index 3b63189c8c3cfca6e89c790c8bc294f03024c732..aa5c21372730536662106a035307d885fa011107 100644 --- a/python_binding/graph/pybind_Node.cpp +++ b/python_binding/graph/pybind_Node.cpp @@ -16,136 +16,150 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" +#include "aidge/graph/Connector.hpp" #include "aidge/utils/Types.h" namespace py = pybind11; namespace Aidge { void init_Node(py::module& m) { py::class_<Node, std::shared_ptr<Node>>(m, "Node") - .def("name", &Node::name, - R"mydelimiter( - Name of the Node. - )mydelimiter") - - .def("type", &Node::type, - R"mydelimiter( - Type of the node. - )mydelimiter") - - .def("get_operator", &Node::getOperator, - R"mydelimiter( - Get the Operator object of the Node. - )mydelimiter") - - .def("set_name", &Node::setName, py::arg("name"), - R"mydelimiter( - Set the Node name. - - :param name: New name for the node. - :type name: str - :rtype: str - )mydelimiter") - - .def("add_child", - (void (Node::*)(std::shared_ptr<Node>, const IOIndex_t, IOIndex_t)) & - Node::addChild, - py::arg("other_node"), py::arg("out_id") = 0, py::arg("other_in_id") = gk_IODefaultIndex, - R"mydelimiter( - Link another Node to an output of the current Node. - - :param other_node: Pointer to the other Node. - :type other_node: :py:class: Node - :param out_id: ID of the current Node output to connect to the other Node. Default to 0. - :type out_id: int - :param other_in_id: ID of the other Node input to connect to the current Node. Default to the first avaible data input. - :type other_in_id: int - )mydelimiter") - - .def("add_child", - (void (Node::*)(std::shared_ptr<GraphView>, const IOIndex_t, - std::pair<std::shared_ptr<Node>, IOIndex_t>)) & - Node::addChild, - py::arg("other_graph"), py::arg("out_id") = 0, - py::arg("other_in_id") = - std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex), - R"mydelimiter( - Link a Node from a specific GraphView to the current Node. - - :param other_view: Pointer to the GraphView whose content should be linked to the current Node. - :type other_view: :py:class: GraphView - :param out_id: ID of the current Node output to connect to the other Node. Default to 0. - :type out_id: int - :param other_in_id: Pair of Node and input connection ID for specifying the connection. If the GraphView whose content is linked has only one input Node, then it defaults to the first available data input ID of this Node. - :type other_in_id: tuple[:py:class: Node, int] - )mydelimiter") - - .def("inputs", &Node::inputs, - R"mydelimiter( - Get ordered list of parent Node and the associated output index connected to the current Node's inputs. - - :return: List of connections. When an input is not linked to any parent, the default value is (None, default_index) - :rtype: list[tuple[Node, int]] - )mydelimiter") - - .def("input", &Node::input, py::arg("in_id"), - R"mydelimiter( - Get the parent Node and the associated output index connected to the i-th input of the current Node. - - :param in_id: input index of the current Node object. - :type in_id: int - :return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index) - :rtype: tuple[Node, int] - )mydelimiter") - - .def("outputs", &Node::outputs, - R"mydelimiter( - Get, for each output of the Node, a list of the children Node and the associated input index connected to it. - - :return: List of a list of connections. When an outut is not linked to any child, its list a empty. - :rtype: list[list[tuple[Node, int]]] - )mydelimiter") - - .def("output", &Node::output, py::arg("out_id"), - R"mydelimiter( - Get a list of the children Node for a specific output and the associated input index connected to it. - - :param out_id: input index of the current Node object. - :type out_id: int - :return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index) - :rtype: list[tuple[Node, int]] - )mydelimiter") - - .def("get_nb_inputs", &Node::nbInputs, - R"mydelimiter( - Number of inputs. - - :rtype: int - )mydelimiter") - - .def("get_nb_data", &Node::nbData, - R"mydelimiter( - Number of data inputs. - - :rtype: int - )mydelimiter") - - .def("get_nb_outputs", &Node::nbOutputs, - R"mydelimiter( - Number of outputs. - - :rtype: int - )mydelimiter") - - .def("get_parents", &Node::getParents, - R"mydelimiter( - Get parents. - )mydelimiter") - - .def("get_children", (std::set<std::shared_ptr<Node>> (Node::*)() const) &Node::getChildren, - R"mydelimiter( - Get children. - )mydelimiter") - - .def("__call__", &Node::operator(), py::arg("connectors")); + .def("name", &Node::name, + R"mydelimiter( + Name of the Node. + )mydelimiter") + + .def("type", &Node::type, + R"mydelimiter( + Type of the node. + )mydelimiter") + + .def("get_operator", &Node::getOperator, + R"mydelimiter( + Get the Operator object of the Node. + )mydelimiter") + + .def("set_name", &Node::setName, py::arg("name"), + R"mydelimiter( + Set the Node name. + + :param name: New name for the node. + :type name: str + :rtype: str + )mydelimiter") + + .def("add_child", + (void (Node::*)(std::shared_ptr<Node>, const IOIndex_t, IOIndex_t)) & + Node::addChild, + py::arg("other_node"), py::arg("out_id") = 0, py::arg("other_in_id") = gk_IODefaultIndex, + R"mydelimiter( + Link another Node to an output of the current Node. + + :param other_node: Pointer to the other Node. + :type other_node: :py:class: Node + :param out_id: ID of the current Node output to connect to the other Node. Default to 0. + :type out_id: int + :param other_in_id: ID of the other Node input to connect to the current Node. Default to the first avaible data input. + :type other_in_id: int + )mydelimiter") + + .def("add_child", + (void (Node::*)(std::shared_ptr<GraphView>, const IOIndex_t, + std::pair<std::shared_ptr<Node>, IOIndex_t>)) & + Node::addChild, + py::arg("other_graph"), py::arg("out_id") = 0, + py::arg("other_in_id") = + std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex), + R"mydelimiter( + Link a Node from a specific GraphView to the current Node. + + :param other_view: Pointer to the GraphView whose content should be linked to the current Node. + :type other_view: :py:class: GraphView + :param out_id: ID of the current Node output to connect to the other Node. Default to 0. + :type out_id: int + :param other_in_id: Pair of Node and input connection ID for specifying the connection. If the GraphView whose content is linked has only one input Node, then it defaults to the first available data input ID of this Node. + :type other_in_id: tuple[:py:class: Node, int] + )mydelimiter") + + .def("inputs", &Node::inputs, + R"mydelimiter( + Get ordered list of parent Node and the associated output index connected to the current Node's inputs. + + :return: List of connections. When an input is not linked to any parent, the default value is (None, default_index) + :rtype: list[tuple[Node, int]] + )mydelimiter") + + .def("input", &Node::input, py::arg("in_id"), + R"mydelimiter( + Get the parent Node and the associated output index connected to the i-th input of the current Node. + + :param in_id: input index of the current Node object. + :type in_id: int + :return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index) + :rtype: tuple[Node, int] + )mydelimiter") + + .def("outputs", &Node::outputs, + R"mydelimiter( + Get, for each output of the Node, a list of the children Node and the associated input index connected to it. + + :return: List of a list of connections. When an outut is not linked to any child, its list a empty. + :rtype: list[list[tuple[Node, int]]] + )mydelimiter") + + .def("output", &Node::output, py::arg("out_id"), + R"mydelimiter( + Get a list of the children Node for a specific output and the associated input index connected to it. + + :param out_id: input index of the current Node object. + :type out_id: int + :return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index) + :rtype: list[tuple[Node, int]] + )mydelimiter") + + .def("get_nb_inputs", &Node::nbInputs, + R"mydelimiter( + Number of inputs. + + :rtype: int + )mydelimiter") + + .def("get_nb_data", &Node::nbData, + R"mydelimiter( + Number of data inputs. + + :rtype: int + )mydelimiter") + + .def("get_nb_outputs", &Node::nbOutputs, + R"mydelimiter( + Number of outputs. + + :rtype: int + )mydelimiter") + + .def("get_parents", &Node::getParents, + R"mydelimiter( + Get parents. + )mydelimiter") + + .def("get_children", (std::set<std::shared_ptr<Node>> (Node::*)() const) &Node::getChildren, + R"mydelimiter( + Get children. + )mydelimiter") + + .def("__call__", + [](Node &self, pybind11::args args) { + std::vector<Connector> connectors; + for (const auto &arg : args) { + // Check if the argument is an instance of Connector + if (pybind11::isinstance<Connector>(arg)) { + // Convert Python object to C++ object adn push it ot vector + connectors.push_back(arg.cast<Connector>()); + } else { + throw std::runtime_error("One of the arguments was not a Connector."); + } + } + return self(connectors); + }); } } // namespace Aidge diff --git a/python_binding/operator/pybind_Identity.cpp b/python_binding/operator/pybind_Identity.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b1b1e8888976c578ff490f35776c890ba59911dc --- /dev/null +++ b/python_binding/operator/pybind_Identity.cpp @@ -0,0 +1,27 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/operator/Identity.hpp" +#include "aidge/operator/Operator.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Identity(py::module& m) { + py::class_<Identity_Op, std::shared_ptr<Identity_Op>, Operator>(m, "IdentityOp", py::multiple_inheritance()) + .def("get_inputs_name", &Identity_Op::getInputsName) + .def("get_outputs_name", &Identity_Op::getOutputsName); + + m.def("Identity", &Identity, py::arg("name") = ""); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp index aa87ce28c17d9ba272ef8501a510014f391b381c..6df5a43f64bf8335108ccd99a1588a1367955b77 100644 --- a/python_binding/operator/pybind_MetaOperatorDefs.cpp +++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp @@ -122,6 +122,15 @@ void init_MetaOperatorDefs(py::module &m) { declare_PaddedMaxPoolingOp<2>(m); declare_PaddedMaxPoolingOp<3>(m); + py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, Operator>(m, "MetaOperator_Op", py::multiple_inheritance()); + + m.def("meta_operator", &MetaOperator, + py::arg("type"), + py::arg("graph"), + py::arg("name") = "", + py::arg("input_nodes") = std::vector<NodePtr>(), + py::arg("output_nodes") = std::vector<NodePtr>() + ); } } // namespace Aidge diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index f9482eda2f93b5492cfcc89175da69d140f23df8..559c4b0d6a4c8dd2f7cb7429e663e95a058d7f20 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -27,6 +27,7 @@ void init_Operator(py::module& m){ .def("nb_data", &Operator::nbData) .def("nb_param", &Operator::nbParam) .def("nb_outputs", &Operator::nbOutputs) + .def("output_dims_forwarded", &Operator::outputDimsForwarded) .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) .def("set_datatype", &Operator::setDataType, py::arg("dataType")) .def("set_backend", &Operator::setBackend, py::arg("name")) diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index e625978ba2f15d4aff9e847e18ebc8076f31a165..23b54e46b23a341add8ba7291551c0f84f705bea 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -40,6 +40,7 @@ void init_ReLU(py::module&); void init_Softmax(py::module&); void init_Sqrt(py::module&); void init_Sub(py::module&); +void init_Identity(py::module&); void init_Node(py::module&); void init_GraphView(py::module&); @@ -85,6 +86,7 @@ void init_Aidge(py::module& m){ init_Softmax(m); init_Sqrt(m); init_Sub(m); + init_Identity(m); init_Producer(m); init_GraphRegex(m); diff --git a/src/graph/Connector.cpp b/src/graph/Connector.cpp index cd2ceff8b58076a5054269e4676120b94c8b5beb..98f58259a97b7c4194b29ae7b75a4140885ee122 100644 --- a/src/graph/Connector.cpp +++ b/src/graph/Connector.cpp @@ -41,6 +41,7 @@ std::shared_ptr<Aidge::GraphView> Aidge::generateGraph(std::vector<Connector> ct std::vector<std::shared_ptr<Node>> parents = nodesToAdd.back()->getParents(); const std::set<std::shared_ptr<Node>>& alreadyAdded = graph->getNodes(); for (std::shared_ptr<Node> parent : parents) { + if (!parent) continue; if (alreadyAdded.find(parent) == alreadyAdded.end()) { buffer.push_back(parent); } @@ -51,4 +52,4 @@ std::shared_ptr<Aidge::GraphView> Aidge::generateGraph(std::vector<Connector> ct buffer = {}; } return graph; -} \ No newline at end of file +} diff --git a/src/graphRegex/GraphFsmInterpreter.cpp b/src/graphRegex/GraphFsmInterpreter.cpp index 03e86487513065af47d91fc5265335bba456e64e..18b768c6567e64caf6841ed4a339f13fd16f69d6 100644 --- a/src/graphRegex/GraphFsmInterpreter.cpp +++ b/src/graphRegex/GraphFsmInterpreter.cpp @@ -128,7 +128,7 @@ std::shared_ptr<FsmGraph> GraphFsmInterpreter::qomF(std::shared_ptr<FsmGraph> fs for(auto valid : allValid){ if(haveCommon){ /* - the // quantif case + the // quantify case get the go back and make a lexeme id(number) we need to go back to the ref delta min #TODO */ @@ -145,7 +145,7 @@ std::shared_ptr<FsmGraph> GraphFsmInterpreter::qomF(std::shared_ptr<FsmGraph> fs edge = FsmEdgeFactory::make(valid,start,FsmEdgeTypes::REF,mNodesCondition, lexem.str()); }else{ /* - the sequensial quantif case + the sequencial quantify case no reference to common */ edge = FsmEdgeFactory::make(valid,start,FsmEdgeTypes::EMPTY,mNodesCondition,""); diff --git a/src/graphRegex/GraphRegex.cpp b/src/graphRegex/GraphRegex.cpp index ef0db8c88f3e753f9b9633b1ffb05bbec6d00424..9a9b53da615f77dbdb8e597763411a2e84920b2a 100644 --- a/src/graphRegex/GraphRegex.cpp +++ b/src/graphRegex/GraphRegex.cpp @@ -26,10 +26,17 @@ void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){ -void GraphRegex::addQuery(const std::string query){ - mQuery.push_back(query); -} +// void GraphRegex::addQuery(const std::string query){ +// //TODO one query only but the same string is a same query but +// //2 different string it's maybe the same query , we need to check the AST +// mQueryRecipe[query] = nullptr; +// } + +void GraphRegex::addQuery(const std::string query,RecipesFunctionType f ){ + mQueryRecipe[query] = f; + +} // Function to generate all combinations of n elements from a set @@ -87,7 +94,9 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph std::vector<std::shared_ptr<MatchSolution>> solutions = {}; - for (const std::string& query : mQuery) { + //for (const std::string& query : mQuery) { + for (auto it = mQueryRecipe.begin(); it != mQueryRecipe.end(); ++it) { + const std::string query = it->first; std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>(query,mAllTest); std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); @@ -108,6 +117,15 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph return _findLargestCompatibleSet(solutions); } +void GraphRegex::appliedRecipes(std::shared_ptr<GraphView> ref){ + std::set<std::shared_ptr<MatchSolution>> matchRef = match(ref); + for (const auto& solution : matchRef) { + if(mQueryRecipe[solution->getQuery()] != nullptr){ + mQueryRecipe[solution->getQuery()](solution); + } + } +} + void GraphRegex::setNodeKey(const std::string key, const std::string conditionalExpressions ){ mAllTest.push_back(std::make_shared<ConditionalInterpreter>(key,conditionalExpressions)); _majConditionalInterpreterLambda(); diff --git a/src/graphRegex/matchFsm/FsmEdge.cpp b/src/graphRegex/matchFsm/FsmEdge.cpp index ab307e023209ab770fc63f0550811279bd42eb46..d16dcf9505f5c3324fa621df2895065b7b019e19 100644 --- a/src/graphRegex/matchFsm/FsmEdge.cpp +++ b/src/graphRegex/matchFsm/FsmEdge.cpp @@ -226,6 +226,14 @@ const EdgeTestResult FsmEdgeEmpty::test(const std::shared_ptr<FsmRunTimeContext> } return {true,std::set<NodePtr>({opNode})};//none } +////////////// + +FsmEdgeNone::FsmEdgeNone(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest) +:FsmEdge(source,dest,nullptr) +{} + const EdgeTestResult FsmEdgeNone::test(const std::shared_ptr<FsmRunTimeContext> /*stmContext*/){ + return {false,std::set<NodePtr>()}; + } /// factory std::shared_ptr<FsmEdge> FsmEdgeFactory::make( @@ -260,7 +268,10 @@ const std::string lexeme) std::string commonKey = edgeType + std::to_string(commonIdx); if(allTest.find(edgeType) == allTest.end()){ - throw std::invalid_argument("Bad Node Test " + edgeType ); + //if the key is not linked to a condition + //by default, it is initialized by a edge that is always false + return std::make_shared<FsmEdgeNone>(source, dest); + //throw std::invalid_argument("Bad Node Test " + edgeType ); } return std::make_shared<FsmEdgeCommon> (source, dest, allTest.at(edgeType), commonKey); @@ -274,7 +285,11 @@ const std::string lexeme) std::string edgeType = m[1]; if(allTest.find(edgeType) == allTest.end()){ - throw std::invalid_argument("Bad Node Test " + edgeType ); + + //if the key is not linked to a condition + //by default, it is initialized by a edge that is always false + return std::make_shared<FsmEdgeNone>(source, dest); + //throw std::invalid_argument("Bad Node Test " + edgeType ); } return std::make_shared<FsmEdgeUnique>(source, dest, allTest.at(edgeType)); diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index 28085759f62a2e54746af83fc31043f01808c6f2..bbc921d3c7b334223b2a92a8fbfee1ffae9c10e1 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -18,10 +18,6 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr< : OperatorTensor(type, graph->dataInputs().size(), (graph->inputs().size() - graph->dataInputs().size()), graph->outputs().size()), mGraph(graph) { - // for (std::size_t i = 0; i < mInputs.size(); ++i) { - // mInputs[i] = std::make_shared<Tensor>(); - // } - // Fill inputsNodes and outputsNodes when there is no ambiguity if (inputNodes.empty()) { AIDGE_ASSERT(mGraph->inputNodes().size() == 1, "need to specify internal nodes input mapping"); @@ -66,8 +62,14 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr< } } + AIDGE_INTERNAL_ASSERT(mInputOps.size() == mGraph->inputs().size()); AIDGE_INTERNAL_ASSERT(mOutputOps.size() == mGraph->outputs().size()); + // Associate outputs to micro-graph outputs for custom implementation + for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) { + const auto& outputOp = mOutputOps[outputIdx]; + mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second); + } } Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { @@ -110,6 +112,7 @@ void Aidge::MetaOperator_Op::updateConsummerProducer() { mScheduler = std::make_shared<SequentialScheduler>(mGraph); } + // TODO: check that generateScheduling() can be called multiple time to iteratively update the schedule. // It could be a good idea to unify updateConsummerProducer() and generateScheduling() into a "updateScheduling()" mScheduler->generateScheduling(); diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp index ad8f257e4a7d335e438f10c58dd351c61192166d..4383f3bd74c3278eade5eca081cf34827863d596 100644 --- a/unit_tests/graphRegex/Test_GraphRegex.cpp +++ b/unit_tests/graphRegex/Test_GraphRegex.cpp @@ -2,6 +2,15 @@ #include <catch2/catch_test_macros.hpp> #include "aidge/graphRegex/GraphRegex.hpp" + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Add.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/operator/MatMul.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/Recipies.hpp" + #include "aidge/operator/Conv.hpp" #include "aidge/operator/GenericOperator.hpp" @@ -46,13 +55,9 @@ TEST_CASE("GraphRegexUser") { } - SECTION("CC") { - - - + SECTION("2 query") { std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); - std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1"); @@ -81,4 +86,93 @@ TEST_CASE("GraphRegexUser") { } } + + + SECTION("Not define node Test") { + + //test if the FC is not define only match query not query2 + std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("FC", 1, 1, 1, "c3"); + + g1->add(conv); + g1->addChild(conv1, "c"); + g1->addChild(conv2, "c1"); + g1->addChild(conv3, "c2"); + + + //sut->setKeyFromGraph(g1); + + const std::string query = "Conv->Conv"; + const std::string query2 = "Conv->FC"; + + sut->setNodeKey("Conv","getType($) =='Conv'"); + + sut->addQuery(query); + sut->addQuery(query2); + + + for (const auto& solution : sut->match(g1)) { + REQUIRE(solution->getQuery() == query); + } + + } + + + SECTION("Applied Recipes"){ + + // generate the original GraphView + auto matmul0 = MatMul(5, "matmul0"); + auto add0 = Add<2>("add0"); + auto matmul1 = MatMul(5, "matmul1"); + auto add1 = Add<2>("add1"); + + auto b0 = Producer({5}, "B0"); + auto w0 = Producer({5, 5}, "W0"); + auto b1 = Producer({5}, "B1"); + auto w1 = Producer({5,5},"W1"); + auto input = Producer({2,5}, "input"); + + input->addChild(matmul0, 0, 0); + w0->addChild(matmul0, 0, 1); + + matmul0->addChild(add0, 0, 0); + b0->addChild(add0, 0, 1); + + add0->addChild(matmul1, 0, 0); + w1->addChild(matmul1, 0, 1); + + matmul1->addChild(add1, 0, 0); + b1->addChild(add1, 0, 1); + + auto fc = GenericOperator("FC", 1, 1, 1, "c"); + auto fl = GenericOperator("Flatten", 1, 1, 1, "c"); + + + auto g = std::make_shared<GraphView>(); + g->add({matmul0, add0, matmul1, add1, b0, b1,fl,fc}); + + std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>(); + + kitchenBook->setNodeKey("Add","getType($) =='Add'"); + kitchenBook->setNodeKey("MatMul","getType($) =='MatMul'"); + kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'"); + kitchenBook->setNodeKey("FC","getType($) =='FC'"); + + kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd)); + kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten)); + + kitchenBook->appliedRecipes(g); + + std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); + REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fc})); + //REQUIRE(newNodes.size() == 6); + + + } + } \ No newline at end of file