Skip to content
Snippets Groups Projects
Commit 7932a2c1 authored by vincent  lorrain's avatar vincent lorrain
Browse files

[fix] ConditionalInterpreter

parent 19a41f8e
No related branches found
No related tags found
1 merge request!48Refactor/recipies
Pipeline #34142 failed
......@@ -22,7 +22,7 @@ namespace Aidge{
/////////////////////////////
/**
* @brief class used to register any lambda function without context,
* it encapsulates the source lambda in a lambda which takes as argument ConditionalData* which are any type.
* it encapsulates the source lambda in a lambda which takes as argument std::shared_ptr<ConditionalData> which are any type.
* @see ConditionalData
*/
class ConditionalRegisterFunction {
......@@ -31,12 +31,12 @@ class ConditionalRegisterFunction {
//////////////////////////
/**
* @brief recast the ConditionalData* to the argument type of the lambda
* @brief recast the std::shared_ptr<ConditionalData> to the argument type of the lambda
* @tparam T type of the lambda argument
* @see ConditionalData
*/
template <typename T>
T safeCastInput(ConditionalData* data) {
T safeCastInput( std::shared_ptr<ConditionalData> data) {
//cnvertion and type cheking
if (data->isTypeEqualTo<T>()){
return data->getValue<T>();
......@@ -48,14 +48,14 @@ class ConditionalRegisterFunction {
/**
* @brief recaste the output of the lambda to a ConditionalData*
* @brief recaste the output of the lambda to a std::shared_ptr<ConditionalData>
* @tparam T type of the lambda return
* @see ConditionalData
*/
template <typename T>
ConditionalData* safeCastOutput(T data) {
std::shared_ptr<ConditionalData> safeCastOutput(T data) {
ConditionalData* out = new ConditionalData;
std::shared_ptr<ConditionalData> out = std::make_shared<ConditionalData>();
out->setValue<T>(data);
return out;
......@@ -111,11 +111,11 @@ class ConditionalRegisterFunction {
};
/////////////////////
//change the function to ConditionalData*(std::vector<ConditionalData*>)
//change the function to std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>)
/////////////////////
/**
* @brief Converts a function to a ConditionalData*(std::vector<ConditionalData*>).
* @brief Converts a function to a std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>).
* @tparam F The type of the function to convert.
* @tparam ParamsIdx The indices of the function parameters.
* @param f The function to convert.
......@@ -124,25 +124,31 @@ class ConditionalRegisterFunction {
template <class F, std::size_t... ParamsIdx>
auto funcPointer(F f, std::index_sequence<ParamsIdx...>) {
//wrapp the lambda in a new one that as ConditionalData as inputs and output
return [this,f](std::vector<ConditionalData*> &args) {
if (args.size() != sizeof...(ParamsIdx)){
return [this,f](std::vector< std::shared_ptr<ConditionalData>> &args) {
if (args.size() < sizeof...(ParamsIdx)){
std::ostringstream errorMessage;
errorMessage << "bad Number of argument: get " << args.size() << " need " << sizeof...(ParamsIdx) << "\n";
throw std::runtime_error(errorMessage.str());
}
//assert(args.size() == sizeof...(ParamsIdx));//the size of the vector valide
//we used std::vector< std::shared_ptr<ConditionalData>> as a fifo
std::size_t offset = args.size()-sizeof...(ParamsIdx);
using FuncTraits = function_traits<decltype(f)>;
using outType = typename FuncTraits::return_type;
outType result = f(safeCastInput<typename FuncTraits::template argument<ParamsIdx>::type>(args[ParamsIdx])...);
outType result = f(safeCastInput<typename FuncTraits::template argument<ParamsIdx>::type>(args[offset+ParamsIdx])...);
//suppress what we used
for (size_t i = 0; i < sizeof...(ParamsIdx); ++i) {
args.pop_back();
}
//typename
return safeCastOutput<outType>(result);
};
}
/**
* @brief Converts a function pointer to a ConditionalData*(std::vector<ConditionalData*>).
* @brief Converts a function pointer to a std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>).
* @tparam R The return type of the function.
* @tparam Params The parameter types of the function.
* @param f The function pointer to convert.
......@@ -154,7 +160,7 @@ class ConditionalRegisterFunction {
}
/**
* @brief Converts a std::function to a ConditionalData*(std::vector<ConditionalData*>).
* @brief Converts a std::function to a std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>).
* @tparam R The return type of the function.
* @tparam Params The parameter types of the function.
* @param f The function pointer to convert.
......@@ -196,7 +202,7 @@ class ConditionalRegisterFunction {
* @param datas The vector of input data.
* @return A pointer to the output ConditionalData object.
*/
ConditionalData* run(const std::string key,std::vector<ConditionalData*> & datas);
std::shared_ptr<ConditionalData> run(const std::string key,std::vector< std::shared_ptr<ConditionalData>> & datas);
bool isLambdaRegister(const std::string &key) {
if(mWlambda.find(key) != mWlambda.end()){
......@@ -207,7 +213,7 @@ class ConditionalRegisterFunction {
private:
/// @brief map of name and the converted function.
std::map<const std::string, std::function<ConditionalData*(std::vector<ConditionalData*> &)>> mWlambda;
std::map<const std::string, std::function< std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>> &)>> mWlambda;
};
///////////////////
......@@ -237,15 +243,15 @@ class ConditionalInterpreter
ConditionalRegisterFunction mLambdaRegister;
std::vector<ConditionalData*> mResolution ;
std::vector< std::shared_ptr<ConditionalData>> mResolution ;
void clearRes(){
// void clearRes(){
for (std::size_t i = 0; i < mResolution.size(); ++i) {
delete mResolution[i];
}
mResolution.clear();
}
// for (std::size_t i = 0; i < mResolution.size(); ++i) {
// delete mResolution[i];
// }
// mResolution.clear();
// }
public:
......@@ -258,7 +264,7 @@ class ConditionalInterpreter
ConditionalInterpreter(const std::string key,const std::string ConditionalExpressions);
~ConditionalInterpreter(){clearRes();}
~ConditionalInterpreter(){}
/**
* @brief get the condition key
......@@ -293,12 +299,12 @@ class ConditionalInterpreter
* @param NodeOp The node currently being tested
* @param nodes The AST given by the parsing process
*/
std::vector<ConditionalData*> visit(const ASTNodeCh& nodes, const NodePtr NodeOp );
std::vector< std::shared_ptr<ConditionalData>> visit(const ASTNodeCh& nodes, const NodePtr NodeOp );
/**
* @defgroup ASTnodeInterpreterF Functions for interpreting AST nodes
* @brief For each node type in the AST, function defines the processing to be performed
* they return a std::vector<ConditionalData*> which corresponds to the value(s) obtained
* they return a std::vector< std::shared_ptr<ConditionalData>> which corresponds to the value(s) obtained
*/
/**
......@@ -308,38 +314,38 @@ class ConditionalInterpreter
void fLambda(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node);
/**
* @ingroup ASTnodeInterpreterF
* @brief Converted the lexeme to a int and to ConditionalData*
* @brief Converted the lexeme to a int and to std::shared_ptr<ConditionalData>
*/
void fStrToInteger(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node);
/**
* @ingroup ASTnodeInterpreterF
* @brief Converted the lexeme to a float and to ConditionalData*
* @brief Converted the lexeme to a float and to std::shared_ptr<ConditionalData>
*/
void fStrToFloat(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node);
/**
* @ingroup ASTnodeInterpreterF
* @brief Converted the lexeme to a str and to ConditionalData*
* @brief Converted the lexeme to a str and to std::shared_ptr<ConditionalData>
*/
void fStrToStr(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node);
/**
* @ingroup ASTnodeInterpreterF
* @brief makes the == operation between two previously converted ConditionalData*
* @brief makes the == operation between two previously converted std::shared_ptr<ConditionalData>
*/
void fEq(void);
/**
* @ingroup ASTnodeInterpreterF
* @brief makes the != operation between two previously converted ConditionalData*
* @brief makes the != operation between two previously converted std::shared_ptr<ConditionalData>
*/
void fNeq(void);
/**
* @ingroup ASTnodeInterpreterF
* @brief makes the && operation between two previously converted ConditionalData* in bool
* @brief makes the && operation between two previously converted std::shared_ptr<ConditionalData> in bool
*/
void fAnd(void);
/**
* @ingroup ASTnodeInterpreterF
* @brief makes the || operation between two previously converted ConditionalData* in bool
* @brief makes the || operation between two previously converted std::shared_ptr<ConditionalData> in bool
*/
void fOr(void);
......
......@@ -70,7 +70,12 @@ void removeFlatten(std::shared_ptr<GraphView> graphView);
*
* @param nodes Strict set of Node to merge.
*/
void fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes);
void fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm);
void fuseBatchNorm(std::shared_ptr<MatchSolution> solution);
/**
* @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes.
* Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
......
......@@ -8,7 +8,7 @@ using namespace Aidge;
//ConditionalRegisterFunction
///////////////////////////////
ConditionalData* ConditionalRegisterFunction::run(const std::string key,std::vector<ConditionalData*> & datas){
std::shared_ptr<ConditionalData> ConditionalRegisterFunction::run(const std::string key,std::vector< std::shared_ptr<ConditionalData>> & datas){
auto lambdaIt = mWlambda.find(key);
if (lambdaIt != mWlambda.end()) {
......@@ -45,10 +45,9 @@ using namespace Aidge;
bool ConditionalInterpreter::test( const NodePtr nodeOp)
{
clearRes();
mResolution.clear();
try{
std::vector<ConditionalData*> r = visit({mTree},nodeOp);
std::vector< std::shared_ptr<ConditionalData>> r = visit({mTree},nodeOp);
if (mResolution.size() != 1){
throw std::runtime_error("Multi output interpretation output");
......@@ -72,8 +71,8 @@ using namespace Aidge;
}
/////
std::vector<ConditionalData*> ConditionalInterpreter::visit(const ASTNodeCh& nodes, const NodePtr nodeOp ){
std::vector<ConditionalData*> dataVector;
std::vector< std::shared_ptr<ConditionalData>> ConditionalInterpreter::visit(const ASTNodeCh& nodes, const NodePtr nodeOp ){
std::vector< std::shared_ptr<ConditionalData>> dataVector;
for ( std::shared_ptr<AstNode<ConditionalTokenTypes>> node : nodes) {
try{
......@@ -140,7 +139,7 @@ using namespace Aidge;
case ConditionalTokenTypes::NODE: //TODO
{
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
data->setValue<NodePtr>(nodeOp);
mResolution.push_back(data);
......@@ -157,7 +156,7 @@ using namespace Aidge;
case ConditionalTokenTypes::BOOL: //TODO
{
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
if(node->getValue() == "true"){
data->setValue<bool>(true);
......@@ -195,7 +194,8 @@ using namespace Aidge;
void ConditionalInterpreter::fStrToInteger(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node)
{
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
data->setValue<int>(std::stoi(node->getValue()));
mResolution.push_back(data);
}
......@@ -203,14 +203,14 @@ using namespace Aidge;
void ConditionalInterpreter::fStrToFloat(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node)
{
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
data->setValue<float>(std::stof(node->getValue()));
mResolution.push_back(data);
}
void ConditionalInterpreter::fStrToStr(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node)
{
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
data->setValue<std::string>(node->getValue());
mResolution.push_back(data);
}
......@@ -218,7 +218,7 @@ using namespace Aidge;
void ConditionalInterpreter::fLambda(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node)
{
//if the lambda have input
ConditionalData* data;
std::shared_ptr<ConditionalData> data;
try {
data = mLambdaRegister.run(node->getValue(),mResolution);
} catch (const std::exception& e) {
......@@ -227,17 +227,20 @@ using namespace Aidge;
throw std::runtime_error(errorMessage.str());
}
clearRes();
//clearRes();
mResolution.push_back(data);
}
void ConditionalInterpreter::fEq(void)
{
if (mResolution.size() != 2){
if (mResolution.size() < 2){
throw std::runtime_error("EQ need 2 arg and get :" + std::to_string(mResolution.size()));
}
auto a = mResolution[0];
auto b = mResolution[1];
auto a = mResolution.back();
mResolution.pop_back();
auto b = mResolution.back();
mResolution.pop_back();
if (a->getType() != b->getType()){
throw std::runtime_error("EQ Unsupported between type :" + a->getType() +" "+ b->getType());
......@@ -245,7 +248,7 @@ using namespace Aidge;
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
if (a->isTypeEqualTo<int>()) {
data->setValue<bool>( a->getValue<int>() == b->getValue<int>());
......@@ -259,23 +262,25 @@ using namespace Aidge;
throw std::runtime_error("EQ Unknown type encountered :" + a->getType() );
}
clearRes();
mResolution.push_back(data);
}
void ConditionalInterpreter::fNeq(void)
{
if (mResolution.size() != 2){
if (mResolution.size() < 2){
throw std::runtime_error("NEQ need 2 arg and get :" + std::to_string(mResolution.size()));
}
auto a = mResolution[0];
auto b = mResolution[1];
auto a = mResolution.back();
mResolution.pop_back();
auto b = mResolution.back();
mResolution.pop_back();
if (a->getType() != b->getType()){
throw std::runtime_error("NEQ Unsupported between type :" + a->getType() +" "+ b->getType());
}
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
if (a->isTypeEqualTo<int>()) {
data->setValue<bool>( a->getValue<int>() != b->getValue<int>());
......@@ -288,67 +293,72 @@ using namespace Aidge;
throw std::runtime_error("NEQ Unknown type encountered :" + a->getType() );
}
clearRes();
mResolution.push_back(data);
}
void ConditionalInterpreter::fAnd(void)
{
if (mResolution.size() != 2){
if (mResolution.size() < 2){
throw std::runtime_error("AND need 2 arg and get :" + std::to_string(mResolution.size()));
}
auto a = mResolution[0];
auto b = mResolution[1];
auto a = mResolution.back();
mResolution.pop_back();
auto b = mResolution.back();
mResolution.pop_back();
if (a->getType() != typeid(bool).name() || b->getType() != typeid(bool).name()){
throw std::runtime_error("AND Unknown type encountered need bool get :" + a->getType() );
}
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
data->setValue<bool>( a->getValue<bool>() && b->getValue<bool>());
clearRes();
mResolution.push_back(data);
}
void ConditionalInterpreter::fOr(void)
{
if (mResolution.size() != 2){
if (mResolution.size() < 2){
throw std::runtime_error("OR need 2 arg and get :" + std::to_string(mResolution.size()));
}
auto a = mResolution[0];
auto b = mResolution[1];
auto a = mResolution.back();
mResolution.pop_back();
auto b = mResolution.back();
mResolution.pop_back();
if (a->getType() != typeid(bool).name() || b->getType() != typeid(bool).name()){
throw std::runtime_error("OR Unknown type encountered need bool get :" + a->getType() );
}
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
data->setValue<bool>( a->getValue<bool>() || b->getValue<bool>());
clearRes();
mResolution.push_back(data);
}
void ConditionalInterpreter::fNot()
{
if (mResolution.size() != 1){
if (mResolution.size() < 1){
throw std::runtime_error("NOT need 1 arg and get :" + std::to_string(mResolution.size()));
}
auto a = mResolution[0];
auto a = mResolution.back();
mResolution.pop_back();
if (a->getType() != typeid(bool).name()){
throw std::runtime_error("NOT Unknown type encountered need bool get :" + a->getType() );
}
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
data->setValue<bool>( !a->getValue<bool>() );
clearRes();
mResolution.push_back(data);
}
......@@ -24,25 +24,30 @@
// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
using namespace Aidge;
void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){
//Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
using namespace Aidge;
// Assert the nodes types are correct to be fused
std::shared_ptr<Node> conv;
std::shared_ptr<Node> batchnorm;
for (const auto& element : nodes) {
assert((element->type() == "Conv" || element->type() == "BatchNorm") && "Wrong type for the nodes to replace");
if (element->type() == "Conv"){
conv = element;
}
else if (element->type() == "BatchNorm") {
batchnorm = element;
}
}
void Aidge::fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm){
// assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
// // Assert the nodes types are correct to be fused
// std::shared_ptr<Node> conv;
// std::shared_ptr<Node> batchnorm;
// for (const auto& element : nodes) {
// assert((element->type() == "Conv" || element->type() == "BatchNorm") && "Wrong type for the nodes to replace");
// if (element->type() == "Conv"){
// conv = element;
// }
// else if (element->type() == "BatchNorm") {
// batchnorm = element;
// }
// }
// TODO : check if batchnorm is the only child of the Conv or FC
std::shared_ptr<Tensor> scale = batchnorm->input(1).first->getOperator()->getOutput(batchnorm->input(1).second);
std::shared_ptr<Tensor> shift = batchnorm->input(2).first->getOperator()->getOutput(batchnorm->input(2).second);
std::shared_ptr<Tensor> b_mean = batchnorm->input(3).first->getOperator()->getOutput(batchnorm->input(3).second);
......@@ -127,19 +132,45 @@ void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){
}
void Aidge::fuseBatchNorm(std::shared_ptr<MatchSolution> solution){
assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n");
assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n");
for (const auto& op : solution->at("OP")) {
for (const auto& batchNorm : solution->at("BatchNorm")) {
fuseBatchNorm(op,batchNorm);
}
}
}
void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView){
std::map<std::string,NodeRegex*> nodesRegex ;
nodesRegex["BatchNorm"] = new NodeRegex("BatchNorm");
nodesRegex["Conv"] = new NodeRegex("Conv");
nodesRegex["FC"] = new NodeRegex("FC");
std::vector<std::string> seqRegex;
seqRegex.push_back("Conv -> BatchNorm;"); // TODO: Add (Conv | FC)
GRegex GReg(nodesRegex, seqRegex);
Match matches = GReg.match(graphView);
std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
for (size_t i = 0; i < matches.getNbMatch(); ++i) {
fuseBatchNorm(matchNodes[i]);
// std::map<std::string,NodeRegex*> nodesRegex ;
// nodesRegex["BatchNorm"] = new NodeRegex("BatchNorm");
// nodesRegex["Conv"] = new NodeRegex("Conv");
// nodesRegex["FC"] = new NodeRegex("FC");
// std::vector<std::string> seqRegex;
// seqRegex.push_back("Conv -> BatchNorm;"); // TODO: Add (Conv | FC)
// GRegex GReg(nodesRegex, seqRegex);
// Match matches = GReg.match(graphView);
// std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
// for (size_t i = 0; i < matches.getNbMatch(); ++i) {
// fuseBatchNorm(matchNodes[i]);
// }
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("BatchNorm","getType($) =='BatchNorm'");
regex->setNodeKey("OP","getType($) =='Conv'");// || getType($) =='FC' ");
regex->addQuery("OP -> BatchNorm");
for (const auto& solution : regex->match(graphView)) {
fuseBatchNorm(solution);
}
}
......@@ -12,13 +12,38 @@ TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") {
SECTION("custom Lambda") {
const std::string test = " !toto($) == true " ;
ConditionalInterpreter conditionalParser = ConditionalInterpreter("A",test);
conditionalParser.insertLambda("toto",+[](NodePtr NodeOp){return false;});
ConditionalInterpreter conditionalParserB = ConditionalInterpreter("A"," bad($) == false ");
ConditionalInterpreter conditionalParserG = ConditionalInterpreter("A"," good($) == true ");
conditionalParserB.insertLambda("bad",+[](NodePtr NodeOp){return NodeOp->name() == "ZZ";});
conditionalParserG.insertLambda("good",+[](NodePtr NodeOp){return NodeOp->name() == "Gop1";});
std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1");
bool result = conditionalParser.test(nodeOp);
REQUIRE(result == true);
REQUIRE(conditionalParserB.test(nodeOp) == true);
REQUIRE(conditionalParserG.test(nodeOp) == true);
}
ConditionalInterpreter conditionalParserT = ConditionalInterpreter("A","isConv($)==true");
conditionalParserT.insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";});
std::shared_ptr<Node> zz = GenericOperator("conv", 0, 0, 0, "Gop1");
conditionalParserT.test(zz);
SECTION("Lambdas") {
ConditionalInterpreter conditionalParser = ConditionalInterpreter("OP_test","getType($) =='Conv' || getType($) =='FC' ");
std::shared_ptr<Node> A = GenericOperator("Conv", 0, 0, 0, "A");
REQUIRE(conditionalParser.test(A) == true);
std::shared_ptr<Node> B = GenericOperator("FC", 0, 0, 0, "B");
REQUIRE(conditionalParser.test(B) == true);
std::shared_ptr<Node> C = GenericOperator("A", 0, 0, 0, "C");
conditionalParser.test(C);
REQUIRE(conditionalParser.test(C) == false);
}
SECTION("syntax error") {
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <set>
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/utils/Recipies.hpp"
#include <cstddef>
namespace Aidge {
TEST_CASE("[FuseBatchNorm] conv") {
auto g1 = Sequential({
Producer({16, 3, 224, 224}, "dataProvider"),
Conv(3, 32, {3, 3}, "conv1"),
BatchNorm<32>()
});
fuseBatchNorm(g1);
SECTION("Check resulting nodes") {
// REQUIRE(g1->getNodes().size() == 2);
// REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling");
// REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
// REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling");
// REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
// REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling");
}
}
}
\ No newline at end of file
......@@ -26,6 +26,7 @@
namespace Aidge {
/*
TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") {
// generate the original GraphView
auto matmul0 = MatMul(5, "matmul0");
......@@ -74,4 +75,5 @@ TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") {
REQUIRE(((node->type() == "Producer") || (node->type() == "FC")));
}
}
*/
} // namespace Aidge
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment