Skip to content
Snippets Groups Projects
Commit f0acf0de authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

Merge branch aidge_core:main into vit_operators

parents 02fdf34a 3cd94d47
No related branches found
No related tags found
2 merge requests!59Improvements and fixes,!47Vit operators
......@@ -11,6 +11,11 @@
namespace Aidge{
/**
* type for recipes function use in query and resolve
*/
using RecipesFunctionType = std::function<void(std::shared_ptr<MatchSolution>)>;
/**
* @brief class which is the hight level interface for graph matching, used to simplify match definition
*
......@@ -19,9 +24,10 @@ class GraphRegex{
private:
std::vector<std::string> mQuery;
//std::vector<std::string> mQuery;
std::vector<std::shared_ptr<ConditionalInterpreter>> mAllTest;
std::map<std::string, std::function<bool(NodePtr)>> mAllLambda;
std::map<std::string,RecipesFunctionType> mQueryRecipe;
public:
GraphRegex(){};
......@@ -31,7 +37,15 @@ class GraphRegex{
* @brief add a topology query to the match
* @param query the topology query to find
**/
void addQuery(const std::string query);
//void addQuery(const std::string query);
/**
* @brief add a topology query to the match and a function for recipe
* @param query the topology query to find
* @param f the funct
**/
void addQuery(const std::string query,RecipesFunctionType f = nullptr);
/**
* @brief get all the types of a graph and set it as type key in the query
......@@ -53,13 +67,19 @@ class GraphRegex{
**/
void setNodeKey(const std::string key,std::function<bool(NodePtr)> f);
/***
/**
* @brief brief match the queries in the graph
* @param Reference the graph were the querys in search
* @param ref the graph were the querys in search
* @return the result
*/
std::set<std::shared_ptr<MatchSolution>> match(std::shared_ptr<GraphView> ref);
/***
* @brief match the queries in the graph and applied the recipes fuction
* @param ref the graph were the querys in search
*/
void appliedRecipes(std::shared_ptr<GraphView> ref);
private:
void _generateCombinationsStart(const std::set<NodePtr>& elements, std::size_t n, std::size_t index,
......
......@@ -116,7 +116,7 @@ namespace Aidge{
};
/**
* @brief class spesialisation for not commun node (node that must be match one Unique) transition
* @brief class specialization for not commun node (node that must be match one Unique) transition
*/
class FsmEdgeUnique:public FsmEdge
{
......@@ -127,7 +127,7 @@ namespace Aidge{
};
/**
* @brief class spesialisation for commun node transition
* @brief class specialization for commun node transition
* @see FsmEdge
*/
class FsmEdgeCommon:public FsmEdge
......@@ -181,7 +181,7 @@ namespace Aidge{
};
/**
* @brief class spesialisation for ref empty transition
* @brief class specialization for ref empty transition
* @see FsmEdge
*/
class FsmEdgeEmpty:public FsmEdge
......@@ -195,6 +195,20 @@ namespace Aidge{
};
/**
* @brief class specialization for ref empty transition
* @see FsmEdge
*/
class FsmEdgeNone:public FsmEdge
{
public:
FsmEdgeNone(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest);
const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> /*stmContext*/) override;
};
////////////////////////
// FACTORY
......
......@@ -78,12 +78,6 @@ public:
void computeOutputDims() override final {
// Forward dims of micro-graph
mGraph->forwardDims();
// Associate outputs to micro-graph outputs for custom implementation
for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) {
const auto& outputOp = mOutputOps[outputIdx];
mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second);
}
}
bool outputDimsForwarded() const override final { return !(mOutputs[0]->empty()); }
......
......@@ -128,7 +128,7 @@ std::shared_ptr<FsmGraph> GraphFsmInterpreter::qomF(std::shared_ptr<FsmGraph> fs
for(auto valid : allValid){
if(haveCommon){
/*
the // quantif case
the // quantify case
get the go back and make a lexeme id(number)
we need to go back to the ref delta min #TODO
*/
......@@ -145,7 +145,7 @@ std::shared_ptr<FsmGraph> GraphFsmInterpreter::qomF(std::shared_ptr<FsmGraph> fs
edge = FsmEdgeFactory::make(valid,start,FsmEdgeTypes::REF,mNodesCondition, lexem.str());
}else{
/*
the sequensial quantif case
the sequencial quantify case
no reference to common
*/
edge = FsmEdgeFactory::make(valid,start,FsmEdgeTypes::EMPTY,mNodesCondition,"");
......
......@@ -26,10 +26,17 @@ void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){
void GraphRegex::addQuery(const std::string query){
mQuery.push_back(query);
}
// void GraphRegex::addQuery(const std::string query){
// //TODO one query only but the same string is a same query but
// //2 different string it's maybe the same query , we need to check the AST
// mQueryRecipe[query] = nullptr;
// }
void GraphRegex::addQuery(const std::string query,RecipesFunctionType f ){
mQueryRecipe[query] = f;
}
// Function to generate all combinations of n elements from a set
......@@ -87,7 +94,9 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph
std::vector<std::shared_ptr<MatchSolution>> solutions = {};
for (const std::string& query : mQuery) {
//for (const std::string& query : mQuery) {
for (auto it = mQueryRecipe.begin(); it != mQueryRecipe.end(); ++it) {
const std::string query = it->first;
std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>(query,mAllTest);
std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret();
......@@ -108,6 +117,15 @@ std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<Graph
return _findLargestCompatibleSet(solutions);
}
void GraphRegex::appliedRecipes(std::shared_ptr<GraphView> ref){
std::set<std::shared_ptr<MatchSolution>> matchRef = match(ref);
for (const auto& solution : matchRef) {
if(mQueryRecipe[solution->getQuery()] != nullptr){
mQueryRecipe[solution->getQuery()](solution);
}
}
}
void GraphRegex::setNodeKey(const std::string key, const std::string conditionalExpressions ){
mAllTest.push_back(std::make_shared<ConditionalInterpreter>(key,conditionalExpressions));
_majConditionalInterpreterLambda();
......
......@@ -226,6 +226,14 @@ const EdgeTestResult FsmEdgeEmpty::test(const std::shared_ptr<FsmRunTimeContext>
}
return {true,std::set<NodePtr>({opNode})};//none
}
//////////////
FsmEdgeNone::FsmEdgeNone(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest)
:FsmEdge(source,dest,nullptr)
{}
const EdgeTestResult FsmEdgeNone::test(const std::shared_ptr<FsmRunTimeContext> /*stmContext*/){
return {false,std::set<NodePtr>()};
}
/// factory
std::shared_ptr<FsmEdge> FsmEdgeFactory::make(
......@@ -260,7 +268,10 @@ const std::string lexeme)
std::string commonKey = edgeType + std::to_string(commonIdx);
if(allTest.find(edgeType) == allTest.end()){
throw std::invalid_argument("Bad Node Test " + edgeType );
//if the key is not linked to a condition
//by default, it is initialized by a edge that is always false
return std::make_shared<FsmEdgeNone>(source, dest);
//throw std::invalid_argument("Bad Node Test " + edgeType );
}
return std::make_shared<FsmEdgeCommon> (source, dest, allTest.at(edgeType), commonKey);
......@@ -274,7 +285,11 @@ const std::string lexeme)
std::string edgeType = m[1];
if(allTest.find(edgeType) == allTest.end()){
throw std::invalid_argument("Bad Node Test " + edgeType );
//if the key is not linked to a condition
//by default, it is initialized by a edge that is always false
return std::make_shared<FsmEdgeNone>(source, dest);
//throw std::invalid_argument("Bad Node Test " + edgeType );
}
return std::make_shared<FsmEdgeUnique>(source, dest, allTest.at(edgeType));
......
......@@ -22,10 +22,6 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<
for (std::size_t i = 0; i < mInputs.size(); ++i) {
mInputs[i] = std::make_shared<Tensor>();
}
mOutputs = std::vector<std::shared_ptr<Tensor>>(mGraph->outputs().size());
for (std::size_t i = 0; i < mOutputs.size(); ++i) {
mOutputs[i] = std::make_shared<Tensor>();
}
// Fill inputsNodes and outputsNodes when there is no ambiguity
if (inputNodes.empty()) {
......@@ -46,7 +42,7 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<
AIDGE_ASSERT(mGraph->inView(inputNode), "input node must be in the graph");
const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
inputNode->inputs();
int inputIdx = 0; // input idx relative to the current node
for (const auto& in : inputNodeinputs) {
if (in.first == nullptr || !mGraph->inView(in.first)) {
......@@ -71,8 +67,15 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<
}
}
AIDGE_INTERNAL_ASSERT(mInputOps.size() == mGraph->inputs().size());
AIDGE_INTERNAL_ASSERT(mOutputOps.size() == mGraph->outputs().size());
mOutputs = std::vector<std::shared_ptr<Tensor>>(mGraph->outputs().size());
// Associate outputs to micro-graph outputs for custom implementation
for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) {
const auto& outputOp = mOutputOps[outputIdx];
mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second);
}
}
Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const {
......@@ -114,7 +117,7 @@ void Aidge::MetaOperator_Op::updateConsummerProducer() {
// Lazy initialization
mScheduler = std::make_shared<SequentialScheduler>(mGraph);
}
// TODO: check that generateScheduling() can be called multiple time to iteratively update the schedule.
// It could be a good idea to unify updateConsummerProducer() and generateScheduling() into a "updateScheduling()"
mScheduler->generateScheduling();
......
......@@ -2,6 +2,15 @@
#include <catch2/catch_test_macros.hpp>
#include "aidge/graphRegex/GraphRegex.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Recipies.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"
......@@ -46,13 +55,9 @@ TEST_CASE("GraphRegexUser") {
}
SECTION("CC") {
SECTION("2 query") {
std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c");
std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1");
......@@ -81,4 +86,93 @@ TEST_CASE("GraphRegexUser") {
}
}
SECTION("Not define node Test") {
//test if the FC is not define only match query not query2
std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c");
std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1");
std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2");
std::shared_ptr<Node> conv3 = GenericOperator("FC", 1, 1, 1, "c3");
g1->add(conv);
g1->addChild(conv1, "c");
g1->addChild(conv2, "c1");
g1->addChild(conv3, "c2");
//sut->setKeyFromGraph(g1);
const std::string query = "Conv->Conv";
const std::string query2 = "Conv->FC";
sut->setNodeKey("Conv","getType($) =='Conv'");
sut->addQuery(query);
sut->addQuery(query2);
for (const auto& solution : sut->match(g1)) {
REQUIRE(solution->getQuery() == query);
}
}
SECTION("Applied Recipes"){
// generate the original GraphView
auto matmul0 = MatMul(5, "matmul0");
auto add0 = Add<2>("add0");
auto matmul1 = MatMul(5, "matmul1");
auto add1 = Add<2>("add1");
auto b0 = Producer({5}, "B0");
auto w0 = Producer({5, 5}, "W0");
auto b1 = Producer({5}, "B1");
auto w1 = Producer({5,5},"W1");
auto input = Producer({2,5}, "input");
input->addChild(matmul0, 0, 0);
w0->addChild(matmul0, 0, 1);
matmul0->addChild(add0, 0, 0);
b0->addChild(add0, 0, 1);
add0->addChild(matmul1, 0, 0);
w1->addChild(matmul1, 0, 1);
matmul1->addChild(add1, 0, 0);
b1->addChild(add1, 0, 1);
auto fc = GenericOperator("FC", 1, 1, 1, "c");
auto fl = GenericOperator("Flatten", 1, 1, 1, "c");
auto g = std::make_shared<GraphView>();
g->add({matmul0, add0, matmul1, add1, b0, b1,fl,fc});
std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>();
kitchenBook->setNodeKey("Add","getType($) =='Add'");
kitchenBook->setNodeKey("MatMul","getType($) =='MatMul'");
kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'");
kitchenBook->setNodeKey("FC","getType($) =='FC'");
kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd));
kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten));
kitchenBook->appliedRecipes(g);
std::set<std::shared_ptr<Node>> newNodes = g->getNodes();
REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fc}));
//REQUIRE(newNodes.size() == 6);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment