Skip to content
Snippets Groups Projects
Commit fd0e25a7 authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge remote-tracking branch 'origin/main' into OperatorTensor

parents ae6cc90d 3cd94d47
No related branches found
No related tags found
No related merge requests found
Showing
with 507 additions and 160 deletions
......@@ -11,6 +11,11 @@
namespace Aidge{
/**
* Type for the recipe functions used in query and resolve.
*/
using RecipesFunctionType = std::function<void(std::shared_ptr<MatchSolution>)>;
/**
* @brief class which is the high-level interface for graph matching, used to simplify match definition
*
......@@ -19,9 +24,10 @@ class GraphRegex{
private:
std::vector<std::string> mQuery;
//std::vector<std::string> mQuery;
std::vector<std::shared_ptr<ConditionalInterpreter>> mAllTest;
std::map<std::string, std::function<bool(NodePtr)>> mAllLambda;
std::map<std::string,RecipesFunctionType> mQueryRecipe;
public:
GraphRegex(){};
......@@ -31,7 +37,15 @@ class GraphRegex{
* @brief add a topology query to the match
* @param query the topology query to find
**/
void addQuery(const std::string query);
//void addQuery(const std::string query);
/**
* @brief add a topology query to the match and a function for recipe
* @param query the topology query to find
* @param f the recipe function applied to each match solution of the query
**/
void addQuery(const std::string query,RecipesFunctionType f = nullptr);
/**
* @brief get all the types of a graph and set it as type key in the query
......@@ -53,13 +67,19 @@ class GraphRegex{
**/
void setNodeKey(const std::string key,std::function<bool(NodePtr)> f);
/***
/**
* @brief brief match the queries in the graph
* @param Reference the graph were the querys in search
* @param ref the graph in which the queries are searched
* @return the result
*/
std::set<std::shared_ptr<MatchSolution>> match(std::shared_ptr<GraphView> ref);
/***
* @brief match the queries in the graph and apply the recipe functions
* @param ref the graph in which the queries are searched
*/
void appliedRecipes(std::shared_ptr<GraphView> ref);
private:
void _generateCombinationsStart(const std::set<NodePtr>& elements, std::size_t n, std::size_t index,
......
......@@ -116,7 +116,7 @@ namespace Aidge{
};
/**
* @brief class spesialisation for not commun node (node that must be match one Unique) transition
* @brief class specialization for not commun node (node that must be match one Unique) transition
*/
class FsmEdgeUnique:public FsmEdge
{
......@@ -127,7 +127,7 @@ namespace Aidge{
};
/**
* @brief class spesialisation for commun node transition
* @brief class specialization for commun node transition
* @see FsmEdge
*/
class FsmEdgeCommon:public FsmEdge
......@@ -181,7 +181,7 @@ namespace Aidge{
};
/**
* @brief class spesialisation for ref empty transition
* @brief class specialization for ref empty transition
* @see FsmEdge
*/
class FsmEdgeEmpty:public FsmEdge
......@@ -195,6 +195,20 @@ namespace Aidge{
};
/**
* @brief class specialization for ref empty transition
* @see FsmEdge
*/
class FsmEdgeNone:public FsmEdge
{
public:
FsmEdgeNone(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest);
const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> /*stmContext*/) override;
};
////////////////////////
// FACTORY
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_IDENTITY_H_
#define AIDGE_CORE_OPERATOR_IDENTITY_H_
#include <cassert>
#include <cstring>
#include <memory>
#include <vector>

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
/**
 * @brief Identity_Op is a helper operator made to ease the declaration of MetaNodes.
 * This Operator has no dedicated implementation: it simply forwards its input Tensor.
 * Note: errors may occur if new methods relying on an implementation are added to
 * Operator, as this class would then need updating to remove the use of Impl.
 */
class Identity_Op : public OperatorTensor,
    public Registrable<Identity_Op, std::string, std::unique_ptr<OperatorImpl>(const Identity_Op&)> {
public:
    static constexpr const char* Type = "Identity";

    Identity_Op()
        : OperatorTensor(Type, 1, 0, 0)
    {
        mImpl = std::make_shared<OperatorImpl>(*this);
    }

    /**
     * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
     * @param op Operator to copy.
     */
    Identity_Op(const Identity_Op& op)
        : OperatorTensor(op)
    {
        mImpl = std::make_shared<OperatorImpl>(*this);
    }

    /**
     * @brief Clone the operator using its copy-constructor.
     * @see Operator::Identity_Op
     */
    std::shared_ptr<Operator> clone() const override {
        return std::make_shared<Identity_Op>(*this);
    }

    /// Identity forwards its input as-is, so there are no dims to compute.
    void computeOutputDims() override final {} // Do nothing

    /// Dims are considered forwarded as soon as the (single) input Tensor is non-empty.
    bool outputDimsForwarded() const override final {
        if (mInputs[0])
            return !mInputs[0]->empty();
        else
            return false;
    }

    void forward() override final { runHooks(); }

    void backward() override final { }

    // NOTE: outputs are aliased to inputs — setOutput()/getOutput() operate on
    // mInputs, and output indices are therefore bounded by nbInputs().
    void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override final {
        if (strcmp(data->type(), "Tensor") != 0) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as outputs", type().c_str());
        }
        if (outputIdx >= nbInputs()) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbInputs());
        }
        *mInputs[outputIdx] = *std::dynamic_pointer_cast<Tensor>(data);
    }

    void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override final {
        if (strcmp(data->type(), "Tensor") != 0) {
            // Fixed message: this overload also sets *outputs* (it previously said "inputs").
            AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as outputs", type().c_str());
        }
        if (outputIdx >= nbInputs()) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbInputs());
        }
        *mInputs[outputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data));
    }

    const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const override final {
        if (outputIdx >= nbInputs()) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbInputs());
        }
        return mInputs[outputIdx];
    }

    void setBackend(const std::string& /*name*/) override final {
        // setBackend does nothing: an Identity node has no backend, it just passes the same Tensor.
    }
    void setDataType(const DataType& /*dataType*/) const override final {
        // setDataType does nothing: an Identity node has no backend, it just passes the same Tensor.
    }

    static const std::vector<std::string> getInputsName(){
        return {"data_input"};
    }
    static const std::vector<std::string> getOutputsName(){
        return {"data_output"};
    }
};
/**
 * @brief Create a Node wrapping an Identity_Op.
 * @param name Optional name for the created Node (empty by default).
 */
