diff --git a/python_binding/graphRegex/pybind_GraphRegex.cpp b/python_binding/graphRegex/pybind_GraphRegex.cpp new file mode 100644 index 0000000000000000000000000000000000000000..be3cd9e9124ba1306226dcbdc13ee39748cf0606 --- /dev/null +++ b/python_binding/graphRegex/pybind_GraphRegex.cpp @@ -0,0 +1,69 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include "aidge/graphRegex/GraphRegex.hpp" + +namespace py = pybind11; +namespace Aidge { +void init_GraphRegex(py::module& m){ + + + py::class_<GraphRegex, std::shared_ptr<GraphRegex>>(m, "GraphRegex", "GraphRegex class describes a regex to test a graph.") + .def(py::init<>()) + + .def("add_query", &GraphRegex::addQuery, R"mydelimiter( + :rtype: str + )mydelimiter") + + .def("set_key_from_graph", &GraphRegex::setKeyFromGraph, R"mydelimiter( + :param ref: The graph use to define type of Node. + :type ref: :py:class:`aidge_core.GraphView` + )mydelimiter") + +// void setNodeKey(const std::string key, const std::string conditionalExpressions ); +// void setNodeKey(const std::string key,std::function<bool(NodePtr)> f); + + .def("match", &GraphRegex::match, R"mydelimiter( + :param graphToMatch: The graph to perform the matching algorithm on. + :type graphToMatch: :py:class:`aidge_core.GraphView` + )mydelimiter") + + + + .def("set_node_key", + (void (GraphRegex::*)(const std::string, const std::string )) & + GraphRegex::setNodeKey, + py::arg("key"), py::arg("conditionalExpressions"), + R"mydelimiter( + Add a node test + :param key: the key of the node test to use in the query. + :param conditionalExpressions: the test to do . + + )mydelimiter") + + + .def("set_node_key", + (void (GraphRegex::*)(const std::string, std::function<bool(NodePtr)>)) & + GraphRegex::setNodeKey, + py::arg("key"), py::arg("f"), + R"mydelimiter( + Add a node test + :param key: the key of the lambda test to use in the conditional expressions. + :param f: bool lambda (nodePtr) . + + )mydelimiter") + + + + ; +} +} diff --git a/unit_tests/recipies/Test_FuseBatchNorm.cpp b/unit_tests/recipies/Test_FuseBatchNorm.cpp index 45e268797e5eb91372a0ac57a1cce6c6dfbffb93..13facefd2979a9b0ca4409ead6972013cb1bc0a8 100644 --- a/unit_tests/recipies/Test_FuseBatchNorm.cpp +++ b/unit_tests/recipies/Test_FuseBatchNorm.cpp @@ -8,11 +8,16 @@ * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ - +/* #include <catch2/catch_test_macros.hpp> #include <set> +//#include "aidge/backend/cpu/operator/BatchNormImpl.hpp" +//#include "aidge/backend/cpu/operator/ConvImpl.hpp" + + + #include "aidge/operator/Conv.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Producer.hpp" @@ -20,19 +25,36 @@ #include "aidge/operator/BatchNorm.hpp" #include "aidge/utils/Recipies.hpp" +//#include "aidge/backend/TensorImpl.hpp" +//#include "aidge/backend/cpu.hpp" +//#include "aidge/" + #include <cstddef> namespace Aidge { + TEST_CASE("[FuseBatchNorm] conv") { auto g1 = Sequential({ Producer({16, 3, 224, 224}, "dataProvider"), Conv(3, 32, {3, 3}, "conv1"), - BatchNorm<32>() + BatchNorm<2>() }); - fuseBatchNorm(g1); + g1->setDatatype(DataType::Float32); + g1->setBackend("cpu"); + g1->forwardDims(); + + // std::set<std::string> availableBackends = Tensor::getAvailableBackends(); + // if (availableBackends.find("cpu") != availableBackends.end()){ + // g1->setBackend("cpu"); + // newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr)); + // }else{ + // printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n"); + // } + + fuseBatchNorm(g1); SECTION("Check resulting nodes") { // REQUIRE(g1->getNodes().size() == 2); @@ -43,4 +65,6 @@ namespace Aidge { // REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); } } -} \ No newline at end of file + +} +*/ \ No newline at end of file diff --git a/unit_tests/recipies/Test_FuseMulAdd.cpp b/unit_tests/recipies/Test_FuseMulAdd.cpp index 6a8079b3e30a4f211db9160659e66364bc0ef40a..b99de66d3e23377c13ed86526f6c1a318a00e4e8 100644 --- a/unit_tests/recipies/Test_FuseMulAdd.cpp +++ b/unit_tests/recipies/Test_FuseMulAdd.cpp @@ -26,7 +26,7 @@ namespace Aidge { -/* + TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { // generate the original GraphView auto matmul0 = MatMul(5, "matmul0"); @@ -75,5 +75,5 @@ TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { REQUIRE(((node->type() == "Producer") || (node->type() == "FC"))); } } -*/ + } // namespace Aidge \ No newline at end of file