inline std::shared_ptr<Node> Identity(const std::string& name = "") {
    auto op = std::make_shared<Identity_Op>();
    return std::make_shared<Node>(op, name);
}
}
#endif /* AIDGE_CORE_OPERATOR_IDENTITY_H_ */
......@@ -74,12 +74,6 @@ public:
void computeOutputDims() override final {
// Forward dims of micro-graph
mGraph->forwardDims();
// Associate outputs to micro-graph outputs for custom implementation
for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) {
const auto& outputOp = mOutputOps[outputIdx];
mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second);
}
}
......
......@@ -51,9 +51,18 @@ void init_GraphView(py::module& m) {
Include a Node to the current GraphView object.
:param other_node: Node to add
:type oth_Node: Node
:param includeLearnableParameter: include non-data inputs, like weights and biases. Default True.
:type includeLearnableParameter: bool
:type other_node: Node
:param include_learnable_parameters: include non-data inputs, like weights and biases, default True.
:type include_learnable_parameters: bool, optional
)mydelimiter")
.def("add", (void (GraphView::*)(std::shared_ptr<GraphView>)) & GraphView::add,
py::arg("other_graph"),
R"mydelimiter(
Include a GraphView to the current GraphView object.
:param other_graph: GraphView to add
:type other_graph: GraphView
)mydelimiter")
.def("add_child",
......@@ -105,4 +114,4 @@ void init_GraphView(py::module& m) {
// })
;
}
} // namespace Aidge
\ No newline at end of file
} // namespace Aidge
......@@ -16,136 +16,150 @@
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/graph/Connector.hpp"
#include "aidge/utils/Types.h"
namespace py = pybind11;
namespace Aidge {
void init_Node(py::module& m) {
py::class_<Node, std::shared_ptr<Node>>(m, "Node")
.def("name", &Node::name,
R"mydelimiter(
Name of the Node.
)mydelimiter")
.def("type", &Node::type,
R"mydelimiter(
Type of the node.
)mydelimiter")
.def("get_operator", &Node::getOperator,
R"mydelimiter(
Get the Operator object of the Node.
)mydelimiter")
.def("set_name", &Node::setName, py::arg("name"),
R"mydelimiter(
Set the Node name.
:param name: New name for the node.
:type name: str
:rtype: str
)mydelimiter")
.def("add_child",
(void (Node::*)(std::shared_ptr<Node>, const IOIndex_t, IOIndex_t)) &
Node::addChild,
py::arg("other_node"), py::arg("out_id") = 0, py::arg("other_in_id") = gk_IODefaultIndex,
R"mydelimiter(
Link another Node to an output of the current Node.
:param other_node: Pointer to the other Node.
:type other_node: :py:class: Node
:param out_id: ID of the current Node output to connect to the other Node. Default to 0.
:type out_id: int
:param other_in_id: ID of the other Node input to connect to the current Node. Default to the first avaible data input.
:type other_in_id: int
)mydelimiter")
.def("add_child",
(void (Node::*)(std::shared_ptr<GraphView>, const IOIndex_t,
std::pair<std::shared_ptr<Node>, IOIndex_t>)) &
Node::addChild,
py::arg("other_graph"), py::arg("out_id") = 0,
py::arg("other_in_id") =
std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex),
R"mydelimiter(
Link a Node from a specific GraphView to the current Node.
:param other_view: Pointer to the GraphView whose content should be linked to the current Node.
:type other_view: :py:class: GraphView
:param out_id: ID of the current Node output to connect to the other Node. Default to 0.
:type out_id: int
:param other_in_id: Pair of Node and input connection ID for specifying the connection. If the GraphView whose content is linked has only one input Node, then it defaults to the first available data input ID of this Node.
:type other_in_id: tuple[:py:class: Node, int]
)mydelimiter")
.def("inputs", &Node::inputs,
R"mydelimiter(
Get ordered list of parent Node and the associated output index connected to the current Node's inputs.
:return: List of connections. When an input is not linked to any parent, the default value is (None, default_index)
:rtype: list[tuple[Node, int]]
)mydelimiter")
.def("input", &Node::input, py::arg("in_id"),
R"mydelimiter(
Get the parent Node and the associated output index connected to the i-th input of the current Node.
:param in_id: input index of the current Node object.
:type in_id: int
:return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index)
:rtype: tuple[Node, int]
)mydelimiter")
.def("outputs", &Node::outputs,
R"mydelimiter(
Get, for each output of the Node, a list of the children Node and the associated input index connected to it.
:return: List of a list of connections. When an outut is not linked to any child, its list a empty.
:rtype: list[list[tuple[Node, int]]]
)mydelimiter")
.def("output", &Node::output, py::arg("out_id"),
R"mydelimiter(
Get a list of the children Node for a specific output and the associated input index connected to it.
:param out_id: input index of the current Node object.
:type out_id: int
:return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index)
:rtype: list[tuple[Node, int]]
)mydelimiter")
.def("get_nb_inputs", &Node::nbInputs,
R"mydelimiter(
Number of inputs.
:rtype: int
)mydelimiter")
.def("get_nb_data", &Node::nbData,
R"mydelimiter(
Number of data inputs.
:rtype: int
)mydelimiter")
.def("get_nb_outputs", &Node::nbOutputs,
R"mydelimiter(
Number of outputs.
:rtype: int
)mydelimiter")
.def("get_parents", &Node::getParents,
R"mydelimiter(
Get parents.
)mydelimiter")
.def("get_children", (std::set<std::shared_ptr<Node>> (Node::*)() const) &Node::getChildren,
R"mydelimiter(
Get children.
)mydelimiter")
.def("__call__", &Node::operator(), py::arg("connectors"));
.def("name", &Node::name,
R"mydelimiter(
Name of the Node.
)mydelimiter")
.def("type", &Node::type,
R"mydelimiter(
Type of the node.
)mydelimiter")
.def("get_operator", &Node::getOperator,
R"mydelimiter(
Get the Operator object of the Node.
)mydelimiter")
.def("set_name", &Node::setName, py::arg("name"),
R"mydelimiter(
Set the Node name.
:param name: New name for the node.
:type name: str
:rtype: str
)mydelimiter")
.def("add_child",
(void (Node::*)(std::shared_ptr<Node>, const IOIndex_t, IOIndex_t)) &
Node::addChild,
py::arg("other_node"), py::arg("out_id") = 0, py::arg("other_in_id") = gk_IODefaultIndex,
R"mydelimiter(
Link another Node to an output of the current Node.
:param other_node: Pointer to the other Node.
:type other_node: :py:class: Node
:param out_id: ID of the current Node output to connect to the other Node. Default to 0.
:type out_id: int
:param other_in_id: ID of the other Node input to connect to the current Node. Default to the first available data input.
:type other_in_id: int
)mydelimiter")
.def("add_child",
(void (Node::*)(std::shared_ptr<GraphView>, const IOIndex_t,
std::pair<std::shared_ptr<Node>, IOIndex_t>)) &
Node::addChild,
py::arg("other_graph"), py::arg("out_id") = 0,
py::arg("other_in_id") =
std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex),
R"mydelimiter(
Link a Node from a specific GraphView to the current Node.
:param other_view: Pointer to the GraphView whose content should be linked to the current Node.
:type other_view: :py:class: GraphView
:param out_id: ID of the current Node output to connect to the other Node. Default to 0.
:type out_id: int
:param other_in_id: Pair of Node and input connection ID for specifying the connection. If the GraphView whose content is linked has only one input Node, then it defaults to the first available data input ID of this Node.
:type other_in_id: tuple[:py:class: Node, int]
)mydelimiter")
.def("inputs", &Node::inputs,
R"mydelimiter(
Get ordered list of parent Node and the associated output index connected to the current Node's inputs.
:return: List of connections. When an input is not linked to any parent, the default value is (None, default_index)
:rtype: list[tuple[Node, int]]
)mydelimiter")
.def("input", &Node::input, py::arg("in_id"),
R"mydelimiter(
Get the parent Node and the associated output index connected to the i-th input of the current Node.
:param in_id: input index of the current Node object.
:type in_id: int
:return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index)
:rtype: tuple[Node, int]
)mydelimiter")
.def("outputs", &Node::outputs,
R"mydelimiter(
Get, for each output of the Node, a list of the children Node and the associated input index connected to it.
:return: List of a list of connections. When an output is not linked to any child, its list is empty.
:rtype: list[list[tuple[Node, int]]]
)mydelimiter")
.def("output", &Node::output, py::arg("out_id"),
R"mydelimiter(
Get a list of the children Node for a specific output and the associated input index connected to it.
:param out_id: input index of the current Node object.
:type out_id: int
:return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index)
:rtype: list[tuple[Node, int]]
)mydelimiter")
.def("get_nb_inputs", &Node::nbInputs,
R"mydelimiter(
Number of inputs.
:rtype: int
)mydelimiter")
.def("get_nb_data", &Node::nbData,
R"mydelimiter(
Number of data inputs.
:rtype: int
)mydelimiter")
.def("get_nb_outputs", &Node::nbOutputs,
R"mydelimiter(
Number of outputs.
:rtype: int
)mydelimiter")
.def("get_parents", &Node::getParents,
R"mydelimiter(
Get parents.
)mydelimiter")
.def("get_children", (std::set<std::shared_ptr<Node>> (Node::*)() const) &Node::getChildren,
R"mydelimiter(
Get children.
)mydelimiter")
.def("__call__",
[](Node &self, pybind11::args args) {
std::vector<Connector> connectors;
for (const auto &arg : args) {
// Check if the argument is an instance of Connector
if (pybind11::isinstance<Connector>(arg)) {
// Convert the Python object to a C++ Connector and push it into the vector
connectors.push_back(arg.cast<Connector>());
} else {
throw std::runtime_error("One of the arguments was not a Connector.");
}
}
return self(connectors);
});
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/operator/Identity.hpp"
#include "aidge/operator/Operator.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Identity(py::module& m) {
    // Expose the Identity operator class to Python.
    auto identityClass = py::class_<Identity_Op, std::shared_ptr<Identity_Op>, Operator>(m, "IdentityOp", py::multiple_inheritance());
    identityClass.def("get_inputs_name", &Identity_Op::getInputsName);
    identityClass.def("get_outputs_name", &Identity_Op::getOutputsName);

    // Factory helper: aidge_core.Identity(name="")
    m.def("Identity", &Identity, py::arg("name") = "");
}
} // namespace Aidge
......@@ -122,6 +122,15 @@ void init_MetaOperatorDefs(py::module &m) {
declare_PaddedMaxPoolingOp<2>(m);
declare_PaddedMaxPoolingOp<3>(m);
py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, Operator>(m, "MetaOperator_Op", py::multiple_inheritance());
m.def("meta_operator", &MetaOperator,
py::arg("type"),
py::arg("graph"),
py::arg("name") = "",
py::arg("input_nodes") = std::vector<NodePtr>(),
py::arg("output_nodes") = std::vector<NodePtr>()
);
}
} // namespace Aidge
......@@ -27,6 +27,7 @@ void init_Operator(py::module& m){
.def("nb_data", &Operator::nbData)
.def("nb_param", &Operator::nbParam)
.def("nb_outputs", &Operator::nbOutputs)
.def("output_dims_forwarded", &Operator::outputDimsForwarded)
.def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data"))
.def("set_datatype", &Operator::setDataType, py::arg("dataType"))
.def("set_backend", &Operator::setBackend, py::arg("name"))
......
......@@ -40,6 +40,7 @@ void init_ReLU(py::module&);
void init_Softmax(py::module&);
void init_Sqrt(py::module&);
void init_Sub(py::module&);
void init_Identity(py::module&);
void init_Node(py::module&);
void init_GraphView(py::module&);
......@@ -85,6 +86,7 @@ void init_Aidge(py::module& m){
init_Softmax(m);
init_Sqrt(m);
init_Sub(m);
init_Identity(m);
init_Producer(m);
init_GraphRegex(m);
......
......@@ -41,6 +41,7 @@ std::shared_ptr<Aidge::GraphView> Aidge::generateGraph(std::vector<Connector> ct
std::vector<std::shared_ptr<Node>> parents = nodesToAdd.back()->getParents();
const std::set<std::shared_ptr<Node>>& alreadyAdded = graph->getNodes();
for (std::shared_ptr<Node> parent : parents) {
if (!parent) continue;
if (alreadyAdded.find(parent) == alreadyAdded.end()) {
buffer.push_back(parent);
}
......@@ -51,4 +52,4 @@ std::shared_ptr<Aidge::GraphView> Aidge::generateGraph(std::vector<Connector> ct
buffer = {};
}
return graph;
}
\ No newline at end of file
}
......@@ -128,7 +128,7 @@ std::shared_ptr<FsmGraph> GraphFsmInterpreter::qomF(std::shared_ptr<FsmGraph> fs
for(auto valid : allValid){
if(haveCommon){
/*
the // quantif case
the '//' (parallel) quantifier case
get the go back and make a lexeme id(number)
we need to go back to the ref delta min #TODO
*/
......@@ -145,7 +145,7 @@ std::shared_ptr<FsmGraph> GraphFsmInterpreter::qomF(std::shared_ptr<FsmGraph> fs
edge = FsmEdgeFactory::make(valid,start,FsmEdgeTypes::REF,mNodesCondition, lexem.str());
}else{
/*
the sequensial quantif case
the sequential quantifier case
no reference to common
*/
edge = FsmEdgeFactory::make(valid,start,FsmEdgeTypes::EMPTY,mNodesCondition,"");
......
......@@ -26,10 +26,17 @@ void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){
void GraphRegex::addQuery(const std::string query){
mQuery.push_back(query);
}
// void GraphRegex::addQuery(const std::string query){
// //TODO one query only but the same string is a same query but
// //2 different string it's maybe the same query , we need to check the AST
// mQueryRecipe[query] = nullptr;
// }
// Register a topology query together with its (possibly null) recipe function.
// Registering the same query string again overwrites the previous recipe.
void GraphRegex::addQuery(const std::string query, RecipesFunctionType f) {
    mQueryRecipe[query] = std::move(f);
}
// Function to generate all combinations of n elements from a set
......@@ -87,7 +94,9 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph
std::vector<std::shared_ptr<MatchSolution>> solutions = {};
for (const std::string& query : mQuery) {
//for (const std::string& query : mQuery) {
for (auto it = mQueryRecipe.begin(); it != mQueryRecipe.end(); ++it) {
const std::string query = it->first;
std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>(query,mAllTest);
std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret();
......@@ -108,6 +117,15 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph
return _findLargestCompatibleSet(solutions);
}
/**
 * @brief Match the registered queries on the graph and apply the associated recipe functions.
 * @param ref the graph in which the queries are searched
 */
void GraphRegex::appliedRecipes(std::shared_ptr<GraphView> ref){
    const std::set<std::shared_ptr<MatchSolution>> matchRef = match(ref);
    for (const auto& solution : matchRef) {
        // Use a single find() instead of two operator[] calls: avoids a double
        // lookup and does not default-insert a null recipe for an unknown query.
        const auto recipeIt = mQueryRecipe.find(solution->getQuery());
        if (recipeIt != mQueryRecipe.end() && recipeIt->second != nullptr) {
            recipeIt->second(solution);
        }
    }
}
void GraphRegex::setNodeKey(const std::string key, const std::string conditionalExpressions ){
mAllTest.push_back(std::make_shared<ConditionalInterpreter>(key,conditionalExpressions));
_majConditionalInterpreterLambda();
......
......@@ -226,6 +226,14 @@ const EdgeTestResult FsmEdgeEmpty::test(const std::shared_ptr<FsmRunTimeContext>
}
return {true,std::set<NodePtr>({opNode})};//none
}
//////////////
// FsmEdgeNone: an edge built with no associated condition (nullptr test),
// used as an always-failing transition when a node key has no registered condition.
FsmEdgeNone::FsmEdgeNone(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest)
:FsmEdge(source,dest,nullptr)
{}
// An FsmEdgeNone can never be taken: the test always fails and matches no node.
const EdgeTestResult FsmEdgeNone::test(const std::shared_ptr<FsmRunTimeContext> /*stmContext*/){
    std::set<NodePtr> matchedNodes;  // intentionally left empty
    return {false, matchedNodes};
}
/// factory
std::shared_ptr<FsmEdge> FsmEdgeFactory::make(
......@@ -260,7 +268,10 @@ const std::string lexeme)
std::string commonKey = edgeType + std::to_string(commonIdx);
if(allTest.find(edgeType) == allTest.end()){
throw std::invalid_argument("Bad Node Test " + edgeType );
//if the key is not linked to a condition,
//it is by default initialized with an edge that always fails
return std::make_shared<FsmEdgeNone>(source, dest);
//throw std::invalid_argument("Bad Node Test " + edgeType );
}
return std::make_shared<FsmEdgeCommon> (source, dest, allTest.at(edgeType), commonKey);
......@@ -274,7 +285,11 @@ const std::string lexeme)
std::string edgeType = m[1];
if(allTest.find(edgeType) == allTest.end()){
throw std::invalid_argument("Bad Node Test " + edgeType );
//if the key is not linked to a condition,
//it is by default initialized with an edge that always fails
return std::make_shared<FsmEdgeNone>(source, dest);
//throw std::invalid_argument("Bad Node Test " + edgeType );
}
return std::make_shared<FsmEdgeUnique>(source, dest, allTest.at(edgeType));
......
......@@ -18,10 +18,6 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<
: OperatorTensor(type, graph->dataInputs().size(), (graph->inputs().size() - graph->dataInputs().size()), graph->outputs().size()),
mGraph(graph)
{
// for (std::size_t i = 0; i < mInputs.size(); ++i) {
// mInputs[i] = std::make_shared<Tensor>();
// }
// Fill inputsNodes and outputsNodes when there is no ambiguity
if (inputNodes.empty()) {
AIDGE_ASSERT(mGraph->inputNodes().size() == 1, "need to specify internal nodes input mapping");
......@@ -66,8 +62,14 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<
}
}
AIDGE_INTERNAL_ASSERT(mInputOps.size() == mGraph->inputs().size());
AIDGE_INTERNAL_ASSERT(mOutputOps.size() == mGraph->outputs().size());
// Associate outputs to micro-graph outputs for custom implementation
for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) {
const auto& outputOp = mOutputOps[outputIdx];
mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second);
}
}
Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const {
......@@ -110,6 +112,7 @@ void Aidge::MetaOperator_Op::updateConsummerProducer() {
mScheduler = std::make_shared<SequentialScheduler>(mGraph);
}
// TODO: check that generateScheduling() can be called multiple time to iteratively update the schedule.
// It could be a good idea to unify updateConsummerProducer() and generateScheduling() into a "updateScheduling()"
mScheduler->generateScheduling();
......
......@@ -2,6 +2,15 @@
#include <catch2/catch_test_macros.hpp>
#include "aidge/graphRegex/GraphRegex.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Recipies.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"
......@@ -46,13 +55,9 @@ TEST_CASE("GraphRegexUser") {
}
SECTION("CC") {
SECTION("2 query") {
std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1");
......@@ -81,4 +86,93 @@ TEST_CASE("GraphRegexUser") {
}
}
SECTION("Not define node Test") {
//test that when the FC key is not defined, only `query` matches, not `query2`
std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c");
std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1");
std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2");
std::shared_ptr<Node> conv3 = GenericOperator("FC", 1, 1, 1, "c3");
g1->add(conv);
g1->addChild(conv1, "c");
g1->addChild(conv2, "c1");
g1->addChild(conv3, "c2");
//sut->setKeyFromGraph(g1);
const std::string query = "Conv->Conv";
const std::string query2 = "Conv->FC";
sut->setNodeKey("Conv","getType($) =='Conv'");
sut->addQuery(query);
sut->addQuery(query2);
for (const auto& solution : sut->match(g1)) {
REQUIRE(solution->getQuery() == query);
}
}
SECTION("Applied Recipes"){
// generate the original GraphView
auto matmul0 = MatMul(5, "matmul0");
auto add0 = Add<2>("add0");
auto matmul1 = MatMul(5, "matmul1");
auto add1 = Add<2>("add1");
auto b0 = Producer({5}, "B0");
auto w0 = Producer({5, 5}, "W0");
auto b1 = Producer({5}, "B1");
auto w1 = Producer({5,5},"W1");
auto input = Producer({2,5}, "input");
input->addChild(matmul0, 0, 0);
w0->addChild(matmul0, 0, 1);
matmul0->addChild(add0, 0, 0);
b0->addChild(add0, 0, 1);
add0->addChild(matmul1, 0, 0);
w1->addChild(matmul1, 0, 1);
matmul1->addChild(add1, 0, 0);
b1->addChild(add1, 0, 1);
auto fc = GenericOperator("FC", 1, 1, 1, "c");
auto fl = GenericOperator("Flatten", 1, 1, 1, "c");
auto g = std::make_shared<GraphView>();
g->add({matmul0, add0, matmul1, add1, b0, b1,fl,fc});
std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>();
kitchenBook->setNodeKey("Add","getType($) =='Add'");
kitchenBook->setNodeKey("MatMul","getType($) =='MatMul'");
kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'");
kitchenBook->setNodeKey("FC","getType($) =='FC'");
kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd));
kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten));
kitchenBook->appliedRecipes(g);
std::set<std::shared_ptr<Node>> newNodes = g->getNodes();
REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fc}));
//REQUIRE(newNodes.size() == 6);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment