diff --git a/.gitlab/ci/_global.gitlab-ci.yml b/.gitlab/ci/_global.gitlab-ci.yml index aab5d745367d22052f82c6e3ef144680a822cd45..94e5658ff6adc8e07036d3d59ea39a68fbddc4bf 100644 --- a/.gitlab/ci/_global.gitlab-ci.yml +++ b/.gitlab/ci/_global.gitlab-ci.yml @@ -9,6 +9,14 @@ variables: GIT_SSL_NO_VERIFY: 1 DEBIAN_FRONTEND: noninteractive +# See https://docs.gitlab.com/ee/ci/yaml/workflow.html#switch-between-branch-pipelines-and-merge-request-pipelines +workflow: + rules: + - if: $CI_PIPELINE_SOURCE == "merge_request_event" + - if: $CI_COMMIT_BRANCH && $CI_OPEN_MERGE_REQUESTS + when: never + - if: $CI_COMMIT_BRANCH + default: image: nvidia/cuda:12.2.0-devel-ubuntu22.04 before_script: diff --git a/.gitlab/ci/build.gitlab-ci.yml b/.gitlab/ci/build.gitlab-ci.yml index da0d23c9de978ebcdbb370a6f4a92262829e05b9..a4579e2951ccbafc4335ae428c62eba94c0757e5 100644 --- a/.gitlab/ci/build.gitlab-ci.yml +++ b/.gitlab/ci/build.gitlab-ci.yml @@ -12,10 +12,71 @@ build:ubuntu_cpp: - make -j4 all install artifacts: + expire_in: 1 week paths: - build_cpp/ - install_cpp/ +build:ubuntu_cpp_g++10: + stage: build + needs: [] + tags: + - docker + + script: + - apt install -y g++-10 + - mkdir -p build_cpp + - mkdir -p install_cpp + - cd build_cpp + - export CXX=/usr/bin/g++-10 + - cmake -DCMAKE_INSTALL_PREFIX:PATH=../install_cpp -DCMAKE_BUILD_TYPE=Debug -DWERROR=ON -DCOVERAGE=ON .. + - make -j4 all install + +build:ubuntu_cpp_g++12: + stage: build + needs: [] + tags: + - docker + + script: + - apt install -y g++-12 + - mkdir -p build_cpp + - mkdir -p install_cpp + - cd build_cpp + - export CXX=/usr/bin/g++-12 + - cmake -DCMAKE_INSTALL_PREFIX:PATH=../install_cpp -DCMAKE_BUILD_TYPE=Debug -DWERROR=ON -DCOVERAGE=ON .. 
+ - make -j4 all install + +build:ubuntu_cpp_clang12: + stage: build + needs: [] + tags: + - docker + + script: + - apt install -y clang-12 + - mkdir -p build_cpp + - mkdir -p install_cpp + - cd build_cpp + - export CXX=/usr/bin/clang++-12 + - cmake -DCMAKE_INSTALL_PREFIX:PATH=../install_cpp -DCMAKE_BUILD_TYPE=Debug -DWERROR=ON -DCOVERAGE=ON .. + - make -j4 all install + +build:ubuntu_cpp_clang15: + stage: build + needs: [] + tags: + - docker + + script: + - apt install -y clang-15 + - mkdir -p build_cpp + - mkdir -p install_cpp + - cd build_cpp + - export CXX=/usr/bin/clang++-15 + - cmake -DCMAKE_INSTALL_PREFIX:PATH=../install_cpp -DCMAKE_BUILD_TYPE=Debug -DWERROR=ON -DCOVERAGE=ON .. + - make -j4 all install + build:ubuntu_python: stage: build needs: [] @@ -26,9 +87,11 @@ build:ubuntu_python: - python3 -m pip install virtualenv - virtualenv venv - source venv/bin/activate - - export AIDGE_INSTALL=`pwd`/install + # Numpy dependancy for unit test + - python3 -m pip install -r requirements.txt - python3 -m pip install . artifacts: + expire_in: 1 week paths: - venv/ @@ -57,6 +120,35 @@ build:windows_cpp: - cmake --install . 
--config Debug artifacts: + expire_in: 1 week paths: - build_cpp/ - install_cpp/ + +build:windows_python: + stage: build + needs: [] + tags: + - windows + + image: buildtools + before_script: + # Install Chocolatey + - Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1')) + # Install dependencies + - choco install cmake.install --installargs '"ADD_CMAKE_TO_PATH=System"' -Y + - choco install git -Y + - choco install python -Y + # Update PATH + - $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") + script: + - python -m pip install virtualenv + - virtualenv venv + - venv\Scripts\Activate.ps1 + # Numpy dependancy for unit test + - python -m pip install -r requirements.txt + - python -m pip install . + artifacts: + expire_in: 1 week + paths: + - venv/ diff --git a/.gitlab/ci/coverage.gitlab-ci.yml b/.gitlab/ci/coverage.gitlab-ci.yml index 027f3078180bb32b36ca4666f171dda90ef7f7be..3c7b7654190e0768adc6a904f1cb548f020b0c92 100644 --- a/.gitlab/ci/coverage.gitlab-ci.yml +++ b/.gitlab/ci/coverage.gitlab-ci.yml @@ -24,8 +24,10 @@ coverage:ubuntu_python: script: - source venv/bin/activate - python3 -m pip install numpy coverage - - cd aidge_core - - python3 -m coverage run --source=. -m unittest discover -s unit_tests/ -v -b + - cd ${CI_PROJECT_NAME} + # Retrieve the installation path of the module, since it is installed with pip. + - export MODULE_LOCATION=`python -c "import ${CI_PROJECT_NAME} as _; print(_.__path__[0])"` + - python3 -m coverage run --source=$MODULE_LOCATION -m unittest discover -s unit_tests/ -v -b - python3 -m coverage report - python3 -m coverage xml coverage: '/(?i)total.*? 
(100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/' @@ -33,4 +35,4 @@ coverage:ubuntu_python: reports: coverage_report: coverage_format: cobertura - path: aidge_core/coverage.xml + path: ${CI_PROJECT_NAME}/coverage.xml diff --git a/.gitlab/ci/static_analysis.gitlab-ci.yml b/.gitlab/ci/static_analysis.gitlab-ci.yml index f7c09a33a65801fb25b1f20f76eac5a7a7952917..3955b87d4efdd9b3610b661779ab9709320754f2 100644 --- a/.gitlab/ci/static_analysis.gitlab-ci.yml +++ b/.gitlab/ci/static_analysis.gitlab-ci.yml @@ -26,8 +26,8 @@ static_analysis:python: script: - pip install pylint - pip install pylint-gitlab - - pylint --rcfile=.pylintrc --exit-zero --output-format=pylint_gitlab.GitlabCodeClimateReporter aidge_core/ > codeclimate.json - - pylint --rcfile=.pylintrc --exit-zero --output-format=pylint_gitlab.GitlabPagesHtmlReporter aidge_core/ > pylint.html + - pylint --rcfile=.pylintrc --exit-zero --output-format=pylint_gitlab.GitlabCodeClimateReporter ${CI_PROJECT_NAME}/ > codeclimate.json + - pylint --rcfile=.pylintrc --exit-zero --output-format=pylint_gitlab.GitlabPagesHtmlReporter ${CI_PROJECT_NAME}/ > pylint.html - mkdir -p public/python/$CI_COMMIT_REF_NAME - mv pylint.html public/python/$CI_COMMIT_REF_NAME/ artifacts: diff --git a/.gitlab/ci/test.gitlab-ci.yml b/.gitlab/ci/test.gitlab-ci.yml index 1e67ce273abc7d6b02f9e3148264ff3f9ea1cf07..81e6ca9ac5b868287aa0ef27040c0ead785d3639 100644 --- a/.gitlab/ci/test.gitlab-ci.yml +++ b/.gitlab/ci/test.gitlab-ci.yml @@ -17,14 +17,14 @@ test:ubuntu_python: - docker script: - source venv/bin/activate - - cd aidge_core + - cd ${CI_PROJECT_NAME} - python3 -m pip install unittest-xml-reporting - python3 -m pip list # Run on discovery all tests located in core/unit_tests/python - python3 -m xmlrunner discover -s unit_tests/ -v -b --output-file xmlrunner-results.xml artifacts: reports: - junit: aidge_core/xmlrunner-results.xml + junit: ${CI_PROJECT_NAME}/xmlrunner-results.xml test:windows_cpp: stage: test @@ -37,6 +37,7 @@ test:windows_cpp: - 
Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1')) # Install dependencies - choco install cmake.install --installargs '"ADD_CMAKE_TO_PATH=System"' -Y + - choco install python -Y # Update PATH - $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") script: diff --git a/CMakeLists.txt b/CMakeLists.txt index 67ad9304bc3e682a9436fb52306b3ca8120c1c4b..f8dbe375e217020a4c4570bd67c1b466e6593130 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,7 +6,7 @@ file(READ "${CMAKE_SOURCE_DIR}/project_name.txt" project) message(STATUS "Project name: ${project}") message(STATUS "Project version: ${version}") -# Note : project name is {project} and python module name is also {project} +# Note : project name is {project} and python module name is also {project} set(module_name _${project}) # target name @@ -52,12 +52,12 @@ target_include_directories(${module_name} ) # PYTHON BINDING -generate_python_binding(${project} ${module_name}) - if (PYBIND) + generate_python_binding(${project} ${module_name}) + # Handles Python + pybind11 headers dependencies target_link_libraries(${module_name} - PUBLIC + PUBLIC pybind11::pybind11 PRIVATE Python::Python @@ -66,22 +66,16 @@ endif() target_compile_features(${module_name} PRIVATE cxx_std_14) - -if(WERROR) - target_compile_options(${module_name} PRIVATE +# -fvisibility=hidden required by pybind11 +target_compile_options(${module_name} PUBLIC + $<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>: + -fvisibility=hidden>) +target_compile_options(${module_name} PRIVATE $<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>: - -Wall -Wextra -fPIC -Wold-style-cast -Winline -pedantic 
-Werror=narrowing -Wshadow -Werror>) - target_compile_options(${module_name} PRIVATE + -Wall -Wextra -Wold-style-cast -Winline -pedantic -Werror=narrowing -Wshadow $<$<BOOL:${WERROR}>:-Werror>>) +target_compile_options(${module_name} PRIVATE $<$<CXX_COMPILER_ID:MSVC>: /W4>) -else() - target_compile_options(${module_name} PRIVATE - $<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>: - -Wall -Wextra -fPIC -Wold-style-cast -Winline -pedantic -Werror=narrowing -Wshadow -Wpedantic>) - target_compile_options(${module_name} PRIVATE - $<$<CXX_COMPILER_ID:MSVC>: - /W4>) -endif() if(CMAKE_COMPILER_IS_GNUCXX AND COVERAGE) append_coverage_compiler_flags() @@ -107,8 +101,8 @@ install(DIRECTORY include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) install(EXPORT ${project}-targets FILE "${project}-targets.cmake" DESTINATION ${INSTALL_CONFIGDIR} -# COMPONENT ${module_name} -) +# COMPONENT ${module_name} +) #Create a ConfigVersion.cmake file include(CMakePackageConfigHelpers) @@ -142,4 +136,4 @@ export(EXPORT ${project}-targets if(TEST) enable_testing() add_subdirectory(unit_tests) -endif() \ No newline at end of file +endif() diff --git a/README.md b/README.md index 992344a796a4634a25d2127fc49b57adeae45863..5b07e147cb05c2fa1a6d275d567dda218b131996 100644 --- a/README.md +++ b/README.md @@ -6,16 +6,19 @@ You can find here the C++ code of the Core library of Aidge. ## Pip installation -To install aidge_core using pip, make sure to set the desired install path : -``` bash -export AIDGE_INSTALL = '<path_to_aidge>/install' -``` -Then run in your python environnement : + +To install aidge_core using pip, run the following command in your python environnement : ``` bash pip install . -v ``` +**Note:** you can specify a custom install folder by setting an environment variable: + +``` bash +export AIDGE_INSTALL='<path_to_aidge>/install' +``` + ## Standard C++ Compilation Create two directories ``build`` and ``ìnstall``. 
diff --git a/aidge_core/__init__.py b/aidge_core/__init__.py index ad18a8ef1b23625dcb52951f52c43adc4222c997..c65dcc6cfc4be8825d1213854014718fb7170854 100644 --- a/aidge_core/__init__.py +++ b/aidge_core/__init__.py @@ -8,3 +8,4 @@ http://www.eclipse.org/legal/epl-2.0. SPDX-License-Identifier: EPL-2.0 """ from aidge_core.aidge_core import * # import so generated by PyBind +from aidge_core.export import ExportNode diff --git a/aidge_core/export/__init__.py b/aidge_core/export/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..00b44121d68af06171525fdf953bf50e53328421 --- /dev/null +++ b/aidge_core/export/__init__.py @@ -0,0 +1 @@ +from .node_export import * diff --git a/aidge_core/export/node_export.py b/aidge_core/export/node_export.py new file mode 100644 index 0000000000000000000000000000000000000000..980cb05a5814b7476d64757353e393ad6130218b --- /dev/null +++ b/aidge_core/export/node_export.py @@ -0,0 +1,61 @@ +import aidge_core + +from abc import ABC, abstractmethod + + +class ExportNode(ABC): + """Abstract class to interface node with export generation. + """ + + @abstractmethod + def __init__(self, aidge_node: aidge_core.Node) -> None: + """Create ExportNode and retieve attirubtes from ``aidge_node``: + + - name: aidge Node name + - attributes: dictionnary of attributes of the aidge Operator linked to the node, attributes name follow aidge naming convention + - parameters: List of parameters node, order in the list is the same as the one defined by the aidge operator + + """ + super().__init__() + self.node = aidge_node + self.operator = aidge_node.get_operator() + self.name = self.node.name() + self.attributes = {} # Attributes are auto fetched from aidge operators + if isinstance(self.operator, aidge_core.Attributes): + for attr_name in self.operator.get_attrs_name(): + self.attributes[attr_name] = self.operator.get_attr(attr_name) + + # rename is_leaf ? 
+ self.is_last = len(self.node.get_children()) == 0 + + + self.inputs = [] + self.outputs = [] + self.inputs_dims = [] + self.outputs_dims = [] + + for idx, parent_node in enumerate(self.node.get_parents()): + self.inputs.append(parent_node) + if parent_node is not None: + self.inputs_dims.append(self.operator.input(idx).dims()) + else: + self.inputs_dims.append(None) + + for idx, child_node in enumerate(self.node.get_children()): + self.outputs.append(child_node) + + # Dirty hot fix, change it quickly + self.outputs_dims.append(self.operator.output(0).dims()) + + @abstractmethod + def export(self, export_folder:str, list_configs:list): + """Define how to export the node definition. + """ + pass + + @abstractmethod + def forward(self, list_actions:list): + """Define how to generate code to perform a forward pass. + """ + pass + diff --git a/aidge_core/unit_tests/test_operator_binding.py b/aidge_core/unit_tests/test_operator_binding.py index b326e0748c2c77612dd79122fe891a6207d945dc..7bd1e730a973810db89aa786b52fa05c53c43590 100644 --- a/aidge_core/unit_tests/test_operator_binding.py +++ b/aidge_core/unit_tests/test_operator_binding.py @@ -30,36 +30,102 @@ class test_operator_binding(unittest.TestCase): self.assertNotEqual(gop.name(), "") def test_param_bool(self): - self.generic_operator.add_parameter("bool", True) - self.assertEqual(self.generic_operator.get_parameter("bool"), True) + self.generic_operator.add_attr("bool", True) + self.assertEqual(self.generic_operator.has_attr("bool"), True) + self.assertEqual(self.generic_operator.get_attr("bool"), True) + self.assertEqual(self.generic_operator.get_attr_type("bool"), "bool") + self.assertEqual(self.generic_operator.get_attrs_name(), {"bool"}) + self.generic_operator.del_attr("bool") + self.assertEqual(self.generic_operator.has_attr("bool"), False) + self.assertEqual(len(self.generic_operator.get_attrs_name()), 0) def test_param_int(self): - self.generic_operator.add_parameter("int", 1) - 
self.assertEqual(self.generic_operator.get_parameter("int"), 1) + self.generic_operator.add_attr("int", 1) + self.assertEqual(self.generic_operator.get_attr("int"), 1) def test_param_float(self): - self.generic_operator.add_parameter("float", 2.0) - self.assertEqual(self.generic_operator.get_parameter("float"), 2.0) + self.generic_operator.add_attr("float", 2.0) + self.assertEqual(self.generic_operator.get_attr("float"), 2.0) def test_param_str(self): - self.generic_operator.add_parameter("str", "value") - self.assertEqual(self.generic_operator.get_parameter("str"), "value") + self.generic_operator.add_attr("str", "value") + self.assertEqual(self.generic_operator.get_attr("str"), "value") def test_param_l_int(self): - self.generic_operator.add_parameter("l_int", [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]) - self.assertEqual(self.generic_operator.get_parameter("l_int"), [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]) + self.generic_operator.add_attr("l_int", [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]) + self.assertEqual(self.generic_operator.get_attr("l_int"), [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]) def test_param_l_bool(self): - self.generic_operator.add_parameter("l_bool", [True, False, False, True]) - self.assertEqual(self.generic_operator.get_parameter("l_bool"), [True, False, False, True]) + self.generic_operator.add_attr("l_bool", [True, False, False, True]) + self.assertEqual(self.generic_operator.get_attr("l_bool"), [True, False, False, True]) def test_param_l_float(self): - self.generic_operator.add_parameter("l_float", [2.0, 1.0]) - self.assertEqual(self.generic_operator.get_parameter("l_float"), [2.0, 1.0]) + self.generic_operator.add_attr("l_float", [2.0, 1.0]) + self.assertEqual(self.generic_operator.get_attr("l_float"), [2.0, 1.0]) def test_param_l_str(self): - self.generic_operator.add_parameter("l_str", ["ok"]) - self.assertEqual(self.generic_operator.get_parameter("l_str"), ["ok"]) + self.generic_operator.add_attr("l_str", ["ok"]) + 
self.assertEqual(self.generic_operator.get_attr("l_str"), ["ok"]) + + def test_dynamicattribute_binding(self): + # Check original C++ attributes are binded + attrs = aidge_core.test_DynamicAttributes_binding() + self.assertEqual(attrs.has_attr("a"), True) + self.assertEqual(attrs.get_attr("a"), 42) + self.assertEqual(attrs.has_attr("b"), True) + self.assertEqual(attrs.get_attr("b"), "test") + self.assertEqual(attrs.has_attr("c"), True) + self.assertEqual(attrs.get_attr("c"), [True, False, True]) + self.assertEqual(attrs.get_attrs_name(), {"a", "b", "c"}) + self.assertEqual(attrs.has_attr("d"), False) + + # Add Python attributes + attrs.add_attr("d", 18.56) + self.assertEqual(attrs.get_attr("d"), 18.56) + self.assertEqual(attrs.has_attr("d"), True) + self.assertEqual(attrs.get_attrs_name(), {"a", "b", "c", "d"}) + self.assertEqual(attrs.has_attr("e"), False) + + # Check that added Python attribute is accessible in C++ + # Return the value of an attribute named "d" of type float64 (double in C++) + self.assertEqual(aidge_core.test_DynamicAttributes_binding_check(attrs), 18.56) + attrs.set_attr("d", 23.89) + self.assertEqual(aidge_core.test_DynamicAttributes_binding_check(attrs), 23.89) + + def test_compute_output_dims(self): + in_dims=[25, 25] + input = aidge_core.Producer(in_dims, name="In") + genOp = aidge_core.GenericOperator("genOp", 1, 1, 1, name="genOp") + _ = aidge_core.sequential([input, genOp]) + self.assertListEqual(genOp.get_operator().output(0).dims(), []) + genOp.get_operator().set_compute_output_dims(lambda x:x) + genOp.get_operator().compute_output_dims() + self.assertListEqual(genOp.get_operator().output(0).dims(), in_dims) + + def test_set_impl(self): + + class PythonCustomImpl(aidge_core.OperatorImpl): + """Dummy implementation to test that C++ call python code + """ + def __init__(self, op: aidge_core.Operator): + aidge_core.OperatorImpl.__init__(self, op) # Recquired to avoid type error ! 
+ self.idx = 0 + + def forward(self): + """Increment idx attribute on forward. + """ + self.idx += 1 + + generic_node = aidge_core.GenericOperator("Relu", 1, 1, 1, name="myReLu") + generic_op = generic_node.get_operator() + customImpl = PythonCustomImpl(generic_op) + + generic_op.forward() # Do nothing, no implementation set + generic_op.set_impl(customImpl) + generic_op.forward() # Increment idx + self.assertEqual(customImpl.idx, 1) + + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/aidge_core/unit_tests/test_parameters.py b/aidge_core/unit_tests/test_parameters.py index 02c7598820d2429bc49ff9a2f02c8ee841783173..566650713c36236c19763f466ee906970466c02e 100644 --- a/aidge_core/unit_tests/test_parameters.py +++ b/aidge_core/unit_tests/test_parameters.py @@ -11,7 +11,7 @@ SPDX-License-Identifier: EPL-2.0 import unittest import aidge_core -class test_parameters(unittest.TestCase): +class test_attributes(unittest.TestCase): """Very basic test to make sure the python APi is not broken. Can be remove in later stage of the developpement. 
""" @@ -27,21 +27,21 @@ class test_parameters(unittest.TestCase): out_channels = 8 k_dims = [2, 2] conv_op = aidge_core.Conv2D(in_channels , out_channels, k_dims).get_operator() - self.assertEqual(conv_op.get("InChannels"), in_channels) - self.assertEqual(conv_op.get("OutChannels"), out_channels) - self.assertEqual(conv_op.get("KernelDims"), k_dims) + self.assertEqual(conv_op.get_attr("InChannels"), in_channels) + self.assertEqual(conv_op.get_attr("OutChannels"), out_channels) + self.assertEqual(conv_op.get_attr("KernelDims"), k_dims) def test_fc(self): out_channels = 8 nb_bias = True fc_op = aidge_core.FC(out_channels, nb_bias).get_operator() - self.assertEqual(fc_op.get("OutChannels"), out_channels) - self.assertEqual(fc_op.get("NoBias"), nb_bias) + self.assertEqual(fc_op.get_attr("OutChannels"), out_channels) + self.assertEqual(fc_op.get_attr("NoBias"), nb_bias) def test_matmul(self): out_channels = 8 - matmul_op = aidge_core.Matmul(out_channels).get_operator() - self.assertEqual(matmul_op.get("OutChannels"), out_channels) + matmul_op = aidge_core.MatMul(out_channels).get_operator() + self.assertEqual(matmul_op.get_attr("OutChannels"), out_channels) def test_producer_1D(self): dims = [5] @@ -71,7 +71,7 @@ class test_parameters(unittest.TestCase): def test_leaky_relu(self): negative_slope = 0.25 leakyrelu_op = aidge_core.LeakyReLU(negative_slope).get_operator() - self.assertEqual(leakyrelu_op.get("NegativeSlope"), negative_slope) + self.assertEqual(leakyrelu_op.get_attr("NegativeSlope"), negative_slope) if __name__ == '__main__': unittest.main() diff --git a/aidge_core/unit_tests/test_recipies.py b/aidge_core/unit_tests/test_recipies.py new file mode 100644 index 0000000000000000000000000000000000000000..754907443530f7e73d1e10ed9549d0c8eb78a011 --- /dev/null +++ b/aidge_core/unit_tests/test_recipies.py @@ -0,0 +1,78 @@ +""" +Copyright (c) 2023 CEA-List + +This program and the accompanying materials are made available under the +terms of the Eclipse Public License 
2.0 which is available at +http://www.eclipse.org/legal/epl-2.0. + +SPDX-License-Identifier: EPL-2.0 +""" + +import unittest +import aidge_core + +class test_recipies(unittest.TestCase): + """ + """ + def setUp(self): + pass + + def tearDown(self): + pass + + def test_remove_flatten(self): + graph_view = aidge_core.sequential([ + aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"), + aidge_core.FC(50, name='0') + ]) + old_nodes = graph_view.get_nodes() + aidge_core.remove_flatten(graph_view) + self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 1) + self.assertTrue("Flatten0" not in [i.name for i in graph_view.get_nodes()]) + + self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()])) + + def test_fuse_matmul_add(self): + matmul0 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul0") + add0 = aidge_core.Add(name="Add0") + matmul1 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul1") + add1 = aidge_core.Add(name="Add1") + + graph_view = aidge_core.sequential([matmul0, add0, matmul1, add1]) + + w0 = aidge_core.Producer([1, 1], name="W0") + w0.add_child(matmul0, 0, 1) + graph_view.add(w0) + + b0 = aidge_core.Producer([1], name="B0") + b0.add_child(add0, 0, 1) + graph_view.add(b0) + + w1 = aidge_core.Producer([1, 1], name="W1") + w1.add_child(matmul1, 0, 1) + graph_view.add(w1) + + b1 = aidge_core.Producer([1], name="B1") + b1.add_child(add1, 0, 1) + graph_view.add(b1) + + old_nodes = graph_view.get_nodes() + aidge_core.fuse_mul_add(graph_view) + + self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 2) + self.assertTrue("MatMul0" not in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("Add0" not in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("MatMul1" not in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("Add1" not in [i.name() for i in graph_view.get_nodes()]) + + self.assertTrue("W0" in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("B0" in 
[i.name() for i in graph_view.get_nodes()]) + self.assertTrue("W1" in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("B1" in [i.name() for i in graph_view.get_nodes()]) + # TODO : Vérifier que FC bien crée + +if __name__ == '__main__': + unittest.main() + + + diff --git a/aidge_core/unit_tests/test_tensor.py b/aidge_core/unit_tests/test_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..a214a0e354c64b515d0a7ac24d81c85e116938ca --- /dev/null +++ b/aidge_core/unit_tests/test_tensor.py @@ -0,0 +1,44 @@ +""" +Copyright (c) 2023 CEA-List + +This program and the accompanying materials are made available under the +terms of the Eclipse Public License 2.0 which is available at +http://www.eclipse.org/legal/epl-2.0. + +SPDX-License-Identifier: EPL-2.0 +""" + +import unittest +import aidge_core + +from functools import reduce +import numpy as np + +class test_tensor(unittest.TestCase): + """ + """ + def setUp(self): + pass + + def tearDown(self): + pass + + def test_getcoord_getidx(self): + dims = [2,2,2] + size = reduce((lambda x, y: x*y), dims) + + np_array = np.arange(size).reshape(dims) + + t = aidge_core.Tensor(np_array) + for i in range(size): + coord = t.get_coord(i) + idx = t.get_idx(coord) + self.assertEqual(idx, i) + +if __name__ == '__main__': + unittest.main() + + + + + diff --git a/cmake/PybindModuleCreation.cmake b/cmake/PybindModuleCreation.cmake index 18f4abc38e2537c3f4d949f08772a57b90758cb0..8030c1a8639e4b7ae0c5fb865e928a4260c6ae7d 100644 --- a/cmake/PybindModuleCreation.cmake +++ b/cmake/PybindModuleCreation.cmake @@ -1,23 +1,21 @@ -function(generate_python_binding name target_to_bind) - if (PYBIND) - add_definitions(-DPYBIND) - Include(FetchContent) +function(generate_python_binding name target_to_bind) + add_definitions(-DPYBIND) + Include(FetchContent) - FetchContent_Declare( - PyBind11 - GIT_REPOSITORY https://github.com/pybind/pybind11.git - GIT_TAG v2.10.4 # or a later release - ) + FetchContent_Declare( + 
PyBind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG v2.10.4 # or a later release + ) - # Use the New FindPython mode, recommanded. Requires CMake 3.15+ - find_package(Python COMPONENTS Interpreter Development) - FetchContent_MakeAvailable(PyBind11) + # Use the New FindPython mode, recommanded. Requires CMake 3.15+ + find_package(Python COMPONENTS Interpreter Development) + FetchContent_MakeAvailable(PyBind11) - message(STATUS "Creating binding for module ${name}") - file(GLOB_RECURSE pybind_src_files "python_binding/*.cpp") + message(STATUS "Creating binding for module ${name}") + file(GLOB_RECURSE pybind_src_files "python_binding/*.cpp") - pybind11_add_module(${name} MODULE ${pybind_src_files} "NO_EXTRAS") # NO EXTRA recquired for pip install - target_include_directories(${name} PUBLIC "python_binding") - target_link_libraries(${name} PUBLIC ${target_to_bind}) - endif() + pybind11_add_module(${name} MODULE ${pybind_src_files} "NO_EXTRAS") # NO EXTRA recquired for pip install + target_include_directories(${name} PUBLIC "python_binding") + target_link_libraries(${name} PUBLIC ${target_to_bind}) endfunction() diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp index 091bdbc0f0681352d6983b231c3a68a50a2be716..ff1ff00938e4d2f35b0a220bd4d199f51bfb802f 100644 --- a/include/aidge/aidge.hpp +++ b/include/aidge/aidge.hpp @@ -34,17 +34,27 @@ #include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/operator/ConvDepthWise.hpp" +#include "aidge/operator/Div.hpp" #include "aidge/operator/FC.hpp" #include "aidge/operator/GenericOperator.hpp" -#include "aidge/operator/Matmul.hpp" +#include "aidge/operator/MatMul.hpp" +#include "aidge/operator/MaxPooling.hpp" #include "aidge/operator/MetaOperator.hpp" +#include "aidge/operator/MetaOperatorDefs.hpp" +#include "aidge/operator/Mul.hpp" #include "aidge/operator/Operator.hpp" +#include "aidge/operator/Pad.hpp" #include "aidge/operator/Producer.hpp" +#include 
"aidge/operator/Pow.hpp" #include "aidge/operator/ReLU.hpp" +#include "aidge/operator/Scaling.hpp" #include "aidge/operator/Softmax.hpp" +#include "aidge/operator/Sqrt.hpp" +#include "aidge/operator/Sub.hpp" #include "aidge/scheduler/Scheduler.hpp" -#include "aidge/utils/CParameter.hpp" -#include "aidge/utils/Parameter.hpp" +#include "aidge/utils/Attributes.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/DynamicAttributes.hpp" #include "aidge/utils/Recipies.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index 5aa2829e16f612b0867ab69feccb829ba2095e1b..19f0837504016f38ae96dd852bc6fa41b5ab53ba 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -14,13 +14,17 @@ #include <cstddef> #include <vector> +#include <memory> #include "aidge/utils/Types.h" namespace Aidge { +class Operator; + class OperatorImpl { public: - virtual void forward(){}; - virtual void backward() {} + OperatorImpl(const Operator& op); + virtual void forward(); + virtual void backward(); /** * @brief Minimum amount of data from a specific input required by the @@ -29,13 +33,13 @@ public: * @param inputIdx Index of the input analysed. * @return std::size_t */ - virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const = 0; + virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; // Amount of input data that cannot be overwritten during the execution. - virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const = 0; + virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; // Memory required at an output for a given input size. 
- virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const = 0; + virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const; /** * @brief Total amount of consumed data from a specific input. @@ -43,17 +47,28 @@ public: * @param inputIdx Index of the input analysed. * @return DimSize_t */ - virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const = 0; + virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; /** - * @brief TOtal amount of produced data ready to be used on a specific output. + * @brief Total amount of produced data ready to be used on a specific output. * * @param outputIdx Index of the output analysed. * @return DimSize_t */ - virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const = 0; + virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const; + + /** + * @brief Update the Consummer Producer system by simulating the consumption and production of i/o + * + */ + virtual void updateConsummerProducer(); virtual ~OperatorImpl() = default; + +protected: + const Operator &mOp; + std::vector<NbElts_t> mNbConsumedData; + std::vector<NbElts_t> mNbProducedData; }; } // namespace Aidge diff --git a/include/aidge/backend/TensorImpl.hpp b/include/aidge/backend/TensorImpl.hpp index b54d8b5d7cebdde1a938090f779fdd61663b5014..f8d398c7801f45a0411fafa446ae7c51ce671cfc 100644 --- a/include/aidge/backend/TensorImpl.hpp +++ b/include/aidge/backend/TensorImpl.hpp @@ -27,6 +27,9 @@ public: { printf("Cannot set raw pointer for backend %s\n", mBackend); }; + + virtual void* getRaw(std::size_t /*idx*/)=0; + virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes) constexpr const char *backend() const { return mBackend; } virtual ~TensorImpl() = default; diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp index 
81b7810a8a548df7e5a2829b1a31cbe337491382..02f4df320d87d1bb02edfa5c11ffe8bc7f560986 100644 --- a/include/aidge/data/Data.hpp +++ b/include/aidge/data/Data.hpp @@ -12,7 +12,7 @@ #ifndef AIDGE_DATA_H_ #define AIDGE_DATA_H_ -#include "aidge/utils/Parameter.hpp" +#include "aidge/utils/Attributes.hpp" namespace Aidge { enum class DataType { diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index c3a6e478f8943253a9f9b3565db2d4452a9ca133..58c434bccc7c8dd39a93c46ecf74c38d7d834d1a 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -446,29 +446,45 @@ class Tensor : public Data, */ bool empty() const { return mDims.empty(); } - template <typename expectedType, std::array<std::size_t, 1>::size_type DIM> - constexpr expectedType &get(std::array<std::size_t, DIM> idx) { - assert(DIM == mDims.size()); - assert(mImpl); - std::size_t unfoldedIdx = 0; - for (std::size_t i = 0; i < DIM - std::size_t(1); ++i) { - unfoldedIdx = (unfoldedIdx + idx[i]) * mDims[i + 1]; - } - unfoldedIdx += idx[DIM - 1]; - return static_cast<expectedType *>(mImpl->rawPtr())[unfoldedIdx]; + template <typename expectedType> + expectedType& get(std::size_t idx){ + // TODO : add assert expected Type compatible with datatype + // TODO : add assert idx < Size + return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx)); + } + + template <typename expectedType> + expectedType& get(std::vector<std::size_t> coordIdx){ + return get<expectedType>(getIdx(coordIdx)); + } + + template <typename expectedType> + void set(std::size_t idx, expectedType value){ + // TODO : add assert expected Type compatible with datatype + // TODO : add assert idx < Size + void* dataPtr = mImpl->getRaw(idx); + std::memcpy(dataPtr, &value, sizeof(expectedType)); + } + + template <typename expectedType> + void set(std::vector<std::size_t> coordIdx, expectedType value){ + set<expectedType>(getIdx(coordIdx), value); } + + std::string toString() { if (dims().empty()) { return "{}"; } 
std::string res; std::size_t dim = 0; - std::size_t *dimVals = new std::size_t[nbDims()]; - for (std::size_t i = 0; i < nbDims(); ++i) { - dimVals[i] = 0; - } std::size_t counter = 0; - res += "{\n"; - if (nbDims()>=2){ + if (nbDims()>=2) { + std::size_t *dimVals = new std::size_t[nbDims()]; + for (std::size_t i = 0; i < nbDims(); ++i) { + dimVals[i] = 0; + } + // std::vector<std::size_t> dimVals = std::vector<std::size_t>(nbDims(), 0); + res += "{\n"; while (counter < mSize) { std::string spaceString = std::string((dim+1)<<1,' '); if (dim < nbDims()-2) { @@ -517,31 +533,35 @@ class Tensor : public Data, } res += "\n"; } + if (dim == 0) { + break; + } dimVals[dim--] = 0; dimVals[dim]++; } } - for(int i = static_cast<int>(dim); i>=0; --i) { + delete[] dimVals; + + for(int i = static_cast<int>(dim); i > 0; --i) { res += std::string((dim+1)<<1,' ') + "}\n"; } - }else{ + } else { + res += "{"; for (DimSize_t j = 0; j < dims()[0]; ++j) { switch (mDataType) { case DataType::Int32: - res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "\n"); + res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : ""); break; case DataType::Float64: - res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "\n"); + res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : ""); break; default: - res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "\n"); + res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : ""); break; } } } - - res += "}"; return res; } @@ -559,6 +579,42 @@ class Tensor : public Data, return mGrad; } + /** + * @brief From the the 1D index, return the coordinate of an element in the tensor. + * + * @param flatIdx 1D index of the value considering a flatten tensor. 
+ * @return std::vector<DimSize_t> + */ + std::vector<std::size_t> getCoord(std::size_t flatIdx) const { + std::vector<std::size_t> coordIdx = std::vector<std::size_t>(mDims.size()); + std::size_t idx = flatIdx; + for (std::size_t i = mDims.size() - 1; i > 0; --i){ + coordIdx[i] = (idx % mDims[i]); + idx/=mDims[i]; + } + coordIdx[0] = idx % mDims[0]; + return coordIdx; + } + + /** + * @brief From the coordinate returns the 1D index of an element in the tensor. + * + * @param coordIdx Coordinate to an element in the tensor + * @return DimSize_t + */ + std::size_t getIdx(std::vector<std::size_t> coordIdx) const { + // std::size_t flatIdx = 0; + // std::size_t stride = 1; + std::size_t flatIdx = 0; + assert(coordIdx.size() == mDims.size() && "Coordinates does not match number of dimensions"); + std::size_t i = 0; + for(; i < mDims.size() - 1; ++i){ + assert(coordIdx[i] < mDims[i] && "Coordinates dimensions does not fit the dimensions of the tensor"); + flatIdx = (flatIdx + coordIdx[i]) * mDims[i + 1]; + } + return flatIdx + coordIdx[i]; + } + private: ///\bug not protected against overflow std::size_t computeSize() { diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 718eddeaf6a5d08c9dab4898f5a57c0192dcb80b..89ba148497709f0af475bbf953ff285c88036102 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -124,7 +124,7 @@ public: } /** - * @brief List dataInput connections of the GraphView object's inputNodes. + * @brief List outside dataInput connections of the GraphView object's inputNodes. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; @@ -137,7 +137,7 @@ public: inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); } /** - * @brief List input connections of the GraphView object's inputNodes. + * @brief List outside input connections of the GraphView object's inputNodes. 
* @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; @@ -208,7 +208,7 @@ public: * @brief Get the Nodes pointed to by the GraphView object. * @return std::set<NodePtr> */ - inline std::set<NodePtr> getNodes() const { return mNodes; } + inline const std::set<NodePtr>& getNodes() const { return mNodes; } /** * @brief Get the operator with the corresponding name if it is in the @@ -217,7 +217,7 @@ public: * @return NodePtr returns a new empty node if the one asked for * was not found. */ - NodePtr getNode(const char *nodeName) const; + NodePtr getNode(const std::string& nodeName) const; /** * @brief Remove a Node from the current GraphView scope without affecting its connections. @@ -320,8 +320,20 @@ public: void link(std::string name1_inID, std::string name2_outID); - void insert(Node &newNode, Node &inNode, std::initializer_list<Node> outNodes, - IOIndex_t tensorIdx); + /** + * @brief Insert a node (newParentNode) as a parent of the passed node (childNode). + * + * @param childNode Node that gets a new parent. + * @param newParentNode Inserted Node. + * @param childInputTensorIdx Index of the input Tensor for the childNode linked to the inserted Node output. + * @param newParentInputTensorIdx Index of the input Tensor for the newParentNode linked to the former parent of childNode. + * @param newParentOutputTensorIdx Index of the output Tensor for the newParentNode linked to the childNode's input Tensor. + */ + void insertParent(NodePtr childNode, + NodePtr newParentNode, + IOIndex_t childInputTensorIdx, + IOIndex_t newParentInputTensorIdx, + IOIndex_t newParentOutputTensorIdx); /** * @brief Replace the current GraphView with the set of given Nodes if possible @@ -336,6 +348,37 @@ public: */ void updateOutputNodes(); + /** + * @brief Clone the GraphView with shared Operators. It is a new GraphView, with cloned Nodes, but the new Nodes refer to the same Operators as the original ones. 
+ * @return std::shared_ptr<GraphView> + */ + inline std::shared_ptr<GraphView> cloneSharedOperators() const { + return cloneCallback(&Node::cloneSharedOperators); + } + + /** + * @brief Clone the GraphView with shared Producers. All the other Operators are copied. + * @return std::shared_ptr<GraphView> + */ + inline std::shared_ptr<GraphView> cloneSharedProducers() const { + return cloneCallback(&Node::cloneSharedProducers); + } + + /** + * @brief Clone the GraphView. Everything is cloned: Nodes and Operators. + * @return std::shared_ptr<GraphView> + */ + inline std::shared_ptr<GraphView> clone() const { + return cloneCallback(&Node::clone); + } + + /** + * @brief Clone the current GraphView using a callback function for the Node cloning, allowing to specify how each Node should be cloned or replaced by another Node type, or removed (i.e. replaced by identity). When a Node is removed, the clone() method automatically finds the next valid parent in line, going backward in the graph and connects it if that makes sense without ambiguity (effectively treating the removed Node as an identity operation). + * @param cloneNode Callback function to clone a node + * @return std::shared_ptr<GraphView> + */ + std::shared_ptr<GraphView> cloneCallback(NodePtr(*cloneNode)(NodePtr)) const; + private: /////////////////////////////////////////////////////// // TENSOR MANAGEMENT diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index f056505e6e7839266213ac355cc0e1b93ab98f0d..1d8449ac25cf8c31192da0c350c14cbfa50a48f4 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -62,7 +62,7 @@ public: * @param op Operator giving the Node its number of connections. * @param name (optional) name for the Node. */ - Node(std::shared_ptr<Operator> op, const char *name = nullptr); + Node(std::shared_ptr<Operator> op, const std::string& name = ""); virtual ~Node() = default; @@ -303,7 +303,7 @@ public: * @param inId Input index. 
* @return std::shared_ptr<Node>& */ - inline NodePtr &getParents(const IOIndex_t inId) { + inline NodePtr &getParent(const IOIndex_t inId) { assert(inId != gk_IODefaultIndex); return mParents.at(inId); } @@ -350,6 +350,67 @@ public: */ void resetConnections(bool includeLearnableParam = false); + /////////////////////////////////////////////////////// + // CLONE + /////////////////////////////////////////////////////// + + /** + * @brief Clone the current Node. The Operator attribute of the new Node is not copied but shared with the current Node. The new node has no connection. + * @return NodePtr + */ + NodePtr cloneSharedOperators() const; + + /** + * @brief Clone the Node. Every attribute is copied, even Operator pointer except for Producers for which it is shared. The new Node has no connection. + * @return NodePtr + */ + NodePtr cloneSharedProducers() const; + + /** + * @brief Clone the Node and its Operator. The new Node has no connection. + * @return NodePtr + */ + NodePtr clone() const; + + /** + * @brief Callback function to clone the Node keeping the same Operator object instance. The new Node has no connection. + * @param node Node to clone. + * @return NodePtr + */ + static NodePtr cloneSharedOperators(NodePtr node) { + return node->cloneSharedOperators(); + } + + /** + * @brief Callback function to clone the Node. Every attribute is copied, even Operator pointer except for Producers for which it is shared. The new Node has no connection. + * @param node Node to clone. + * @return NodePtr + */ + static NodePtr cloneSharedProducers(NodePtr node) { + return node->cloneSharedProducers(); + } + + /** + * @brief Callback function to clone the Node and its Operator. The new Node has no connection. + * @param node Node to clone. + * @return NodePtr + */ + static NodePtr clone(NodePtr node) { + return node->clone(); + } + + + /** + * @brief Get the set of pointers to connected node at a distance of a delta. 
+ * @details the recution are cut + * Return a nullptr is nofing found. + * @param delta Input delta. + * @return std::shared_ptr<Node> + */ + + std::set<NodePtr> getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee); + + private: /////////////////////////////////////////////////////// // OPERATORS diff --git a/include/aidge/graph/OpArgs.hpp b/include/aidge/graph/OpArgs.hpp index 560c3a02c641c29526752dbf352905d0ded32a7e..9d1ba6fd1e1df594634bfd93a24663ff178b7ee6 100644 --- a/include/aidge/graph/OpArgs.hpp +++ b/include/aidge/graph/OpArgs.hpp @@ -55,7 +55,7 @@ public: * @param inputs List of Node and GraphView to link sequentially. * @return std::shared_ptr<GraphView> Pointer to the generated view. */ -std::shared_ptr<GraphView> Sequential(std::initializer_list<OpArgs> inputs); +std::shared_ptr<GraphView> Sequential(std::vector<OpArgs> inputs); ///////////////////////////// // Parallel @@ -65,7 +65,7 @@ std::shared_ptr<GraphView> Sequential(std::initializer_list<OpArgs> inputs); * @param inputs List of Node and GraphView to link sequentially. * @return std::shared_ptr<GraphView> pointer to the generated view. */ -std::shared_ptr<GraphView> Parallel(std::initializer_list<OpArgs> inputs); +std::shared_ptr<GraphView> Parallel(std::vector<OpArgs> inputs); ///////////////////////////// // Residual @@ -79,8 +79,8 @@ std::shared_ptr<GraphView> Parallel(std::initializer_list<OpArgs> inputs); * @param inputs List of Node and GraphView to link sequentially. * @return std::shared_ptr<GraphView> pointer to the generated view. 
*/ -std::shared_ptr<GraphView> Residual(std::initializer_list<OpArgs> inputs); +std::shared_ptr<GraphView> Residual(std::vector<OpArgs> inputs); } -#endif /* AIDGE_CORE_GRAPH_OPARGS_H_ */ \ No newline at end of file +#endif /* AIDGE_CORE_GRAPH_OPARGS_H_ */ diff --git a/include/aidge/graphRegex/GraphFsmInterpreter.hpp b/include/aidge/graphRegex/GraphFsmInterpreter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9e92b6fe8fc9d5e44cb8051e687e33d7192e0eb7 --- /dev/null +++ b/include/aidge/graphRegex/GraphFsmInterpreter.hpp @@ -0,0 +1,73 @@ +#ifndef AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ +#define AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ + +#include <string> +#include <memory> + +#include "aidge/utilsParsing/AstNode.hpp" +#include "aidge/graphRegex/GraphRegexTypes.hpp" +#include "aidge/graphRegex/GraphParser.hpp" +#include "aidge/graphRegex/matchFsm/FsmGraph.hpp" + +namespace Aidge { + + class GraphFsmInterpreter + { + private: + /* data */ + GraphParser mParser; + std::size_t mActGroupe; + std::map<std::string,std::shared_ptr<ConditionalInterpreter>> mNodesCondition; + + public: + GraphFsmInterpreter(const std::string graphMatchExpr,std::map<std::string,std::shared_ptr<ConditionalInterpreter>> nodesCondition); + virtual ~GraphFsmInterpreter() =default; + + + std::shared_ptr<FsmGraph> interpret(void); + + private: + + + std::shared_ptr<FsmGraph> visit(std::shared_ptr<AstNode<gRegexTokenTypes>> AstTree); + + /** + * @defgroup graphFsmInterpreterF Functions for interpreting AST nodes + * @brief For each node type in the AST, define how build the FsmGraph + */ + + + /** + * @ingroup graphFsmInterpreterF + * @brief leaf of fsm make the fsm for test one transition + */ + std::shared_ptr<FsmGraph> keyF(std::shared_ptr<AstNode<gRegexTokenTypes>> AstNode); + /** + * @ingroup graphFsmInterpreterF + * @brief combine two fsm of two expression. 
+ */ + std::shared_ptr<FsmGraph> sepF(std::shared_ptr<FsmGraph> leftFsm,std::shared_ptr<FsmGraph> rigthFsm); + /** + * @ingroup graphFsmInterpreterF + * @brief combine two to make a new that match leftFsm next rigthFsm + */ + std::shared_ptr<FsmGraph> nextF(std::shared_ptr<FsmGraph> leftFsm,std::shared_ptr<FsmGraph> rigthFsm); + /** + * @ingroup graphFsmInterpreterF + * @brief make the fsm match + + */ + std::shared_ptr<FsmGraph> qomF(std::shared_ptr<FsmGraph> fsm); + /** + * @ingroup graphFsmInterpreterF + * @brief make the fsm match * + */ + std::shared_ptr<FsmGraph> qzmF(std::shared_ptr<FsmGraph> fsm); + + }; + + + +} + + +#endif // AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ diff --git a/include/aidge/graphRegex/GraphLexer.hpp b/include/aidge/graphRegex/GraphLexer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e4137ab093c466b7349007da91e032dae48eda51 --- /dev/null +++ b/include/aidge/graphRegex/GraphLexer.hpp @@ -0,0 +1,68 @@ +#ifndef AIDGE_CORE_GRAPH_LEXER_H_ +#define AIDGE_CORE_GRAPH_LEXER_H_ + +#include <string> +#include <memory> +#include <regex> +#include <stdexcept> //error +#include <sstream> + +#include "aidge/utilsParsing/ParsingToken.hpp" +#include "aidge/graphRegex/GraphRegexTypes.hpp" + +namespace Aidge { + + class GraphLexer + { + + public: + GraphLexer( const std::string gRegexExpressions ); + + /** + * @brief Get the next token on the gRegexExpressions + * @return ConditionalToken + */ + std::shared_ptr<ParsingToken<gRegexTokenTypes>> getNextToken(void); + /** + * @brief Restart at the start of the gRegexExpressions + * + */ + void rstPosition(void); + + /** + * @brief Test if the string is completely read + * @return bool + */ + bool isEnd(void); + + + /** + * @brief Get the representation of the class + * @return string + */ + const std::string rep(); + + private: + + /** + * @brief Constructs an error message to display the character not understood by the lexer + * @return error mesage + */ + std::runtime_error 
badTokenError(const std::string& currentChars,std::size_t position); + + /** + * @brief The expression of the test to be performed on the nodes + */ + const std::string mRegularExpressions; + /** + * @brief The lexer's current position in mConditionalExpressions + */ + std::size_t mPosition; + + }; +} + + + + +#endif //AIDGE_CORE_GRAPH_LEXER_H_ diff --git a/include/aidge/graphRegex/GraphParser.hpp b/include/aidge/graphRegex/GraphParser.hpp new file mode 100644 index 0000000000000000000000000000000000000000..73406203a8be87e1df75cc694ab1ff281c27fbfa --- /dev/null +++ b/include/aidge/graphRegex/GraphParser.hpp @@ -0,0 +1,98 @@ +#ifndef AIDGE_CORE_GRAPH_PARSER_H_ +#define AIDGE_CORE_GRAPH_PARSER_H_ + + +#include <memory> // for shared_ptr +#include "aidge/graphRegex/GraphLexer.hpp" +#include "aidge/utilsParsing/AstNode.hpp" +#include "aidge/graphRegex/GraphRegexTypes.hpp" + +namespace Aidge{ + +/** + * @brief this class uses the lexer to create an AST according to a set of gramer rules + */ +class GraphParser{ + + public: + /** + * @brief AST graph creation function + * @param gRegexExpressions String representing the logical fuction to be performed + */ + GraphParser(const std::string gRegexExpressions); + + virtual ~GraphParser() = default; + + /** + * @brief AST graph creation function + * @return The AST tree + */ + std::shared_ptr<AstNode<gRegexTokenTypes>> parse(void); + + + private: + /** + * @brief restart at the start of the ConditionalExpressions for LEXER and restart mCurrentToken + */ + void rstParser(void); + + ////////////////// + + /** + * @defgroup ParsingFunctions Function for creating AST + * @brief Functions for recursive construction of the AST representing grammar rules + */ + + /** + * @ingroup ParsingFunctions + * @brief Token reading and verification function + * + */ + void ackToken(gRegexTokenTypes tokenType); + + //TODO TODO + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for key : KEY(QOM | QZM)? 
| CKEY + * @return AST node + */ + std::shared_ptr<AstNode<gRegexTokenTypes>> constructAstExp(void); + + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for sequence : seq :exp (NEXT seq)* + * @return AST node + */ + std::shared_ptr<AstNode<gRegexTokenTypes>> constructAstSeq(void); + + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for domain : (seq NEXT domain)? | LPAREN domain RPAREN (QOM | QZM) (NEXT domain)? + * @return AST node + */ + std::shared_ptr<AstNode<gRegexTokenTypes>> constructAstDomain(void); + + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for multiple exepresion : allExpr: domain (SEP allExpr)* + * @return AST node + */ + std::shared_ptr<AstNode<gRegexTokenTypes>> constructAstAllExpr(void); + + + /** + * @brief The actual token in the parce + */ + std::shared_ptr<ParsingToken<gRegexTokenTypes>> mCurrentToken; + + /** + * @brief The lexem use + */ + GraphLexer mLexer; + +}; + + +} + +#endif //AIDGE_CORE_GRAPH_PARSER_H_ diff --git a/include/aidge/graphRegex/GraphRegexTypes.hpp b/include/aidge/graphRegex/GraphRegexTypes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9e35f8f027eb363b71358922ffe0caa4a55fff1d --- /dev/null +++ b/include/aidge/graphRegex/GraphRegexTypes.hpp @@ -0,0 +1,29 @@ + +#ifndef AIDGE_CORE_GREGEX_TOKEN_TYPES_H_ +#define AIDGE_CORE_GREGEX_TOKEN_TYPES_H_ + + +namespace Aidge { + /** + * @brief enum for all types of token use in the of the regex + * 7-5 type + * 4-0 id + */ + enum class gRegexTokenTypes + { + STOP, + NEXT, /**< -> */ + + QOM, /**< + */ + QZM, /**< * */ + + KEY, /**< [A-Za-z_0-9]+ */ + CKEY, /**< [A-Za-z_0-9]+#[0-9]* */ + + SEP, /**< \( */ + LPAREN, /**< \( */ + RPAREN, /**< \) */ + }; + +} +#endif //AIDGE_CORE_GREGEX_TOKEN_TYPES_H_ diff --git a/include/aidge/graphRegex/GraphStrInterpreter.hpp b/include/aidge/graphRegex/GraphStrInterpreter.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..98dca0e9f84de0be2614aed0e47c9d86ae552674 --- /dev/null +++ b/include/aidge/graphRegex/GraphStrInterpreter.hpp @@ -0,0 +1,40 @@ +#ifndef AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ +#define AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ + +#include <iostream> +#include <sstream> +#include <memory> +#include <algorithm> + +#include "aidge/utilsParsing/AstNode.hpp" +#include "aidge/graphRegex/GraphRegexTypes.hpp" +#include "aidge/graphRegex/GraphParser.hpp" +#include "aidge/graphRegex/matchFsm/FsmGraph.hpp" + +namespace Aidge { + + class GraphStrInterpreter + { + private: + /* data */ + GraphParser mParser; + std::string mToTest; + public: + GraphStrInterpreter(const std::string graphMatchExpr); + virtual ~GraphStrInterpreter() =default; + + + std::string interpret(void); + + private: + + + std::string visit(std::shared_ptr<AstNode<gRegexTokenTypes>> AstTree); + }; + + + +} + + +#endif //AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ diff --git a/include/aidge/graphRegex/matchFsm/FsmEdge.hpp b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c3eae528808dbdb8023718c961b7c45cbf4afac9 --- /dev/null +++ b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp @@ -0,0 +1,228 @@ +#ifndef AIDGE_CORE_FSM_EDGE_H_ +#define AIDGE_CORE_FSM_EDGE_H_ + +#include <memory> +#include <set> +#include <string> + +#include "aidge/nodeTester/ConditionalInterpreter.hpp" + + +namespace Aidge{ + + class FsmNode; + class FsmRunTimeContext; + + struct EdgeTestResult { + bool success; + std::set<NodePtr> node; + }; + + /** + * @brief virtual class use test the node on the node to validate + */ + class FsmEdge: public std::enable_shared_from_this<FsmEdge> + { + private: + + /** + * @brief the relative position to this test relative to all the const key + * first is common id, second is the relative position + */ + std::map<size_t,int> mRelativePos; + /** + * @brief the ptr on the source node + */ + std::shared_ptr<FsmNode> 
mNodeSource; + /** + * @brief the ptr on the dest node + */ + std::shared_ptr<FsmNode> mNodeDest; + /** + * @brief the weak ptr + */ + std::weak_ptr<FsmEdge> weakPtr; + + public: + FsmEdge(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest); + + virtual ~FsmEdge(){}; + + FsmEdge() : weakPtr(shared_from_this()) {} + + + /** + * @brief test is the validation of the node, it must be defined for all types of edge + * it takes as argument an FSM traversal context and returns a set of next nodes + * @return set of next node or nullptr if not next + */ + + virtual const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) =0; + + /** + * @brief test is the egde test a common node + * @return true if is a common + */ + virtual bool isCommon(void); + /** + * @brief get the Common idx of the common test in this edge (if is a common edge) + * @return idx of the common + */ + virtual size_t getCommonIdx(void); + /** + * @brief get the relative postion to the common node deffine in this edge + * @return map + */ + const std::map<size_t,int>& getRelative(void); + /** + * @brief add new relative position + */ + void updateRelative( const std::map<size_t,int>& relativePos ); + /** + * @brief get source FsmNode + * @return FsmNode + */ + std::shared_ptr<FsmNode> getSourceNode(void); + /** + * @brief set a new source to the edge + * @return FsmNode + */ + void reSetSouceNode(const std::shared_ptr<FsmNode>& newSource); + /** + * @brief get dest FsmNode + * @return FsmNode + */ + std::shared_ptr<FsmNode> getDestNode(void); + /** + * @brief set a new dest to the edge + * @return FsmNode + */ + void reSetDestNode(const std::shared_ptr<FsmNode>& newDest); + /** + * @brief propagate the edge mRelativePos to the others Edge and recalcul the relative position + */ + void propagateRelativePos(void); + + /** + * @brief test to make on the node to validate + * @see ConditionalInterpreter + */ + const 
std::shared_ptr<ConditionalInterpreter> mToTest; + + /** + * @brief update week ptr for the node, TODO best + */ + void updateWeak(void); + }; + + /** + * @brief class spesialisation for not commun node (node that must be match one Unique) transition + */ + class FsmEdgeUnique:public FsmEdge + { + + public: + FsmEdgeUnique(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest); + const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override; + }; + + /** + * @brief class spesialisation for commun node transition + * @see FsmEdge + */ + class FsmEdgeCommon:public FsmEdge + { + + private: + /** + * @brief the map that defind the ralation between the commonKey find by the lexer and a unique id use to refer to the common node + */ + static std::map<std::string,int> mCommonIdxMap; + /** + * @brief the common id test in this transition + */ + int mCommonIdx; + public: + + /** + * @brief constructor commun node , + * @details during construction, + * the node key found by the lexer is converted to a unique id and the relative positions are updated. 
+ */ + FsmEdgeCommon(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest, const std::string commonKey); + // ~FsmEdgeCommon() override {} + const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override; + bool isCommon(void) override; + + }; + + + + /** + * @brief class spesialisation for ref transition + * @see FsmEdge + */ + class FsmEdgeRef:public FsmEdge + { + private: + /** + * @brief the id of one common node that we use as an anchor + */ + const int mRefCommonIdx; + /** + * @brief the delta in terme of child or parent refer to the anchor + */ + const int mdeltaCommonIdx; + public: + FsmEdgeRef(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const size_t refCommonIdx,const int deltaCommonIdx); + //~FsmEdgeRef() override {} + const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override; + + }; + + /** + * @brief class spesialisation for ref empty transition + * @see FsmEdge + */ + class FsmEdgeEmpty:public FsmEdge + { + + public: + FsmEdgeEmpty(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest); + //~FsmEdgeEmpty() override {} + const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override; + + }; + + + +//////////////////////// +// FACTORY +//////////////////////// + +enum class FsmEdgeTypes { + EMPTY = 0, + REF, + COMMON, + UNIQUE +}; + + +class FsmEdgeFactory { + public: + /** + * @brief factory for making edge and read the info in the lexeme of the token + * @param source source node of the edge + * @param dest Dest node of the edge + * @param type type of the edge + * @param lexeme the additional information to build the edge + * @return s prt of the edge + */ + static std::shared_ptr<FsmEdge> make(std::shared_ptr<FsmNode> source, std::shared_ptr<FsmNode> dest, + FsmEdgeTypes type,std::map<std::string, std::shared_ptr<ConditionalInterpreter>> allTest, + const std::string lexeme = ""); 
+ }; + +} + +#endif //AIDGE_CORE_FSM_EDGE_H_ diff --git a/include/aidge/graphRegex/matchFsm/FsmGraph.hpp b/include/aidge/graphRegex/matchFsm/FsmGraph.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0a74551367dd492cb0abb820e4c5ce5a601d071e --- /dev/null +++ b/include/aidge/graphRegex/matchFsm/FsmGraph.hpp @@ -0,0 +1,98 @@ + +#ifndef AIDGE_CORE_FSM_GRAPH_H_ +#define AIDGE_CORE_FSM_GRAPH_H_ + +#include <set> +#include <vector> +#include <memory> +#include <stdexcept> //error + +#include "aidge/graphRegex/matchFsm/FsmNode.hpp" +#include "aidge/graphRegex/matchFsm/FsmEdge.hpp" +#include "aidge/graphRegex/matchFsm/MatchResult.hpp" +namespace Aidge{ + + + +class FsmGraph +{ +private: + /** + * @brief all node origine + */ + std::set<std::size_t> mAllOrigine; + std::set<std::shared_ptr<FsmEdge>> mEdges; +public: + FsmGraph(/* args */); + virtual ~FsmGraph() = default; + +std::shared_ptr<MatchResult> test(std::vector<NodePtr>& StartNodes); + + + +const std::set<std::shared_ptr<FsmEdge>>& getEdge(void); +/** + * @brief add edge in the graph, as FsmEdge know the source and dest FsmNode these nodes are also add to the graph +*/ +void addEdge(std::shared_ptr<FsmEdge>& edge); + +/** + * @brief get the liste of the starting states + * @details we need to use a vector because the order of the nodes is important for start node initialization \ref test() +*/ +const std::vector<std::shared_ptr<FsmNode>> getStartNodes(void); + +/** + * @brief get the set of the valide states + * @return set of valide state +*/ +const std::set<std::shared_ptr<FsmNode>> getValidNodes(void); + +/** + * @brief get the set of all the node in the graph + * @return set of all nodes +*/ +const std::set<std::shared_ptr<FsmNode>> getNodes(void); + +/** + * @brief set a groupe idx for all the nodes in the graph +*/ +void setGroupe(std::size_t groupeIdx); + +/** + * @brief make the union beteen this graph and an input graph + * @param fsmGraph graph to union +*/ +void unionG(const 
std::shared_ptr<FsmGraph> fsmGraph); + + +/** + * @brief make the union beteen this graph and an input graph and merge the valide state to the start state + * @param fsmGraph graph to merge +*/ +void mergeOneStartOneValid(const std::shared_ptr< FsmGraph> fsmGraph); +/** + * @brief get the number of sub FSM + * @return number of sub Fsm +*/ +std::size_t getNbSubFsm(void); + +/** + * @brief increment the origine of all node in the graph + * @param incr the incrémentation value +*/ +void incOrigineAllNodeBy(std::size_t incr); + +private: + +/** + * @brief merge tow node of the graph + * @param node +*/ +void _mergeNode(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest); + +}; + + +} +#endif //AIDGE_CORE_FSM_GRAPH_H_ diff --git a/include/aidge/graphRegex/matchFsm/FsmNode.hpp b/include/aidge/graphRegex/matchFsm/FsmNode.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2776ff8eb297fd5ad9a4c425fb386adde0a25269 --- /dev/null +++ b/include/aidge/graphRegex/matchFsm/FsmNode.hpp @@ -0,0 +1,99 @@ +#ifndef AIDGE_CORE_FSM_NODE_H_ +#define AIDGE_CORE_FSM_NODE_H_ + +#include <set> +#include <vector> +#include <memory> + +//#include "graphRegex/matchFsm/FsmEdge.hpp" +//#include "graphRegex/matchFsm/FsmRunTimeContext.hpp" + +namespace Aidge{ + // Forward declaration of the class defined in graphRegex/matchFsm/FsmEdge.hpp + class FsmEdge; + struct EdgeTestResult; + class FsmRunTimeContext; + + + //------------------------------------------------------------------------------ + + // MAY BE IN UTILE + template <typename T> + struct lex_compare { + bool operator() (const std::weak_ptr<T> &lhs, const std::weak_ptr<T> &rhs)const { + auto lptr = lhs.lock(), rptr = rhs.lock(); + if (!rptr) return false; // nothing after expired pointer + if (!lptr) return true; + return lptr < rptr; + } + }; + + /** + * @brief is a node in the FSM graph, it's a state in the FSM + * @details a state can be and/or : + * - a valide state, the match is valide if it stop on 
this edge + * - a start state , the match start on this state + * The state is also define by this origine (is the unique id of it's expretion ) + * and it's groupe (for inner expression TODO) + */ + class FsmNode : public std::enable_shared_from_this<FsmNode> + { + private: + /** + * @brief the edge of the node + * @details the edge have a shared ref to the node so we use weak ref + */ + std::set<std::weak_ptr<FsmEdge>,lex_compare<FsmEdge>> mEdges; + /** + * @brief the parent of the node + */ + std::set<std::weak_ptr<FsmNode>,lex_compare<FsmNode>> mParents; + + std::size_t mOrigineStm = 0; + std::size_t mGroupeStm = 0; + + bool mIsAValid; + bool mIsAStart; + + public: + FsmNode(bool isAValid,bool isAStart ); + virtual ~FsmNode() = default; + /** + * @brief use to MAG the actual context , and return all the posible new context + * @details one input context can generate a multitude of contexts because a graph node + * can have more than one child, and each traversal possibility is a new context. 
+ * @param actContext the actual context + * @return A vector of all the new context + */ + const std::vector<std::shared_ptr<FsmRunTimeContext>> test( std::shared_ptr<FsmRunTimeContext>); + + + std::size_t getOrigine(void); + void incOrigine(std::size_t inc); + + + void rmEdge(std::shared_ptr<FsmEdge>); + void addEdge(std::shared_ptr<FsmEdge>); + + //const std::set<std::shared_ptr<FsmNode>> getChildNodes(void); + + const std::set<std::weak_ptr<FsmNode>,lex_compare<FsmNode>>& getParentNodes(void); + const std::set<std::weak_ptr<FsmEdge>,lex_compare<FsmEdge>>& getEdges(void); + + void setGroupe(std::size_t groupeIdx); + + bool isValid(void); + bool isStart(void); + void unValid(void); + void valid(void); + void unStart(void); + void start(void); + + + + void addParent(std::shared_ptr<FsmNode>); + void rmParent(std::shared_ptr<FsmNode>); + }; + +} +#endif //AIDGE_CORE_FSM_NODE_H_ diff --git a/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp b/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6f1b9fc2bfe68195f67cfc0bf17d57aed5345219 --- /dev/null +++ b/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp @@ -0,0 +1,173 @@ +#ifndef AIDGE_CORE_FSM_RUN_TIME_CONTEXT_H_ +#define AIDGE_CORE_FSM_RUN_TIME_CONTEXT_H_ + +#include <memory> +#include <vector> +#include <set> +#include <algorithm> + +#include "aidge/nodeTester/ConditionalInterpreter.hpp" +#include "aidge/graph/Node.hpp" + + + +namespace Aidge{ + + class FsmNode; + + class FsmNode; + + /** + * @brief a class used to save the execution context of state machines, that is the actual state in the FSM, the actual node in the graph + * all node that have been Validate,Rejecte or Considered common + */ + class FsmRunTimeContext + { + private: + /** + * @brief the list of node rejected for all the context + */ + static std::vector<std::set<NodePtr>> mRejectedNodes; + /** + * @brief the actual state of this Context (where it's in the FSM 
graph) + */ + std::shared_ptr<FsmNode> mActState; + /** + * @brief the actual node of this Context (where it's in the graph) + */ + NodePtr mActOpNode; + /** + * @brief the map of the node consider as common and the common ID + * @details we need to store what node it's consider as common because of the end + * resolution of the matching, all node consider as common need to be the same in all context + */ + std::map<NodePtr,std::size_t> mCommonNodes; + /** + * @brief the map of the node that as been valid in this context , and the test that valide the node + */ + std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> mValidNodes; + /** + * @brief the index in the rejected node of this context + */ + std::size_t mLocalIdxRejeced; + public: + /** + * @brief constructor + * @param actState the actual state in the FSM + * @param actOpNode the actual node in the graph + * @param idxRejeced the idx in the global regected node vector init max() as sentinel value of undefind + */ + FsmRunTimeContext(std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ,std::size_t idxRejeced =std::numeric_limits<std::size_t>::max() ); + FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime); + FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime,std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ); + + virtual ~FsmRunTimeContext()=default; + + /** + * @defgroup FsmRunTimeContextRejected Function for managing rejected nodes + */ + + /** + * @ingroup FsmRunTimeContextRejected + * @brief Add a node as rejected in this context + */ + void addRejectedNode(NodePtr node); + + /** + * @ingroup FsmRunTimeContextRejected + * @brief get the rejected nodes of this context + */ + std::set<NodePtr> getRejectedNodes(void); + + + /** + * @defgroup FsmRunTimeContextTest Function for test the context + */ + + /** + * @ingroup FsmRunTimeContextTest + * @brief test if the actual state is valide + * @return bool + */ + bool isOnValidState(void); + /** + * @ingroup 
FsmRunTimeContextTest + * @brief test if the node is considered as common in this context + * @param node node to test + * @return bool + */ + bool isCommonDefined(NodePtr node); + /** + * @ingroup FsmRunTimeContextTest + * @brief test if has already validated in this context + * @param node node to test + * @return bool + */ + bool isAlreadyValid(NodePtr node); + /** + * @ingroup FsmRunTimeContextTest + * @brief test if this context is compatible with an others + * @details to say that two contexts are compatible is to check : + * that the contexts do not validate the same nodes (other than the common ones) + * and that the common ones have the same idx + * @param fsmContext the others context + * @return bool + */ + bool areCompatible(std::shared_ptr<FsmRunTimeContext> fsmContext); + /** + * @ingroup FsmRunTimeContextTest + * @brief test if this context is strictly equal with an others + * @param fsmContext the others context + * @return bool + */ + bool areEqual(std::shared_ptr<FsmRunTimeContext> fsmContext); + + /** + * @defgroup FsmRunTimeContextSet Function set context + */ + + + void setCommon(NodePtr node,std::size_t commonIdx); + + + void setValid(NodePtr node,std::shared_ptr<ConditionalInterpreter> tag); + + /** + * @defgroup FsmRunTimeContextGet Function get context + */ + + + /** + * @ingroup FsmRunTimeContextGet + * @brief get the sub idx state + * @return bool + */ + std::size_t getSubStmId(void); + + NodePtr getCommonNodeFromIdx(std::size_t commonIdx); + std::size_t getCommonNodeIdx(NodePtr node); + std::set<NodePtr> getCommonNodes(void); + + std::map<NodePtr,std::size_t> getCommon(void); + std::set<NodePtr> getValidNodes(void); + + std::set<NodePtr> getValidNodesNoCommon(void); + std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> getValid(void); + + + NodePtr getActNode(void); + std::shared_ptr<FsmNode> getActState(void); + + + /** + * @defgroup FsmRunTimeContextMem + */ + + void rst(void); + + + }; + +} + +#endif 
//AIDGE_CORE_FSM_RUN_TIME_CONTEXT_H_ diff --git a/include/aidge/graphRegex/matchFsm/MatchResult.hpp b/include/aidge/graphRegex/matchFsm/MatchResult.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ac2f2a627a9d88b3cabeac4b181af2f3b7566d72 --- /dev/null +++ b/include/aidge/graphRegex/matchFsm/MatchResult.hpp @@ -0,0 +1,60 @@ +#ifndef AIDGE_CORE_MATCH_RESULT_H_ +#define AIDGE_CORE_MATCH_RESULT_H_ + +#include <memory> +#include <vector> +#include <map> + + +#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" +#include "aidge/graph/Node.hpp" + +namespace Aidge{ + +/** + * @brief class that old the result of a matching + * give acess to all node ant there tag in the expression +*/ +class MatchResult +{ +private: + /* data */ + std::vector<std::shared_ptr<FsmRunTimeContext>> mAllValid; + + /* + the Run time of eatch sub FSM , to have a valide match we need a set of one run time per FSM compatible + the id must be contigue + */ + std::vector<std::vector<std::shared_ptr<FsmRunTimeContext>>> mIdToRunTime; + + std::vector<std::set<NodePtr>> mSolve; + + std::size_t mNbSubStm; + +public: + MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm); + + virtual ~MatchResult() = default; + + /** + * @brief get the set of the node match for une expression + * @return the set of node of the graph that corresponding to an expression + */ + std::set<NodePtr> getBiggerSolution(void); + +private: + +/** + * @brief recurent function use to inite mSolve in the constructor + * + **/ + +void _generateCombinationd( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence); + +}; + + +} + + +#endif //AIDGE_CORE_MATCH_RESULT_H_ diff --git a/include/aidge/hook/ExecTime.hpp b/include/aidge/hook/ExecTime.hpp new file mode 100644 index 0000000000000000000000000000000000000000..212fef58696be702e89c8ad973dcc0dd0fc389ae --- /dev/null +++ b/include/aidge/hook/ExecTime.hpp @@ -0,0 +1,59 @@ +/** + * \file 
execTime.hpp + * \brief execTime structure + * \version file 1.0.0 + * \date Creation 27 June 2023 + * \date 27 June 2023 + * \par ChangeLog + * \par + * v1.0.0, 27 June 2023<br> + * - Initial version. + * \author mn271187, ik243221 + * \copyright + * Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory. All + * rights reserved. + */ + +#ifndef execTime_H_ +#define execTime_H_ + +#include "aidge/operator/Operator.hpp" +#include "aidge/hook/hook.hpp" +#include <memory> +#include <chrono> +#include <vector> + +namespace Aidge { + +class ExecTime : public Hook { +private: + std::vector<std::chrono::high_resolution_clock::time_point> registeredTimes = std::vector<std::chrono::high_resolution_clock::time_point>(); +public: + ExecTime(const std::shared_ptr<Operator> op) : Hook(op) {} + ~ExecTime() = default; + + void call() override final { + registeredTimes.push_back(std::chrono::high_resolution_clock::now()); + } + + static std::shared_ptr<ExecTime> create(const std::shared_ptr<Operator> op) + { + return std::make_shared<ExecTime>(op); + } + + std::vector<std::chrono::high_resolution_clock::time_point> getTimes() { + return registeredTimes; + } + + std::chrono::high_resolution_clock::time_point getTime(size_t idx) { + return registeredTimes[idx]; + } + +}; + +namespace { + static Registrar<Hook> registrarHook_ExecTime({"execution_time"}, Aidge::ExecTime::create); +} +} + +#endif /* execTime_H_ */ \ No newline at end of file diff --git a/include/aidge/hook/Hook.hpp b/include/aidge/hook/Hook.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5e00db5d68f11aadd4f3b6eb8174ba61b33e4a49 --- /dev/null +++ b/include/aidge/hook/Hook.hpp @@ -0,0 +1,41 @@ +/** + * \file Hook.hpp + * \brief Hook structure + * \version file 1.0.0 + * \date Creation 27 June 2023 + * \date 27 June 2023 + * \par ChangeLog + * \par + * v1.0.0, 27 June 2023<br> + * - Initial version. 
+ * \author mn271187, ik243221 + * \copyright + * Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory. All + * rights reserved. + */ + +#ifndef Hook_H_ +#define Hook_H_ + +#include "aidge/utils/Attributes.hpp" +#include "aidge/utils/Registrar.hpp" +#include <memory> + +namespace Aidge { + +class Operator; +class Hook : public Registrable<Hook, std::tuple<std::string>, std::shared_ptr<Hook>(const std::shared_ptr<Operator>)> { +//class Hook : public Registrable<Hook, std::tuple<std::string>, std::shared_ptr<Hook>(const std::shared_ptr<Operator>)>{ +protected: + const std::shared_ptr<Operator> mOperator; + +public: + Hook(std::shared_ptr<Operator> op) : mOperator(op) {} + virtual ~Hook() = default; + + virtual void call() = 0; + +}; +} + +#endif /* Hook_H_ */ diff --git a/include/aidge/hook/OutputRange.hpp b/include/aidge/hook/OutputRange.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a2da2a997d594c0ef78fb7c31f33b32c3495c4eb --- /dev/null +++ b/include/aidge/hook/OutputRange.hpp @@ -0,0 +1,74 @@ +/** + * \file execTime.hpp + * \brief execTime structure + * \version file 1.0.0 + * \date Creation 27 June 2023 + * \date 27 June 2023 + * \par ChangeLog + * \par + * v1.0.0, 27 June 2023<br> + * - Initial version. + * \author ik243221 + * \copyright + * Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory. All + * rights reserved. 
+ */ + +#ifndef AIDGE_CORE_HOOK_OUTPUTRANGE_H_ +#define AIDGE_CORE_HOOK_OUTPUTRANGE_H_ + +#include "aidge/operator/Operator.hpp" +#include "aidge/hook/hook.hpp" +#include <memory> +#include <chrono> +#include <vector> +#include <cmath> +namespace Aidge { + +class OutputRange : public Hook { +private: + std::vector<float> registeredOutputs = std::vector<float>(); +public: + OutputRange(const std::shared_ptr<Operator> op) : Hook(op) {} + ~OutputRange() = default; + + void call() override final { + //std::cout << "call() outputRange hook " << std::endl; + //this assumes there is only 1 output possible + std::shared_ptr<Tensor> tensor = mOperator->getOutput(0); + //tensor->print(); + //std::cout << "call() outputRange hook : tensor printed" << std::endl; + float max_value = 0.; + float * casted_tensor = static_cast<float *>(tensor->getImpl()->rawPtr()); + //find the absolute max value in the tensor, save it to registered outputs + for(std::size_t i = 0; i < tensor->size(); ++i) { + //std::cout << "call() outputRange hook : casted_tensor[i] = " << casted_tensor[i] << std::endl; + if(std::abs(casted_tensor[i]) > max_value){ + max_value = std::abs(casted_tensor[i]); + } + } + //std::cout << "call() outputRange hook : max_value = " << max_value << std::endl; + registeredOutputs.push_back(max_value); + } + + static std::shared_ptr<OutputRange> create(const std::shared_ptr<Operator> op) + { + return std::make_shared<OutputRange>(op); + } + + std::vector<float> getOutputs() { + return registeredOutputs; + } + + float getOutput(size_t idx) { + return registeredOutputs[idx]; + } + +}; + +namespace { + static Registrar<Hook> registrarHook_OutputRange({"output_range"}, Aidge::OutputRange::create); +} +} + +#endif /* outputRange_H_ */ \ No newline at end of file diff --git a/include/aidge/nodeTester/ConditionalData.hpp b/include/aidge/nodeTester/ConditionalData.hpp new file mode 100644 index 0000000000000000000000000000000000000000..12df32a728571678a3885f9981e526e1d73db785 --- 
/dev/null +++ b/include/aidge/nodeTester/ConditionalData.hpp @@ -0,0 +1,98 @@ + +#ifndef AIDGE_CORE_CONDITIONAL_DATA_H_ +#define AIDGE_CORE_CONDITIONAL_DATA_H_ + +#include <vector> +#include <string> +#include <stdexcept> //error +#include <memory> +#include <map> +namespace Aidge{ + + + +///////////////////////// +// The data type in AST Intepretation +//////////////////////// + +class BaseConditionalValue { +public: + virtual ~BaseConditionalValue() {} +}; + +template <typename T> +class ConditionalValue : public BaseConditionalValue { +public: + ConditionalValue(const T& data) : value(data) {} + T value; +}; + + +struct ConditionalData { + /** + * @brief generic type to propagate all the different values in the AST interpretation + */ + //void* value; + std::unique_ptr<BaseConditionalValue> value; + const std::type_info* type =nullptr; + + ///////////////////////////////// + // + //////////////////////////////// + /** + * @brief set a value + */ + template <typename T> + void setValue(const T& newValue) { + //make sure that the old value is free + deleteValue(); + value = std::make_unique<ConditionalValue<T>>(newValue); + type = &typeid(T); + } + + /** + * @brief get the actual value + * @details recaste the value to the templaited type and checks that the conversion type is compatible with type + * @tparam the type of the return value + * @return the value + */ + template <typename T> + T getValue() const { + if (type && *type == typeid(T)) { + //const Value<T>* typedValue = dynamic_cast<const Value<T>*>(static_cast<const BaseValue*>(value)); + const ConditionalValue<T>* typedValue = dynamic_cast<const ConditionalValue<T>*>(value.get()); + if (typedValue) { + return typedValue->value; + } + } + throw std::runtime_error(std::string("DATA ERROR ") + type->name() + " != " + typeid(T).name()); + } + /////////////////////////////////// + // + /////////////////////////////////// + std::string getType() const { + return type ? 
type->name() : "nullptr"; + } + + + template <typename T> + bool isTypeEqualTo() const { + return (type && *type == typeid(T)); + } + + void deleteValue() { + if (type) { + value.reset(); + type = nullptr; + } + } + + ~ConditionalData() { // TODO best can we have a list of type supported ? + deleteValue(); + } +}; + +} + + +#endif //AIDGE_CORE_CONDITIONAL_DATA_H_ diff --git a/include/aidge/nodeTester/ConditionalInterpreter.hpp b/include/aidge/nodeTester/ConditionalInterpreter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..165fac1c2ae98bf76b73c039de9fc975e9845cc9 --- /dev/null +++ b/include/aidge/nodeTester/ConditionalInterpreter.hpp @@ -0,0 +1,339 @@ + + +#ifndef AIDGE_CORE_CONDITIONAL_INTERPRETER_H_ +#define AIDGE_CORE_CONDITIONAL_INTERPRETER_H_ + +#include "aidge/nodeTester/ConditionalParser.hpp" +#include "aidge/nodeTester/ConditionalData.hpp" + +#include <memory> // for shared_ptr +#include <unordered_map> +#include <functional> +#include "aidge/graph/Node.hpp" +#include <sstream> + + +namespace Aidge{ + + + +////////////////////////////// +// +///////////////////////////// +/** + * @brief class used to register any lambda function without context, + * it encapsulates the source lambda in a lambda which takes as argument ConditionalData* which are any type. 
+ * @see ConditionalData + */ +class ConditionalRegisterFunction { + ////////////////////////// + //Safe recaste + ////////////////////////// + + /** + * @brief recast the ConditionalData* to the argument type of the lambda + * @tparam T type of the lambda argument + * @see ConditionalData + */ + template <typename T> + T safeCastInput(ConditionalData* data) { + //cnvertion and type cheking + if (data->isTypeEqualTo<T>()){ + return data->getValue<T>(); + }else{ + throw std::invalid_argument( "incompatible input type " + data->getType() +" "+ typeid(T).name() ); + } + + } + + + /** + * @brief recaste the output of the lambda to a ConditionalData* + * @tparam T type of the lambda return + * @see ConditionalData + */ + template <typename T> + ConditionalData* safeCastOutput(T data) { + + ConditionalData* out = new ConditionalData; + out->setValue<T>(data); + + return out; + } + + + + + ////////////////////// + // get all the type of the function + ////////////////////// + + /** + * @brief Retrieves information about a function's return type and argument types. + * @tparam T The function type. + */ + template <typename T> + struct function_traits; + + + /** + * @brief Specialization of function_traits for function pointers. + * @tparam R The return type of the function. + * @tparam Args The argument types of the function. + */ + template <typename R, typename... Args> + struct function_traits<R (*)(Args...)> { + using return_type = R; + static constexpr std::size_t arity = sizeof...(Args); + + template <std::size_t N> + struct argument { + static_assert(N < arity, "Index out of range."); + using type = typename std::tuple_element<N, std::tuple<Args...>>::type; + }; + }; + + /** + * @brief Specialization of function_traits for std::function types. + * @tparam R The return type of the function. + * @tparam Args The argument types of the function. + */ + template <typename R, typename... 
Args> + struct function_traits<std::function<R(Args...)>> { + using return_type = R; + static constexpr std::size_t arity = sizeof...(Args); + + template <std::size_t N> + struct argument { + static_assert(N < arity, "Index out of range."); + using type = typename std::tuple_element<N, std::tuple<Args...>>::type; + }; + }; + + ///////////////////// + //change the function to ConditionalData*(std::vector<ConditionalData*>) + ///////////////////// + + /** + * @brief Converts a function to a ConditionalData*(std::vector<ConditionalData*>). + * @tparam F The type of the function to convert. + * @tparam ParamsIdx The indices of the function parameters. + * @param f The function to convert. + * @return The pointer to the converted function. + */ + template <class F, std::size_t... ParamsIdx> + auto funcPointer(F f, std::index_sequence<ParamsIdx...>) { + //wrapp the lambda in a new one that as ConditionalData as inputs and output + return [this,f](std::vector<ConditionalData*> &args) { + if (args.size() != sizeof...(ParamsIdx)){ + std::ostringstream errorMessage; + errorMessage << "bad Number of argument: get " << args.size() << " need " << sizeof...(ParamsIdx) << "\n"; + throw std::runtime_error(errorMessage.str()); + } + //assert(args.size() == sizeof...(ParamsIdx));//the size of the vector valide + + using FuncTraits = function_traits<decltype(f)>; + using outType = typename FuncTraits::return_type; + + outType result = f(safeCastInput<typename FuncTraits::template argument<ParamsIdx>::type>(args[ParamsIdx])...); + //typename + return safeCastOutput<outType>(result); + }; + } + + /** + * @brief Converts a function pointer to a ConditionalData*(std::vector<ConditionalData*>). + * @tparam R The return type of the function. + * @tparam Params The parameter types of the function. + * @param f The function pointer to convert. + * @return The pointer to the converted function. + */ + template <class R,class... 
Params> + auto funcPointer(R (*f)(Params...)) { + return funcPointer(f, std::index_sequence_for<Params...>{}); + } + + /** + * @brief Converts a std::function to a ConditionalData*(std::vector<ConditionalData*>). + * @tparam R The return type of the function. + * @tparam Params The parameter types of the function. + * @param f The function pointer to convert. + * @return The pointer to the converted function. + */ + template <class R,class... Params> + auto funcPointer(std::function<R(Params...)> f) { + return funcPointer(f, std::index_sequence_for<Params...>{}); + } + + + /////////////////// + // interface + /////////////////// + + public: + + /** + * @brief Default constructor + */ + ConditionalRegisterFunction(){} + + + /** + * @brief Inserts a function into the map with the provided key. + * @tparam T The function type. + * @param key The key to associate with the function. + * @param f The function to insert. + */ + template <class T> + void insert(const std::string key,T f){ + mWlambda.insert({ key, funcPointer(f)}); + } + + + /** + * @brief Runs the function associated with the given key, using the provided vector of input data. + * @param key The key of the function to run. + * @param datas The vector of input data. + * @return A pointer to the output ConditionalData object. + */ + ConditionalData* run(const std::string key,std::vector<ConditionalData*> & datas); + + private: + /// @brief map of name and the converted function. + std::map<const std::string, std::function<ConditionalData*(std::vector<ConditionalData*> &)>> mWlambda; +}; + +/////////////////// +//AST tree node +// //////////////// +/** + * @brief this class interprets AST to generate a test on a graph node. For each AST node, + * it generates an interpretation and registers lambda functions that can be used in the test expression. 
+ * there are two lambda control mechanisms: + * - A cpp mechanism which allows any lambda to be inserted into the constructor that use templaite + * - A user mechanism limited to lambda bool(NodePtr) + * @see ConditionalParser use to get the AST + */ +class ConditionalInterpreter +{ + private: + + /** + * @brief the AST generate by the Parser + * @see ConditionalParser + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> mTree; + /** + * @brief the registery for the lambda fuction + * @see ConditionalRegisterFunction + */ + ConditionalRegisterFunction mLambdaRegiter; + + + std::vector<ConditionalData*> mResolution ; + + void clearRes(){ + + for (std::size_t i = 0; i < mResolution.size(); ++i) { + delete mResolution[i]; + } + mResolution.clear(); + } + + public: + /** + * @brief Constructor + * @param ConditionalExpressions The expression of the test to be performed on the nodes + */ + + ConditionalInterpreter(const std::string ConditionalExpressions); + + ~ConditionalInterpreter(){clearRes();} + + /** + * @brief Test a node depending of the ConditionalExpressions + * @details the AST is visit using \ref visit() whith the $ init whit the nodeOp + * @return bool the match node has the initialized expresion + * @see visit() This function uses the visit() function to perform the evaluation. 
+ */ + bool test( const NodePtr nodeOp); + + /** + * @brief Interface for inserting custom lambda bool(NodePtr) functions in AST interpretation, + * it will be available in the ConditionalExpressions expretion as : key($) + * @param key The key that will be used to call the function in the expression + * @param f The pointer to function + */ + void insertLambda(const std::string key,std::function<bool(Aidge::NodePtr)> f); + + + ///// + + private: + /** + * @brief Recursive AST traversal function, using the for interpreting AST nodes function, + * using \ref ASTnodeInterpreterF fuctions + * @param NodeOp The node currently being tested + * @param nodes The AST given by the parsing process + */ + std::vector<ConditionalData*> visit(const ASTNodeCh& nodes, const NodePtr NodeOp ); + + /** + * @defgroup ASTnodeInterpreterF Functions for interpreting AST nodes + * @brief For each node type in the AST, function defines the processing to be performed + * they return a std::vector<ConditionalData*> which corresponds to the value(s) obtained + */ + + /** + * @ingroup ASTnodeInterpreterF + * @brief Function that does something. 
+ */ + void fLambda(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); + /** + * @ingroup ASTnodeInterpreterF + * @brief Converted the lexeme to a int and to ConditionalData* + */ + void fStrToInteger(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); + /** + * @ingroup ASTnodeInterpreterF + * @brief Converted the lexeme to a float and to ConditionalData* + */ + void fStrToFloat(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); + /** + * @ingroup ASTnodeInterpreterF + * @brief Converted the lexeme to a str and to ConditionalData* + */ + void fStrToStr(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); + + /** + * @ingroup ASTnodeInterpreterF + * @brief makes the == operation between two previously converted ConditionalData* + */ + void fEq(void); + /** + * @ingroup ASTnodeInterpreterF + * @brief makes the != operation between two previously converted ConditionalData* + */ + void fNeq(void); + /** + * @ingroup ASTnodeInterpreterF + * @brief makes the && operation between two previously converted ConditionalData* in bool + */ + void fAnd(void); + /** + * @ingroup ASTnodeInterpreterF + * @brief makes the || operation between two previously converted ConditionalData* in bool + */ + void fOr(void); + + /** + * @ingroup ASTnodeInterpreterF + * @brief makes the ! operation + */ + void fNot(void); +}; + + +} + +#endif //AIDGE_CORE_CONDITIONAL_INTERPRETER_H_ diff --git a/include/aidge/nodeTester/ConditionalLexer.hpp b/include/aidge/nodeTester/ConditionalLexer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fcfb9ebe783ac719076ce675e6fc3d78caf5be07 --- /dev/null +++ b/include/aidge/nodeTester/ConditionalLexer.hpp @@ -0,0 +1,88 @@ +/** + * @file + * @brief + * @version file 1.0.0 + * @author vl241552 + * @copyright + * Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory. All + * rights reserved. 
+ */ + + + +#ifndef AIDGE_CORE_CONDITIONAL_LEXER_H_ +#define AIDGE_CORE_CONDITIONAL_LEXER_H_ + +#include <string> +#include <regex> +#include <memory> // for shared_ptr + + +#include <stdexcept> //error +#include <sstream> + +#include "aidge/nodeTester/ConditionalTypes.hpp" +#include "aidge/utilsParsing/ParsingToken.hpp" + + +namespace Aidge{ + + + +class ConditionalLexer +{ + +public: +ConditionalLexer( const std::string ConditionalExpressions ); + +/** + * @brief Get the next token on the ConditionalExpressions + * @return ParsingToken<ConditionalTokenTypes> + */ +std::shared_ptr<ParsingToken<ConditionalTokenTypes>> getNextToken(void); +/** + * @brief Restart at the start of the ConditionalExpressions + * + */ +void rstPosition(void); + +/** + * @brief Test if the string is completely read + * @return bool + */ +bool isEnd(void); + + +/** + * @brief Get the representation of the class + * @return string + */ +const std::string rep(){ + return mConditionalExpressions; +} + +private: + +/** + * @brief Constructs an error message to display the character not understood by the lexer + * @return error mesage + */ +std::runtime_error badTokenError(const std::string& currentChars,std::size_t position); + +/** + * @brief The expression of the test to be performed on the nodes + */ +const std::string mConditionalExpressions; +/** + * @brief The lexer's current position in mConditionalExpressions + */ +std::size_t mPosition; + +}; + +///////////////////////////////////// + + +} + +#endif //AIDGE_CORE_CONDITIONAL_LEXER_H_ diff --git a/include/aidge/nodeTester/ConditionalParser.hpp b/include/aidge/nodeTester/ConditionalParser.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a99f5374182f57c0adca3b4d44691ff4e37de44d --- /dev/null +++ b/include/aidge/nodeTester/ConditionalParser.hpp @@ -0,0 +1,109 @@ + + + +#ifndef AIDGE_CORE_CONDITIONAL_PARSER_H_ +#define AIDGE_CORE_CONDITIONAL_PARSER_H_ + + +#include <memory> // for shared_ptr +#include <map> +#include 
<vector> + +#include "aidge/nodeTester/ConditionalLexer.hpp" +#include "aidge/nodeTester/ConditionalTypes.hpp" +#include "aidge/utilsParsing/ParsingToken.hpp" +#include "aidge/utilsParsing/AstNode.hpp" + +namespace Aidge{ + +const std::map<ConditionalTokenTypes, std::size_t> ConditionalPrec{ + {ConditionalTokenTypes::AND,2}, + {ConditionalTokenTypes::OR,1} +}; + + + + +using ASTNodeCh = std::vector<std::shared_ptr<AstNode<ConditionalTokenTypes>>>; + +/** + * @brief this class uses the lexer to create an AST according to a set of gramer rules + */ +class ConditionalParser{ + + public: + /** + * @brief AST graph creation function + * @param ConditionalExpressions String representing the logical fuction to be performed + */ + ConditionalParser(const std::string ConditionalExpressions); + + virtual ~ConditionalParser() = default; + /** + * @brief AST graph creation function + * @return The AST tree + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> parse(void); + + + private: + /** + * @brief restart at the start of the ConditionalExpressions for LEXER and restart mCurrentToken + */ + void rstParser(void); + + ////////////////// + + /** + * @defgroup ParsingFunctions Function for creating AST + * @brief Functions for recursive construction of the AST representing grammar rules + */ + + /** + * @ingroup ParsingFunctions + * @brief Token reading and verification function + * + */ + void ackToken(ConditionalTokenTypes tokenType); + + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for values : (KEY|INTEGER|FOAT|STRING|LAMBDA lambda) + * @return AST node + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> constructAstVal(void); + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for comparison : val (EQ|NEQ) val | LPAREN expr RPAREN + * @return AST node + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> constructAstCmpr(void); + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for arguments of a 
lambda : LAMBDA val (ARGSEP val)* RPAREN + * @return AST node + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> constructAstLambda(void); + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for a expresion : cmpr ((AND | OR) cmpr)* + * @return AST node + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> constructAstExpr(std::size_t precLimit = 0); + + + /** + * @brief The actual token in the parce + */ + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> mCurrentToken; + /** + * @brief The lexem use + */ + ConditionalLexer mLexer; + +}; + + +} + +#endif //AIDGE_CORE_CONDITIONAL_PARSER_H_ diff --git a/include/aidge/nodeTester/ConditionalTypes.hpp b/include/aidge/nodeTester/ConditionalTypes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6cb2edfd78e6b43c4f2dbc89c49cdaa9ea79f7d2 --- /dev/null +++ b/include/aidge/nodeTester/ConditionalTypes.hpp @@ -0,0 +1,36 @@ + + +#ifndef AIDGE_CORE_CONDITIONAL_TYPES_H_ +#define AIDGE_CORE_CONDITIONAL_TYPES_H_ +namespace Aidge{ + /** + * @brief enum for all types of token use in the parsing + * 7-5 type + * 4-0 id + */ + enum class ConditionalTokenTypes + { + STOP, + + NOT, /**< ! 
*/ + AND, /**< && */ + OR, /**< || */ + + EQ, /**< == */ + NEQ, /**< != */ + + KEY, /**< [A-Za-z][A-Za-z0-9_]* */ + INTEGER, /**< [0-9]+ */ + FLOAT, /**< [0-9]+\.[0-9]* */ + STRING , /**< \'.*\' */ + BOOL, /**< true|false */ + NODE, /**< \$ */ + LAMBDA , /**< [A-Za-z][A-Za-z0-9_]*\( */ + + ARGSEP, /**< , */ + LPAREN, /**< \( */ + RPAREN, /**< \) */ + + }; +} +#endif // AIDGE_CORE_CONDITIONAL_TYPES_H_ diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index c96b2c571f412124ccdfb83dde685e111448a222..65c7e8ce0e47bd470e2a1499a682ed2f2c8c2dbc 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -32,14 +32,13 @@ class Add_Op : public Operator, public: // FIXME: change accessibility std::array<std::shared_ptr<Tensor>, NUM> mInputs; - const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(shared_from_this()); + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); public: static constexpr const char* Type = "Add"; constexpr Add_Op() - : Operator(Type), - mOutput(std::make_shared<Tensor>()) + : Operator(Type) { assert(NUM > 0 && "Add should have at least one input"); for (std::size_t i = 0; i<NUM; ++i) { @@ -48,6 +47,31 @@ public: setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Add_Op(const Add_Op<NUM>& op) + : Operator(Type), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + assert(NUM > 0 && "Add should have at least one input"); + for (std::size_t i = 0; i<NUM; ++i) { + mInputs[i] = std::make_shared<Tensor>(); + } + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Add_Op<NUM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::Add_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Add_Op>(*this); + } + // Data operator[](const char* inputName) override final { // std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] : // (strcmp(inputName, "weight") ? mInputs[1] : @@ -57,14 +81,14 @@ public: // return *in; // } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator."); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInputs[0]->empty()) { const auto expectedDims = mInputs[0]->dims(); std::size_t nonEmptyInputTensor = 1; @@ -116,7 +140,7 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<Add_Op<NUM>>::create(name)(*this); mOutput->setBackend(name); @@ -126,7 +150,7 @@ public: } } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -138,10 +162,16 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return NUM; } inline IOIndex_t nbDataInputs() const noexcept override final { return NUM; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input_0", "data_input_n"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <std::size_t NUM> -inline std::shared_ptr<Node> Add(const char* name = nullptr) { +inline std::shared_ptr<Node> Add(const 
std::string& name = "") { return std::make_shared<Node>(std::make_shared<Add_Op<NUM>>(), name); } } diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp index 7bf8740877e635cc2e59418bee1c444c7f3884e8..dfcd0d5b3b4d892f201485e85710d42cd5b71dba 100644 --- a/include/aidge/operator/AvgPooling.hpp +++ b/include/aidge/operator/AvgPooling.hpp @@ -21,20 +21,19 @@ #include "aidge/graph/Node.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/operator/Producer.hpp" -#include "aidge/utils/Parameter.hpp" +#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" namespace Aidge { -enum class AvgPoolingParam { StrideDims, KernelDims, PaddingDims }; +enum class AvgPoolingAttr { StrideDims, KernelDims }; template <DimIdx_t DIM> class AvgPooling_Op : public Operator, public Registrable<AvgPooling_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const AvgPooling_Op<DIM> &)>, - public Parameterizable<AvgPoolingParam, + public StaticAttributes<AvgPoolingAttr, std::array<DimSize_t, DIM>, - std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1) >> { + std::array<DimSize_t, DIM>> { private: // FIXME: change accessibility std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); @@ -45,25 +44,43 @@ public: AvgPooling_Op() = delete; - using Parameterizable_ = Parameterizable<AvgPoolingParam, - std::array<DimSize_t, DIM>, + using Attributes_ = StaticAttributes<AvgPoolingAttr, std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1)> >; - template <AvgPoolingParam e> - using param = typename Parameterizable_::template param<e>; + std::array<DimSize_t, DIM>>; + template <AvgPoolingAttr e> + using attr = typename Attributes_::template attr<e>; constexpr AvgPooling_Op(const std::array<DimSize_t, DIM> &kernel_dims, - const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = 
create_array<DimSize_t,(DIM<<1)>(0)) + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) : Operator(Type), - Parameterizable_(param<AvgPoolingParam::StrideDims>(stride_dims), - param<AvgPoolingParam::KernelDims>(kernel_dims), - param<AvgPoolingParam::PaddingDims>(padding_dims)), - mOutput(std::make_shared<Tensor>()) { + Attributes_(attr<AvgPoolingAttr::StrideDims>(stride_dims), + attr<AvgPoolingAttr::KernelDims>(kernel_dims)) { setDatatype(DataType::Float32); } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + AvgPooling_Op(const AvgPooling_Op<DIM>& op) + : Operator(Type), + Attributes_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<AvgPooling_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::AvgPooling_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<AvgPooling_Op<DIM>>(*this); + } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 1 && "operators supports only 3 inputs"); (void) inputIdx; // avoid unused warning assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); @@ -71,17 +88,15 @@ public: mInput = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInput->empty()) { std::array<DimSize_t, DIM + 2> outputDims = {}; - for (std::size_t dim = 0; dim < this->template get<AvgPoolingParam::KernelDims>().size() ; ++dim) { + for (std::size_t dim = 0; dim < this->template getAttr<AvgPoolingAttr::KernelDims>().size() ; ++dim) { outputDims[dim+2] = 1 + static_cast<DimSize_t>( std::floor(static_cast<float>(mInput->dims()[dim+2] - - this->template get<AvgPoolingParam::KernelDims>()[dim] + - this->template get<AvgPoolingParam::PaddingDims>()[dim] + - this->template get<AvgPoolingParam::PaddingDims>()[dim+DIM]) / - static_cast<float>(this->template get<AvgPoolingParam::StrideDims>()[dim]))); + this->template getAttr<AvgPoolingAttr::KernelDims>()[dim]) / + static_cast<float>(this->template getAttr<AvgPoolingAttr::StrideDims>()[dim]))); } outputDims[1] = mInput->dims()[1]; outputDims[0] = mInput->dims()[0]; @@ -124,7 +139,7 @@ public: } - void setBackend(const std::string &name) { + void setBackend(const std::string &name) override { mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); @@ -132,7 +147,7 @@ public: mInput->setBackend(name); } - void setDatatype(const DataType &datatype) { + void setDatatype(const DataType &datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -142,34 +157,37 @@ public: inline IOIndex_t nbInputs() const noexcept override final 
{ return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> AvgPooling(const std::array<DimSize_t, DIM> &kernel_dims, - const char *name = nullptr, - const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) { - // FIXME: properly handle default w&b initialization in every cases + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by AvgPooling, not supported"); - auto avgPool = std::make_shared<Node>(std::make_shared<AvgPooling_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, padding_dims), name); - return avgPool; + return std::make_shared<Node>(std::make_shared<AvgPooling_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims), name); } +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction template <DimSize_t DIM> inline std::shared_ptr<Node> AvgPooling( DimSize_t const (&kernel_dims)[DIM], - const char *name = nullptr, - const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) { + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by AvgPooling, not supported"); - return AvgPooling(to_array(kernel_dims), name, stride_dims, padding_dims); + return 
AvgPooling(to_array(kernel_dims), name, stride_dims); } } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::AvgPoolingParam>::data[] = {"StrideDims", - "KernelDims", "PaddingDims"}; +const char *const EnumStrings<Aidge::AvgPoolingAttr>::data[] = {"StrideDims", + "KernelDims"}; } #endif /* AIDGE_CORE_OPERATOR_AVGPOOLING_H_ */ diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index 07af5fa8416cf726e209cd9e690af345b321fb0e..da7360c8ba3816cdfe1d2d00f80b08808a80f961 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -21,17 +21,17 @@ #include "aidge/graph/Node.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/operator/Producer.hpp" -#include "aidge/utils/Parameter.hpp" +#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Registrar.hpp" namespace Aidge { -enum class BatchNormParam { Epsilon, Momentum }; +enum class BatchNormAttr { Epsilon, Momentum }; template <DimIdx_t DIM> class BatchNorm_Op : public Operator, public Registrable<BatchNorm_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const BatchNorm_Op<DIM> &)>, - public Parameterizable<BatchNormParam, float, float> { + public StaticAttributes<BatchNormAttr, float, float> { public: // FIXME: change accessibility std::array<std::shared_ptr<Tensor>, 5> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(), @@ -44,18 +44,40 @@ public: BatchNorm_Op() = delete; - using Parameterizable_ = Parameterizable<BatchNormParam, float, float>; - template <BatchNormParam e> - using param = typename Parameterizable_::template param<e>; + using Attributes_ = StaticAttributes<BatchNormAttr, float, float>; + template <BatchNormAttr e> + using attr = typename Attributes_::template attr<e>; constexpr BatchNorm_Op(float epsilon, float momentum) : Operator(Type), - Parameterizable_(param<BatchNormParam::Epsilon>(epsilon), - param<BatchNormParam::Momentum>(momentum)), + 
Attributes_(attr<BatchNormAttr::Epsilon>(epsilon), + attr<BatchNormAttr::Momentum>(momentum)), mOutput(std::make_shared<Tensor>()) { setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + BatchNorm_Op(const BatchNorm_Op<DIM>& op) + : Operator(Type), + Attributes_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<BatchNorm_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::BatchNorm_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<BatchNorm_Op<DIM>>(*this); + } + // Data operator[](const char* inputName) override final { // std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] : // (strcmp(inputName, "weight") ? 
mInputs[1] : @@ -65,18 +87,17 @@ public: // return *in; // } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 5 && "operators supports only 5 inputs"); assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInputs[0]->empty()) { for (std::size_t i = nbDataInputs(); i < nbInputs(); ++i) { if(mInputs[i]->size() != mInputs[0]->dims()[1]) { - assert(!mInputs[0]->hasImpl() && "Incompatible size with already implemented learnable parameter"); mInputs[i]->resize(std::array<DimSize_t, 1>({mInputs[0]->dims()[1]})); } } @@ -115,7 +136,7 @@ public: } - void setBackend(const std::string &name) { + void setBackend(const std::string &name) override { mImpl = Registrar<BatchNorm_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); @@ -126,7 +147,7 @@ public: mInputs[4]->setBackend(name); } - void setDatatype(const DataType &datatype) { + void setDatatype(const DataType &datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -139,12 +160,18 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 5; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input", "scale", "shift", "mean", "variance"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <DimSize_t DIM> inline std::shared_ptr<Node> BatchNorm(const float epsilon = 1.0e-5F, const float momentum = 0.1F, - const char *name = nullptr) { + const std::string& name = "") { 
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported"); auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name); addProducer(batchNorm, 1, std::array<DimSize_t,0>({}), "scale"); @@ -157,7 +184,7 @@ inline std::shared_ptr<Node> BatchNorm(const float epsilon = 1.0e-5F, namespace { template <> -const char *const EnumStrings<Aidge::BatchNormParam>::data[] = { "Epsilon", "Momentum" }; +const char *const EnumStrings<Aidge::BatchNormAttr>::data[] = { "Epsilon", "Momentum" }; } -#endif //AIDGE_CORE_OPERATOR_BATCHNORM_H_ \ No newline at end of file +#endif //AIDGE_CORE_OPERATOR_BATCHNORM_H_ diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index d6efba2cec6908ad58b9feea5e53807c7227cc88..b1e3e34b0eff681632d90cb8314ebd8c96722eec 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -21,18 +21,18 @@ #include "aidge/graph/Node.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/operator/Producer.hpp" -#include "aidge/utils/Parameter.hpp" +#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" namespace Aidge { -enum class ConvParam { StrideDims, DilationDims, InChannels, OutChannels, KernelDims, PaddingDims }; +enum class ConvAttr { StrideDims, DilationDims, InChannels, OutChannels, KernelDims }; template <DimIdx_t DIM> class Conv_Op : public Operator, public Registrable<Conv_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Conv_Op<DIM> &)>, - public Parameterizable<ConvParam, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, - DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >> { + public StaticAttributes<ConvAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, + DimSize_t, std::array<DimSize_t, DIM>> { public: // FIXME: change accessibility std::array<std::shared_ptr<Tensor>, 3> 
mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(), @@ -44,28 +44,47 @@ public: Conv_Op() = delete; - using Parameterizable_ = Parameterizable<ConvParam, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, - DimSize_t, DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >>; - template <ConvParam e> - using param = typename Parameterizable_::template param<e>; + using Attributes_ = StaticAttributes<ConvAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, + DimSize_t, DimSize_t, std::array<DimSize_t, DIM>>; + template <ConvAttr e> + using attr = typename Attributes_::template attr<e>; constexpr Conv_Op(DimSize_t in_channels, DimSize_t out_channels, const std::array<DimSize_t, DIM> &kernel_dims, const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) : Operator(Type), - Parameterizable_(param<ConvParam::StrideDims>(stride_dims), - param<ConvParam::DilationDims>(dilation_dims), - param<ConvParam::InChannels>(in_channels), - param<ConvParam::OutChannels>(out_channels), - param<ConvParam::KernelDims>(kernel_dims), - param<ConvParam::PaddingDims>(padding_dims)), - mOutput(std::make_shared<Tensor>()) { + Attributes_(attr<ConvAttr::StrideDims>(stride_dims), + attr<ConvAttr::DilationDims>(dilation_dims), + attr<ConvAttr::InChannels>(in_channels), + attr<ConvAttr::OutChannels>(out_channels), + attr<ConvAttr::KernelDims>(kernel_dims)) { setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. 
+ */ + Conv_Op(const Conv_Op<DIM>& op) + : Operator(Type), + Attributes_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Conv_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Conv_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Conv_Op<DIM>>(*this); + } + // Data operator[](const char* inputName) override final { // std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] : // (strcmp(inputName, "weight") ? mInputs[1] : @@ -79,30 +98,28 @@ public: // } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 3 && "operators supports only 3 inputs"); assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInputs[0]->empty()) { std::array<DimSize_t, DIM + 2> outputDims = {}; - for (std::size_t dim = 0; dim < this->template get<ConvParam::KernelDims>().size() ; ++dim) { - const DimSize_t kernelExtent = this->template get<ConvParam::DilationDims>()[dim] * - (this->template get<ConvParam::KernelDims>()[dim] - 1) + + for (std::size_t dim = 0; dim < this->template getAttr<ConvAttr::KernelDims>().size() ; ++dim) { + const DimSize_t kernelExtent = this->template getAttr<ConvAttr::DilationDims>()[dim] * + (this->template getAttr<ConvAttr::KernelDims>()[dim] - 1) + 1; outputDims[dim+2] = 1 + static_cast<DimSize_t>( - floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent + - this->template get<ConvParam::PaddingDims>()[dim] + - this->template 
get<ConvParam::PaddingDims>()[dim+DIM]) / - static_cast<float>(this->template get<ConvParam::StrideDims>()[dim]))); + floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent) / + static_cast<float>(this->template getAttr<ConvAttr::StrideDims>()[dim]))); } - outputDims[1] = this->template get<ConvParam::OutChannels>(); + outputDims[1] = this->template getAttr<ConvAttr::OutChannels>(); outputDims[0] = mInputs[0]->dims()[0]; mOutput->resize(outputDims); } @@ -139,7 +156,7 @@ public: } - void setBackend(const std::string &name) { + void setBackend(const std::string &name) override { mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); @@ -148,7 +165,7 @@ public: mInputs[2]->setBackend(name); } - void setDatatype(const DataType &datatype) { + void setDatatype(const DataType &datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -160,43 +177,53 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 3; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input", "weight", "bias"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> Conv(DimSize_t in_channels, DimSize_t out_channels, const std::array<DimSize_t, DIM> &kernel_dims, - const char *name = nullptr, + const std::string& name = "", const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { // FIXME: properly handle default w&b initialization in every cases static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Conv, not 
supported"); - auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, padding_dims, dilation_dims), name); + auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), name); // addProducer(conv, 1, append(append(kernel_dims, in_channels), out_channels), "w"); addProducer(conv, 1, append(out_channels, append(in_channels, kernel_dims)), "w"); - addProducer(conv, 2, {out_channels}, "b"); + addProducer(conv, 2, std::array<DimSize_t, 1>({out_channels}), "b"); return conv; } +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction template <DimSize_t DIM> inline std::shared_ptr<Node> Conv( DimSize_t in_channels, DimSize_t out_channels, DimSize_t const (&kernel_dims)[DIM], - const char *name = nullptr, + const std::string& name = "", const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Conv, not supported"); - return Conv(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); + return Conv(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, dilation_dims); } } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::ConvParam>::data[] = {"StrideDims", "DilationDims", "InChannels", "OutChannels", - "KernelDims", "PaddingDims"}; +const char *const EnumStrings<Aidge::ConvAttr>::data[] = { + "StrideDims", + "DilationDims", + "InChannels", + "OutChannels", + "KernelDims" +}; } #endif /* AIDGE_CORE_OPERATOR_CONV_H_ */ diff --git a/include/aidge/operator/ConvDepthWise.hpp 
b/include/aidge/operator/ConvDepthWise.hpp index a3b7fbf3b21a5b3fd9e532e0cc19cebd46e5d022..4caec2032a3c61529d452ae855f00c1da411af10 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -21,22 +21,21 @@ #include "aidge/graph/Node.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/operator/Producer.hpp" -#include "aidge/utils/Parameter.hpp" +#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" namespace Aidge { -enum class ConvDepthWiseParam { StrideDims, DilationDims, Channels, KernelDims, PaddingDims }; +enum class ConvDepthWiseAttr { StrideDims, DilationDims, Channels, KernelDims }; template <DimIdx_t DIM> class ConvDepthWise_Op : public Operator, public Registrable<ConvDepthWise_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const ConvDepthWise_Op<DIM> &)>, - public Parameterizable<ConvDepthWiseParam, + public StaticAttributes<ConvDepthWiseAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, - std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1) >> { + std::array<DimSize_t, DIM>> { public: // FIXME: change accessibility std::array<std::shared_ptr<Tensor>, 3> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(), @@ -48,53 +47,69 @@ class ConvDepthWise_Op : public Operator, ConvDepthWise_Op() = delete; - using Parameterizable_ = Parameterizable<ConvDepthWiseParam, + using Attributes_ = StaticAttributes<ConvDepthWiseAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, - std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1) >>; - template <ConvDepthWiseParam e> - using param = typename Parameterizable_::template param<e>; + std::array<DimSize_t, DIM>>; + template <ConvDepthWiseAttr e> + using attr = typename Attributes_::template attr<e>; constexpr ConvDepthWise_Op(const std::array<DimSize_t, DIM> &kernel_dims, const std::array<DimSize_t, DIM> &stride_dims = 
create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) : Operator(Type), - Parameterizable_(param<ConvDepthWiseParam::StrideDims>(stride_dims), - param<ConvDepthWiseParam::DilationDims>(dilation_dims), - param<ConvDepthWiseParam::Channels>(0), - param<ConvDepthWiseParam::KernelDims>(kernel_dims), - param<ConvDepthWiseParam::PaddingDims>(padding_dims)), - mOutput(std::make_shared<Tensor>()) { + Attributes_(attr<ConvDepthWiseAttr::StrideDims>(stride_dims), + attr<ConvDepthWiseAttr::DilationDims>(dilation_dims), + attr<ConvDepthWiseAttr::Channels>(0), + attr<ConvDepthWiseAttr::KernelDims>(kernel_dims)) { setDatatype(DataType::Float32); } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + ConvDepthWise_Op(const ConvDepthWise_Op<DIM>& op) + : Operator(Type), + Attributes_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<ConvDepthWise_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::ConvDepthWise_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<ConvDepthWise_Op<DIM>>(*this); + } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 3 && "operators supports only 3 inputs"); assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInputs[0]->empty()) { std::array<DimSize_t, DIM + 2> outputDims = {}; - for (std::size_t dim = 0; dim < this->template get<ConvDepthWiseParam::KernelDims>().size() ; ++dim) { - const DimSize_t kernelExtent = this->template get<ConvDepthWiseParam::DilationDims>()[dim] * - (this->template get<ConvDepthWiseParam::KernelDims>()[dim] - 1) + + for (std::size_t dim = 0; dim < this->template getAttr<ConvDepthWiseAttr::KernelDims>().size() ; ++dim) { + const DimSize_t kernelExtent = this->template getAttr<ConvDepthWiseAttr::DilationDims>()[dim] * + (this->template getAttr<ConvDepthWiseAttr::KernelDims>()[dim] - 1) + 1; outputDims[dim+2] = 1 + static_cast<DimSize_t>( - floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent + - this->template get<ConvDepthWiseParam::PaddingDims>()[dim] + - this->template get<ConvDepthWiseParam::PaddingDims>()[dim+DIM]) / - static_cast<float>(this->template get<ConvDepthWiseParam::StrideDims>()[dim]))); + floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent) / + static_cast<float>(this->template getAttr<ConvDepthWiseAttr::StrideDims>()[dim]))); } - this->template get<ConvDepthWiseParam::Channels>() = mInputs[0]->dims()[1]; - // std::array<DimSize_t, DIM+2> weightDims = append(mInputs[0]->dims()[1],append(1, this->template get<ConvDepthWiseParam::KernelDims>())); + this->template getAttr<ConvDepthWiseAttr::Channels>() = mInputs[0]->dims()[1]; + // 
std::array<DimSize_t, DIM+2> weightDims = append(mInputs[0]->dims()[1],append(1, this->template getAttr<ConvDepthWiseAttr::KernelDims>())); // if (mInputs[1]->empty()) { // mInputs[1]->resize(weightDims); // } @@ -140,7 +155,7 @@ class ConvDepthWise_Op : public Operator, - void setBackend(const std::string &name) { + void setBackend(const std::string &name) override { mImpl = Registrar<ConvDepthWise_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); @@ -149,7 +164,7 @@ class ConvDepthWise_Op : public Operator, mInputs[2]->setBackend(name); } - void setDatatype(const DataType &datatype) { + void setDatatype(const DataType &datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -161,38 +176,43 @@ class ConvDepthWise_Op : public Operator, inline IOIndex_t nbInputs() const noexcept override final { return 3; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input", "weight", "bias"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> ConvDepthWise(const std::array<DimSize_t, DIM> &kernel_dims, - const char *name = nullptr, + const std::string& name = "", const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { // FIXME: properly handle default w&b initialization in every cases static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ConvDepthWise, not supported"); - auto convDW = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, padding_dims, dilation_dims), name); + auto 
convDW = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), name); addProducer(convDW, 1, std::array<DimSize_t,0>({}), "w"); addProducer(convDW, 2, std::array<DimSize_t,0>({}), "b"); return convDW; } +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction template <DimSize_t DIM> inline std::shared_ptr<Node> ConvDepthWise( DimSize_t const (&kernel_dims)[DIM], - const char *name = nullptr, + const std::string& name = "", const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ConvDepthWise, not supported"); - return ConvDepthWise(to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); + return ConvDepthWise(to_array(kernel_dims), name, stride_dims, dilation_dims); } } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::ConvDepthWiseParam>::data[] = {"StrideDims", "DilationDims", "Channels", - "KernelDims", "PaddingDims"}; +const char *const EnumStrings<Aidge::ConvDepthWiseAttr>::data[] = {"StrideDims", "DilationDims", "Channels", + "KernelDims"}; } #endif /* AIDGE_CORE_OPERATOR_CONVDEPTHWISE_H_ */ diff --git a/include/aidge/operator/Div.hpp b/include/aidge/operator/Div.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4213f979cf9d675f523a228095edc5606f9412ee --- /dev/null +++ b/include/aidge/operator/Div.hpp @@ -0,0 +1,146 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * 
http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_DIV_H_ +#define AIDGE_CORE_OPERATOR_DIV_H_ + +#include <cassert> +#include <memory> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/data/Data.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +class Div_Op : public Operator, + public Registrable<Div_Op, std::string, std::unique_ptr<OperatorImpl>(const Div_Op&)> { +public: + // FIXME: change accessibility + std::array<std::shared_ptr<Tensor>, 2> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>()}; + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char* Type = "Div"; + + Div_Op() + : Operator(Type) + { + setDatatype(DataType::Float32); + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Div_Op(const Div_Op& op) + : Operator(Type), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Div_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::Div_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Div_Op>(*this); + } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < 2 && "operator supports only 2 inputs"); + (void) inputIdx; // avoid unused warning + assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + if (!mInputs[0]->empty()) + mOutput->resize(mInputs[0]->dims()); + } + + bool outputDimsForwarded() const override final { + return !(mOutput->empty()); + } + + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(static_cast<std::size_t>(inputIdx) < 2 && "wrong inputIdx for Add operator."); + return *(mInputs[inputIdx].get()); + } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert((inputIdx < 2) && "Div Operator has 2 inputs"); + (void) inputIdx; // avoid unused warning + return mInputs[inputIdx]; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert((outputIdx == 0) && "Div Operator has only 1 output"); + (void) outputIdx; // avoid unused warning + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 2 && "operator supports only 2 inputs"); + (void) inputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mInputs[inputIdx]); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + (void) outputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string& 
name) override { + mImpl = Registrar<Div_Op>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInputs[0]->setBackend(name); + mInputs[1]->setBackend(name); + } + void setDatatype(const DataType& datatype) override { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInputs[0]->setDatatype(datatype); + mInputs[1]->setDatatype(datatype); + } + + inline IOIndex_t nbInputs() const noexcept override final { return 2; } + inline IOIndex_t nbDataInputs() const noexcept override final { return 2; } + inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Div(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Div_Op>(), name); +} +} + +#endif /* AIDGE_CORE_OPERATOR_DIV_H_ */ diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 6e4c54a030c108c29c08a8f5dfdc24d084ccc91c..b949527c51b9330077dd3bd8f8b4bf1f1b9d719c 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -23,17 +23,17 @@ #include "aidge/graph/Node.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/operator/Producer.hpp" -#include "aidge/utils/Parameter.hpp" +#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Registrar.hpp" namespace Aidge { -enum class FCParam { OutChannels, NoBias }; +enum class FCAttr { OutChannels, NoBias }; class FC_Op : public Operator, public Registrable<FC_Op, std::string, std::unique_ptr<OperatorImpl>(const FC_Op &)>, - public Parameterizable<FCParam, DimSize_t, bool> { + public StaticAttributes<FCAttr, DimSize_t, bool> { public: // FIXME: change accessibility std::array<std::shared_ptr<Tensor>, 3> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(), std::make_shared<Tensor>()}; @@ -44,24 
+44,45 @@ public: FC_Op() = delete; - using Parameterizable_ = Parameterizable<FCParam, DimSize_t, bool>; - template <FCParam e> using param = typename Parameterizable_::template param<e>; + using Attributes_ = StaticAttributes<FCAttr, DimSize_t, bool>; + template <FCAttr e> using attr = typename Attributes_::template attr<e>; FC_Op(DimSize_t out_channels, bool noBias) : Operator(Type), - Parameterizable_( - param<FCParam::OutChannels>(out_channels), - param<FCParam::NoBias>(noBias)), - mOutput(std::make_shared<Tensor>()) + Attributes_( + attr<FCAttr::OutChannels>(out_channels), + attr<FCAttr::NoBias>(noBias)) { setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + FC_Op(const FC_Op& op) + : Operator(Type), + Attributes_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<FC_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::FC_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<FC_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 3 && "operators supports only 3 inputs"); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); if (inputIdx == 2) { - assert(std::dynamic_pointer_cast<Tensor>(data)->size() == ((this->template get<FCParam::NoBias>()) == false ? static_cast<std::size_t>(this->template get<FCParam::OutChannels>()) : 0)); + assert(std::dynamic_pointer_cast<Tensor>(data)->size() == ((this->template getAttr<FCAttr::NoBias>()) == false ? 
static_cast<std::size_t>(this->template getAttr<FCAttr::OutChannels>()) : 0)); assert(std::dynamic_pointer_cast<Tensor>(data)->nbDims() == 1); } mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); @@ -72,9 +93,9 @@ public: void computeOutputDims() override final { if (!mInputs[0]->empty()) { // <in_features**, out_channels> - std::array<DimSize_t, 2> weightDims = {this->template get<FCParam::OutChannels>(), static_cast<DimSize_t>(mInputs[0]->sizeM1())}; + std::array<DimSize_t, 2> weightDims = {this->template getAttr<FCAttr::OutChannels>(), static_cast<DimSize_t>(mInputs[0]->sizeM1())}; // <out_channels, batch> - std::array<DimSize_t, 2> outputDims = {mInputs[0]->dims()[0], this->template get<FCParam::OutChannels>()}; + std::array<DimSize_t, 2> outputDims = {mInputs[0]->dims()[0], this->template getAttr<FCAttr::OutChannels>()}; mInputs[1]->resize(weightDims); mOutput->resize(outputDims); @@ -114,7 +135,7 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<FC_Op>::create(name)(*this); mOutput->setBackend(name); @@ -124,7 +145,7 @@ public: mInputs[2]->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -137,21 +158,27 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 3; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input", "weight", "bias"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; -inline std::shared_ptr<Node> FC(DimSize_t out_channels, bool noBias = false, const char* name = nullptr) { +inline std::shared_ptr<Node> FC(DimSize_t out_channels, bool noBias = false, const std::string& name = "") 
{ // FIXME: properly handle default w&b initialization in every cases auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(out_channels, noBias), name); - addProducer(fc, 1, {out_channels, 1}, "w"); - addProducer(fc, 2, {(noBias ? 0 : out_channels)}, "b"); // already sets bias dims + addProducer(fc, 1, std::array<DimSize_t, 2>({out_channels, 1}), "w"); + addProducer(fc, 2, (noBias ? std::array<DimSize_t, 1>({0}) : std::array<DimSize_t, 1>({out_channels})), "b"); // already sets bias dims return fc; } } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::FCParam>::data[] = {"OutChannels", +const char *const EnumStrings<Aidge::FCAttr>::data[] = {"OutChannels", "NoBias"}; } -#endif /* AIDGE_CORE_OPERATOR_FC_H_ */ \ No newline at end of file +#endif /* AIDGE_CORE_OPERATOR_FC_H_ */ diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index a3e1f02912fb3abdc8adeb09971ee090e875c1fb..55ccbf1516fa79663d57e1e44bc4017bc5c8b843 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -16,24 +16,29 @@ #include <vector> #include <string> #include <cassert> +#include <cstring> #include "aidge/graph/Node.hpp" #include "aidge/operator/Operator.hpp" -#include "aidge/utils/CParameter.hpp" +#include "aidge/utils/DynamicAttributes.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" + namespace Aidge { class GenericOperator_Op : public Operator, - public Registrable<GenericOperator_Op, std::string, std::unique_ptr<OperatorImpl>(std::shared_ptr<GenericOperator_Op>)> { + public Registrable<GenericOperator_Op, std::string, std::unique_ptr<OperatorImpl>(std::shared_ptr<GenericOperator_Op>)>, + public DynamicAttributes { private: - CParameter mParams; + using ComputeDimsFunc = std::function<std::vector<std::vector<size_t>>(const std::vector<std::vector<size_t>>&)>; + IOIndex_t mNbDataIn; IOIndex_t mNbIn; IOIndex_t mNbOut; 
std::vector<std::shared_ptr<Tensor>> mInputs; std::vector<std::shared_ptr<Tensor>> mOutputs; + ComputeDimsFunc mComputeOutputDims; public: GenericOperator_Op(const char *type, IOIndex_t nbDataIn, IOIndex_t nbIn, IOIndex_t nbOut) @@ -50,52 +55,76 @@ class GenericOperator_Op } /** - * @brief Get the Parameter object identified by its name. - * @tparam T expected parameter type. - * @param key Parameter name. - * @details assert if T is not the actual parameter type, if the parameter - * does not exist or internal parameter position is invalid. - * @todo Returning a T const& ? But dangerous => may get an address within - * param buffer that will get invalid after the CParam death. - * @note at() throws if the parameter does not exist, using find to test - * for parameter existance - * @return template<class T> The parameter. + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. 
*/ - template <class T> - T getParameter(std::string const &key) const { - return mParams.Get<T>(key); + GenericOperator_Op(const GenericOperator_Op& op) + : Operator(op.type().c_str()), mNbDataIn(op.mNbDataIn), mNbIn(op.mNbIn), mNbOut(op.mNbOut) + { + // cpy-ctor + mInputs = std::vector<std::shared_ptr<Tensor>>(mNbIn); + for (std::size_t i = 0; i < mNbIn; ++i) { + mInputs[i] = std::make_shared<Tensor>(); + } + mOutputs = std::vector<std::shared_ptr<Tensor>>(mNbOut); + for (std::size_t i = 0; i < mNbOut; ++i) { + mOutputs[i] = std::make_shared<Tensor>(*op.mOutputs[i]); + } } - ///\brief Add a parameter value, identified by its name - ///\tparam T expected parameter type - ///\param i_ParamName Parameter name - ///\param i_Value Parameter value - ///\todo Pass i_Value by ref if large or not trivial - ///\bug If parameter already exists, its value is changed but written in the - /// internal buffer in a new location (previous value is still in memory at - /// its previous location) - template <class T> - void addParameter(std::string const &key, T const &value) { - mParams.Add<T>(key, value); + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::GenericOperator_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<GenericOperator_Op>(*this); } + // Helper functions that can be used with setComputeOutputDims(): + static const ComputeDimsFunc Identity; - std::string getParameterType(std::string const &key) { return mParams.getParamType(key); } - - std::vector<std::string> getParametersName() { return mParams.getParametersName(); } + void setComputeOutputDims(ComputeDimsFunc func) { + mComputeOutputDims = func; + } // Override Virtual Opertor methods - void associateInput(const IOIndex_t /*inputIdx*/, std::shared_ptr<Data> /*data*/) override final { - printf("Info: using associateInput() on a GenericOperator.\n"); + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < mNbIn && "operators supports only x inputs"); + + if (strcmp(data->type(), Tensor::Type) == 0) { + // TODO: associate input only if of type Tensor, otherwise do nothing + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } } void computeOutputDims() override final { - assert(false && "Cannot compute output dim of a GenericOperator"); + if (mComputeOutputDims) { + std::vector<std::vector<size_t>> inputsDims(mNbIn, std::vector<size_t>()); + for (std::size_t i = 0; i < mNbIn; ++i) { + if (mInputs[i]) { + inputsDims[i] = mInputs[i]->dims(); + } + } + + const auto& outputsDims = mComputeOutputDims(inputsDims); + assert(outputsDims.size() == mNbOut && "The provided ComputeDimsFunc function returns the wrong number of outputs"); + for (std::size_t i = 0; i < mNbOut; ++i) { + mOutputs[i]->resize(outputsDims[i]); + } + } + else { + assert(false && "Cannot compute output dim of a GenericOperator"); + } } bool outputDimsForwarded() const override final { - assert(false && "GenericOperator cannot forward dims"); - return false; + if (mComputeOutputDims) { + return !(mOutputs[0]->empty()); + } + else { + assert(false && "GenericOperator 
cannot forward dims"); + return false; + } } std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { @@ -137,11 +166,22 @@ class GenericOperator_Op ~GenericOperator_Op() = default; - void setBackend(const std::string & /*name*/) { printf("setBackend: not available yet.\n"); } - void setDatatype(const DataType & /*datatype*/) { printf("setDatatype: not available yet.\n"); } - void forward() override final { printf("forward: not available yet.\n"); } - void backward() override final { printf("backward: not available yet.\n"); } - + void setBackend(const std::string & /*name*/) override { printf("setBackend: not available yet.\n"); } + void setDatatype(const DataType & /*datatype*/) override { printf("setDatatype: not available yet.\n"); } + void forward() override final { + if(mImpl){ + mImpl->forward(); + }else{ + printf("forward: No implementation is linked.\n"); + } + } + void backward() override final { + if(mImpl){ + mImpl->backward(); + }else{ + printf("backward: No implementation is linked.\n"); + } + } inline IOIndex_t nbInputs() const noexcept override final { return mNbIn; }; inline IOIndex_t nbDataInputs() const noexcept override final { return mNbDataIn; }; inline IOIndex_t nbOutputs() const noexcept override final { return mNbOut; }; @@ -158,7 +198,7 @@ class GenericOperator_Op * @return std::shared_ptr<Node> Node associated with the Generic Operator. 
*/ inline std::shared_ptr<Node> GenericOperator(const char *type, IOIndex_t nbDataIn, IOIndex_t nbIn, IOIndex_t nbOut, - const char *name = nullptr) { + const std::string& name = "") { return std::make_shared<Node>(std::make_shared<GenericOperator_Op>(type, nbDataIn, nbIn, nbOut), name); } } // namespace Aidge diff --git a/include/aidge/operator/LeakyReLU.hpp b/include/aidge/operator/LeakyReLU.hpp index 64587d51de784082da455eb64aa5bbe175773b5d..bcdcbc7cabd8eda46a7c0c4930f317e562fb46a0 100644 --- a/include/aidge/operator/LeakyReLU.hpp +++ b/include/aidge/operator/LeakyReLU.hpp @@ -15,7 +15,7 @@ #include <vector> #include <memory> -#include "aidge/utils/Parameter.hpp" +#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/backend/OperatorImpl.hpp" @@ -25,13 +25,13 @@ #include "aidge/utils/Types.h" namespace Aidge { -enum class LeakyReLUParam { +enum class LeakyReLUAttr { NegativeSlope }; class LeakyReLU_Op : public Operator, public Registrable<LeakyReLU_Op, std::string, std::unique_ptr<OperatorImpl>(const LeakyReLU_Op&)>, - public Parameterizable<LeakyReLUParam, float> { + public StaticAttributes<LeakyReLUAttr, float> { public: // FIXME: change accessibility std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); @@ -42,17 +42,39 @@ public: LeakyReLU_Op() = delete; - using Parameterizable_ = Parameterizable<LeakyReLUParam, float>; - template <LeakyReLUParam e> using param = typename Parameterizable_::template param<e>; + using Attributes_ = StaticAttributes<LeakyReLUAttr, float>; + template <LeakyReLUAttr e> using attr = typename Attributes_::template attr<e>; LeakyReLU_Op(float negativeSlope) : Operator(Type), - Parameterizable_( - param<LeakyReLUParam::NegativeSlope>(negativeSlope)) + Attributes_( + attr<LeakyReLUAttr::NegativeSlope>(negativeSlope)) { setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. 
Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + LeakyReLU_Op(const LeakyReLU_Op& op) + : Operator(Type), + Attributes_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<LeakyReLU_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::LeakyReLU_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<LeakyReLU_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx == 0 && "operator supports only 1 input"); (void) inputIdx; // avoid unused warning @@ -98,14 +120,14 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<LeakyReLU_Op>::create(name)(*this); mOutput->setBackend(name); // FIXME: temporary workaround mInput->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -115,17 +137,22 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; -inline std::shared_ptr<Node> LeakyReLU(float negativeSlope = 0.0f, const char* name = nullptr) { - // FIXME: properly handle default w&b initialization in every cases +inline std::shared_ptr<Node> LeakyReLU(float negativeSlope = 0.0f, const std::string& name = "") { return 
std::make_shared<Node>(std::make_shared<LeakyReLU_Op>(negativeSlope), name); } } namespace { template <> -const char* const EnumStrings<Aidge::LeakyReLUParam>::data[] +const char* const EnumStrings<Aidge::LeakyReLUAttr>::data[] = {"NegativeSlope"}; } diff --git a/include/aidge/operator/Matmul.hpp b/include/aidge/operator/MatMul.hpp similarity index 61% rename from include/aidge/operator/Matmul.hpp rename to include/aidge/operator/MatMul.hpp index b44e8a9b9540e287ff35af1c9642c8202fd096d0..eed1ec04535aa5896aa3d01a27d8023d37a42183 100644 --- a/include/aidge/operator/Matmul.hpp +++ b/include/aidge/operator/MatMul.hpp @@ -23,38 +23,59 @@ #include "aidge/graph/Node.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/operator/Producer.hpp" -#include "aidge/utils/Parameter.hpp" +#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Registrar.hpp" namespace Aidge { -enum class MatmulParam { OutChannels }; +enum class MatMulAttr { OutChannels }; -class Matmul_Op : public Operator, - public Registrable<Matmul_Op, +class MatMul_Op : public Operator, + public Registrable<MatMul_Op, std::string, - std::unique_ptr<OperatorImpl>(const Matmul_Op &)>, - public Parameterizable<MatmulParam, DimSize_t> { + std::unique_ptr<OperatorImpl>(const MatMul_Op &)>, + public StaticAttributes<MatMulAttr, DimSize_t> { public: std::array<std::shared_ptr<Tensor>, 2> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>()}; const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); public: - static constexpr const char* Type = "Matmul"; + static constexpr const char* Type = "MatMul"; - Matmul_Op() = delete; + MatMul_Op() = delete; - using Parameterizable_ = Parameterizable<MatmulParam, DimSize_t>; - template <MatmulParam e> using param = typename Parameterizable_::template param<e>; + using Attributes_ = StaticAttributes<MatMulAttr, DimSize_t>; + template <MatMulAttr e> using attr = typename Attributes_::template attr<e>; - Matmul_Op(DimSize_t out_channels) + 
MatMul_Op(DimSize_t out_channels) : Operator(Type), - Parameterizable_( - param<MatmulParam::OutChannels>(out_channels)), - mOutput(std::make_shared<Tensor>()) + Attributes_( + attr<MatMulAttr::OutChannels>(out_channels)) { setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + MatMul_Op(const MatMul_Op& op) + : Operator(Type), + Attributes_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<MatMul_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::MatMul_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<MatMul_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 2 && "operators supports only 2 inputs"); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); @@ -64,9 +85,9 @@ public: void computeOutputDims() override final { if (!mInputs[0]->empty()) { // <in_features**, out_channels> - std::array<DimSize_t, 2> weightDims = {static_cast<DimSize_t>(mInputs[0]->size()), this->template get<MatmulParam::OutChannels>()}; + std::array<DimSize_t, 2> weightDims = {this->template getAttr<MatMulAttr::OutChannels>(), static_cast<DimSize_t>(mInputs[0]->sizeM1())}; // <out_channels, batch> - std::array<DimSize_t, 1> outputDims = {this->template get<MatmulParam::OutChannels>()}; + std::array<DimSize_t, 2> outputDims = {mInputs[0]->dims()[0], this->template getAttr<MatMulAttr::OutChannels>()}; mInputs[1]->resize(weightDims); mOutput->resize(outputDims); @@ -106,8 +127,8 @@ public: } - void setBackend(const std::string& name) { - mImpl = Registrar<Matmul_Op>::create(name)(*this); 
+ void setBackend(const std::string& name) override { + mImpl = Registrar<MatMul_Op>::create(name)(*this); mOutput->setBackend(name); // FIXME: temporary workaround @@ -115,7 +136,7 @@ public: mInputs[1]->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -127,19 +148,25 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 2; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input", "weight"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; -inline std::shared_ptr<Node> Matmul(DimSize_t out_channels, const char* name = nullptr) { - // FIXME: properly handle default w&b initialization in every cases - auto matmul = std::make_shared<Node>(std::make_shared<Matmul_Op>(out_channels), name); - addProducer(matmul, 1, {1, out_channels}, "w"); +inline std::shared_ptr<Node> MatMul(DimSize_t out_channels, const std::string& name = "") { + // FIXME: properly handle default w initialization in every cases + auto matmul = std::make_shared<Node>(std::make_shared<MatMul_Op>(out_channels), name); + addProducer(matmul, 1, std::array<DimSize_t, 2>({out_channels, 1}), "w"); return matmul; } } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::MatmulParam>::data[] = {"OutChannels"}; +const char *const EnumStrings<Aidge::MatMulAttr>::data[] = {"OutChannels"}; } #endif /* AIDGE_CORE_OPERATOR__MATMUL_H_ */ diff --git a/include/aidge/operator/MaxPooling.hpp b/include/aidge/operator/MaxPooling.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bcf47f13cc34132f668ea1ffcb2c91ed6f06f44d --- /dev/null +++ b/include/aidge/operator/MaxPooling.hpp @@ -0,0 +1,206 @@ 
+/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_MAXPOOLING_H_ +#define AIDGE_CORE_OPERATOR_MAXPOOLING_H_ + +#include <array> +#include <numeric> +#include <vector> +#include <cmath> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +enum class MaxPoolingAttr { StrideDims, KernelDims, CeilMode }; + +template <DimIdx_t DIM> +class MaxPooling_Op : public Operator, + public Registrable<MaxPooling_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const MaxPooling_Op<DIM> &)>, + public StaticAttributes<MaxPoolingAttr, + std::array<DimSize_t, DIM>, + std::array<DimSize_t, DIM>, + bool> { +private: + // FIXME: change accessibility + std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char *Type = "MaxPooling"; + + MaxPooling_Op() = delete; + + using Attributes_ = StaticAttributes<MaxPoolingAttr, + std::array<DimSize_t, DIM>, + std::array<DimSize_t, DIM>, + bool>; + template <MaxPoolingAttr e> + using attr = typename Attributes_::template attr<e>; + + constexpr MaxPooling_Op(const std::array<DimSize_t, DIM> &kernel_dims, + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + bool ceil_mode = false) + : Operator(Type), + Attributes_(attr<MaxPoolingAttr::StrideDims>(stride_dims), + 
attr<MaxPoolingAttr::KernelDims>(kernel_dims), + attr<MaxPoolingAttr::CeilMode>(ceil_mode)), + mOutput(std::make_shared<Tensor>()) { + setDatatype(DataType::Float32); + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + MaxPooling_Op(const MaxPooling_Op<DIM>& op) + : Operator(Type), + Attributes_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<MaxPooling_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::MaxPooling_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<MaxPooling_Op<DIM>>(*this); + } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < 1 && "operators supports only 3 inputs"); + (void) inputIdx; // avoid unused warning + assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); + + mInput = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + if (!mInput->empty()) { + std::array<DimSize_t, DIM + 2> outputDims = {}; + + std::function<float(float)> roundingFunction; + if (this->template getAttr<MaxPoolingAttr::CeilMode>()) { + roundingFunction = [](float x) { return std::ceil(x); }; + } else { + roundingFunction = [](float x) { return std::floor(x); }; + } + + for (std::size_t dim = 0; dim < this->template getAttr<MaxPoolingAttr::KernelDims>().size() ; ++dim) { + outputDims[dim+2] = 1 + static_cast<DimSize_t>( + roundingFunction(static_cast<float>(mInput->dims()[dim+2] - + this->template getAttr<MaxPoolingAttr::KernelDims>()[dim]) / + static_cast<float>(this->template getAttr<MaxPoolingAttr::StrideDims>()[dim]))); + } + outputDims[1] = 
mInput->dims()[1]; + outputDims[0] = mInput->dims()[0]; + mOutput->resize(outputDims); + } + } + + bool outputDimsForwarded() const override final { return !(mOutput->empty()); } + + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "operators supports only 1 inputs"); + (void) inputIdx; // avoid unused warning + return *(mInput.get()); + } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "MaxPooling Operators supports only 1 inputs"); + (void) inputIdx; // avoid unused warning + return mInput; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "MaxPooling Operators has only 1 outputs"); + (void) outputIdx; // avoid unused warning + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "operators supports only 1 inputs"); + (void) inputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mInput); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + (void) outputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string &name) override { + mImpl = Registrar<MaxPooling_Op<DIM>>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInput->setBackend(name); + } + + void setDatatype(const DataType &datatype) override { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInput->setDatatype(datatype); + } + + inline IOIndex_t nbInputs() const noexcept override final { return 1; } + inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } + inline 
IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> MaxPooling(const std::array<DimSize_t, DIM> &kernel_dims, + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + bool ceil_mode=false) { + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by MaxPooling, not supported"); + return std::make_shared<Node>(std::make_shared<MaxPooling_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, ceil_mode), name); +} + +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction +template <DimSize_t DIM> +inline std::shared_ptr<Node> MaxPooling( + DimSize_t const (&kernel_dims)[DIM], + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + bool ceil_mode = false) { + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by MaxPooling, not supported"); + return MaxPooling(to_array(kernel_dims), name, stride_dims, ceil_mode); +} +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::MaxPoolingAttr>::data[] = {"StrideDims", "KernelDims", "CeilMode"}; +} + +#endif /* AIDGE_CORE_OPERATOR_MAXPOOLING_H_ */ diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 35a59b56cbf5c10a78116f81de96a8baddc03ff0..72058dfcba6e811a01a22e261208741879638cad 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -13,16 +13,155 @@ #define AIDGE_CORE_OPERATOR_METAOPERATOR_H_ #include "aidge/operator/Operator.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/OpArgs.hpp" +#include 
"aidge/scheduler/Scheduler.hpp" namespace Aidge { -class MetaOperator : public Operator { +class MetaOperator_Op : public Operator, + public Registrable<MetaOperator_Op, std::array<std::string, 2>, std::unique_ptr<OperatorImpl>(const MetaOperator_Op &)> { public: - MetaOperator() - : Operator("MetaOp") + std::vector<std::shared_ptr<Tensor>> mInputs; + std::vector<std::shared_ptr<Tensor>> mOutputs; // These are shared with micro-graph outputs tensors + + // Micro-graph handling: + std::shared_ptr<GraphView> mGraph; // Meta operator micro-graph + std::shared_ptr<SequentialScheduler> mScheduler; + // Need to store an ordored list of input/output operators for the micro-graph, + // because input/output nodes in a GraphView are unordered. + // TODO: refactor GraphView to handle ordered input/output? + std::vector<std::pair<std::shared_ptr<Operator>, IOIndex_t>> mInputOps; + std::vector<std::pair<std::shared_ptr<Operator>, IOIndex_t>> mOutputOps; + + public: + MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph, + std::vector<NodePtr> inputNodes = std::vector<NodePtr>(), + std::vector<NodePtr> outputNodes = std::vector<NodePtr>()); + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + MetaOperator_Op(const MetaOperator_Op& op) + : Operator(op.type().c_str()), + mGraph(op.mGraph->clone()) { + // cpy-ctor } - ~MetaOperator() = default; + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::MatMul_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<MetaOperator_Op>(*this); + } + + const std::shared_ptr<GraphView>& getMicroGraph() const { + return mGraph; + } + + const std::shared_ptr<SequentialScheduler>& getMicroGraphScheduler() const { + return mScheduler; + } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); + + const auto& inputOp = mInputOps[inputIdx]; + inputOp.first->associateInput(inputOp.second, data); + + // Associate inputs for custom implementation + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + // Forward dims of micro-graph + mGraph->forwardDims(); + + // Associate outputs to micro-graph outputs for custom implementation + for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) { + const auto& outputOp = mOutputOps[outputIdx]; + mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second); + } + } + + bool outputDimsForwarded() const override final { return !(mOutputs[0]->empty()); } + + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(inputIdx < mInputs.size() && "inputIdx out of range"); + return *(mInputs[inputIdx].get()); + } + + inline Tensor& output(const IOIndex_t outputIdx) const override final { + assert(outputIdx < mOutputs.size() && "outputIdx out of range"); + return *(mOutputs[outputIdx].get()); + } + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < mInputs.size() && "inputIdx out of range"); + return mInputs[inputIdx]; + } + + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx < mOutputs.size() && "outputIdx out of range"); + return mOutputs[outputIdx]; + } + + std::shared_ptr<Data> getRawInput(const 
IOIndex_t inputIdx) const override final { + assert(inputIdx < mInputs.size() && "inputIdx out of range"); + return std::static_pointer_cast<Data>(mInputs[inputIdx]); + } + + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx < mOutputs.size() && "outputIdx out of range"); + return std::static_pointer_cast<Data>(mOutputs[outputIdx]); + } + + void setBackend(const std::string &name) override { + if (Registrar<MetaOperator_Op>::exists({name, type()})) { + // A custom implementation exists for this meta operator + mImpl = Registrar<MetaOperator_Op>::create({name, type()})(*this); + } + + // The micro-graph should always be set to the right backend, since it + // shares input/output tensors. + // Input/output tensors backend are updated here. + mGraph->setBackend(name); + } + + void setDatatype(const DataType &datatype) override { + // The micro-graph should always be set to the right data type, since it + // shares input/output tensors. + // Input/output tensors data type are updated here. 
+ mGraph->setDatatype(datatype); + } + + inline IOIndex_t nbInputs() const noexcept override final { return mGraph->inputs().size(); } + inline IOIndex_t nbDataInputs() const noexcept override final { return mGraph->dataInputs().size(); } + inline IOIndex_t nbOutputs() const noexcept override final { return mGraph->outputs().size(); } + + NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override; + NbElts_t getNbConsumedData(IOIndex_t inputIdx) const override; + NbElts_t getNbProducedData(IOIndex_t outputIdx) const override; + + void updateConsummerProducer() override; + void forward() override; + void backward() override { + assert(false && "not implemented"); + } + }; + +inline std::shared_ptr<Node> MetaOperator(const char *type, + const std::shared_ptr<GraphView>& graph, + const std::string& name = "", + std::vector<NodePtr> inputNodes = std::vector<NodePtr>(), + std::vector<NodePtr> outputNodes = std::vector<NodePtr>()) +{ + return std::make_shared<Node>(std::make_shared<MetaOperator_Op>(type, graph, inputNodes, outputNodes), name); } +} // namespace Aidge #endif /* MetaOperator_H_ */ diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp new file mode 100644 index 0000000000000000000000000000000000000000..73feb134837787ae8d0d280dd723182c9d21438b --- /dev/null +++ b/include/aidge/operator/MetaOperatorDefs.hpp @@ -0,0 +1,142 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ +#define AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ + +#include "aidge/operator/MetaOperator.hpp" +#include "aidge/operator/AvgPooling.hpp" +#include "aidge/operator/MaxPooling.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/ConvDepthWise.hpp" +#include "aidge/operator/Pad.hpp" + +namespace Aidge { +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> PaddedConv(DimSize_t in_channels, + DimSize_t out_channels, + const std::array<DimSize_t, DIM> &kernel_dims, + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0), + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) +{ + // Construct micro-graph + auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0); + auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? 
name + "_conv" : ""); + // Need to specify the ordered list of input operators + const std::vector<NodePtr> orderedInputNodes = {pad, conv}; + + auto metaOp = MetaOperator("PaddedConv", Sequential({pad, conv}), name, orderedInputNodes); + addProducer(metaOp, 1, append(out_channels, append(in_channels, kernel_dims)), "w"); + addProducer(metaOp, 2, {out_channels}, "b"); + return metaOp; +} + +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction +template <DimSize_t DIM> +inline std::shared_ptr<Node> PaddedConv( + DimSize_t in_channels, + DimSize_t out_channels, + DimSize_t const (&kernel_dims)[DIM], + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0), + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) +{ + return PaddedConv(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); +} + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> PaddedConvDepthWise(const std::array<DimSize_t, DIM> &kernel_dims, + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0), + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) +{ + // Construct micro-graph + auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0); + auto conv = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? 
name + "_conv" : ""); + // Need to specify the ordered list of input operators + const std::vector<NodePtr> orderedInputNodes = {pad, conv}; + + auto metaOp = MetaOperator("PaddedConvDepthWise", Sequential({pad, conv}), name, orderedInputNodes); + addProducer(metaOp, 1, std::array<DimSize_t,0>({}), "w"); + addProducer(metaOp, 2, std::array<DimSize_t,0>({}), "b"); + return metaOp; +} + +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction +template <DimSize_t DIM> +inline std::shared_ptr<Node> PaddedConvDepthWise( + DimSize_t const (&kernel_dims)[DIM], + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0), + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) +{ + return PaddedConvDepthWise(to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); +} + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> PaddedAvgPooling(const std::array<DimSize_t, DIM> &kernel_dims, + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0)) +{ + auto graph = Sequential({ + Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""), + AvgPooling(kernel_dims, (!name.empty()) ? 
name + "_avgpooling" : "", stride_dims) + }); + + return MetaOperator("PaddedAvgPooling", graph, name); +} + +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction +template <DimSize_t DIM> +inline std::shared_ptr<Node> PaddedAvgPooling( + DimSize_t const (&kernel_dims)[DIM], + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0)) +{ + return PaddedAvgPooling(to_array(kernel_dims), name, stride_dims, padding_dims); +} + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> PaddedMaxPooling(const std::array<DimSize_t, DIM> &kernel_dims, + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0), + bool ceil_mode = false) +{ + auto graph = Sequential({ + Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""), + MaxPooling(kernel_dims, (!name.empty()) ? 
name + "_maxpooling" : "", stride_dims, ceil_mode) + }); + + return MetaOperator("PaddedMaxPooling", graph, name); +} + +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction +template <DimSize_t DIM> +inline std::shared_ptr<Node> PaddedMaxPooling( + DimSize_t const (&kernel_dims)[DIM], + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0), + bool ceil_mode= false) +{ + return PaddedMaxPooling(to_array(kernel_dims), name, stride_dims, padding_dims, ceil_mode); +} +} // namespace Aidge + +#endif /* AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ */ diff --git a/include/aidge/operator/Mul.hpp b/include/aidge/operator/Mul.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4ea79fe52622b22f8ea8fbd9191d50d45e26acac --- /dev/null +++ b/include/aidge/operator/Mul.hpp @@ -0,0 +1,146 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_MUL_H_ +#define AIDGE_CORE_OPERATOR_MUL_H_ + +#include <cassert> +#include <memory> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/data/Data.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +class Mul_Op : public Operator, + public Registrable<Mul_Op, std::string, std::unique_ptr<OperatorImpl>(const Mul_Op&)> { +public: + // FIXME: change accessibility + std::array<std::shared_ptr<Tensor>, 2> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>()}; + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char* Type = "Mul"; + + Mul_Op() + : Operator(Type) + { + setDatatype(DataType::Float32); + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Mul_Op(const Mul_Op& op) + : Operator(Type), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Mul_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::Mul_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Mul_Op>(*this); + } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < 2 && "operator supports only 2 inputs"); + (void) inputIdx; // avoid unused warning + assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + if (!mInputs[0]->empty()) + mOutput->resize(mInputs[0]->dims()); + } + + bool outputDimsForwarded() const override final { + return !(mOutput->empty()); + } + + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(static_cast<std::size_t>(inputIdx) < 2 && "wrong inputIdx for Mul operator."); + return *(mInputs[inputIdx].get()); + } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert((inputIdx < 2) && "Mul Operator has 2 inputs"); + (void) inputIdx; // avoid unused warning + return mInputs[inputIdx]; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert((outputIdx == 0) && "Mul Operator has only 1 output"); + (void) outputIdx; // avoid unused warning + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 2 && "operator supports only 2 inputs"); + (void) inputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mInputs[inputIdx]); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + (void) outputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string& 
name) override { + mImpl = Registrar<Mul_Op>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInputs[0]->setBackend(name); + mInputs[1]->setBackend(name); + } + void setDatatype(const DataType& datatype) override { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInputs[0]->setDatatype(datatype); + mInputs[1]->setDatatype(datatype); + } + + inline IOIndex_t nbInputs() const noexcept override final { return 2; } + inline IOIndex_t nbDataInputs() const noexcept override final { return 2; } + inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Mul(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Mul_Op>(), name); +} +} + +#endif /* AIDGE_CORE_OPERATOR_MUL_H_ */ diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 30e1ce2a7f664485077282405ec60ddf49513cb5..903b6362adf3db0c867dc419086e0cb6ddaa65c7 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -20,12 +20,14 @@ #include "aidge/data/Data.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/utils/Types.h" +#include "aidge/hook/Hook.hpp" namespace Aidge { class Operator : public std::enable_shared_from_this<Operator> { protected: - std::unique_ptr<OperatorImpl> mImpl; // implementation of the operator + std::shared_ptr<OperatorImpl> mImpl; // implementation of the operator + std::map<std::string, std::shared_ptr<Hook>> mHooks; private: std::string mType; @@ -33,8 +35,18 @@ private: public: Operator() = delete; Operator(const char* type) : mType(type) {} + virtual std::shared_ptr<Operator> clone() const = 0; virtual ~Operator(); + Operator(const Operator& op): + std::enable_shared_from_this<Operator>() + { + mType = 
op.mType; + mImpl = nullptr; + // Implementation is never cloned. It is up to the non-abstract Operator copy-constructor to create a new implementation matching the copied Operator implementation. + // See https://gitlab.eclipse.org/eclipse/aidge/aidge_core/-/merge_requests/8#note_1214050 for the discussion. + // Hooks are not copied. + } public: @@ -48,6 +60,15 @@ public: virtual std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const = 0; virtual Tensor& output(const IOIndex_t /*outputIdx*/) const = 0; + std::shared_ptr<Hook> getHook(std::string hookName) { + return mHooks[hookName]; + } + void addHook(std::string hookName) { + mHooks.insert(std::pair<std::string, std::shared_ptr<Hook>>(hookName,Registrar<Hook>::create({hookName})(shared_from_this()))); + } + + void runHooks() const; + /////////////////////////////////////////////////////// // IMPLEMENTATION /////////////////////////////////////////////////////// @@ -55,12 +76,20 @@ public: virtual void setBackend(const std::string& name) = 0; virtual void setDatatype(const DataType& datatype) = 0; + /** + * @brief Set the a new OperatorImpl to the Operator + * + */ + void setImpl(std::shared_ptr<OperatorImpl> impl){ + mImpl = impl; + } + /** * @brief Minimum amount of data from a specific input for one computation pass. * @param inputIdx Index of the input analysed. * @return NbElts_t */ - NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; + virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; /** * @brief Amount of data from a specific input actually used in one computation pass. @@ -68,7 +97,7 @@ public: * @param inputIdx Index of the input analysed. * @return NbElts_t */ - NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; + virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; /** * @brief Amount of data ready to be used on a specific output. @@ -76,7 +105,9 @@ public: * @param outputIdx Index of the output analysed. 
* @return NbElts_t */ - NbElts_t getNbProducedData(const IOIndex_t outputIdx) const; + virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const; + + virtual void updateConsummerProducer(); virtual void forward(); @@ -93,6 +124,12 @@ public: virtual IOIndex_t nbInputs() const noexcept = 0; virtual IOIndex_t nbDataInputs() const noexcept = 0; virtual IOIndex_t nbOutputs() const noexcept = 0; + static const std::vector<std::string> getInputsName(){ + return {}; + } + static const std::vector<std::string> getOutputsName(){ + return {}; + } }; } // namespace Aidge diff --git a/include/aidge/operator/Pad.hpp b/include/aidge/operator/Pad.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cbebb16e1e24501b0ea371fb45211047f6e2b5e7 --- /dev/null +++ b/include/aidge/operator/Pad.hpp @@ -0,0 +1,201 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_PAD_H_ +#define AIDGE_CORE_OPERATOR_PAD_H_ + +#include <array> +#include <numeric> +#include <vector> +#include <cmath> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +enum class PadAttr { BeginEndBorders, BorderType, BorderValue }; +enum class PadBorderType { Constant, Edge, Reflect, Wrap }; + +template <DimIdx_t DIM> +class Pad_Op : public Operator, + public Registrable<Pad_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Pad_Op<DIM> &)>, + public StaticAttributes<PadAttr, + std::array<DimSize_t, 2*DIM>, + PadBorderType, + double> { +private: + // FIXME: change accessibility + std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char *Type = "Pad"; + + Pad_Op() = delete; + + using Attributes_ = StaticAttributes<PadAttr, + std::array<DimSize_t, 2*DIM>, + PadBorderType, + double>; + template <PadAttr e> + using attr = typename Attributes_::template attr<e>; + + constexpr Pad_Op(const std::array<DimSize_t, 2*DIM> &beginEndTuples, + const PadBorderType &borderType = PadBorderType::Constant, + double borderValue = 0.0) + : Operator(Type), + Attributes_(attr<PadAttr::BeginEndBorders>(beginEndTuples), + attr<PadAttr::BorderType>(borderType), + attr<PadAttr::BorderValue>(borderValue)) { + setDatatype(DataType::Float32); + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. 
+ */ + Pad_Op(const Pad_Op& op) + : Operator(Type), + Attributes_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Pad_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Pad_Op<DIM>>(*this); + } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < 1 && "operator supports only 1 input"); + (void) inputIdx; // avoid unused warning + assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); + + mInput = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + if (!mInput->empty()) { + std::array<DimSize_t, DIM + 2> outputDims = {}; + + for (std::size_t dim = 0; dim < DIM; ++dim) { + outputDims[dim+2] = this->template getAttr<PadAttr::BeginEndBorders>()[2*dim] + + mInput->dims()[dim+2] + + this->template getAttr<PadAttr::BeginEndBorders>()[2*dim+1]; + } + outputDims[1] = mInput->dims()[1]; + outputDims[0] = mInput->dims()[0]; + mOutput->resize(outputDims); + } + } + + bool outputDimsForwarded() const override final { return !(mOutput->empty()); } + + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "operators supports only 1 inputs"); + (void) inputIdx; // avoid unused warning + return *(mInput.get()); + } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "Pad Operators supports only 1 inputs"); + (void) inputIdx; // avoid unused warning + return mInput; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "Pad Operators has only 1 outputs"); + (void) outputIdx; // 
avoid unused warning + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "operators supports only 1 inputs"); + (void) inputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mInput); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + (void) outputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string &name) override { + mImpl = Registrar<Pad_Op<DIM>>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInput->setBackend(name); + } + + void setDatatype(const DataType &datatype) override { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInput->setDatatype(datatype); + } + + inline IOIndex_t nbInputs() const noexcept override final { return 1; } + inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } + inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> Pad(const std::array<DimSize_t, 2*DIM> &beginEndTuples, + const std::string& name = "", + const PadBorderType &borderType = PadBorderType::Constant, + double borderValue = 0.0) +{ + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Pad, not supported"); + return std::make_shared<Node>(std::make_shared<Pad_Op<static_cast<DimIdx_t>(DIM)>>(beginEndTuples, borderType, borderValue), name); +} + +// helper with C-style array instead of std::array for beginEndTuples to allow automatic template DIM deduction +template <DimSize_t DIM> +inline std::shared_ptr<Node> Pad( + DimSize_t 
const (&beginEndTuples)[2*DIM], + const std::string& name = "", + const PadBorderType &borderType = PadBorderType::Constant, + double borderValue = 0.0) +{ + return Pad<DIM>(to_array(beginEndTuples), name, borderType, borderValue); +} +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::PadAttr>::data[] = {"BeginEndBorders", "BorderType", "BorderValue"}; + +template <> +const char *const EnumStrings<Aidge::PadBorderType>::data[] = {"Constant", "Edge", "Reflect", "Wrap"}; +} + +#endif /* AIDGE_CORE_OPERATOR_PAD_H_ */ diff --git a/include/aidge/operator/Pow.hpp b/include/aidge/operator/Pow.hpp new file mode 100644 index 0000000000000000000000000000000000000000..732cf36b4ef7e7640648c542191acd02d0875a4f --- /dev/null +++ b/include/aidge/operator/Pow.hpp @@ -0,0 +1,146 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_POW_H_ +#define AIDGE_CORE_OPERATOR_POW_H_ + +#include <cassert> +#include <memory> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/data/Data.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +class Pow_Op : public Operator, + public Registrable<Pow_Op, std::string, std::unique_ptr<OperatorImpl>(const Pow_Op&)> { +public: + // FIXME: change accessibility + std::array<std::shared_ptr<Tensor>, 2> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>()}; + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char* Type = "Pow"; + + Pow_Op() + : Operator(Type) + { + setDatatype(DataType::Float32); + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Pow_Op(const Pow_Op& op) + : Operator(Type), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Pow_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
 + * @see Operator::Pow_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Pow_Op>(*this); + } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < 2 && "operator supports only 2 inputs"); + (void) inputIdx; // avoid unused warning + assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + if (!mInputs[0]->empty()) + mOutput->resize(mInputs[0]->dims()); + } + + bool outputDimsForwarded() const override final { + return !(mOutput->empty()); + } + + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(static_cast<std::size_t>(inputIdx) < 2 && "wrong inputIdx for Pow operator."); + return *(mInputs[inputIdx].get()); + } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert((inputIdx < 2) && "Pow Operator has 2 inputs"); + (void) inputIdx; // avoid unused warning + return mInputs[inputIdx]; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert((outputIdx == 0) && "Pow Operator has only 1 output"); + (void) outputIdx; // avoid unused warning + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 2 && "operator supports only 2 inputs"); + (void) inputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mInputs[inputIdx]); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + (void) outputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string& 
name) override { + mImpl = Registrar<Pow_Op>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInputs[0]->setBackend(name); + mInputs[1]->setBackend(name); + } + void setDatatype(const DataType& datatype) override { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInputs[0]->setDatatype(datatype); + mInputs[1]->setDatatype(datatype); + } + + inline IOIndex_t nbInputs() const noexcept override final { return 2; } + inline IOIndex_t nbDataInputs() const noexcept override final { return 2; } + inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Pow(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Pow_Op>(), name); +} +} + +#endif /* AIDGE_CORE_OPERATOR_POW_H_ */ diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index 1f77400ce8a8ef727ea9e0a7d12477c6519ea2df..d747b340618cc7e321f2cfc2ed9169798e5d77e9 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -19,7 +19,7 @@ #include "aidge/data/Tensor.hpp" #include "aidge/graph/Node.hpp" #include "aidge/operator/Operator.hpp" -#include "aidge/utils/Parameter.hpp" +#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Registrar.hpp" namespace Aidge { @@ -29,15 +29,14 @@ class Producer_Op public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>( const Producer_Op &)> { private: - std::shared_ptr<Tensor> mOutput; + std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); public: static constexpr const char* Type = "Producer"; template <std::size_t DIM> Producer_Op(const std::array<DimSize_t, DIM>& dims) - : Operator(Type), - mOutput(std::make_shared<Tensor>()) + : Operator(Type) { //ctor 
setDatatype(DataType::Float32); @@ -51,10 +50,41 @@ public: setDatatype(tensor->dataType()); } + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Producer_Op(const Producer_Op& op) + : Operator(Type), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Producer_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Producer_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Producer_Op>(*this); + } + void associateInput(const IOIndex_t /*inputIdx*/, std::shared_ptr<Data> /*data*/) override final { assert(false && "Producer operator takes no input"); } + /** + * @brief Set the Output Tensor of the Producer operator. + * This method will create a copy of the Tensor. 
+ * + * @param newOutput Tensor containing the values to copy + */ + void setOutputTensor(const Tensor& newOutput) { + *mOutput = newOutput; + } + void computeOutputDims() override final {} bool outputDimsForwarded() const override final {return true;} @@ -91,17 +121,23 @@ public: inline const std::vector<DimSize_t> dims() const noexcept { return mOutput->dims(); } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<Producer_Op>::create(name)(*this); mOutput->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); } inline IOIndex_t nbInputs() const noexcept override final { return 0; }; inline IOIndex_t nbDataInputs() const noexcept override final { return 0; }; inline IOIndex_t nbOutputs() const noexcept override final { return 1; }; + static const std::vector<std::string> getInputsName(){ + return {}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } public: void forward() override final { @@ -113,34 +149,36 @@ public: }; template <std::array<DimSize_t, 1>::size_type DIM> -inline std::shared_ptr<Node> Producer(const std::array<DimSize_t, DIM> &dims, const char *name = nullptr) { +inline std::shared_ptr<Node> Producer(const std::array<DimSize_t, DIM> &dims, const std::string& name = "") { static_assert(DIM<=MaxDim,"Too many tensor dimensions required by Producer, not supported"); return std::make_shared<Node>(std::make_shared<Producer_Op>(dims), name); } +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction template <std::size_t DIM> -inline std::shared_ptr<Node> Producer(DimSize_t const (&dims)[DIM], const char *name = nullptr) { +inline std::shared_ptr<Node> Producer(DimSize_t const (&dims)[DIM], const std::string& name = "") { return Producer(to_array(dims), name); } -inline std::shared_ptr<Node> 
Producer(const std::shared_ptr<Tensor> tensor, const char *name = nullptr) { +inline std::shared_ptr<Node> Producer(const std::shared_ptr<Tensor> tensor, const std::string& name = "") { return std::make_shared<Node>(std::make_shared<Producer_Op>(tensor), name); } template <std::array<DimSize_t, 1>::size_type DIM> -void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, const std::array<DimSize_t, DIM>& dims, const char* extension) { +void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, const std::array<DimSize_t, DIM>& dims, const std::string& extension) { assert(inputIdx != gk_IODefaultIndex); static_assert(DIM<=MaxDim,"Too many tensor dimensions required by addProducer, not supported"); - const char* prodName = otherNode->name().empty() ? nullptr : (otherNode->name() + std::string("_") + std::string(extension)).c_str(); + const std::string prodName = (otherNode->name().empty()) ? "" : (otherNode->name() + std::string("_") + extension); auto prod = Producer(dims, prodName); prod->addChild(otherNode, 0, inputIdx); otherNode->getOperator()->associateInput(inputIdx, prod->getOperator()->getRawOutput(0)); } +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction template <std::size_t DIM> -void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, DimSize_t const (&dims)[DIM], const char* extension) { +void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, DimSize_t const (&dims)[DIM], const std::string& extension) { addProducer(otherNode, inputIdx, to_array(dims), extension); } } // namespace Aidge -#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */ \ No newline at end of file +#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */ diff --git a/include/aidge/operator/ReLU.hpp b/include/aidge/operator/ReLU.hpp index 3ea90462cf2b083a1a61ae39be06471093ec9f9f..52f13f1c5ce1d0b7a0d4ccaa4d7fe9927bcc3e53 100644 --- a/include/aidge/operator/ReLU.hpp +++ 
b/include/aidge/operator/ReLU.hpp @@ -42,6 +42,27 @@ public: setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + ReLU_Op(const ReLU_Op& op) + : Operator(Type), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<ReLU_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::ReLU_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<ReLU_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx == 0 && "operator supports only 1 input"); (void) inputIdx; // avoid unused warning @@ -87,14 +108,14 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<ReLU_Op>::create(name)(*this); mOutput->setBackend(name); // FIXME: temporary workaround mInput->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -104,10 +125,15 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; -inline std::shared_ptr<Node> ReLU(const char* name = nullptr) { - // FIXME: properly handle default w&b initialization in every cases +inline std::shared_ptr<Node> ReLU(const std::string& name = "") { 
return std::make_shared<Node>(std::make_shared<ReLU_Op>(), name); } } diff --git a/include/aidge/operator/Scaling.hpp b/include/aidge/operator/Scaling.hpp new file mode 100644 index 0000000000000000000000000000000000000000..353666fb3950d034a7dbe8ec1d3ebdb312679f95 --- /dev/null +++ b/include/aidge/operator/Scaling.hpp @@ -0,0 +1,168 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_CORE_OPERATOR_Scaling_H__ +#define __AIDGE_CORE_OPERATOR_Scaling_H__ + +#include <vector> +#include <memory> + + + +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/data/Data.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +enum class ScalingAttr { + scalingFactor +}; + +class Scaling_Op : public Operator, + public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>, + public StaticAttributes<ScalingAttr, float> { +public: + // FIXME: change accessibility + std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char* Type = "Scaling"; + + Scaling_Op() = delete; + + using Attributes_ = StaticAttributes<ScalingAttr, float>; + template <ScalingAttr e> using attr = typename Attributes_::template attr<e>; + + Scaling_Op(float scalingFactor) + : Operator(Type), + Attributes_( + attr<ScalingAttr::scalingFactor>(scalingFactor)) + { + 
setDatatype(DataType::Float32); + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Scaling_Op(const Scaling_Op& op) + : Operator(Type), + Attributes_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Scaling_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Scaling_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Scaling_Op>(*this); + } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx == 0 && "operator supports only 1 input"); + assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); + (void) inputIdx; //avoid unused warning + mInput = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + if (!mInput->empty()) + mOutput->resize(mInput->dims()); + } + + bool outputDimsForwarded() const override final { + return !(mOutput->empty()); + } + + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert((inputIdx == 0) && "Scaling Operator has only 1 input"); + (void) inputIdx; // avoid unused warning + return *(mInput.get()); + } + inline Tensor& output(const IOIndex_t outputIdx) const override final { + assert((outputIdx == 0) && "Scaling Operator has only 1 output"); + (void) outputIdx; // avoid unused warning + return *(mOutput.get()); + } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert((inputIdx == 0) && "Scaling Operator has only 1 input"); + (void) inputIdx; // avoid unused warning + return mInput; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const 
override final { + assert((outputIdx == 0) && "Scaling Operator has only 1 output"); + (void) outputIdx; // avoid unused warning + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "operator supports only 1 input"); + (void) inputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mInput); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + (void) outputIdx; // avoid unused warning + return mOutput; + } + + + void setBackend(const std::string& name) override { + mImpl = Registrar<Scaling_Op>::create(name)(*this); + mOutput->setBackend(name); + // FIXME: temporary workaround + mInput->setBackend(name); + } + void setDatatype(const DataType& datatype) override { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInput->setDatatype(datatype); + } + + inline IOIndex_t nbInputs() const noexcept override final { return 1; } + inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } + inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor), name); +} +} + +namespace { +template <> +const char* const EnumStrings<Aidge::ScalingAttr>::data[] + = {"scalingFactor"}; +} + +#endif /* __AIDGE_CORE_OPERATOR_Scaling_H__ */ diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp index 93eb262f703ca7eb385641c77df7ae7e79c00b96..ba6132a5ee00325d0f7de57db117a169d42352e9 100644 --- a/include/aidge/operator/Softmax.hpp +++ 
b/include/aidge/operator/Softmax.hpp @@ -42,6 +42,27 @@ public: setDatatype(DataType::Float32); } + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Softmax_Op(const Softmax_Op& op) + : Operator(Type), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Softmax_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Softmax_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Softmax_Op>(*this); + } + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx == 0 && "operator supports only 1 input"); (void) inputIdx; // avoid unused warning @@ -87,14 +108,14 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<Softmax_Op>::create(name)(*this); mOutput->setBackend(name); // FIXME: temporary workaround mInput->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -104,10 +125,15 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; -inline std::shared_ptr<Node> Softmax(const char* name = nullptr) { - // FIXME: properly handle default w&b initialization in every cases +inline std::shared_ptr<Node> 
Softmax(const std::string& name = "") { return std::make_shared<Node>(std::make_shared<Softmax_Op>(), name); } } diff --git a/include/aidge/operator/Sqrt.hpp b/include/aidge/operator/Sqrt.hpp new file mode 100644 index 0000000000000000000000000000000000000000..90b2ae6a8ae1311aef14e4eba4d3563a28a3d18e --- /dev/null +++ b/include/aidge/operator/Sqrt.hpp @@ -0,0 +1,141 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_SQRT_H_ +#define AIDGE_CORE_OPERATOR_SQRT_H_ + +#include <cassert> +#include <memory> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/data/Data.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +class Sqrt_Op : public Operator, + public Registrable<Sqrt_Op, std::string, std::unique_ptr<OperatorImpl>(const Sqrt_Op&)> { +public: + // FIXME: change accessibility + std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char* Type = "Sqrt"; + + Sqrt_Op() + : Operator(Type) + { + setDatatype(DataType::Float32); + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. 
+ */ + Sqrt_Op(const Sqrt_Op& op) + : Operator(Type), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Sqrt_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Sqrt_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Sqrt_Op>(*this); + } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx == 0 && "operator supports only 1 input"); + (void) inputIdx; // avoid unused warning + assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); + mInput = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + if (!mInput->empty()) + mOutput->resize(mInput->dims()); + } + + bool outputDimsForwarded() const override final { + return !(mOutput->empty()); + } + + + inline Tensor& input(const IOIndex_t /*inputIdx*/) const override final { return *(mInput.get()); } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert((inputIdx == 0) && "Sqrt Operator has only 1 input"); + (void) inputIdx; // avoid unused warning + return mInput; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert((outputIdx == 0) && "Sqrt Operator has only 1 output"); + (void) outputIdx; // avoid unused warning + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "operator supports only 1 input"); + (void) inputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mInput); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + 
assert(outputIdx == 0 && "operator supports only 1 output"); + (void) outputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string& name) override { + mImpl = Registrar<Sqrt_Op>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInput->setBackend(name); + } + void setDatatype(const DataType& datatype) override { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInput->setDatatype(datatype); + } + + inline IOIndex_t nbInputs() const noexcept override final { return 1; } + inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } + inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Sqrt(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Sqrt_Op>(), name); +} +} + +#endif /* AIDGE_CORE_OPERATOR_SQRT_H_ */ diff --git a/include/aidge/operator/Sub.hpp b/include/aidge/operator/Sub.hpp new file mode 100644 index 0000000000000000000000000000000000000000..451cba08f58e7a580576531ce2a97c92fb9be3ae --- /dev/null +++ b/include/aidge/operator/Sub.hpp @@ -0,0 +1,146 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_SUB_H_ +#define AIDGE_CORE_OPERATOR_SUB_H_ + +#include <cassert> +#include <memory> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/data/Data.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +class Sub_Op : public Operator, + public Registrable<Sub_Op, std::string, std::unique_ptr<OperatorImpl>(const Sub_Op&)> { +public: + // FIXME: change accessibility + std::array<std::shared_ptr<Tensor>, 2> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>()}; + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char* Type = "Sub"; + + Sub_Op() + : Operator(Type) + { + setDatatype(DataType::Float32); + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Sub_Op(const Sub_Op& op) + : Operator(Type), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Sub_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
 + * @see Operator::Sub_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Sub_Op>(*this); + } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < 2 && "operator supports only 2 inputs"); + (void) inputIdx; // avoid unused warning + assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + if (!mInputs[0]->empty()) + mOutput->resize(mInputs[0]->dims()); + } + + bool outputDimsForwarded() const override final { + return !(mOutput->empty()); + } + + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(static_cast<std::size_t>(inputIdx) < 2 && "wrong inputIdx for Sub operator."); + return *(mInputs[inputIdx].get()); + } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert((inputIdx < 2) && "Sub Operator has 2 inputs"); + (void) inputIdx; // avoid unused warning + return mInputs[inputIdx]; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert((outputIdx == 0) && "Sub Operator has only 1 output"); + (void) outputIdx; // avoid unused warning + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 2 && "operator supports only 2 inputs"); + (void) inputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mInputs[inputIdx]); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + (void) outputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string& 
name) override { + mImpl = Registrar<Sub_Op>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInputs[0]->setBackend(name); + mInputs[1]->setBackend(name); + } + void setDatatype(const DataType& datatype) override { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInputs[0]->setDatatype(datatype); + mInputs[1]->setDatatype(datatype); + } + + inline IOIndex_t nbInputs() const noexcept override final { return 2; } + inline IOIndex_t nbDataInputs() const noexcept override final { return 2; } + inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Sub(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Sub_Op>(), name); +} +} + +#endif /* AIDGE_CORE_OPERATOR_SUB_H_ */ diff --git a/include/aidge/recipies/LabelGraph.hpp b/include/aidge/recipies/LabelGraph.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9dd77e5e9f397260cf936cf77b15616c17ea33b8 --- /dev/null +++ b/include/aidge/recipies/LabelGraph.hpp @@ -0,0 +1,35 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_RECIPIES_LABELGRAPH_H_
+#define AIDGE_RECIPIES_LABELGRAPH_H_
+
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/Node.hpp"
+
+namespace Aidge {
+NodePtr nodeLabel(NodePtr node);
+
+/**
+ * @brief Generate the graph for the pixel-wise labels corresponding to a data graph, taking into account the scaling changes (padding, stride, pooling...).
+ * @details Right now, the behavior is to replace the following operators:
+ * - Conv: MaxPooling
+ * - ConvDepthWise: MaxPooling
+ * - AvgPooling: MaxPooling
+ * - MaxPooling: MaxPooling
+ * - all others: identity (removed)
+ * @param graph Data graph
+ * @return Computing graph for the labels derived from the data graph
+ */
+std::shared_ptr<GraphView> labelGraph(std::shared_ptr<GraphView> graph);
+} // namespace Aidge
+
+#endif /* AIDGE_RECIPIES_LABELGRAPH_H_ */
diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp
index 81b3f31662933fe4f59a17cdb0ee42441fb791bc..1896894ee8690cedaef696394da0829604e36211 100644
--- a/include/aidge/scheduler/Scheduler.hpp
+++ b/include/aidge/scheduler/Scheduler.hpp
@@ -43,6 +43,8 @@ public:
     };
     ~SequentialScheduler() = default;
 
+    void generateScheduling(bool verbose = false);
+
     /**
      * @brief Run the provided Computational Graph with a batch of data
      */
@@ -54,6 +56,15 @@ public:
      */
     void saveSchedulingDiagram(const std::string& fileName) const;
 
+    /**
+     * @brief Return a vector of Node ordered by the order they are called by the scheduler
+     *
+     * @return std::vector<std::shared_ptr<Node>>
+     */
+    std::vector<std::shared_ptr<Node>> getStaticScheduling(){
+        return mStaticSchedule;
+    }
+
 private:
     /**
      * @brief Set of layers receiving an input from currently processing layers
@@ -63,9 +74,22 @@ private:
      */
     std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const;
 
+    
/**
+     * @brief Shared ptr to the scheduled graph view
+     *
+     */
     std::shared_ptr<GraphView> mGraphView;
+    /**
+     * @brief List of SchedulingElement (i.e.: Nodes with their computation time)
+     *
+     */
     std::vector<SchedulingElement> mScheduling;
+    /**
+     * @brief List of nodes ordered by their scheduling order
+     *
+     */
+    std::vector<std::shared_ptr<Node>> mStaticSchedule;
 };
 } // namespace Aidge
-#endif /* AIDGE_SCHEDULER_H_ */
\ No newline at end of file
+#endif /* AIDGE_SCHEDULER_H_ */
diff --git a/include/aidge/utils/Attributes.hpp b/include/aidge/utils/Attributes.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..d3444000191022b575adaf1430319479daa5d4fc
--- /dev/null
+++ b/include/aidge/utils/Attributes.hpp
@@ -0,0 +1,77 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CORE_UTILS_ATTRIBUTES_H_
+#define AIDGE_CORE_UTILS_ATTRIBUTES_H_
+
+#ifdef PYBIND
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#endif
+#include <vector>
+#include <string>
+#include <set>
+
+#ifdef PYBIND
+namespace py = pybind11;
+#endif
+
+namespace {
+// This is the type that will hold all the strings. Each enumerate type will
+// declare its own specialization.
+template <typename T> struct EnumStrings {
+    static const char* const data[];
+};
+}
+
+namespace Aidge {
+template<class T, std::size_t N>
+constexpr std::size_t size(T (&)[N]) { return N; }
+
+/* This abstract class allows to avoid binding Attributes.
+* Otherwise we would need to bind every template possible of Attributes.
+* Every operator can access the methods of this class by inheriting from
+* Attributes in the binding code.
+*/
+class Attributes {
+public:
+    /**
+     * @brief Check if the attribute exists.
+     * @param name Name of the attribute to check.
+     * @return bool True if the attribute exists, false otherwise.
+     */
+    virtual bool hasAttr(const std::string& name) const = 0;
+
+    /**
+     * @brief Get the (implementation defined) name of the type of an attribute, returned by std::type_info::name.
+     * @param name Name of the attribute.
+     * @return std::string Name of the type as returned by std::type_info::name.
+     */
+    virtual std::string getAttrType(const std::string& name) const = 0;
+
+    /**
+     * @brief Get the attribute's name list.
+     * @return std::set<std::string> Set of names of the attributes.
+     */
+    virtual std::set<std::string> getAttrsName() const = 0;
+
+#ifdef PYBIND
+    /* Bindable get function, does not require any templating.
+    * This is thanks to py::object which allows the function to
+    * be agnostic from its return type.
+    */
+    virtual py::object getAttrPy(const std::string& name) const = 0;
+#endif
+    virtual ~Attributes() {}
+};
+}
+
+#endif /* AIDGE_CORE_UTILS_ATTRIBUTES_H_ */
diff --git a/include/aidge/utils/CParameter.hpp b/include/aidge/utils/CParameter.hpp
deleted file mode 100644
index 0f4c74ab8bccb7bc134e035a5f12d31d51663e5d..0000000000000000000000000000000000000000
--- a/include/aidge/utils/CParameter.hpp
+++ /dev/null
@@ -1,115 +0,0 @@
-/********************************************************************************
- * Copyright (c) 2023 CEA-List
- *
- * This program and the accompanying materials are made available under the
- * terms of the Eclipse Public License 2.0 which is available at
- * http://www.eclipse.org/legal/epl-2.0.
- * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_CPARAMETER_H_ -#define AIDGE_CPARAMETER_H_ - -#include <assert.h> -#include <map> -#include <vector> - -namespace Aidge { - -///\todo store also a fix-sized code that indicates the type -///\todo managing complex types or excluding non-trivial, non-aggregate types -class CParameter -{ -private: - template <typename T> - struct is_vector : std::false_type {}; - - template <typename T, typename Alloc> - struct is_vector<std::vector<T, Alloc>> : std::true_type {}; - -public: - // not copyable, not movable - CParameter(CParameter const &) = delete; - CParameter(CParameter &&) = delete; - CParameter &operator=(CParameter const &) = delete; - CParameter &operator=(CParameter &&) = delete; - CParameter() : m_Params({}){}; - ~CParameter() = default; - - /** - * \brief Returning a parameter identified by its name - * \tparam T expected parameter type - * \param i_ParamName Parameter name - * \details assert if T is not the actual parameter type, if the parameter does not - * exist or interna parameter position is invalid. - * \todo Returning a T const& ? But dangerous => the client may get an address within - * param buffer that will get invalid after the CParam death. 
- * \note at() throws if the parameter does not exist, using find to test for parameter existance - */ - template<class T> T Get(std::string const i_ParamName) const - { - assert(m_Params.find(i_ParamName) != m_Params.end()); - assert(m_Types.find(i_ParamName) != m_Types.end()); - assert(m_Params.at(i_ParamName) <= m_OffSet); - assert(typeid(T).name() == m_Types.at(i_ParamName)); - return *reinterpret_cast<T *>(m_BeginBuffer + m_Params.at(i_ParamName)); - } - - ///\brief Add a parameter value, identified by its name - ///\tparam T expected parameter type - ///\param i_ParamName Parameter name - ///\param i_Value Parameter value - ///\todo Pass i_Value by ref if large or not trivial - ///\bug If parameter already exists, its value is changed but written in the - /// internal buffer in a new location (previous value is still in memory at its previous location) - template<class T> void Add(std::string const &i_ParamName, T const &i_Value) - { - m_Buffer.resize(m_Buffer.size() + (sizeof(T) / sizeof(uint8_t))); - m_BeginBuffer = m_Buffer.data(); // Update buffer ptr in case of memory reordering - *reinterpret_cast<T *>(m_BeginBuffer + m_OffSet) - = i_Value; // Black-magic used to add anytype into the vector - m_Params[i_ParamName] = m_OffSet; // Copy pointer offset - m_OffSet += sizeof(T); // Increment offset - - m_Types[i_ParamName] = typeid(i_Value).name(); - } - - - std::string getParamType(std::string const &i_ParamName){ - return m_Types[i_ParamName]; - } - - std::vector<std::string> getParametersName(){ - std::vector<std::string> parametersName; - for(auto const& it: m_Params) - parametersName.push_back(it.first); - return parametersName; - } - -private: - std::map<std::string, std::size_t> m_Params; // { Param name : offset } - - ///\brief Map to check type error - /* Note : i tried this : `std::map<std::string, std::type_info const *> mTypes;` - but looks like the type_ingo object was destroyed. 
- I am not a hugde fan of storing a string and making string comparison. - Maybe we can use a custom enum type (or is there a standard solution ?) - */ - std::map<std::string, std::string> m_Types; - - ///\brief All parameters values concatenated in raw binary form. - std::vector<uint8_t> m_Buffer = {}; - - ///\brief Starting address of the buffer - uint8_t *m_BeginBuffer = m_Buffer.data(); - - ///\brief Offset, in number of uint8_t, of the next parameter to write - std::size_t m_OffSet = 0; - -}; - -} - -#endif /* AIDGE_CPARAMETER_H_ */ diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2af8f47e9420f266cc6eca21f167944c761db7ea --- /dev/null +++ b/include/aidge/utils/DynamicAttributes.hpp @@ -0,0 +1,221 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_UTILS_DYNAMICATTRIBUTES_H_ +#define AIDGE_CORE_UTILS_DYNAMICATTRIBUTES_H_ + +#include <map> +#include <vector> +#include <type_traits> +#include <typeinfo> +#include <cassert> +#include <string> + +#include "aidge/utils/future_std/any.hpp" +#include "aidge/utils/Attributes.hpp" + +#ifdef PYBIND +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include <pybind11/embed.h> + +namespace py = pybind11; +#endif + + +namespace Aidge { + +///\todo store also a fix-sized code that indicates the type +///\todo managing complex types or excluding non-trivial, non-aggregate types +class DynamicAttributes : public Attributes { +public: + /** + * \brief Returning an Attribute identified by its name + * \tparam T expected Attribute type + * \param name Attribute name + * \details assert if T is not the actual Attribute type or if the Attribute does not + * exist + * \note at() throws if the Attribute does not exist, using find to test for Attribute existance + */ + template<class T> T& getAttr(const std::string& name) + { +#ifdef PYBIND + // If attribute does not exist in C++, it might have been created or modified in Python + auto it = mAttrs.find(name); + if (it == mAttrs.end()) { + auto itPy = mAttrsPy.find(name); + if (itPy != mAttrsPy.end()) { + // Insert the attribute back in C++ + mAttrs.emplace(std::make_pair(name, future_std::any(itPy->second.cast<T>()))); + } + } +#endif + + return future_std::any_cast<T&>(mAttrs.at(name)); + } + + template<class T> const T& getAttr(const std::string& name) const + { +#ifdef PYBIND + // If attribute does not exist in C++, it might have been created or modified in Python + auto it = mAttrs.find(name); + if (it == mAttrs.end()) { + auto itPy = mAttrsPy.find(name); + if (itPy != mAttrsPy.end()) { + // Insert the attribute back in C++ + mAttrs.emplace(std::make_pair(name, 
future_std::any(itPy->second.cast<T>()))); + } + } +#endif + + return future_std::any_cast<const T&>(mAttrs.at(name)); + } + + ///\brief Add a new Attribute, identified by its name. If it already exists, asserts. + ///\tparam T expected Attribute type + ///\param name Attribute name + ///\param value Attribute value + template<class T> void addAttr(const std::string& name, const T& value) + { + const auto& res = mAttrs.emplace(std::make_pair(name, future_std::any(value))); + assert(res.second && "attribute already exists"); + +#ifdef PYBIND + // We cannot handle Python object if the Python interpreter is not running + if (Py_IsInitialized()) { + // Keep a copy of the attribute in py::object that is updated everytime + mAttrsPy.emplace(std::make_pair(name, py::cast(value))); + } +#endif + } + + ///\brief Set an Attribute value, identified by its name. If it already exists, its value (and type, if different) is changed. + ///\tparam T expected Attribute type + ///\param name Attribute name + ///\param value Attribute value + template<class T> void setAttr(const std::string& name, const T& value) + { + auto res = mAttrs.emplace(std::make_pair(name, future_std::any(value))); + if (!res.second) + res.first->second = future_std::any(value); + +#ifdef PYBIND + // We cannot handle Python object if the Python interpreter is not running + if (Py_IsInitialized()) { + // Keep a copy of the attribute in py::object that is updated everytime + auto resPy = mAttrsPy.emplace(std::make_pair(name, py::cast(value))); + if (!resPy.second) + resPy.first->second = std::move(py::cast(value)); + } +#endif + } + + void delAttr(const std::string& name) { + mAttrs.erase(name); +#ifdef PYBIND + mAttrsPy.erase(name); +#endif + } + +#ifdef PYBIND + void addAttrPy(const std::string& name, py::object&& value) + { + auto it = mAttrs.find(name); + assert(it == mAttrs.end() && "attribute already exists"); + + const auto& res = mAttrsPy.emplace(std::make_pair(name, value)); + assert(res.second && 
"attribute already exists"); + } + + void setAttrPy(const std::string& name, py::object&& value) + { + auto resPy = mAttrsPy.emplace(std::make_pair(name, value)); + if (!resPy.second) + resPy.first->second = std::move(value); + + // Force getAttr() to take attribute value from mAttrsPy and update mAttrs + mAttrs.erase(name); + } +#endif + + ////////////////////////////////////// + /// Generic Attributes API + ////////////////////////////////////// + bool hasAttr(const std::string& name) const override final { +#ifdef PYBIND + // Attributes might have been created in Python, the second condition is necessary. + return (mAttrs.find(name) != mAttrs.end() || mAttrsPy.find(name) != mAttrsPy.end()); +#else + return (mAttrs.find(name) != mAttrs.end()); +#endif + } + + std::string getAttrType(const std::string& name) const override final { + // In order to remain consistent between C++ and Python, with or without PyBind, the name of the type is: + // - C-style for C++ created attributes + // - Python-style for Python created attributes +#ifdef PYBIND + // If attribute does not exist in C++, it might have been created in Python + auto it = mAttrs.find(name); + if (it == mAttrs.end()) { + auto itPy = mAttrsPy.find(name); + if (itPy != mAttrsPy.end()) { + return std::string(Py_TYPE(itPy->second.ptr())->tp_name); + } + } +#endif + + return mAttrs.at(name).type().name(); + } + + std::set<std::string> getAttrsName() const override final { + std::set<std::string> attrsName; + for(auto const& it: mAttrs) + attrsName.insert(it.first); +#ifdef PYBIND + // Attributes might have been created in Python + for(auto const& it: mAttrsPy) + attrsName.insert(it.first); +#endif + return attrsName; + } + +#ifdef PYBIND + /** + * @detail See https://github.com/pybind/pybind11/issues/1590 as to why a + * generic type caster for std::any is not feasable. + * The strategy here is to keep a copy of each attribute in py::object that is updated everytime. 
+ */
+    py::object getAttrPy(const std::string& name) const override final {
+        return mAttrsPy.at(name);
+    };
+#endif
+
+private:
+#ifdef PYBIND
+    // Stores C++ attributes (copy) and Python-only attributes
+    // Code should be compiled with -fvisibility=hidden
+    // See https://pybind11.readthedocs.io/en/stable/faq.html:
+    // "'SomeClass' declared with greater visibility than the type of its
+    // field 'SomeClass::member' [-Wattributes]"
+    // This map will only be populated if the Python interpreter is running
+    std::map<std::string, py::object> mAttrsPy;
+    // Stores C++ attributes only
+    // mutable because it may be updated in getAttr() from Python
+    mutable std::map<std::string, future_std::any> mAttrs;
+#else
+    std::map<std::string, future_std::any> mAttrs;
+#endif
+};
+
+}
+
+#endif /* AIDGE_CORE_UTILS_DYNAMICATTRIBUTES_H_ */
diff --git a/include/aidge/utils/ErrorHandling.hpp b/include/aidge/utils/ErrorHandling.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..8fbeff30abecfec0077786b21825b6a6f36677c6
--- /dev/null
+++ b/include/aidge/utils/ErrorHandling.hpp
@@ -0,0 +1,59 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+
+#ifndef AIDGE_ERRORHANDLING_H_
+#define AIDGE_ERRORHANDLING_H_
+
+#include <cstdio>
+#include <memory>
+
+#define AIDGE_STRINGIZE_DETAIL(x) #x
+#define AIDGE_STRINGIZE(x) AIDGE_STRINGIZE_DETAIL(x)
+
+#ifdef NO_EXCEPTION
+#define AIDGE_THROW_OR_ABORT(ex, ...) \
+do { std::printf(__VA_ARGS__); std::abort(); } while (false)
+#else
+#include <stdexcept>
+#include <memory>
+#define AIDGE_THROW_OR_ABORT(ex, ...) 
\ +do { \ + int n = 128; \ + std::unique_ptr<char[]> formatted; \ + formatted.reset(new char[n]); \ + const int len = std::snprintf(formatted.get(), n, __VA_ARGS__); \ + if (len >= n) { \ + formatted.reset(new char[len + 1]); \ + std::snprintf(formatted.get(), len + 1, __VA_ARGS__); \ + }; \ + throw ex(formatted.get()); \ +} while (false) +#endif + +/** + * Macro for specified API assertions. + * Used to check logic directly related to user's inputs. + * If it asserts, it means an user error. +*/ +#define AIDGE_ASSERT(stm, ...) \ +if (!(stm)) { printf("Assertion failed: " AIDGE_STRINGIZE(stm) " in " __FILE__ ":%d", __LINE__); \ + AIDGE_THROW_OR_ABORT(std::runtime_error, __VA_ARGS__); } + +/** + * Macro for internal assertions. + * Used to check internal logic not directly related to API user's inputs. + * If it asserts, it means a bug. +*/ +#define AIDGE_INTERNAL_ASSERT(stm) \ +assert((stm) && "Internal assertion failed: " #stm " in " __FILE__ ":" AIDGE_STRINGIZE(__LINE__)) + +#endif //AIDGE_ERRORHANDLING_H_ diff --git a/include/aidge/utils/Parameter.hpp b/include/aidge/utils/Parameter.hpp deleted file mode 100644 index b0c6e35950187f17d991cfe5b2c9bd2b09f1e70f..0000000000000000000000000000000000000000 --- a/include/aidge/utils/Parameter.hpp +++ /dev/null @@ -1,197 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. 
- * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_CORE_UTILS_PARAMETER_H_ -#define AIDGE_CORE_UTILS_PARAMETER_H_ - -#ifdef PYBIND -#include <pybind11/pybind11.h> -#include <pybind11/stl.h> -#include <string> // Add this inclue to print error -#endif -#include <tuple> -#include <cassert> -#include <cstddef> - -#ifdef PYBIND -namespace py = pybind11; -#endif - -namespace { -// This is the type that will hold all the strings. Each enumerate type will -// declare its own specialization. -template <typename T> struct EnumStrings { - static const char* const data[]; -}; -} - -namespace Aidge { -template<class T, std::size_t N> -constexpr std::size_t size(T (&)[N]) { return N; } - -#ifdef PYBIND -/* This abstract class allows to avoid binding Parametrizable. -* Otherwise we would need to bind every template possible of Parametrizable. -* Every operators can access the methods of this class by inheriting from -* PyAbstractParametrizable in the binding code. -*/ -class PyAbstractParametrizable{ - public: - /* Bindable get function, does not recquire any templating. - * This is thanks to py::object which allow the function to - * be agnostic from its return type. - */ - virtual py::object getPy(const char* /*name*/) = 0; -}; -#endif - -template <class PARAM_ENUM, class ...T> -class Parameterizable -#ifdef PYBIND - : public PyAbstractParametrizable -#endif - { -public: - using Parameters = std::tuple<T...>; - - // Helper class to pass to the constructor - template <PARAM_ENUM paramEnum> - class param { - public: - constexpr param(const typename std::tuple_element<static_cast<std::size_t>(paramEnum),std::tuple<T...>>::type& v) : value(v) {} - const typename std::tuple_element<static_cast<std::size_t>(paramEnum),std::tuple<T...>>::type value; - }; - -/* - // Direct tuple initialization - Parameterizable(T... 
params) : mParams({params...}) { - - } -*/ - - // Constructor for parameters initialization. - // Compile-time garantee that every parameter is initialized. - template <PARAM_ENUM ...paramEnum> // non-type parameter pack - constexpr Parameterizable(const param<paramEnum>&&... params) { - // Check number of params consistency - static_assert(sizeof...(params) == std::tuple_size<std::tuple<T...>>::value, "wrong number of parameters in constructor"); - // static_assert(size(EnumStrings<PARAM_ENUM>::data) == std::tuple_size<std::tuple<T...>>::value, "wrong number of parameters in enum string"); - - // Check no duplicates - constexpr std::array<PARAM_ENUM, std::tuple_size<std::tuple<T...>>::value> pe = { paramEnum... }; - static_assert(!hasDuplicates(pe), "duplicate parameter"); // requires C++14 - - // Init params with constructor arguments - const std::array<PARAM_ENUM, std::tuple_size<std::tuple<T...>>::value> p = { ((void)(get<paramEnum>() = params.value), paramEnum) ... }; - (void)p; // avoid unused warning - } - - // Compile-time access with enum - template <PARAM_ENUM paramEnum> - constexpr typename std::tuple_element<static_cast<std::size_t>(paramEnum),std::tuple<T...>>::type& get() { - return std::get<static_cast<std::size_t>(paramEnum)>(mParams); - } - - template <PARAM_ENUM paramEnum> - constexpr const typename std::tuple_element<static_cast<std::size_t>(paramEnum),std::tuple<T...>>::type& get() const { - return std::get<static_cast<std::size_t>(paramEnum)>(mParams); - } - - // Runtime access with enum - template <typename R> - constexpr R& get(PARAM_ENUM paramEnum) { - return get<R>(static_cast<std::size_t>(paramEnum)); - } - - template <typename R> - constexpr const R& get(PARAM_ENUM paramEnum) const { - return get<R>(static_cast<std::size_t>(paramEnum)); - } - - // Runtime existance check with name - constexpr bool isParam(const char* name) const { - for (std::size_t i = 0; i < size(EnumStrings<PARAM_ENUM>::data); ++i) { - if 
(strcmp(EnumStrings<PARAM_ENUM>::data[i], name) == 0) { - return true; - } - } - - return false; - } - - // Runtime access with name - template <typename R> - constexpr R& get(const char* name) { - for (std::size_t i = 0; i < size(EnumStrings<PARAM_ENUM>::data); ++i) { - if (strcmp(EnumStrings<PARAM_ENUM>::data[i], name) == 0) { - return get<R>(i); - } - } - - assert(false && "parameter not found"); - } - - template <typename R, std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value-1> - constexpr typename std::enable_if<(SIZE > 0), R&>::type get(std::size_t i) { - if (i == SIZE) { - if (std::is_same<R, typename std::tuple_element<SIZE,std::tuple<T...>>::type>::value) { - return reinterpret_cast<R&>(std::get<SIZE>(mParams)); - } - else { - assert(false && "wrong parameter type"); - } - } - else { - return get<R, SIZE-1>(i); - } - } - - template <typename R, std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value-1> - constexpr typename std::enable_if<(SIZE <= 0), R&>::type get(std::size_t i) { - assert(false && "parameter not found"); - } - - constexpr const std::tuple<T...>& getParams() const { - return mParams; - } - - #ifdef PYBIND - py::object getPy(const char* name){ - for (std::size_t i = 0; i < size(EnumStrings<PARAM_ENUM>::data); ++i) { - if (strcmp(EnumStrings<PARAM_ENUM>::data[i], name) == 0) { - // https://github.com/pybind/pybind11/blob/f3e0602802c7840992c97f4960515777cad6a5c7/include/pybind11/pytypes.h#L1119-L1138 - // Normal accessor would not work has we convert the tuple to a py::object which can be anything - return py::detail::accessor_policies::tuple_item::get(py::cast(mParams), static_cast<py::size_t>(i)); - } - } - throw py::value_error("Parameter : " + std::string(name) + " does not exist." 
); - }; - #endif - -private: - template <typename V, std::size_t N> - static constexpr bool hasDuplicates(const std::array<V, N>& array) { - for (std::size_t i = 1; i < N; i++) { - for (std::size_t j = 0; j < i; j++) { - if (array[i] == array[j]) { - return true; - } - } - } - - return false; - } - - std::tuple<T...> mParams; -}; -} - -#endif /* AIDGE_CORE_UTILS_PARAMETER_H_ */ diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp index 4cbf8fd284bef314dbe28b19ebdae05172467bad..894e56fae2e9c2f6bcf11e4e76a433f5c8058080 100644 --- a/include/aidge/utils/Recipies.hpp +++ b/include/aidge/utils/Recipies.hpp @@ -17,11 +17,54 @@ namespace Aidge{ +// FUSE MATMUL + ADD -> FC + +/** + * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node. + * + * @param nodes Strict set of Node to merge. + */ void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); +/** + * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node. + * + * @param graphView Graph view to use graph matching on, in order to apply transfomrations. + */ +void fuseMulAdd(std::shared_ptr<GraphView> graphView); + + +// REMOVE FLATTEN + FC -> FC + +/** + * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. + * + * @param nodes Strict set of Node to merge. + */ void removeFlatten(std::set<std::shared_ptr<Node>> nodes); +/** + * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. + * + * @param graphView Graph view to use graph matching on, in order to apply transfomrations. + */ +void removeFlatten(std::shared_ptr<GraphView> graphView); + +// FUSE BN + FC || CONV -> FC || CONV +/** + * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes. + * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ + * + * @param nodes Strict set of Node to merge. 
+ */
+void fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes);
+/**
+ * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes.
+ * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
+ *
+ * @param graphView Graph view to use graph matching on, in order to apply transformations.
+ */
+void fuseBatchNorm(std::shared_ptr<GraphView> graphView);
 }
-
-#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */
\ No newline at end of file
+#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */
diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp
index 98749c1349bad644dee2c1a8549559939791f71c..ece74509d466800c870d73d1e0bbe1d639f8bf54 100644
--- a/include/aidge/utils/Registrar.hpp
+++ b/include/aidge/utils/Registrar.hpp
@@ -34,7 +34,8 @@ public:
     static std::map<Key, std::function<Func>>& registry()
     {
 #ifdef PYBIND
-        if (std::getenv("AIDGE_CORE_WITH_PYBIND")){
+        #define _CRT_SECURE_NO_WARNINGS
+        if (Py_IsInitialized()){
             std::string name = std::string("registrar_")+typeid(Registrable<DerivedClass, Key, Func>).name();
             static auto shared_data = reinterpret_cast<std::map<Key, std::function<Func>> *>(py::get_shared_data(name));
             if (!shared_data)
@@ -57,6 +58,11 @@ struct Registrar {
         //assert(newInsert && "registrar already exists");
     }
 
+    static bool exists(const typename C::registrar_key& key) {
+        const auto it = C::registry().find(key);
+        return (it != C::registry().end());
+    }
+
     static auto create(const typename C::registrar_key& key){
         const auto it = C::registry().find(key);
         assert(it != C::registry().end() && "invalid registrar key");
@@ -72,4 +78,4 @@ struct Registrar {
 };
 }
 
-#endif //AIDGE_CORE_UTILS_REGISTRAR_H_
\ No newline at end of file
+#endif //AIDGE_CORE_UTILS_REGISTRAR_H_
diff --git a/include/aidge/utils/StaticAttributes.hpp b/include/aidge/utils/StaticAttributes.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..b67f69ae7afc2c22f3b424812ec994b10974b668
--- /dev/null
+++ 
b/include/aidge/utils/StaticAttributes.hpp @@ -0,0 +1,204 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_UTILS_STATICATTRIBUTES_H_ +#define AIDGE_CORE_UTILS_STATICATTRIBUTES_H_ + +#include <tuple> +#include <cassert> +#include <cstddef> +#include <typeinfo> + +#include "aidge/utils/Attributes.hpp" +#include "aidge/utils/ErrorHandling.hpp" + +namespace Aidge { +/** + * @brief This class is designed to handle static attributes (i.e. known at compile-time) + * with named accessors, with minimal overhead (the name strings are not stored in each object + * instance and it remains possible to access attribute without overhead at compile-time). +*/ +template <class ATTRS_ENUM, class ...T> +class StaticAttributes : public Attributes { +public: + using Attrs = std::tuple<T...>; + + // Helper class to pass to the constructor + template <ATTRS_ENUM attrsEnum> + class attr { + public: + constexpr attr(const typename std::tuple_element<static_cast<std::size_t>(attrsEnum),std::tuple<T...>>::type& v) : value(v) {} + const typename std::tuple_element<static_cast<std::size_t>(attrsEnum),std::tuple<T...>>::type value; + }; + +/* + // Direct tuple initialization + StaticAttributes(T... attrs) : mAttrs({attrs...}) { + + } +*/ + + // Constructor for Attributes initialization. + // Compile-time garantee that every attribute is initialized. + template <ATTRS_ENUM ...attrsEnum> // non-type attribute pack + constexpr StaticAttributes(const attr<attrsEnum>&&... 
attrs) { + // Check number of attrs consistency + static_assert(sizeof...(attrs) == std::tuple_size<std::tuple<T...>>::value, "wrong number of attributes in constructor"); + // static_assert(size(EnumStrings<ATTRS_ENUM>::data) == std::tuple_size<std::tuple<T...>>::value, "wrong number of attributes in enum string"); + + // Check no duplicates + constexpr std::array<ATTRS_ENUM, std::tuple_size<std::tuple<T...>>::value> pe = { attrsEnum... }; + static_assert(!hasDuplicates(pe), "duplicate attribute"); // requires C++14 + + // Init attrs with constructor arguments + const std::array<ATTRS_ENUM, std::tuple_size<std::tuple<T...>>::value> p = { ((void)(getAttr<attrsEnum>() = attrs.value), attrsEnum) ... }; + (void)p; // avoid unused warning + } + + // Compile-time access with enum + template <ATTRS_ENUM attrsEnum> + constexpr typename std::tuple_element<static_cast<std::size_t>(attrsEnum),std::tuple<T...>>::type& getAttr() { + return std::get<static_cast<std::size_t>(attrsEnum)>(mAttrs); + } + + template <ATTRS_ENUM attrsEnum> + constexpr const typename std::tuple_element<static_cast<std::size_t>(attrsEnum),std::tuple<T...>>::type& getAttr() const { + return std::get<static_cast<std::size_t>(attrsEnum)>(mAttrs); + } + + // Runtime access with enum + template <typename R> + constexpr R& getAttr(ATTRS_ENUM attrsEnum) { + return getAttr<R>(static_cast<std::size_t>(attrsEnum)); + } + + template <typename R> + constexpr const R& getAttr(ATTRS_ENUM attrsEnum) const { + return getAttr<R>(static_cast<std::size_t>(attrsEnum)); + } + + // Runtime access with name + template <typename R> + R& getAttr(const char* name) { + for (std::size_t i = 0; i < size(EnumStrings<ATTRS_ENUM>::data); ++i) { + if (strcmp(EnumStrings<ATTRS_ENUM>::data[i], name) == 0) { + return getAttr<R>(i); + } + } + + AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute \"%s\" not found", name); + } + + template <typename R, std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value> + typename 
std::enable_if<(SIZE > 0), R&>::type getAttr(std::size_t i) { + if (i == SIZE-1) { + if (std::is_same<R, typename std::tuple_element<SIZE-1,std::tuple<T...>>::type>::value) { + return reinterpret_cast<R&>(std::get<SIZE-1>(mAttrs)); + } + else { + AIDGE_THROW_OR_ABORT(std::runtime_error, "wrong type for attribute with index %lu", i); + } + } + else { + return getAttr<R, SIZE-1>(i); + } + } + + template <typename R, std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value> + [[noreturn]] typename std::enable_if<(SIZE == 0), R&>::type getAttr(std::size_t /*i*/) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute not found"); + } + + template <std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value> + constexpr typename std::enable_if<(SIZE > 0), const std::type_info&>::type getAttrType(std::size_t i) const { + if (i == SIZE-1) { + return typeid(typename std::tuple_element<SIZE-1,std::tuple<T...>>::type); + } + else { + return getAttrType<SIZE-1>(i); + } + } + + template <std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value> + [[noreturn]] typename std::enable_if<(SIZE == 0), const std::type_info&>::type getAttrType(std::size_t /*i*/) const { + AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute not found"); + } + + constexpr const std::tuple<T...>& getStaticAttributes() const { + return mAttrs; + } + + ////////////////////////////////////// + /// Generic Attributes API + ////////////////////////////////////// + // Runtime existance check with name + bool hasAttr(const std::string& name) const override final { + for (std::size_t i = 0; i < size(EnumStrings<ATTRS_ENUM>::data); ++i) { + if (name == EnumStrings<ATTRS_ENUM>::data[i]) { + return true; + } + } + + return false; + } + + // Runtime type access with name + std::string getAttrType(const std::string& name) const override final { + for (std::size_t i = 0; i < size(EnumStrings<ATTRS_ENUM>::data); ++i) { + if (name == EnumStrings<ATTRS_ENUM>::data[i]) { + return getAttrType(i).name(); + } + } + + 
AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute \"%s\" not found", name.c_str()); + } + + std::set<std::string> getAttrsName() const override final { + std::set<std::string> attrsName; + for (std::size_t i = 0; i < size(EnumStrings<ATTRS_ENUM>::data); ++i) { + attrsName.insert(EnumStrings<ATTRS_ENUM>::data[i]); + } + return attrsName; + } + + #ifdef PYBIND + py::object getAttrPy(const std::string& name) const override { + for (std::size_t i = 0; i < size(EnumStrings<ATTRS_ENUM>::data); ++i) { + if (name == EnumStrings<ATTRS_ENUM>::data[i]) { + // https://github.com/pybind/pybind11/blob/f3e0602802c7840992c97f4960515777cad6a5c7/include/pybind11/pytypes.h#L1119-L1138 + // Normal accessor would not work has we convert the tuple to a py::object which can be anything + return py::detail::accessor_policies::tuple_item::get(py::cast(mAttrs), static_cast<py::size_t>(i)); + } + } + + AIDGE_THROW_OR_ABORT(py::value_error, "attribute \"%s\" not found", name.c_str()); + }; + #endif + +private: + template <typename V, std::size_t N> + static constexpr bool hasDuplicates(const std::array<V, N>& array) { + for (std::size_t i = 1; i < N; i++) { + for (std::size_t j = 0; j < i; j++) { + if (array[i] == array[j]) { + return true; + } + } + } + + return false; + } + + std::tuple<T...> mAttrs; +}; +} + +#endif /* AIDGE_CORE_UTILS_STATICATTRIBUTES_H_ */ diff --git a/include/aidge/utils/TensorUtils.hpp b/include/aidge/utils/TensorUtils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6387619546c66922e48cf95a8a56487d4b0d0641 --- /dev/null +++ b/include/aidge/utils/TensorUtils.hpp @@ -0,0 +1,52 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CORE_UTILS_TENSOR_UTILS_H_
+#define AIDGE_CORE_UTILS_TENSOR_UTILS_H_
+#include <cmath> // std::abs
+#include "aidge/data/Tensor.hpp"
+
+/**
+ * @brief Compare two :cpp:class:`Aidge::Tensor` value wise. The comparison function is:
+ *
+ * |t1-t2| <= absolute + relative * |t2|
+ *
+ * If a tensor value is different from the other tensor return False
+ * If the tensor does not have the same size, return False
+ * If the datatype is not the same between each tensor return False
+ * If the templated type does not correspond to the datatype of each tensor, raise an assertion error
+ *
+ * @tparam T should correspond to the type of the tensor, define the type of the absolute and relative error
+ * @param t1 first :cpp:class:`Aidge::Tensor` to test
+ * @param t2 second :cpp:class:`Aidge::Tensor` to test
+ * @param relative relative difference allowed (should be between 0 and 1)
+ * @param absolute absolute error allowed (should be positive)
+ * @return true if both tensor are approximately equal and have the datatype, shape. 
Else return false
+ */
+template <typename T>
+bool approxEq(Aidge::Tensor t1, Aidge::Tensor t2, float relative, float absolute){
+    assert(t1.dataType() == t2.dataType());
+    assert(t1.dataType() == NativeType<T>::type);
+    // Documented contract: relative is a fraction in [0, 1], absolute is any positive error bound.
+    assert(relative >= 0 && relative <= 1);
+    assert(absolute >= 0);
+
+    if (t1.size() != t2.size()){
+        return false;
+    }
+    // Index must be zero-initialized: an uninitialized loop counter is undefined behavior.
+    for(std::size_t i = 0; i < t1.size(); ++i){
+        if (static_cast<float>(std::abs(t1.get<T>(i) - t2.get<T>(i))) > (absolute + (relative * static_cast<float>(std::abs(t2.get<T>(i)))))){
+            return false;
+        }
+    }
+    return true;
+}
+
+#endif /* AIDGE_CORE_UTILS_TENSOR_UTILS_H_ */
diff --git a/include/aidge/utils/future_std/any.hpp b/include/aidge/utils/future_std/any.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..8d9bfe28d0497dc12c59aaed68a23d3a9563815e
--- /dev/null
+++ b/include/aidge/utils/future_std/any.hpp
@@ -0,0 +1,552 @@
+/**
+ * Origin: https://github.com/claudiofantacci/any
+ *
+ * Implementation of N4562 std::experimental::any (merged into C++17 as std::any)
+ * for C++11 compilers.
+ *
+ * See also:
+ * + http://en.cppreference.com/w/cpp/any
+ * + http://en.cppreference.com/w/cpp/experimental/any
+ * + http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4562.html#any
+ * + https://cplusplus.github.io/LWG/lwg-active.html#2509
+ *
+ * Copyright (c) 2016 Denilson das Mercês Amorim
+ * Copyright (c) 2018 Claudio Fantacci
+ *
+ * Distributed under the Boost Software License, Version 1.0.
+ * (See copy at http://www.boost.org/LICENSE_1_0.txt)
+ */
+
+#ifndef AIDGE_CORE_UTILS_FUTURE_STD_ANY_H_
+#define AIDGE_CORE_UTILS_FUTURE_STD_ANY_H_
+
+#include <stdexcept>
+#include <typeinfo>
+#include <type_traits>
+#include <utility>
+
+
+namespace future_std
+{
+
+class bad_any_cast : public std::bad_cast
+{
+public:
+    const char* what() const noexcept override
+    {
+        return "bad any_cast";
+    }
+};
+
+
+class any final
+{
+public:
+    /**
+     * Constructs an object of type any with an empty state. 
+ */ + any() : + vtable(nullptr) + { } + + + /** + * Constructs an object of type any with an equivalent state as other. + */ + any(const any& rhs) : + vtable(rhs.vtable) + { + if(rhs.has_value()) + { + rhs.vtable->copy(rhs.storage, this->storage); + } + } + + + /** + * Constructs an object of type any with a state equivalent to the original state of other. + * rhs is left in a valid but otherwise unspecified state. + */ + any(any&& rhs) noexcept : + vtable(rhs.vtable) + { + if(rhs.has_value()) + { + rhs.vtable->move(rhs.storage, this->storage); + rhs.vtable = nullptr; + } + } + + + /** + * Same effect as this->clear(). + */ + ~any() + { + this->reset(); + } + + + /** + * Constructs an object of type any that contains an object of type T direct-initialized with std::forward<ValueType>(value). + * T shall satisfy the CopyConstructible requirements, otherwise the program is ill-formed. + * This is because an `any` may be copy constructed into another `any` at any time, so a copy should always be allowed. + */ + template<typename ValueType, typename = typename std::enable_if<!std::is_same<typename std::decay<ValueType>::type, any>::value>::type> + any(ValueType&& value) + { + static_assert(std::is_copy_constructible<typename std::decay<ValueType>::type>::value, + "T shall satisfy the CopyConstructible requirements."); + this->construct(std::forward<ValueType>(value)); + } + + + /** + * Has the same effect as any(rhs).swap(*this). No effects if an exception is thrown. + */ + any& operator=(const any& rhs) + { + any(rhs).swap(*this); + return *this; + } + + + /** + * Has the same effect as any(std::move(rhs)).swap(*this). + * The state of *this is equivalent to the original state of rhs and rhs is left in a valid + * but otherwise unspecified state. + */ + any& operator=(any&& rhs) noexcept + { + any(std::move(rhs)).swap(*this); + return *this; + } + + + /** + * Has the same effect as any(std::forward<ValueType>(value)).swap(*this). No effect if a exception is thrown. 
+ * T shall satisfy the CopyConstructible requirements, otherwise the program is ill-formed. + * This is because an `any` may be copy constructed into another `any` at any time, so a copy should always be allowed. + */ + template<typename ValueType, typename = typename std::enable_if<!std::is_same<typename std::decay<ValueType>::type, any>::value>::type> + any& operator=(ValueType&& value) + { + static_assert(std::is_copy_constructible<typename std::decay<ValueType>::type>::value, "T shall satisfy the CopyConstructible requirements."); + any(std::forward<ValueType>(value)).swap(*this); + return *this; + } + + + /** + * If not empty, destroys the contained object. + */ + void reset() noexcept + { + if(has_value()) + { + this->vtable->destroy(storage); + this->vtable = nullptr; + } + } + + + /** + * Returns true if *this has no contained object, otherwise false. + */ + bool has_value() const noexcept + { + return this->vtable != nullptr; + } + + + /** + * If *this has a contained object of type T, typeid(T); otherwise typeid(void). + */ + const std::type_info& type() const noexcept + { + return has_value()? this->vtable->type() : typeid(void); + } + + + /** + * Exchange the states of *this and rhs. + */ + void swap(any& other) noexcept + { + if(this->vtable != other.vtable) + { + any tmp(std::move(other)); + + other.vtable = this->vtable; + if(this->vtable != nullptr) + this->vtable->move(this->storage, other.storage); + + this->vtable = tmp.vtable; + if(tmp.vtable != nullptr) + { + tmp.vtable->move(tmp.storage, this->storage); + tmp.vtable = nullptr; + } + } + else + { + if(this->vtable != nullptr) + this->vtable->swap(this->storage, other.storage); + } + } + + +private: + union storage_union + { + using stack_storage_t = typename std::aligned_storage<2 * sizeof(void*), std::alignment_of<void*>::value>::type; + + void* dynamic; + + stack_storage_t stack; + }; + + + /** + * Base VTable specification. 
+ * + * Note: The caller is responsible for doing .vtable = nullptr after destructful operations + * such as destroy() and/or move(). + */ + struct vtable_type + { + /** + * The type of the object this vtable is for. + */ + const std::type_info& (*type)() noexcept; + + + /** + * Destroys the object in the union. + * The state of the union after this call is unspecified, caller must ensure not to use src anymore. + */ + void(*destroy)(storage_union&) noexcept; + + + /** + * Copies the **inner** content of the src union into the yet unitialized dest union. + * As such, both inner objects will have the same state, but on separate memory locations. + */ + void(*copy)(const storage_union& src, storage_union& dest); + + + /** + * Moves the storage from src to the yet unitialized dest union. + * The state of src after this call is unspecified, caller must ensure not to use src anymore. + */ + void(*move)(storage_union& src, storage_union& dest) noexcept; + + + /** + * Exchanges the storage between lhs and rhs. + */ + void(*swap)(storage_union& lhs, storage_union& rhs) noexcept; + }; + + + /** + * VTable for dynamically allocated storage. + */ + template<typename T> + struct vtable_dynamic + { + static const std::type_info& type() noexcept + { + return typeid(T); + } + + + static void destroy(storage_union& storage) noexcept + { + delete reinterpret_cast<T*>(storage.dynamic); + } + + + static void copy(const storage_union& src, storage_union& dest) + { + dest.dynamic = new T(*reinterpret_cast<const T*>(src.dynamic)); + } + + + static void move(storage_union& src, storage_union& dest) noexcept + { + dest.dynamic = src.dynamic; + src.dynamic = nullptr; + } + + + static void swap(storage_union& lhs, storage_union& rhs) noexcept + { + std::swap(lhs.dynamic, rhs.dynamic); + } + }; + + + /** + * VTable for stack allocated storage. 
+ */ + template<typename T> + struct vtable_stack + { + static const std::type_info& type() noexcept + { + return typeid(T); + } + + + static void destroy(storage_union& storage) noexcept + { + reinterpret_cast<T*>(&storage.stack)->~T(); + } + + + static void copy(const storage_union& src, storage_union& dest) + { + new (&dest.stack) T(reinterpret_cast<const T&>(src.stack)); + } + + + static void move(storage_union& src, storage_union& dest) noexcept + { + /** + * One of the conditions for using vtable_stack is a nothrow move constructor, + * so this move constructor will never throw a exception. + */ + new (&dest.stack) T(std::move(reinterpret_cast<T&>(src.stack))); + destroy(src); + } + + + static void swap(storage_union& lhs, storage_union& rhs) noexcept + { + storage_union tmp_storage; + move(rhs, tmp_storage); + move(lhs, rhs); + move(tmp_storage, lhs); + } + }; + + + /** + * Whether the type T must be dynamically allocated or can be stored on the stack. + */ + template<typename T> + struct requires_allocation : + std::integral_constant<bool, !(std::is_nothrow_move_constructible<T>::value // N4562 6.3/3 [any.class] + && sizeof(T) <= sizeof(storage_union::stack) + && std::alignment_of<T>::value <= std::alignment_of<storage_union::stack_storage_t>::value)> + { }; + + + /** + * Returns the pointer to the vtable of the type T. 
+ */ + template<typename T> + static vtable_type* vtable_for_type() + { + using VTableType = typename std::conditional<requires_allocation<T>::value, vtable_dynamic<T>, vtable_stack<T>>::type; + static vtable_type table = { VTableType::type, VTableType::destroy, VTableType::copy, VTableType::move, VTableType::swap }; + return &table; + } + + +protected: + template<typename T> + friend const T* any_cast(const any* operand) noexcept; + + + template<typename T> + friend T* any_cast(any* operand) noexcept; + + + /** + * Same effect as is_same(this->type(), t); + */ + bool is_typed(const std::type_info& t) const + { + return is_same(this->type(), t); + } + + + /** + * Checks if two type infos are the same. + * If ANY_IMPL_FAST_TYPE_INFO_COMPARE is defined, checks only the address of the + * type infos, otherwise does an actual comparision. Checking addresses is + * only a valid approach when there's no interaction with outside sources + * (other shared libraries and such). + */ + static bool is_same(const std::type_info& a, const std::type_info& b) + { +#ifdef ANY_IMPL_FAST_TYPE_INFO_COMPARE + return &a == &b; +#else + return a == b; +#endif + } + + + /** + * Casts (with no type_info checks) the storage pointer as const T*. + */ + template<typename T> + const T* cast() const noexcept + { + return requires_allocation<typename std::decay<T>::type>::value ? reinterpret_cast<const T*>(storage.dynamic) : reinterpret_cast<const T*>(&storage.stack); + } + + + /** + * Casts (with no type_info checks) the storage pointer as T*. + */ + template<typename T> + T* cast() noexcept + { + return requires_allocation<typename std::decay<T>::type>::value ? 
reinterpret_cast<T*>(storage.dynamic) : reinterpret_cast<T*>(&storage.stack); + } + + +private: + storage_union storage; // On offset(0) so no padding for align + + vtable_type* vtable; + + + template<typename ValueType, typename T> + typename std::enable_if<requires_allocation<T>::value>::type do_construct(ValueType&& value) + { + storage.dynamic = new T(std::forward<ValueType>(value)); + } + + + template<typename ValueType, typename T> + typename std::enable_if<!requires_allocation<T>::value>::type do_construct(ValueType&& value) + { + new (&storage.stack) T(std::forward<ValueType>(value)); + } + + + /** + * Chooses between stack and dynamic allocation for the type decay_t<ValueType>, + * assigns the correct vtable, and constructs the object on our storage. + */ + template<typename ValueType> + void construct(ValueType&& value) + { + using T = typename std::decay<ValueType>::type; + + this->vtable = vtable_for_type<T>(); + + do_construct<ValueType,T>(std::forward<ValueType>(value)); + } +}; + + +namespace detail +{ + template<typename ValueType> + inline ValueType any_cast_move_if_true(typename std::remove_reference<ValueType>::type* p, std::true_type) + { + return std::move(*p); + } + + + template<typename ValueType> + inline ValueType any_cast_move_if_true(typename std::remove_reference<ValueType>::type* p, std::false_type) + { + return *p; + } +} + + +/** + * Performs *any_cast<add_const_t<remove_reference_t<ValueType>>>(&operand), or throws bad_any_cast on failure. + */ +template<typename ValueType> +inline ValueType any_cast(const any& operand) +{ + auto p = any_cast<typename std::add_const<typename std::remove_reference<ValueType>::type>::type>(&operand); + if(p == nullptr) throw bad_any_cast(); + return *p; +} + + +/** + * Performs *any_cast<remove_reference_t<ValueType>>(&operand), or throws bad_any_cast on failure. 
+ */ +template<typename ValueType> +inline ValueType any_cast(any& operand) +{ + auto p = any_cast<typename std::remove_reference<ValueType>::type>(&operand); + if(p == nullptr) throw bad_any_cast(); + return *p; +} + + +/** + * If ANY_IMPL_ANYCAST_MOVEABLE is not defined, does as N4562 specifies: + * Performs *any_cast<remove_reference_t<ValueType>>(&operand), or throws bad_any_cast on failure. + * + * If ANY_IMPL_ANYCAST_MOVEABLE is defined, does as LWG Defect 2509 specifies [1]: + * If ValueType is MoveConstructible and isn't a lvalue reference, performs + * std::move(*any_cast<remove_reference_t<ValueType>>(&operand)), otherwise + * *any_cast<remove_reference_t<ValueType>>(&operand). + * Throws bad_any_cast on failure. + * + * [1] https://cplusplus.github.io/LWG/lwg-active.html#2509 + */ +template<typename ValueType> +inline ValueType any_cast(any&& operand) +{ +#ifdef ANY_IMPL_ANY_CAST_MOVEABLE + using can_move = std::integral_constant<bool, std::is_move_constructible<ValueType>::value && !std::is_lvalue_reference<ValueType>::value>; +#else + using can_move = std::false_type; +#endif + + auto p = any_cast<typename std::remove_reference<ValueType>::type>(&operand); + if(p == nullptr) throw bad_any_cast(); + return detail::any_cast_move_if_true<ValueType>(p, can_move()); +} + + +/** + * If operand != nullptr && operand->type() == typeid(ValueType), a pointer to the object + * contained by operand, otherwise nullptr. + */ +template<typename T> +inline const T* any_cast(const any* operand) noexcept +{ + if(operand == nullptr || !operand->is_typed(typeid(T))) + return nullptr; + else + return operand->cast<T>(); +} + + +/** + * If operand != nullptr && operand->type() == typeid(ValueType), a pointer to the object + * contained by operand, otherwise nullptr. 
+ */ +template<typename T> +inline T* any_cast(any* operand) noexcept +{ + if(operand == nullptr || !operand->is_typed(typeid(T))) + return nullptr; + else + return operand->cast<T>(); +} + + +inline void swap(any& lhs, any& rhs) noexcept +{ + lhs.swap(rhs); +} + +} + +#endif /* AIDGE_CORE_UTILS_FUTURE_STD_ANY_H_ */ diff --git a/include/aidge/utils/future_std/expected.hpp b/include/aidge/utils/future_std/expected.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c271d0e8d8066c0bcd0358f28f8bcd711a8b6ba0 --- /dev/null +++ b/include/aidge/utils/future_std/expected.hpp @@ -0,0 +1,3487 @@ +// Origin: https://github.com/martinmoene/expected-lite +// +// This version targets C++11 and later. +// +// Copyright (C) 2016-2020 Martin Moene. +// +// Distributed under the Boost Software License, Version 1.0. +// (See copy at http://www.boost.org/LICENSE_1_0.txt) +// +// expected lite is based on: +// A proposal to add a utility class to represent expected monad +// by Vicente J. Botet Escriba and Pierre Talbot. http:://wg21.link/p0323 + +#ifndef AIDGE_CORE_UTILS_FUTURE_STD_EXPECTED_H_ +#define AIDGE_CORE_UTILS_FUTURE_STD_EXPECTED_H_ + +#define expected_lite_MAJOR 0 +#define expected_lite_MINOR 6 +#define expected_lite_PATCH 3 + +#define expected_lite_VERSION expected_STRINGIFY(expected_lite_MAJOR) "." expected_STRINGIFY(expected_lite_MINOR) "." 
expected_STRINGIFY(expected_lite_PATCH) + +#define expected_STRINGIFY( x ) expected_STRINGIFY_( x ) +#define expected_STRINGIFY_( x ) #x + +// expected-lite configuration: + +#define nsel_EXPECTED_DEFAULT 0 +#define nsel_EXPECTED_FUTURE_STD 1 +#define nsel_EXPECTED_STD 2 + +// tweak header support: + +#ifdef __has_include +# if __has_include(<future_std/expected.tweak.hpp>) +# include <future_std/expected.tweak.hpp> +# endif +#define expected_HAVE_TWEAK_HEADER 1 +#else +#define expected_HAVE_TWEAK_HEADER 0 +//# pragma message("expected.hpp: Note: Tweak header not supported.") +#endif + +// expected selection and configuration: + +#if !defined( nsel_CONFIG_SELECT_EXPECTED ) +# define nsel_CONFIG_SELECT_EXPECTED ( nsel_HAVE_STD_EXPECTED ? nsel_EXPECTED_STD : nsel_EXPECTED_FUTURE_STD ) +#endif + +// Proposal revisions: +// +// DXXXXR0: -- +// N4015 : -2 (2014-05-26) +// N4109 : -1 (2014-06-29) +// P0323R0: 0 (2016-05-28) +// P0323R1: 1 (2016-10-12) +// -------: +// P0323R2: 2 (2017-06-15) +// P0323R3: 3 (2017-10-15) +// P0323R4: 4 (2017-11-26) +// P0323R5: 5 (2018-02-08) +// P0323R6: 6 (2018-04-02) +// P0323R7: 7 (2018-06-22) * +// +// expected-lite uses 2 and higher + +#ifndef nsel_P0323R +# define nsel_P0323R 7 +#endif + +// Monadic operations proposal revisions: +// +// P2505R0: 0 (2021-12-12) +// P2505R1: 1 (2022-02-10) +// P2505R2: 2 (2022-04-15) +// P2505R3: 3 (2022-06-05) +// P2505R4: 4 (2022-06-15) +// P2505R5: 5 (2022-09-20) * +// +// expected-lite uses 5 + +#ifndef nsel_P2505R +# define nsel_P2505R 5 +#endif + +// Control presence of C++ exception handling (try and auto discover): + +#ifndef nsel_CONFIG_NO_EXCEPTIONS +# if defined(_MSC_VER) +# include <cstddef> // for _HAS_EXCEPTIONS +# endif +# if defined(__cpp_exceptions) || defined(__EXCEPTIONS) || (_HAS_EXCEPTIONS) +# define nsel_CONFIG_NO_EXCEPTIONS 0 +# else +# define nsel_CONFIG_NO_EXCEPTIONS 1 +# endif +#endif + +// at default use SEH with MSVC for no C++ exceptions + +#ifndef 
nsel_CONFIG_NO_EXCEPTIONS_SEH +# define nsel_CONFIG_NO_EXCEPTIONS_SEH ( nsel_CONFIG_NO_EXCEPTIONS && _MSC_VER ) +#endif + +// C++ language version detection (C++23 is speculative): +// Note: VC14.0/1900 (VS2015) lacks too much from C++14. + +#ifndef nsel_CPLUSPLUS +# if defined(_MSVC_LANG ) && !defined(__clang__) +# define nsel_CPLUSPLUS (_MSC_VER == 1900 ? 201103L : _MSVC_LANG ) +# else +# define nsel_CPLUSPLUS __cplusplus +# endif +#endif + +#define nsel_CPP98_OR_GREATER ( nsel_CPLUSPLUS >= 199711L ) +#define nsel_CPP11_OR_GREATER ( nsel_CPLUSPLUS >= 201103L ) +#define nsel_CPP14_OR_GREATER ( nsel_CPLUSPLUS >= 201402L ) +#define nsel_CPP17_OR_GREATER ( nsel_CPLUSPLUS >= 201703L ) +#define nsel_CPP20_OR_GREATER ( nsel_CPLUSPLUS >= 202002L ) +#define nsel_CPP23_OR_GREATER ( nsel_CPLUSPLUS >= 202300L ) + +// Use C++23 std::expected if available and requested: + +#if nsel_CPP23_OR_GREATER && defined(__has_include ) +# if __has_include( <expected> ) +# define nsel_HAVE_STD_EXPECTED 1 +# else +# define nsel_HAVE_STD_EXPECTED 0 +# endif +#else +# define nsel_HAVE_STD_EXPECTED 0 +#endif + +#define nsel_USES_STD_EXPECTED ( (nsel_CONFIG_SELECT_EXPECTED == nsel_EXPECTED_STD) || ((nsel_CONFIG_SELECT_EXPECTED == nsel_EXPECTED_DEFAULT) && nsel_HAVE_STD_EXPECTED) ) + +// +// in_place: code duplicated in any-lite, expected-lite, expected-lite, value-ptr-lite, variant-lite: +// + +#ifndef future_std_lite_HAVE_IN_PLACE_TYPES +#define future_std_lite_HAVE_IN_PLACE_TYPES 1 + +// C++17 std::in_place in <utility>: + +#if nsel_CPP17_OR_GREATER + +#include <utility> + +namespace future_std { + +using std::in_place; +using std::in_place_type; +using std::in_place_index; +using std::in_place_t; +using std::in_place_type_t; +using std::in_place_index_t; + +#define future_std_lite_in_place_t( T) std::in_place_t +#define future_std_lite_in_place_type_t( T) std::in_place_type_t<T> +#define future_std_lite_in_place_index_t(K) std::in_place_index_t<K> + +#define future_std_lite_in_place( T) 
std::in_place_t{} +#define future_std_lite_in_place_type( T) std::in_place_type_t<T>{} +#define future_std_lite_in_place_index(K) std::in_place_index_t<K>{} + +} // namespace future_std + +#else // nsel_CPP17_OR_GREATER + +#include <cstddef> + +namespace future_std { +namespace detail { + +template< class T > +struct in_place_type_tag {}; + +template< std::size_t K > +struct in_place_index_tag {}; + +} // namespace detail + +struct in_place_t {}; + +template< class T > +inline in_place_t in_place( detail::in_place_type_tag<T> = detail::in_place_type_tag<T>() ) +{ + return in_place_t(); +} + +template< std::size_t K > +inline in_place_t in_place( detail::in_place_index_tag<K> = detail::in_place_index_tag<K>() ) +{ + return in_place_t(); +} + +template< class T > +inline in_place_t in_place_type( detail::in_place_type_tag<T> = detail::in_place_type_tag<T>() ) +{ + return in_place_t(); +} + +template< std::size_t K > +inline in_place_t in_place_index( detail::in_place_index_tag<K> = detail::in_place_index_tag<K>() ) +{ + return in_place_t(); +} + +// mimic templated typedef: + +#define future_std_lite_in_place_t( T) future_std::in_place_t(&)( future_std::detail::in_place_type_tag<T> ) +#define future_std_lite_in_place_type_t( T) future_std::in_place_t(&)( future_std::detail::in_place_type_tag<T> ) +#define future_std_lite_in_place_index_t(K) future_std::in_place_t(&)( future_std::detail::in_place_index_tag<K> ) + +#define future_std_lite_in_place( T) future_std::in_place_type<T> +#define future_std_lite_in_place_type( T) future_std::in_place_type<T> +#define future_std_lite_in_place_index(K) future_std::in_place_index<K> + +} // namespace future_std + +#endif // nsel_CPP17_OR_GREATER +#endif // future_std_lite_HAVE_IN_PLACE_TYPES + +// +// Using std::expected: +// + +#if nsel_USES_STD_EXPECTED + +#include <expected> + +namespace future_std { + + using std::expected; +// ... 
+} + +#else // nsel_USES_STD_EXPECTED + +#include <cassert> +#include <exception> +#include <functional> +#include <initializer_list> +#include <memory> +#include <new> +#include <system_error> +#include <type_traits> +#include <utility> + +// additional includes: + +#if nsel_CONFIG_NO_EXCEPTIONS +# if nsel_CONFIG_NO_EXCEPTIONS_SEH +# include <windows.h> // for ExceptionCodes +# else +// already included: <cassert> +# endif +#else +# include <stdexcept> +#endif + +// C++ feature usage: + +#if nsel_CPP11_OR_GREATER +# define nsel_constexpr constexpr +#else +# define nsel_constexpr /*constexpr*/ +#endif + +#if nsel_CPP14_OR_GREATER +# define nsel_constexpr14 constexpr +#else +# define nsel_constexpr14 /*constexpr*/ +#endif + +#if nsel_CPP17_OR_GREATER +# define nsel_inline17 inline +#else +# define nsel_inline17 /*inline*/ +#endif + +// Compiler versions: +// +// MSVC++ 6.0 _MSC_VER == 1200 nsel_COMPILER_MSVC_VERSION == 60 (Visual Studio 6.0) +// MSVC++ 7.0 _MSC_VER == 1300 nsel_COMPILER_MSVC_VERSION == 70 (Visual Studio .NET 2002) +// MSVC++ 7.1 _MSC_VER == 1310 nsel_COMPILER_MSVC_VERSION == 71 (Visual Studio .NET 2003) +// MSVC++ 8.0 _MSC_VER == 1400 nsel_COMPILER_MSVC_VERSION == 80 (Visual Studio 2005) +// MSVC++ 9.0 _MSC_VER == 1500 nsel_COMPILER_MSVC_VERSION == 90 (Visual Studio 2008) +// MSVC++ 10.0 _MSC_VER == 1600 nsel_COMPILER_MSVC_VERSION == 100 (Visual Studio 2010) +// MSVC++ 11.0 _MSC_VER == 1700 nsel_COMPILER_MSVC_VERSION == 110 (Visual Studio 2012) +// MSVC++ 12.0 _MSC_VER == 1800 nsel_COMPILER_MSVC_VERSION == 120 (Visual Studio 2013) +// MSVC++ 14.0 _MSC_VER == 1900 nsel_COMPILER_MSVC_VERSION == 140 (Visual Studio 2015) +// MSVC++ 14.1 _MSC_VER >= 1910 nsel_COMPILER_MSVC_VERSION == 141 (Visual Studio 2017) +// MSVC++ 14.2 _MSC_VER >= 1920 nsel_COMPILER_MSVC_VERSION == 142 (Visual Studio 2019) + +#if defined(_MSC_VER) && !defined(__clang__) +# define nsel_COMPILER_MSVC_VER (_MSC_VER ) +# define nsel_COMPILER_MSVC_VERSION (_MSC_VER / 10 - 10 * ( 5 + 
(_MSC_VER < 1900)) ) +#else +# define nsel_COMPILER_MSVC_VER 0 +# define nsel_COMPILER_MSVC_VERSION 0 +#endif + +#define nsel_COMPILER_VERSION( major, minor, patch ) ( 10 * ( 10 * (major) + (minor) ) + (patch) ) + +#if defined(__clang__) +# define nsel_COMPILER_CLANG_VERSION nsel_COMPILER_VERSION(__clang_major__, __clang_minor__, __clang_patchlevel__) +#else +# define nsel_COMPILER_CLANG_VERSION 0 +#endif + +#if defined(__GNUC__) && !defined(__clang__) +# define nsel_COMPILER_GNUC_VERSION nsel_COMPILER_VERSION(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) +#else +# define nsel_COMPILER_GNUC_VERSION 0 +#endif + +// half-open range [lo..hi): +//#define nsel_BETWEEN( v, lo, hi ) ( (lo) <= (v) && (v) < (hi) ) + +// Method enabling + +#define nsel_REQUIRES_0(...) \ + template< bool B = (__VA_ARGS__), typename std::enable_if<B, int>::type = 0 > + +#define nsel_REQUIRES_T(...) \ + , typename std::enable_if< (__VA_ARGS__), int >::type = 0 + +#define nsel_REQUIRES_R(R, ...) \ + typename std::enable_if< (__VA_ARGS__), R>::type + +#define nsel_REQUIRES_A(...) 
\ + , typename std::enable_if< (__VA_ARGS__), void*>::type = nullptr + +// Presence of language and library features: + +#ifdef _HAS_CPP0X +# define nsel_HAS_CPP0X _HAS_CPP0X +#else +# define nsel_HAS_CPP0X 0 +#endif + +//#define nsel_CPP11_140 (nsel_CPP11_OR_GREATER || nsel_COMPILER_MSVC_VER >= 1900) + +// Clang, GNUC, MSVC warning suppression macros: + +#ifdef __clang__ +# pragma clang diagnostic push +#elif defined __GNUC__ +# pragma GCC diagnostic push +#endif // __clang__ + +#if nsel_COMPILER_MSVC_VERSION >= 140 +# pragma warning( push ) +# define nsel_DISABLE_MSVC_WARNINGS(codes) __pragma( warning(disable: codes) ) +#else +# define nsel_DISABLE_MSVC_WARNINGS(codes) +#endif + +#ifdef __clang__ +# define nsel_RESTORE_WARNINGS() _Pragma("clang diagnostic pop") +#elif defined __GNUC__ +# define nsel_RESTORE_WARNINGS() _Pragma("GCC diagnostic pop") +#elif nsel_COMPILER_MSVC_VERSION >= 140 +# define nsel_RESTORE_WARNINGS() __pragma( warning( pop ) ) +#else +# define nsel_RESTORE_WARNINGS() +#endif + +// Suppress the following MSVC (GSL) warnings: +// - C26409: Avoid calling new and delete explicitly, use std::make_unique<T> instead (r.11) + +nsel_DISABLE_MSVC_WARNINGS( 26409 ) + +// +// expected: +// + +namespace future_std { namespace expected_lite { + +// type traits C++17: + +namespace std17 { + +#if nsel_CPP17_OR_GREATER + +using std::conjunction; +using std::is_swappable; +using std::is_nothrow_swappable; + +#else // nsel_CPP17_OR_GREATER + +namespace detail { + +using std::swap; + +struct is_swappable +{ + template< typename T, typename = decltype( swap( std::declval<T&>(), std::declval<T&>() ) ) > + static std::true_type test( int /* unused */); + + template< typename > + static std::false_type test(...); +}; + +struct is_nothrow_swappable +{ + // wrap noexcept(expr) in separate function as work-around for VC140 (VS2015): + + template< typename T > + static constexpr bool satisfies() + { + return noexcept( swap( std::declval<T&>(), std::declval<T&>() ) ); + 
} + + template< typename T > + static auto test( int ) -> std::integral_constant<bool, satisfies<T>()>{} + + template< typename > + static auto test(...) -> std::false_type; +}; +} // namespace detail + +// is [nothrow] swappable: + +template< typename T > +struct is_swappable : decltype( detail::is_swappable::test<T>(0) ){}; + +template< typename T > +struct is_nothrow_swappable : decltype( detail::is_nothrow_swappable::test<T>(0) ){}; + +// conjunction: + +template< typename... > struct conjunction : std::true_type{}; +template< typename B1 > struct conjunction<B1> : B1{}; + +template< typename B1, typename... Bn > +struct conjunction<B1, Bn...> : std::conditional<bool(B1::value), conjunction<Bn...>, B1>::type{}; + +#endif // nsel_CPP17_OR_GREATER + +} // namespace std17 + +// type traits C++20: + +namespace std20 { + +#if defined(__cpp_lib_remove_cvref) + +using std::remove_cvref; + +#else + +template< typename T > +struct remove_cvref +{ + typedef typename std::remove_cv< typename std::remove_reference<T>::type >::type type; +}; + +#endif + +} // namespace std20 + +// forward declaration: + +template< typename T, typename E > +class expected; + +namespace detail { + +#if nsel_P2505R >= 3 +template< typename T > +struct is_expected : std::false_type {}; + +template< typename T, typename E > +struct is_expected< expected< T, E > > : std::true_type {}; +#endif // nsel_P2505R >= 3 + +/// discriminated union to hold value or 'error'. 
+ +template< typename T, typename E > +class storage_t_noncopy_nonmove_impl +{ + template< typename, typename > friend class future_std::expected_lite::expected; + +public: + using value_type = T; + using error_type = E; + + // no-op construction + storage_t_noncopy_nonmove_impl() {} + ~storage_t_noncopy_nonmove_impl() {} + + explicit storage_t_noncopy_nonmove_impl( bool has_value ) + : m_has_value( has_value ) + {} + + void construct_value() + { + new( &m_value ) value_type(); + } + + // void construct_value( value_type const & e ) + // { + // new( &m_value ) value_type( e ); + // } + + // void construct_value( value_type && e ) + // { + // new( &m_value ) value_type( std::move( e ) ); + // } + + template< class... Args > + void emplace_value( Args&&... args ) + { + new( &m_value ) value_type( std::forward<Args>(args)...); + } + + template< class U, class... Args > + void emplace_value( std::initializer_list<U> il, Args&&... args ) + { + new( &m_value ) value_type( il, std::forward<Args>(args)... ); + } + + void destruct_value() + { + m_value.~value_type(); + } + + // void construct_error( error_type const & e ) + // { + // // new( &m_error ) error_type( e ); + // } + + // void construct_error( error_type && e ) + // { + // // new( &m_error ) error_type( std::move( e ) ); + // } + + template< class... Args > + void emplace_error( Args&&... args ) + { + new( &m_error ) error_type( std::forward<Args>(args)...); + } + + template< class U, class... Args > + void emplace_error( std::initializer_list<U> il, Args&&... args ) + { + new( &m_error ) error_type( il, std::forward<Args>(args)... 
); + } + + void destruct_error() + { + m_error.~error_type(); + } + + constexpr value_type const & value() const & + { + return m_value; + } + + value_type & value() & + { + return m_value; + } + + constexpr value_type const && value() const && + { + return std::move( m_value ); + } + + nsel_constexpr14 value_type && value() && + { + return std::move( m_value ); + } + + value_type const * value_ptr() const + { + return &m_value; + } + + value_type * value_ptr() + { + return &m_value; + } + + error_type const & error() const & + { + return m_error; + } + + error_type & error() & + { + return m_error; + } + + constexpr error_type const && error() const && + { + return std::move( m_error ); + } + + nsel_constexpr14 error_type && error() && + { + return std::move( m_error ); + } + + bool has_value() const + { + return m_has_value; + } + + void set_has_value( bool v ) + { + m_has_value = v; + } + +private: + union + { + value_type m_value; + error_type m_error; + }; + + bool m_has_value = false; +}; + +template< typename T, typename E > +class storage_t_impl +{ + template< typename, typename > friend class future_std::expected_lite::expected; + +public: + using value_type = T; + using error_type = E; + + // no-op construction + storage_t_impl() {} + ~storage_t_impl() {} + + explicit storage_t_impl( bool has_value ) + : m_has_value( has_value ) + {} + + void construct_value() + { + new( &m_value ) value_type(); + } + + void construct_value( value_type const & e ) + { + new( &m_value ) value_type( e ); + } + + void construct_value( value_type && e ) + { + new( &m_value ) value_type( std::move( e ) ); + } + + template< class... Args > + void emplace_value( Args&&... args ) + { + new( &m_value ) value_type( std::forward<Args>(args)...); + } + + template< class U, class... Args > + void emplace_value( std::initializer_list<U> il, Args&&... args ) + { + new( &m_value ) value_type( il, std::forward<Args>(args)... 
); + } + + void destruct_value() + { + m_value.~value_type(); + } + + void construct_error( error_type const & e ) + { + new( &m_error ) error_type( e ); + } + + void construct_error( error_type && e ) + { + new( &m_error ) error_type( std::move( e ) ); + } + + template< class... Args > + void emplace_error( Args&&... args ) + { + new( &m_error ) error_type( std::forward<Args>(args)...); + } + + template< class U, class... Args > + void emplace_error( std::initializer_list<U> il, Args&&... args ) + { + new( &m_error ) error_type( il, std::forward<Args>(args)... ); + } + + void destruct_error() + { + m_error.~error_type(); + } + + constexpr value_type const & value() const & + { + return m_value; + } + + value_type & value() & + { + return m_value; + } + + constexpr value_type const && value() const && + { + return std::move( m_value ); + } + + nsel_constexpr14 value_type && value() && + { + return std::move( m_value ); + } + + value_type const * value_ptr() const + { + return &m_value; + } + + value_type * value_ptr() + { + return &m_value; + } + + error_type const & error() const & + { + return m_error; + } + + error_type & error() & + { + return m_error; + } + + constexpr error_type const && error() const && + { + return std::move( m_error ); + } + + nsel_constexpr14 error_type && error() && + { + return std::move( m_error ); + } + + bool has_value() const + { + return m_has_value; + } + + void set_has_value( bool v ) + { + m_has_value = v; + } + +private: + union + { + value_type m_value; + error_type m_error; + }; + + bool m_has_value = false; +}; + +/// discriminated union to hold only 'error'. 
+ +template< typename E > +struct storage_t_impl<void, E> +{ + template< typename, typename > friend class future_std::expected_lite::expected; + +public: + using value_type = void; + using error_type = E; + + // no-op construction + storage_t_impl() {} + ~storage_t_impl() {} + + explicit storage_t_impl( bool has_value ) + : m_has_value( has_value ) + {} + + void construct_error( error_type const & e ) + { + new( &m_error ) error_type( e ); + } + + void construct_error( error_type && e ) + { + new( &m_error ) error_type( std::move( e ) ); + } + + template< class... Args > + void emplace_error( Args&&... args ) + { + new( &m_error ) error_type( std::forward<Args>(args)...); + } + + template< class U, class... Args > + void emplace_error( std::initializer_list<U> il, Args&&... args ) + { + new( &m_error ) error_type( il, std::forward<Args>(args)... ); + } + + void destruct_error() + { + m_error.~error_type(); + } + + error_type const & error() const & + { + return m_error; + } + + error_type & error() & + { + return m_error; + } + + constexpr error_type const && error() const && + { + return std::move( m_error ); + } + + nsel_constexpr14 error_type && error() && + { + return std::move( m_error ); + } + + bool has_value() const + { + return m_has_value; + } + + void set_has_value( bool v ) + { + m_has_value = v; + } + +private: + union + { + char m_dummy; + error_type m_error; + }; + + bool m_has_value = false; +}; + +template< typename T, typename E, bool isConstructable, bool isMoveable > +class storage_t +{ +public: +}; + +template< typename T, typename E > +class storage_t<T, E, false, false> : public storage_t_noncopy_nonmove_impl<T, E> +{ +public: + storage_t() = default; + ~storage_t() = default; + + explicit storage_t( bool has_value ) + : storage_t_noncopy_nonmove_impl<T, E>( has_value ) + {} + + storage_t( storage_t const & other ) = delete; + storage_t( storage_t && other ) = delete; + +}; + +template< typename T, typename E > +class storage_t<T, E, true, 
true> : public storage_t_impl<T, E> +{ +public: + storage_t() = default; + ~storage_t() = default; + + explicit storage_t( bool has_value ) + : storage_t_impl<T, E>( has_value ) + {} + + storage_t( storage_t const & other ) + : storage_t_impl<T, E>( other.has_value() ) + { + if ( this->has_value() ) this->construct_value( other.value() ); + else this->construct_error( other.error() ); + } + + storage_t(storage_t && other ) + : storage_t_impl<T, E>( other.has_value() ) + { + if ( this->has_value() ) this->construct_value( std::move( other.value() ) ); + else this->construct_error( std::move( other.error() ) ); + } +}; + +template< typename E > +class storage_t<void, E, true, true> : public storage_t_impl<void, E> +{ +public: + storage_t() = default; + ~storage_t() = default; + + explicit storage_t( bool has_value ) + : storage_t_impl<void, E>( has_value ) + {} + + storage_t( storage_t const & other ) + : storage_t_impl<void, E>( other.has_value() ) + { + if ( this->has_value() ) ; + else this->construct_error( other.error() ); + } + + storage_t(storage_t && other ) + : storage_t_impl<void, E>( other.has_value() ) + { + if ( this->has_value() ) ; + else this->construct_error( std::move( other.error() ) ); + } +}; + +template< typename T, typename E > +class storage_t<T, E, true, false> : public storage_t_impl<T, E> +{ +public: + storage_t() = default; + ~storage_t() = default; + + explicit storage_t( bool has_value ) + : storage_t_impl<T, E>( has_value ) + {} + + storage_t( storage_t const & other ) + : storage_t_impl<T, E>(other.has_value()) + { + if ( this->has_value() ) this->construct_value( other.value() ); + else this->construct_error( other.error() ); + } + + storage_t( storage_t && other ) = delete; +}; + +template< typename E > +class storage_t<void, E, true, false> : public storage_t_impl<void, E> +{ +public: + storage_t() = default; + ~storage_t() = default; + + explicit storage_t( bool has_value ) + : storage_t_impl<void, E>( has_value ) + {} + + 
storage_t( storage_t const & other ) + : storage_t_impl<void, E>(other.has_value()) + { + if ( this->has_value() ) ; + else this->construct_error( other.error() ); + } + + storage_t( storage_t && other ) = delete; +}; + +template< typename T, typename E > +class storage_t<T, E, false, true> : public storage_t_impl<T, E> +{ +public: + storage_t() = default; + ~storage_t() = default; + + explicit storage_t( bool has_value ) + : storage_t_impl<T, E>( has_value ) + {} + + storage_t( storage_t const & other ) = delete; + + storage_t( storage_t && other ) + : storage_t_impl<T, E>( other.has_value() ) + { + if ( this->has_value() ) this->construct_value( std::move( other.value() ) ); + else this->construct_error( std::move( other.error() ) ); + } +}; + +template< typename E > +class storage_t<void, E, false, true> : public storage_t_impl<void, E> +{ +public: + storage_t() = default; + ~storage_t() = default; + + explicit storage_t( bool has_value ) + : storage_t_impl<void, E>( has_value ) + {} + + storage_t( storage_t const & other ) = delete; + + storage_t( storage_t && other ) + : storage_t_impl<void, E>( other.has_value() ) + { + if ( this->has_value() ) ; + else this->construct_error( std::move( other.error() ) ); + } +}; + +#if nsel_P2505R >= 3 +// C++11 invoke implementation +template< typename > +struct is_reference_wrapper : std::false_type {}; +template< typename T > +struct is_reference_wrapper< std::reference_wrapper< T > > : std::true_type {}; + +template< typename FnT, typename ClassT, typename ObjectT, typename... Args + nsel_REQUIRES_T( + std::is_function<FnT>::value + && ( std::is_same< ClassT, typename std20::remove_cvref< ObjectT >::type >::value + || std::is_base_of< ClassT, typename std20::remove_cvref< ObjectT >::type >::value ) + ) +> +nsel_constexpr auto invoke_member_function_impl( FnT ClassT::* memfnptr, ObjectT && obj, Args && ... args ) + noexcept( noexcept( (std::forward< ObjectT >( obj ).*memfnptr)( std::forward< Args >( args )... 
) ) ) + -> decltype( (std::forward< ObjectT >( obj ).*memfnptr)( std::forward< Args >( args )...) ) +{ + return (std::forward< ObjectT >( obj ).*memfnptr)( std::forward< Args >( args )... ); +} + +template< typename FnT, typename ClassT, typename ObjectT, typename... Args + nsel_REQUIRES_T( + std::is_function<FnT>::value + && is_reference_wrapper< typename std20::remove_cvref< ObjectT >::type >::value + ) +> +nsel_constexpr auto invoke_member_function_impl( FnT ClassT::* memfnptr, ObjectT && obj, Args && ... args ) + noexcept( noexcept( (obj.get().*memfnptr)( std::forward< Args >( args ) ... ) ) ) + -> decltype( (obj.get().*memfnptr)( std::forward< Args >( args ) ... ) ) +{ + return (obj.get().*memfnptr)( std::forward< Args >( args ) ... ); +} + +template< typename FnT, typename ClassT, typename ObjectT, typename... Args + nsel_REQUIRES_T( + std::is_function<FnT>::value + && !std::is_same< ClassT, typename std20::remove_cvref< ObjectT >::type >::value + && !std::is_base_of< ClassT, typename std20::remove_cvref< ObjectT >::type >::value + && !is_reference_wrapper< typename std20::remove_cvref< ObjectT >::type >::value + ) +> +nsel_constexpr auto invoke_member_function_impl( FnT ClassT::* memfnptr, ObjectT && obj, Args && ... args ) + noexcept( noexcept( ((*std::forward< ObjectT >( obj )).*memfnptr)( std::forward< Args >( args ) ... ) ) ) + -> decltype( ((*std::forward< ObjectT >( obj )).*memfnptr)( std::forward< Args >( args ) ... ) ) +{ + return ((*std::forward<ObjectT>(obj)).*memfnptr)( std::forward< Args >( args ) ... 
); +} + +template< typename MemberT, typename ClassT, typename ObjectT + nsel_REQUIRES_T( + std::is_same< ClassT, typename std20::remove_cvref< ObjectT >::type >::value + || std::is_base_of< ClassT, typename std20::remove_cvref< ObjectT >::type >::value + ) +> +nsel_constexpr auto invoke_member_object_impl( MemberT ClassT::* memobjptr, ObjectT && obj ) + noexcept( noexcept( std::forward< ObjectT >( obj ).*memobjptr ) ) + -> decltype( std::forward< ObjectT >( obj ).*memobjptr ) +{ + return std::forward< ObjectT >( obj ).*memobjptr; +} + +template< typename MemberT, typename ClassT, typename ObjectT + nsel_REQUIRES_T( + is_reference_wrapper< typename std20::remove_cvref< ObjectT >::type >::value + ) +> +nsel_constexpr auto invoke_member_object_impl( MemberT ClassT::* memobjptr, ObjectT && obj ) + noexcept( noexcept( obj.get().*memobjptr ) ) + -> decltype( obj.get().*memobjptr ) +{ + return obj.get().*memobjptr; +} + +template< typename MemberT, typename ClassT, typename ObjectT + nsel_REQUIRES_T( + !std::is_same< ClassT, typename std20::remove_cvref< ObjectT >::type >::value + && !std::is_base_of< ClassT, typename std20::remove_cvref< ObjectT >::type >::value + && !is_reference_wrapper< typename std20::remove_cvref< ObjectT >::type >::value + ) +> +nsel_constexpr auto invoke_member_object_impl( MemberT ClassT::* memobjptr, ObjectT && obj ) + noexcept( noexcept( (*std::forward< ObjectT >( obj )).*memobjptr ) ) + -> decltype( (*std::forward< ObjectT >( obj )).*memobjptr ) +{ + return (*std::forward< ObjectT >( obj )).*memobjptr; +} + +template< typename F, typename... Args + nsel_REQUIRES_T( + std::is_member_function_pointer< typename std20::remove_cvref< F >::type >::value + ) +> +nsel_constexpr auto invoke( F && f, Args && ... args ) + noexcept( noexcept( invoke_member_function_impl( std::forward< F >( f ), std::forward< Args >( args ) ... ) ) ) + -> decltype( invoke_member_function_impl( std::forward< F >( f ), std::forward< Args >( args ) ... 
) ) +{ + return invoke_member_function_impl( std::forward< F >( f ), std::forward< Args >( args ) ... ); +} + +template< typename F, typename... Args + nsel_REQUIRES_T( + std::is_member_object_pointer< typename std20::remove_cvref< F >::type >::value + ) +> +nsel_constexpr auto invoke( F && f, Args && ... args ) + noexcept( noexcept( invoke_member_object_impl( std::forward< F >( f ), std::forward< Args >( args ) ... ) ) ) + -> decltype( invoke_member_object_impl( std::forward< F >( f ), std::forward< Args >( args ) ... ) ) +{ + return invoke_member_object_impl( std::forward< F >( f ), std::forward< Args >( args ) ... ); +} + +template< typename F, typename... Args + nsel_REQUIRES_T( + !std::is_member_function_pointer< typename std20::remove_cvref< F >::type >::value + && !std::is_member_object_pointer< typename std20::remove_cvref< F >::type >::value + ) +> +nsel_constexpr auto invoke( F && f, Args && ... args ) + noexcept( noexcept( std::forward< F >( f )( std::forward< Args >( args ) ... ) ) ) + -> decltype( std::forward< F >( f )( std::forward< Args >( args ) ... ) ) +{ + return std::forward< F >( f )( std::forward< Args >( args ) ... ); +} + +template< typename F, typename ... Args > +using invoke_result_nocvref_t = typename std20::remove_cvref< decltype( invoke( std::declval< F >(), std::declval< Args >()... ) ) >::type; + +#if nsel_P2505R >= 5 +template< typename F, typename ... Args > +using transform_invoke_result_t = typename std::remove_cv< decltype( invoke( std::declval< F >(), std::declval< Args >()... ) ) >::type; +#else +template< typename F, typename ... 
Args >
+using transform_invoke_result_t = invoke_result_nocvref_t< F, Args... >;
+#endif // nsel_P2505R >= 5
+
+template< typename T >
+struct valid_expected_value_type : std::integral_constant< bool, std::is_destructible< T >::value && !std::is_reference< T >::value && !std::is_array< T >::value > {};
+
+#endif // nsel_P2505R >= 3
+} // namespace detail
+
+/// x.x.5 Unexpected object type; unexpected_type; C++17 and later can also use aliased type unexpected.
+
+#if nsel_P0323R <= 2
+template< typename E = std::exception_ptr >
+class unexpected_type
+#else
+template< typename E >
+class unexpected_type
+#endif // nsel_P0323R
+{
+public:
+ using error_type = E;
+
+ // x.x.5.2.1 Constructors
+
+// unexpected_type() = delete;
+
+ constexpr unexpected_type( unexpected_type const & ) = default;
+ constexpr unexpected_type( unexpected_type && ) = default;
+
+ template< typename... Args
+ nsel_REQUIRES_T(
+ std::is_constructible<E, Args&&...>::value
+ )
+ >
+ constexpr explicit unexpected_type( future_std_lite_in_place_t(E), Args &&... args )
+ : m_error( std::forward<Args>( args )...)
+ {}
+
+ template< typename U, typename... Args
+ nsel_REQUIRES_T(
+ std::is_constructible<E, std::initializer_list<U>, Args&&...>::value
+ )
+ >
+ constexpr explicit unexpected_type( future_std_lite_in_place_t(E), std::initializer_list<U> il, Args &&... args )
+ : m_error( il, std::forward<Args>( args )...)
+ {} + + template< typename E2 + nsel_REQUIRES_T( + std::is_constructible<E,E2>::value + && !std::is_same< typename std20::remove_cvref<E2>::type, future_std_lite_in_place_t(E2) >::value + && !std::is_same< typename std20::remove_cvref<E2>::type, unexpected_type >::value + ) + > + constexpr explicit unexpected_type( E2 && error ) + : m_error( std::forward<E2>( error ) ) + {} + + template< typename E2 + nsel_REQUIRES_T( + std::is_constructible< E, E2>::value + && !std::is_constructible<E, unexpected_type<E2> & >::value + && !std::is_constructible<E, unexpected_type<E2> >::value + && !std::is_constructible<E, unexpected_type<E2> const & >::value + && !std::is_constructible<E, unexpected_type<E2> const >::value + && !std::is_convertible< unexpected_type<E2> &, E>::value + && !std::is_convertible< unexpected_type<E2> , E>::value + && !std::is_convertible< unexpected_type<E2> const &, E>::value + && !std::is_convertible< unexpected_type<E2> const , E>::value + && !std::is_convertible< E2 const &, E>::value /*=> explicit */ + ) + > + constexpr explicit unexpected_type( unexpected_type<E2> const & error ) + : m_error( E{ error.value() } ) + {} + + template< typename E2 + nsel_REQUIRES_T( + std::is_constructible< E, E2>::value + && !std::is_constructible<E, unexpected_type<E2> & >::value + && !std::is_constructible<E, unexpected_type<E2> >::value + && !std::is_constructible<E, unexpected_type<E2> const & >::value + && !std::is_constructible<E, unexpected_type<E2> const >::value + && !std::is_convertible< unexpected_type<E2> &, E>::value + && !std::is_convertible< unexpected_type<E2> , E>::value + && !std::is_convertible< unexpected_type<E2> const &, E>::value + && !std::is_convertible< unexpected_type<E2> const , E>::value + && std::is_convertible< E2 const &, E>::value /*=> explicit */ + ) + > + constexpr /*non-explicit*/ unexpected_type( unexpected_type<E2> const & error ) + : m_error( error.value() ) + {} + + template< typename E2 + nsel_REQUIRES_T( + 
std::is_constructible< E, E2>::value + && !std::is_constructible<E, unexpected_type<E2> & >::value + && !std::is_constructible<E, unexpected_type<E2> >::value + && !std::is_constructible<E, unexpected_type<E2> const & >::value + && !std::is_constructible<E, unexpected_type<E2> const >::value + && !std::is_convertible< unexpected_type<E2> &, E>::value + && !std::is_convertible< unexpected_type<E2> , E>::value + && !std::is_convertible< unexpected_type<E2> const &, E>::value + && !std::is_convertible< unexpected_type<E2> const , E>::value + && !std::is_convertible< E2 const &, E>::value /*=> explicit */ + ) + > + constexpr explicit unexpected_type( unexpected_type<E2> && error ) + : m_error( E{ std::move( error.value() ) } ) + {} + + template< typename E2 + nsel_REQUIRES_T( + std::is_constructible< E, E2>::value + && !std::is_constructible<E, unexpected_type<E2> & >::value + && !std::is_constructible<E, unexpected_type<E2> >::value + && !std::is_constructible<E, unexpected_type<E2> const & >::value + && !std::is_constructible<E, unexpected_type<E2> const >::value + && !std::is_convertible< unexpected_type<E2> &, E>::value + && !std::is_convertible< unexpected_type<E2> , E>::value + && !std::is_convertible< unexpected_type<E2> const &, E>::value + && !std::is_convertible< unexpected_type<E2> const , E>::value + && std::is_convertible< E2 const &, E>::value /*=> non-explicit */ + ) + > + constexpr /*non-explicit*/ unexpected_type( unexpected_type<E2> && error ) + : m_error( std::move( error.value() ) ) + {} + + // x.x.5.2.2 Assignment + + nsel_constexpr14 unexpected_type& operator=( unexpected_type const & ) = default; + nsel_constexpr14 unexpected_type& operator=( unexpected_type && ) = default; + + template< typename E2 = E > + nsel_constexpr14 unexpected_type & operator=( unexpected_type<E2> const & other ) + { + unexpected_type{ other.value() }.swap( *this ); + return *this; + } + + template< typename E2 = E > + nsel_constexpr14 unexpected_type & operator=( 
unexpected_type<E2> && other ) + { + unexpected_type{ std::move( other.value() ) }.swap( *this ); + return *this; + } + + // x.x.5.2.3 Observers + + nsel_constexpr14 E & value() & noexcept + { + return m_error; + } + + constexpr E const & value() const & noexcept + { + return m_error; + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + + nsel_constexpr14 E && value() && noexcept + { + return std::move( m_error ); + } + + constexpr E const && value() const && noexcept + { + return std::move( m_error ); + } + +#endif + + // x.x.5.2.4 Swap + + template< typename U=E > + nsel_REQUIRES_R( void, + std17::is_swappable<U>::value + ) + swap( unexpected_type & other ) noexcept ( + std17::is_nothrow_swappable<U>::value + ) + { + using std::swap; + swap( m_error, other.m_error ); + } + + // TODO: ??? unexpected_type: in-class friend operator==, != + +private: + error_type m_error; +}; + +#if nsel_CPP17_OR_GREATER + +/// template deduction guide: + +template< typename E > +unexpected_type( E ) -> unexpected_type< E >; + +#endif + +/// class unexpected_type, std::exception_ptr specialization (P0323R2) + +#if !nsel_CONFIG_NO_EXCEPTIONS +#if nsel_P0323R <= 2 + +// TODO: Should expected be specialized for particular E types such as exception_ptr and how? +// See p0323r7 2.1. 
Ergonomics, http://wg21.link/p0323 +template<> +class unexpected_type< std::exception_ptr > +{ +public: + using error_type = std::exception_ptr; + + unexpected_type() = delete; + + ~unexpected_type(){} + + explicit unexpected_type( std::exception_ptr const & error ) + : m_error( error ) + {} + + explicit unexpected_type(std::exception_ptr && error ) + : m_error( std::move( error ) ) + {} + + template< typename E > + explicit unexpected_type( E error ) + : m_error( std::make_exception_ptr( error ) ) + {} + + std::exception_ptr const & value() const + { + return m_error; + } + + std::exception_ptr & value() + { + return m_error; + } + +private: + std::exception_ptr m_error; +}; + +#endif // nsel_P0323R +#endif // !nsel_CONFIG_NO_EXCEPTIONS + +/// x.x.4, Unexpected equality operators + +template< typename E1, typename E2 > +constexpr bool operator==( unexpected_type<E1> const & x, unexpected_type<E2> const & y ) +{ + return x.value() == y.value(); +} + +template< typename E1, typename E2 > +constexpr bool operator!=( unexpected_type<E1> const & x, unexpected_type<E2> const & y ) +{ + return ! ( x == y ); +} + +#if nsel_P0323R <= 2 + +template< typename E > +constexpr bool operator<( unexpected_type<E> const & x, unexpected_type<E> const & y ) +{ + return x.value() < y.value(); +} + +template< typename E > +constexpr bool operator>( unexpected_type<E> const & x, unexpected_type<E> const & y ) +{ + return ( y < x ); +} + +template< typename E > +constexpr bool operator<=( unexpected_type<E> const & x, unexpected_type<E> const & y ) +{ + return ! ( y < x ); +} + +template< typename E > +constexpr bool operator>=( unexpected_type<E> const & x, unexpected_type<E> const & y ) +{ + return ! 
( x < y ); +} + +#endif // nsel_P0323R + +/// x.x.5 Specialized algorithms + +template< typename E + nsel_REQUIRES_T( + std17::is_swappable<E>::value + ) +> +void swap( unexpected_type<E> & x, unexpected_type<E> & y) noexcept ( noexcept ( x.swap(y) ) ) +{ + x.swap( y ); +} + +#if nsel_P0323R <= 2 + +// unexpected: relational operators for std::exception_ptr: + +inline constexpr bool operator<( unexpected_type<std::exception_ptr> const & /*x*/, unexpected_type<std::exception_ptr> const & /*y*/ ) +{ + return false; +} + +inline constexpr bool operator>( unexpected_type<std::exception_ptr> const & /*x*/, unexpected_type<std::exception_ptr> const & /*y*/ ) +{ + return false; +} + +inline constexpr bool operator<=( unexpected_type<std::exception_ptr> const & x, unexpected_type<std::exception_ptr> const & y ) +{ + return ( x == y ); +} + +inline constexpr bool operator>=( unexpected_type<std::exception_ptr> const & x, unexpected_type<std::exception_ptr> const & y ) +{ + return ( x == y ); +} + +#endif // nsel_P0323R + +// unexpected: traits + +#if nsel_P0323R <= 3 + +template< typename E> +struct is_unexpected : std::false_type {}; + +template< typename E> +struct is_unexpected< unexpected_type<E> > : std::true_type {}; + +#endif // nsel_P0323R + +// unexpected: factory + +// keep make_unexpected() removed in p0323r2 for pre-C++17: + +template< typename E> +nsel_constexpr14 auto +make_unexpected( E && value ) -> unexpected_type< typename std::decay<E>::type > +{ + return unexpected_type< typename std::decay<E>::type >( std::forward<E>(value) ); +} + +#if nsel_P0323R <= 3 + +/*nsel_constexpr14*/ auto inline +make_unexpected_from_current_exception() -> unexpected_type< std::exception_ptr > +{ + return unexpected_type< std::exception_ptr >( std::current_exception() ); +} + +#endif // nsel_P0323R + +/// x.x.6, x.x.7 expected access error + +template< typename E > +class bad_expected_access; + +/// x.x.7 bad_expected_access<void>: expected access error + +template <> +class 
bad_expected_access< void > : public std::exception +{ +public: + explicit bad_expected_access() + : std::exception() + {} +}; + +/// x.x.6 bad_expected_access: expected access error + +#if !nsel_CONFIG_NO_EXCEPTIONS + +template< typename E > +class bad_expected_access : public bad_expected_access< void > +{ +public: + using error_type = E; + + explicit bad_expected_access( error_type error ) + : m_error( error ) + {} + + virtual char const * what() const noexcept override + { + return "bad_expected_access"; + } + + nsel_constexpr14 error_type & error() & + { + return m_error; + } + + constexpr error_type const & error() const & + { + return m_error; + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + + nsel_constexpr14 error_type && error() && + { + return std::move( m_error ); + } + + constexpr error_type const && error() const && + { + return std::move( m_error ); + } + +#endif + +private: + error_type m_error; +}; + +#endif // nsel_CONFIG_NO_EXCEPTIONS + +/// x.x.8 unexpect tag, in_place_unexpected tag: construct an error + +struct unexpect_t{}; +using in_place_unexpected_t = unexpect_t; + +nsel_inline17 constexpr unexpect_t unexpect{}; +nsel_inline17 constexpr unexpect_t in_place_unexpected{}; + +/// class error_traits + +#if nsel_CONFIG_NO_EXCEPTIONS + +namespace detail { + inline bool text( char const * /*text*/ ) { return true; } +} + +template< typename Error > +struct error_traits +{ + static void rethrow( Error const & /*e*/ ) + { +#if nsel_CONFIG_NO_EXCEPTIONS_SEH + RaiseException( EXCEPTION_ACCESS_VIOLATION, EXCEPTION_NONCONTINUABLE, 0, NULL ); +#else + assert( false && detail::text("throw bad_expected_access<Error>{ e };") ); +#endif + } +}; + +template<> +struct error_traits< std::exception_ptr > +{ + static void rethrow( std::exception_ptr const & /*e*/ ) + { +#if nsel_CONFIG_NO_EXCEPTIONS_SEH + RaiseException( EXCEPTION_ACCESS_VIOLATION, EXCEPTION_NONCONTINUABLE, 0, NULL ); +#else + assert( false && detail::text("throw 
bad_expected_access<std::exception_ptr>{ e };") ); +#endif + } +}; + +template<> +struct error_traits< std::error_code > +{ + static void rethrow( std::error_code const & /*e*/ ) + { +#if nsel_CONFIG_NO_EXCEPTIONS_SEH + RaiseException( EXCEPTION_ACCESS_VIOLATION, EXCEPTION_NONCONTINUABLE, 0, NULL ); +#else + assert( false && detail::text("throw std::system_error( e );") ); +#endif + } +}; + +#else // nsel_CONFIG_NO_EXCEPTIONS + +template< typename Error > +struct error_traits +{ + static void rethrow( Error const & e ) + { + throw bad_expected_access<Error>{ e }; + } +}; + +template<> +struct error_traits< std::exception_ptr > +{ + static void rethrow( std::exception_ptr const & e ) + { + std::rethrow_exception( e ); + } +}; + +template<> +struct error_traits< std::error_code > +{ + static void rethrow( std::error_code const & e ) + { + throw std::system_error( e ); + } +}; + +#endif // nsel_CONFIG_NO_EXCEPTIONS + +#if nsel_P2505R >= 3 +namespace detail { + +// from https://en.cppreference.com/w/cpp/utility/expected/unexpected: +// "the type of the unexpected value. The type must not be an array type, a non-object type, a specialization of std::unexpected, or a cv-qualified type." 
+template< typename T > +struct valid_unexpected_type : std::integral_constant< bool, + std::is_same< T, typename std20::remove_cvref< T >::type >::value + && std::is_object< T >::value + && !std::is_array< T >::value +> {}; + +template< typename T > +struct valid_unexpected_type< unexpected_type< T > > : std::false_type {}; + +} // namespace detail +#endif // nsel_P2505R >= 3 + +} // namespace expected_lite + +// provide future_std::unexpected_type: + +using expected_lite::unexpected_type; + +namespace expected_lite { + +/// class expected + +#if nsel_P0323R <= 2 +template< typename T, typename E = std::exception_ptr > +class expected +#else +template< typename T, typename E > +class expected +#endif // nsel_P0323R +{ +private: + template< typename, typename > friend class expected; + +public: + using value_type = T; + using error_type = E; + using unexpected_type = future_std::unexpected_type<E>; + + template< typename U > + struct rebind + { + using type = expected<U, error_type>; + }; + + // x.x.4.1 constructors + + nsel_REQUIRES_0( + std::is_default_constructible<T>::value + ) + nsel_constexpr14 expected() + : contained( true ) + { + contained.construct_value(); + } + + nsel_constexpr14 expected( expected const & ) = default; + nsel_constexpr14 expected( expected && ) = default; + + template< typename U, typename G + nsel_REQUIRES_T( + std::is_constructible< T, U const &>::value + && std::is_constructible<E, G const &>::value + && !std::is_constructible<T, expected<U, G> & >::value + && !std::is_constructible<T, expected<U, G> && >::value + && !std::is_constructible<T, expected<U, G> const & >::value + && !std::is_constructible<T, expected<U, G> const && >::value + && !std::is_convertible< expected<U, G> & , T>::value + && !std::is_convertible< expected<U, G> &&, T>::value + && !std::is_convertible< expected<U, G> const & , T>::value + && !std::is_convertible< expected<U, G> const &&, T>::value + && (!std::is_convertible<U const &, T>::value || 
!std::is_convertible<G const &, E>::value ) /*=> explicit */ + ) + > + nsel_constexpr14 explicit expected( expected<U, G> const & other ) + : contained( other.has_value() ) + { + if ( has_value() ) contained.construct_value( T{ other.contained.value() } ); + else contained.construct_error( E{ other.contained.error() } ); + } + + template< typename U, typename G + nsel_REQUIRES_T( + std::is_constructible< T, U const &>::value + && std::is_constructible<E, G const &>::value + && !std::is_constructible<T, expected<U, G> & >::value + && !std::is_constructible<T, expected<U, G> && >::value + && !std::is_constructible<T, expected<U, G> const & >::value + && !std::is_constructible<T, expected<U, G> const && >::value + && !std::is_convertible< expected<U, G> & , T>::value + && !std::is_convertible< expected<U, G> &&, T>::value + && !std::is_convertible< expected<U, G> const &, T>::value + && !std::is_convertible< expected<U, G> const &&, T>::value + && !(!std::is_convertible<U const &, T>::value || !std::is_convertible<G const &, E>::value ) /*=> non-explicit */ + ) + > + nsel_constexpr14 /*non-explicit*/ expected( expected<U, G> const & other ) + : contained( other.has_value() ) + { + if ( has_value() ) contained.construct_value( other.contained.value() ); + else contained.construct_error( other.contained.error() ); + } + + template< typename U, typename G + nsel_REQUIRES_T( + std::is_constructible< T, U>::value + && std::is_constructible<E, G>::value + && !std::is_constructible<T, expected<U, G> & >::value + && !std::is_constructible<T, expected<U, G> && >::value + && !std::is_constructible<T, expected<U, G> const & >::value + && !std::is_constructible<T, expected<U, G> const && >::value + && !std::is_convertible< expected<U, G> & , T>::value + && !std::is_convertible< expected<U, G> &&, T>::value + && !std::is_convertible< expected<U, G> const & , T>::value + && !std::is_convertible< expected<U, G> const &&, T>::value + && (!std::is_convertible<U, T>::value || 
!std::is_convertible<G, E>::value ) /*=> explicit */ + ) + > + nsel_constexpr14 explicit expected( expected<U, G> && other ) + : contained( other.has_value() ) + { + if ( has_value() ) contained.construct_value( T{ std::move( other.contained.value() ) } ); + else contained.construct_error( E{ std::move( other.contained.error() ) } ); + } + + template< typename U, typename G + nsel_REQUIRES_T( + std::is_constructible< T, U>::value + && std::is_constructible<E, G>::value + && !std::is_constructible<T, expected<U, G> & >::value + && !std::is_constructible<T, expected<U, G> && >::value + && !std::is_constructible<T, expected<U, G> const & >::value + && !std::is_constructible<T, expected<U, G> const && >::value + && !std::is_convertible< expected<U, G> & , T>::value + && !std::is_convertible< expected<U, G> &&, T>::value + && !std::is_convertible< expected<U, G> const & , T>::value + && !std::is_convertible< expected<U, G> const &&, T>::value + && !(!std::is_convertible<U, T>::value || !std::is_convertible<G, E>::value ) /*=> non-explicit */ + ) + > + nsel_constexpr14 /*non-explicit*/ expected( expected<U, G> && other ) + : contained( other.has_value() ) + { + if ( has_value() ) contained.construct_value( std::move( other.contained.value() ) ); + else contained.construct_error( std::move( other.contained.error() ) ); + } + + template< typename U = T + nsel_REQUIRES_T( + std::is_copy_constructible<U>::value + ) + > + nsel_constexpr14 expected( value_type const & value ) + : contained( true ) + { + contained.construct_value( value ); + } + + template< typename U = T + nsel_REQUIRES_T( + std::is_constructible<T,U&&>::value + && !std::is_same<typename std20::remove_cvref<U>::type, future_std_lite_in_place_t(U)>::value + && !std::is_same< expected<T,E> , typename std20::remove_cvref<U>::type>::value + && !std::is_same<future_std::unexpected_type<E>, typename std20::remove_cvref<U>::type>::value + && !std::is_convertible<U&&,T>::value /*=> explicit */ + ) + > + 
nsel_constexpr14 explicit expected( U && value ) noexcept + ( + std::is_nothrow_move_constructible<U>::value && + std::is_nothrow_move_constructible<E>::value + ) + : contained( true ) + { + contained.construct_value( T{ std::forward<U>( value ) } ); + } + + template< typename U = T + nsel_REQUIRES_T( + std::is_constructible<T,U&&>::value + && !std::is_same<typename std20::remove_cvref<U>::type, future_std_lite_in_place_t(U)>::value + && !std::is_same< expected<T,E> , typename std20::remove_cvref<U>::type>::value + && !std::is_same<future_std::unexpected_type<E>, typename std20::remove_cvref<U>::type>::value + && std::is_convertible<U&&,T>::value /*=> non-explicit */ + ) + > + nsel_constexpr14 /*non-explicit*/ expected( U && value ) noexcept + ( + std::is_nothrow_move_constructible<U>::value && + std::is_nothrow_move_constructible<E>::value + ) + : contained( true ) + { + contained.construct_value( std::forward<U>( value ) ); + } + + // construct error: + + template< typename G = E + nsel_REQUIRES_T( + std::is_constructible<E, G const & >::value + && !std::is_convertible< G const &, E>::value /*=> explicit */ + ) + > + nsel_constexpr14 explicit expected( future_std::unexpected_type<G> const & error ) + : contained( false ) + { + contained.construct_error( E{ error.value() } ); + } + + template< typename G = E + nsel_REQUIRES_T( + std::is_constructible<E, G const & >::value + && std::is_convertible< G const &, E>::value /*=> non-explicit */ + ) + > + nsel_constexpr14 /*non-explicit*/ expected( future_std::unexpected_type<G> const & error ) + : contained( false ) + { + contained.construct_error( error.value() ); + } + + template< typename G = E + nsel_REQUIRES_T( + std::is_constructible<E, G&& >::value + && !std::is_convertible< G&&, E>::value /*=> explicit */ + ) + > + nsel_constexpr14 explicit expected( future_std::unexpected_type<G> && error ) + : contained( false ) + { + contained.construct_error( E{ std::move( error.value() ) } ); + } + + template< typename G = 
E + nsel_REQUIRES_T( + std::is_constructible<E, G&& >::value + && std::is_convertible< G&&, E>::value /*=> non-explicit */ + ) + > + nsel_constexpr14 /*non-explicit*/ expected( future_std::unexpected_type<G> && error ) + : contained( false ) + { + contained.construct_error( std::move( error.value() ) ); + } + + // in-place construction, value + + template< typename... Args + nsel_REQUIRES_T( + std::is_constructible<T, Args&&...>::value + ) + > + nsel_constexpr14 explicit expected( future_std_lite_in_place_t(T), Args&&... args ) + : contained( true ) + { + contained.emplace_value( std::forward<Args>( args )... ); + } + + template< typename U, typename... Args + nsel_REQUIRES_T( + std::is_constructible<T, std::initializer_list<U>, Args&&...>::value + ) + > + nsel_constexpr14 explicit expected( future_std_lite_in_place_t(T), std::initializer_list<U> il, Args&&... args ) + : contained( true ) + { + contained.emplace_value( il, std::forward<Args>( args )... ); + } + + // in-place construction, error + + template< typename... Args + nsel_REQUIRES_T( + std::is_constructible<E, Args&&...>::value + ) + > + nsel_constexpr14 explicit expected( unexpect_t, Args&&... args ) + : contained( false ) + { + contained.emplace_error( std::forward<Args>( args )... ); + } + + template< typename U, typename... Args + nsel_REQUIRES_T( + std::is_constructible<E, std::initializer_list<U>, Args&&...>::value + ) + > + nsel_constexpr14 explicit expected( unexpect_t, std::initializer_list<U> il, Args&&... args ) + : contained( false ) + { + contained.emplace_error( il, std::forward<Args>( args )... ); + } + + // x.x.4.2 destructor + + // TODO: ~expected: triviality + // Effects: If T is not cv void and is_trivially_destructible_v<T> is false and bool(*this), calls val.~T(). If is_trivially_destructible_v<E> is false and !bool(*this), calls unexpect.~unexpected<E>(). 
+    // Remarks: If either T is cv void or is_trivially_destructible_v<T> is true, and is_trivially_destructible_v<E> is true, then this destructor shall be a trivial destructor.
+
+    ~expected()
+    {
+        if ( has_value() ) contained.destruct_value();
+        else               contained.destruct_error();
+    }
+
+    // x.x.4.3 assignment
+
+    expected & operator=( expected const & other )
+    {
+        expected( other ).swap( *this );
+        return *this;
+    }
+
+    expected & operator=( expected && other ) noexcept
+    (
+        std::is_nothrow_move_constructible< T>::value
+        && std::is_nothrow_move_assignable< T>::value
+        && std::is_nothrow_move_constructible<E>::value     // added for missing
+        && std::is_nothrow_move_assignable< E>::value )     //  nothrow above
+    {
+        expected( std::move( other ) ).swap( *this );
+        return *this;
+    }
+
+    template< typename U
+        nsel_REQUIRES_T(
+            !std::is_same<expected<T,E>, typename std20::remove_cvref<U>::type>::value
+            && std17::conjunction<std::is_scalar<T>, std::is_same<T, typename std::decay<U>::type> >::value    // NOTE(review): was std::is_same<T, std::decay<U>> — compared T to the trait class itself (always false), so this overload never participated and assignment fell back to conversion + move-assign
+            && std::is_constructible<T ,U>::value
+            && std::is_assignable< T&,U>::value
+            && std::is_nothrow_move_constructible<E>::value )
+    >
+    expected & operator=( U && value )
+    {
+        expected( std::forward<U>( value ) ).swap( *this );
+        return *this;
+    }
+
+    template< typename G = E
+        nsel_REQUIRES_T(
+            std::is_constructible<E, G const&>::value &&
+            std::is_copy_constructible<G>::value    // TODO: std::is_nothrow_copy_constructible<G>
+            && std::is_copy_assignable<G>::value
+        )
+    >
+    expected & operator=( future_std::unexpected_type<G> const & error )
+    {
+        expected( unexpect, error.value() ).swap( *this );
+        return *this;
+    }
+
+    template< typename G = E
+        nsel_REQUIRES_T(
+            std::is_constructible<E, G&&>::value &&
+            std::is_move_constructible<G>::value    // TODO: std::is_nothrow_move_constructible<G>
+            && std::is_move_assignable<G>::value
+        )
+    >
+    expected & operator=( future_std::unexpected_type<G> && error )
+    {
+        expected( unexpect, std::move( error.value() ) ).swap( *this );
+        return *this;
+    }
+
+    
template< typename... Args + nsel_REQUIRES_T( + std::is_nothrow_constructible<T, Args&&...>::value + ) + > + value_type & emplace( Args &&... args ) + { + expected( future_std_lite_in_place(T), std::forward<Args>(args)... ).swap( *this ); + return value(); + } + + template< typename U, typename... Args + nsel_REQUIRES_T( + std::is_nothrow_constructible<T, std::initializer_list<U>&, Args&&...>::value + ) + > + value_type & emplace( std::initializer_list<U> il, Args &&... args ) + { + expected( future_std_lite_in_place(T), il, std::forward<Args>(args)... ).swap( *this ); + return value(); + } + + // x.x.4.4 swap + + template< typename U=T, typename G=E > + nsel_REQUIRES_R( void, + std17::is_swappable< U>::value + && std17::is_swappable<G>::value + && ( std::is_move_constructible<U>::value || std::is_move_constructible<G>::value ) + ) + swap( expected & other ) noexcept + ( + std::is_nothrow_move_constructible<T>::value && std17::is_nothrow_swappable<T&>::value && + std::is_nothrow_move_constructible<E>::value && std17::is_nothrow_swappable<E&>::value + ) + { + using std::swap; + + if ( bool(*this) && bool(other) ) { swap( contained.value(), other.contained.value() ); } + else if ( ! bool(*this) && ! bool(other) ) { swap( contained.error(), other.contained.error() ); } + else if ( bool(*this) && ! bool(other) ) { error_type t( std::move( other.error() ) ); + other.contained.destruct_error(); + other.contained.construct_value( std::move( contained.value() ) ); + contained.destruct_value(); + contained.construct_error( std::move( t ) ); + bool has_value = contained.has_value(); + bool other_has_value = other.has_value(); + other.contained.set_has_value(has_value); + contained.set_has_value(other_has_value); + } + else if ( ! 
bool(*this) && bool(other) ) { other.swap( *this ); } + } + + // x.x.4.5 observers + + constexpr value_type const * operator ->() const + { + return assert( has_value() ), contained.value_ptr(); + } + + value_type * operator ->() + { + return assert( has_value() ), contained.value_ptr(); + } + + constexpr value_type const & operator *() const & + { + return assert( has_value() ), contained.value(); + } + + value_type & operator *() & + { + return assert( has_value() ), contained.value(); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + + constexpr value_type const && operator *() const && + { + return std::move( ( assert( has_value() ), contained.value() ) ); + } + + nsel_constexpr14 value_type && operator *() && + { + return std::move( ( assert( has_value() ), contained.value() ) ); + } + +#endif + + constexpr explicit operator bool() const noexcept + { + return has_value(); + } + + constexpr bool has_value() const noexcept + { + return contained.has_value(); + } + + constexpr value_type const & value() const & + { + return has_value() + ? ( contained.value() ) + : ( error_traits<error_type>::rethrow( contained.error() ), contained.value() ); + } + + value_type & value() & + { + return has_value() + ? ( contained.value() ) + : ( error_traits<error_type>::rethrow( contained.error() ), contained.value() ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + + constexpr value_type const && value() const && + { + return std::move( has_value() + ? ( contained.value() ) + : ( error_traits<error_type>::rethrow( contained.error() ), contained.value() ) ); + } + + nsel_constexpr14 value_type && value() && + { + return std::move( has_value() + ? ( contained.value() ) + : ( error_traits<error_type>::rethrow( contained.error() ), contained.value() ) ); + } + +#endif + + constexpr error_type const & error() const & + { + return assert( ! 
has_value() ), contained.error(); + } + + error_type & error() & + { + return assert( ! has_value() ), contained.error(); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + + constexpr error_type const && error() const && + { + return std::move( ( assert( ! has_value() ), contained.error() ) ); + } + + error_type && error() && + { + return std::move( ( assert( ! has_value() ), contained.error() ) ); + } + +#endif + + constexpr unexpected_type get_unexpected() const + { + return make_unexpected( contained.error() ); + } + + template< typename Ex > + bool has_exception() const + { + using ContainedEx = typename std::remove_reference< decltype( get_unexpected().value() ) >::type; + return ! has_value() && std::is_base_of< Ex, ContainedEx>::value; + } + + template< typename U + nsel_REQUIRES_T( + std::is_copy_constructible< T>::value + && std::is_convertible<U&&, T>::value + ) + > + value_type value_or( U && v ) const & + { + return has_value() + ? contained.value() + : static_cast<T>( std::forward<U>( v ) ); + } + + template< typename U + nsel_REQUIRES_T( + std::is_move_constructible< T>::value + && std::is_convertible<U&&, T>::value + ) + > + value_type value_or( U && v ) && + { + return has_value() + ? std::move( contained.value() ) + : static_cast<T>( std::forward<U>( v ) ); + } + +#if nsel_P2505R >= 4 + template< typename G = E + nsel_REQUIRES_T( + std::is_copy_constructible< E >::value + && std::is_convertible< G, E >::value + ) + > + nsel_constexpr error_type error_or( G && e ) const & + { + return has_value() + ? static_cast< E >( std::forward< G >( e ) ) + : contained.error(); + } + + template< typename G = E + nsel_REQUIRES_T( + std::is_move_constructible< E >::value + && std::is_convertible< G, E >::value + ) + > + nsel_constexpr14 error_type error_or( G && e ) && + { + return has_value() + ? 
static_cast< E >( std::forward< G >( e ) ) + : std::move( contained.error() ); + } +#endif // nsel_P2505R >= 4 + +#if nsel_P2505R >= 3 + // Monadic operations (P2505) + template< typename F + nsel_REQUIRES_T( + detail::is_expected < detail::invoke_result_nocvref_t< F, value_type & > > ::value + && std::is_same< typename detail::invoke_result_nocvref_t< F, value_type & >::error_type, error_type >::value + && std::is_constructible< error_type, error_type & >::value + ) + > + nsel_constexpr14 detail::invoke_result_nocvref_t< F, value_type & > and_then( F && f ) & + { + return has_value() + ? detail::invoke_result_nocvref_t< F, value_type & >( detail::invoke( std::forward< F >( f ), value() ) ) + : detail::invoke_result_nocvref_t< F, value_type & >( unexpect, error() ); + } + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, const value_type & > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F, const value_type & >::error_type, error_type >::value + && std::is_constructible< error_type, const error_type & >::value + ) + > + nsel_constexpr detail::invoke_result_nocvref_t< F, const value_type & > and_then( F && f ) const & + { + return has_value() + ? detail::invoke_result_nocvref_t< F, const value_type & >( detail::invoke( std::forward< F >( f ), value() ) ) + : detail::invoke_result_nocvref_t< F, const value_type & >( unexpect, error() ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, value_type && > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F, value_type && >::error_type, error_type >::value + && std::is_constructible< error_type, error_type && >::value + ) + > + nsel_constexpr14 detail::invoke_result_nocvref_t< F, value_type && > and_then( F && f ) && + { + return has_value() + ? 
detail::invoke_result_nocvref_t< F, value_type && >( detail::invoke( std::forward< F >( f ), std::move( value() ) ) )
+            : detail::invoke_result_nocvref_t< F, value_type && >( unexpect, std::move( error() ) );
+    }
+
+    template<typename F
+        nsel_REQUIRES_T(
+            detail::is_expected< detail::invoke_result_nocvref_t< F, const value_type && > >::value
+            && std::is_same< typename detail::invoke_result_nocvref_t< F, const value_type && >::error_type, error_type >::value    // NOTE(review): was "const value_type &" — value-category mismatch with the const && declaration below; checked the wrong invocation for callables overloaded on value category
+            && std::is_constructible< error_type, const error_type && >::value
+        )
+    >
+    nsel_constexpr detail::invoke_result_nocvref_t< F, const value_type && > and_then( F && f ) const &&
+    {
+        return has_value()
+            ? detail::invoke_result_nocvref_t< F, const value_type && >( detail::invoke( std::forward< F >( f ), std::move( value() ) ) )
+            : detail::invoke_result_nocvref_t< F, const value_type && >( unexpect, std::move( error() ) );
+    }
+#endif
+
+    template<typename F
+        nsel_REQUIRES_T(
+            detail::is_expected< detail::invoke_result_nocvref_t< F, error_type & > >::value
+            && std::is_same< typename detail::invoke_result_nocvref_t< F, error_type & >::value_type, value_type >::value
+            && std::is_constructible< value_type, value_type & >::value
+        )
+    >
+    nsel_constexpr14 detail::invoke_result_nocvref_t< F, error_type & > or_else( F && f ) &
+    {
+        return has_value()
+            ? detail::invoke_result_nocvref_t< F, error_type & >( value() )
+            : detail::invoke_result_nocvref_t< F, error_type & >( detail::invoke( std::forward< F >( f ), error() ) );
+    }
+
+    template<typename F
+        nsel_REQUIRES_T(
+            detail::is_expected< detail::invoke_result_nocvref_t< F, const error_type & > >::value
+            && std::is_same< typename detail::invoke_result_nocvref_t< F, const error_type & >::value_type, value_type >::value
+            && std::is_constructible< value_type, const value_type & >::value
+        )
+    >
+    nsel_constexpr detail::invoke_result_nocvref_t< F, const error_type & > or_else( F && f ) const &
+    {
+        return has_value()
+            ? 
detail::invoke_result_nocvref_t< F, const error_type & >( value() ) + : detail::invoke_result_nocvref_t< F, const error_type & >( detail::invoke( std::forward< F >( f ), error() ) ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, error_type && > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F, error_type && >::value_type, value_type >::value + && std::is_constructible< value_type, value_type && >::value + ) + > + nsel_constexpr14 detail::invoke_result_nocvref_t< F, error_type && > or_else( F && f ) && + { + return has_value() + ? detail::invoke_result_nocvref_t< F, error_type && >( std::move( value() ) ) + : detail::invoke_result_nocvref_t< F, error_type && >( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, const error_type && > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F, const error_type && >::value_type, value_type >::value + && std::is_constructible< value_type, const value_type && >::value + ) + > + nsel_constexpr detail::invoke_result_nocvref_t< F, const error_type && > or_else( F && f ) const && + { + return has_value() + ? 
detail::invoke_result_nocvref_t< F, const error_type && >( std::move( value() ) ) + : detail::invoke_result_nocvref_t< F, const error_type && >( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } +#endif + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type & >::value + && !std::is_void< detail::transform_invoke_result_t< F, value_type & > >::value + && detail::valid_expected_value_type< detail::transform_invoke_result_t< F, value_type & > >::value + ) + > + nsel_constexpr14 expected< detail::transform_invoke_result_t< F, value_type & >, error_type > transform( F && f ) & + { + return has_value() + ? expected< detail::transform_invoke_result_t< F, value_type & >, error_type >( detail::invoke( std::forward< F >( f ), **this ) ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type & >::value + && std::is_void< detail::transform_invoke_result_t< F, value_type & > >::value + ) + > + nsel_constexpr14 expected< void, error_type > transform( F && f ) & + { + return has_value() + ? ( detail::invoke( std::forward< F >( f ), **this ), expected< void, error_type >() ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type & >::value + && !std::is_void< detail::transform_invoke_result_t< F, const value_type & > >::value + && detail::valid_expected_value_type< detail::transform_invoke_result_t< F, const value_type & > >::value + ) + > + nsel_constexpr expected< detail::transform_invoke_result_t< F, const value_type & >, error_type > transform( F && f ) const & + { + return has_value() + ? 
expected< detail::transform_invoke_result_t< F, const value_type & >, error_type >( detail::invoke( std::forward< F >( f ), **this ) ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type & >::value + && std::is_void< detail::transform_invoke_result_t< F, const value_type & > >::value + ) + > + nsel_constexpr expected< void, error_type > transform( F && f ) const & + { + return has_value() + ? ( detail::invoke( std::forward< F >( f ), **this ), expected< void, error_type >() ) + : make_unexpected( error() ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type && >::value + && !std::is_void< detail::transform_invoke_result_t< F, value_type && > >::value + && detail::valid_expected_value_type< detail::transform_invoke_result_t< F, value_type && > >::value + ) + > + nsel_constexpr14 expected< detail::transform_invoke_result_t< F, value_type && >, error_type > transform( F && f ) && + { + return has_value() + ? expected< detail::transform_invoke_result_t< F, value_type && >, error_type >( detail::invoke( std::forward< F >( f ), std::move( **this ) ) ) + : make_unexpected( std::move( error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type && >::value + && std::is_void< detail::transform_invoke_result_t< F, value_type && > >::value + ) + > + nsel_constexpr14 expected< void, error_type > transform( F && f ) && + { + return has_value() + ? 
( detail::invoke( std::forward< F >( f ), **this ), expected< void, error_type >() ) + : make_unexpected( std::move( error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type && >::value + && !std::is_void< detail::transform_invoke_result_t< F, const value_type && > >::value + && detail::valid_expected_value_type< detail::transform_invoke_result_t< F, const value_type && > >::value + ) + > + nsel_constexpr expected< detail::transform_invoke_result_t< F, const value_type && >, error_type > transform( F && f ) const && + { + return has_value() + ? expected< detail::transform_invoke_result_t< F, const value_type && >, error_type >( detail::invoke( std::forward< F >( f ), std::move( **this ) ) ) + : make_unexpected( std::move( error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type && >::value + && std::is_void< detail::transform_invoke_result_t< F, const value_type && > >::value + ) + > + nsel_constexpr expected< void, error_type > transform( F && f ) const && + { + return has_value() + ? ( detail::invoke( std::forward< F >( f ), **this ), expected< void, error_type >() ) + : make_unexpected( std::move( error() ) ); + } +#endif + + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, error_type & > >::value + && std::is_constructible< value_type, value_type & >::value + ) + > + nsel_constexpr14 expected< value_type, detail::transform_invoke_result_t< F, error_type & > > transform_error( F && f ) & + { + return has_value() + ? 
expected< value_type, detail::transform_invoke_result_t< F, error_type & > >( in_place, **this ) + : make_unexpected( detail::invoke( std::forward< F >( f ), error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, const error_type & > >::value + && std::is_constructible< value_type, const value_type & >::value + ) + > + nsel_constexpr expected< value_type, detail::transform_invoke_result_t< F, const error_type & > > transform_error( F && f ) const & + { + return has_value() + ? expected< value_type, detail::transform_invoke_result_t< F, const error_type & > >( in_place, **this ) + : make_unexpected( detail::invoke( std::forward< F >( f ), error() ) ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, error_type && > >::value + && std::is_constructible< value_type, value_type && >::value + ) + > + nsel_constexpr14 expected< value_type, detail::transform_invoke_result_t< F, error_type && > > transform_error( F && f ) && + { + return has_value() + ? expected< value_type, detail::transform_invoke_result_t< F, error_type && > >( in_place, std::move( **this ) ) + : make_unexpected( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, const error_type && > >::value + && std::is_constructible< value_type, const value_type && >::value + ) + > + nsel_constexpr expected< value_type, detail::transform_invoke_result_t< F, const error_type && > > transform_error( F && f ) const && + { + return has_value() + ? 
expected< value_type, detail::transform_invoke_result_t< F, const error_type && > >( in_place, std::move( **this ) ) + : make_unexpected( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } +#endif +#endif // nsel_P2505R >= 3 + // unwrap() + +// template <class U, class E> +// constexpr expected<U,E> expected<expected<U,E>,E>::unwrap() const&; + +// template <class T, class E> +// constexpr expected<T,E> expected<T,E>::unwrap() const&; + +// template <class U, class E> +// expected<U,E> expected<expected<U,E>, E>::unwrap() &&; + +// template <class T, class E> +// template expected<T,E> expected<T,E>::unwrap() &&; + + // factories + +// template< typename Ex, typename F> +// expected<T,E> catch_exception(F&& f); + +// template< typename F> +// expected<decltype(func(declval<T>())),E> map(F&& func) ; + +// template< typename F> +// 'see below' bind(F&& func); + +// template< typename F> +// expected<T,E> catch_error(F&& f); + +// template< typename F> +// 'see below' then(F&& func); + +private: + detail::storage_t + < + T + ,E + , std::is_copy_constructible<T>::value && std::is_copy_constructible<E>::value + , std::is_move_constructible<T>::value && std::is_move_constructible<E>::value + > + contained; +}; + +/// class expected, void specialization + +template< typename E > +class expected<void, E> +{ +private: + template< typename, typename > friend class expected; + +public: + using value_type = void; + using error_type = E; + using unexpected_type = future_std::unexpected_type<E>; + + // x.x.4.1 constructors + + constexpr expected() noexcept + : contained( true ) + {} + + nsel_constexpr14 expected( expected const & other ) = default; + nsel_constexpr14 expected( expected && other ) = default; + + constexpr explicit expected( future_std_lite_in_place_t(void) ) + : contained( true ) + {} + + template< typename G = E + nsel_REQUIRES_T( + !std::is_convertible<G const &, E>::value /*=> explicit */ + ) + > + nsel_constexpr14 explicit expected( 
future_std::unexpected_type<G> const & error ) + : contained( false ) + { + contained.construct_error( E{ error.value() } ); + } + + template< typename G = E + nsel_REQUIRES_T( + std::is_convertible<G const &, E>::value /*=> non-explicit */ + ) + > + nsel_constexpr14 /*non-explicit*/ expected( future_std::unexpected_type<G> const & error ) + : contained( false ) + { + contained.construct_error( error.value() ); + } + + template< typename G = E + nsel_REQUIRES_T( + !std::is_convertible<G&&, E>::value /*=> explicit */ + ) + > + nsel_constexpr14 explicit expected( future_std::unexpected_type<G> && error ) + : contained( false ) + { + contained.construct_error( E{ std::move( error.value() ) } ); + } + + template< typename G = E + nsel_REQUIRES_T( + std::is_convertible<G&&, E>::value /*=> non-explicit */ + ) + > + nsel_constexpr14 /*non-explicit*/ expected( future_std::unexpected_type<G> && error ) + : contained( false ) + { + contained.construct_error( std::move( error.value() ) ); + } + + template< typename... Args + nsel_REQUIRES_T( + std::is_constructible<E, Args&&...>::value + ) + > + nsel_constexpr14 explicit expected( unexpect_t, Args&&... args ) + : contained( false ) + { + contained.emplace_error( std::forward<Args>( args )... ); + } + + template< typename U, typename... Args + nsel_REQUIRES_T( + std::is_constructible<E, std::initializer_list<U>, Args&&...>::value + ) + > + nsel_constexpr14 explicit expected( unexpect_t, std::initializer_list<U> il, Args&&... args ) + : contained( false ) + { + contained.emplace_error( il, std::forward<Args>( args )... ); + } + + // destructor + + ~expected() + { + if ( ! 
has_value() ) + { + contained.destruct_error(); + } + } + + // x.x.4.3 assignment + + expected & operator=( expected const & other ) + { + expected( other ).swap( *this ); + return *this; + } + + expected & operator=( expected && other ) noexcept + ( + std::is_nothrow_move_assignable<E>::value && + std::is_nothrow_move_constructible<E>::value ) + { + expected( std::move( other ) ).swap( *this ); + return *this; + } + + void emplace() + { + expected().swap( *this ); + } + + // x.x.4.4 swap + + template< typename G = E > + nsel_REQUIRES_R( void, + std17::is_swappable<G>::value + && std::is_move_constructible<G>::value + ) + swap( expected & other ) noexcept + ( + std::is_nothrow_move_constructible<E>::value && std17::is_nothrow_swappable<E&>::value + ) + { + using std::swap; + + if ( ! bool(*this) && ! bool(other) ) { swap( contained.error(), other.contained.error() ); } + else if ( bool(*this) && ! bool(other) ) { contained.construct_error( std::move( other.error() ) ); + bool has_value = contained.has_value(); + bool other_has_value = other.has_value(); + other.contained.set_has_value(has_value); + contained.set_has_value(other_has_value); + } + else if ( ! bool(*this) && bool(other) ) { other.swap( *this ); } + } + + // x.x.4.5 observers + + constexpr explicit operator bool() const noexcept + { + return has_value(); + } + + constexpr bool has_value() const noexcept + { + return contained.has_value(); + } + + void value() const + { + if ( ! has_value() ) + { + error_traits<error_type>::rethrow( contained.error() ); + } + } + + constexpr error_type const & error() const & + { + return assert( ! has_value() ), contained.error(); + } + + error_type & error() & + { + return assert( ! has_value() ), contained.error(); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + + constexpr error_type const && error() const && + { + return std::move( ( assert( ! 
has_value() ), contained.error() ) ); + } + + error_type && error() && + { + return std::move( ( assert( ! has_value() ), contained.error() ) ); + } + +#endif + + constexpr unexpected_type get_unexpected() const + { + return make_unexpected( contained.error() ); + } + + template< typename Ex > + bool has_exception() const + { + using ContainedEx = typename std::remove_reference< decltype( get_unexpected().value() ) >::type; + return ! has_value() && std::is_base_of< Ex, ContainedEx>::value; + } + +#if nsel_P2505R >= 4 + template< typename G = E + nsel_REQUIRES_T( + std::is_copy_constructible< E >::value + && std::is_convertible< G, E >::value + ) + > + nsel_constexpr error_type error_or( G && e ) const & + { + return has_value() + ? static_cast< E >( std::forward< G >( e ) ) + : contained.error(); + } + + template< typename G = E + nsel_REQUIRES_T( + std::is_move_constructible< E >::value + && std::is_convertible< G, E >::value + ) + > + nsel_constexpr14 error_type error_or( G && e ) && + { + return has_value() + ? static_cast< E >( std::forward< G >( e ) ) + : std::move( contained.error() ); + } +#endif // nsel_P2505R >= 4 + +#if nsel_P2505R >= 3 + // Monadic operations (P2505) + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F >::error_type, error_type >::value + && std::is_constructible< error_type, error_type & >::value + ) + > + nsel_constexpr14 detail::invoke_result_nocvref_t< F > and_then( F && f ) & + { + return has_value() + ? 
detail::invoke_result_nocvref_t< F >( detail::invoke( std::forward< F >( f ) ) ) + : detail::invoke_result_nocvref_t< F >( unexpect, error() ); + } + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F >::error_type, error_type >::value + && std::is_constructible< error_type, const error_type & >::value + ) + > + nsel_constexpr detail::invoke_result_nocvref_t< F > and_then( F && f ) const & + { + return has_value() + ? detail::invoke_result_nocvref_t< F >( detail::invoke( std::forward< F >( f ) ) ) + : detail::invoke_result_nocvref_t< F >( unexpect, error() ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F >::error_type, error_type >::value + && std::is_constructible< error_type, error_type && >::value + ) + > + nsel_constexpr14 detail::invoke_result_nocvref_t< F > and_then( F && f ) && + { + return has_value() + ? detail::invoke_result_nocvref_t< F >( detail::invoke( std::forward< F >( f ) ) ) + : detail::invoke_result_nocvref_t< F >( unexpect, std::move( error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F >::error_type, error_type >::value + && std::is_constructible< error_type, const error_type && >::value + ) + > + nsel_constexpr detail::invoke_result_nocvref_t< F > and_then( F && f ) const && + { + return has_value() + ? 
detail::invoke_result_nocvref_t< F >( detail::invoke( std::forward< F >( f ) ) ) + : detail::invoke_result_nocvref_t< F >( unexpect, std::move( error() ) ); + } +#endif + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, error_type & > >::value + && std::is_void< typename detail::invoke_result_nocvref_t< F, error_type & >::value_type >::value + ) + > + nsel_constexpr14 detail::invoke_result_nocvref_t< F, error_type & > or_else( F && f ) & + { + return has_value() + ? detail::invoke_result_nocvref_t< F, error_type & >() + : detail::invoke_result_nocvref_t< F, error_type & >( detail::invoke( std::forward< F >( f ), error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, const error_type & > >::value + && std::is_void< typename detail::invoke_result_nocvref_t< F, const error_type & >::value_type >::value + ) + > + nsel_constexpr detail::invoke_result_nocvref_t< F, const error_type & > or_else( F && f ) const & + { + return has_value() + ? detail::invoke_result_nocvref_t< F, const error_type & >() + : detail::invoke_result_nocvref_t< F, const error_type & >( detail::invoke( std::forward< F >( f ), error() ) ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, error_type && > >::value + && std::is_void< typename detail::invoke_result_nocvref_t< F, error_type && >::value_type >::value + ) + > + nsel_constexpr14 detail::invoke_result_nocvref_t< F, error_type && > or_else( F && f ) && + { + return has_value() + ? 
detail::invoke_result_nocvref_t< F, error_type && >() + : detail::invoke_result_nocvref_t< F, error_type && >( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, const error_type && > >::value + && std::is_void< typename detail::invoke_result_nocvref_t< F, const error_type && >::value_type >::value + ) + > + nsel_constexpr detail::invoke_result_nocvref_t< F, const error_type && > or_else( F && f ) const && + { + return has_value() + ? detail::invoke_result_nocvref_t< F, const error_type && >() + : detail::invoke_result_nocvref_t< F, const error_type && >( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } +#endif + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type & >::value + && !std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr14 expected< detail::transform_invoke_result_t< F >, error_type > transform( F && f ) & + { + return has_value() + ? expected< detail::transform_invoke_result_t< F >, error_type >( detail::invoke( std::forward< F >( f ) ) ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type & >::value + && std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr14 expected< void, error_type > transform( F && f ) & + { + return has_value() + ? ( detail::invoke( std::forward< F >( f ) ), expected< void, error_type >() ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type & >::value + && !std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr expected< detail::transform_invoke_result_t< F >, error_type > transform( F && f ) const & + { + return has_value() + ? 
expected< detail::transform_invoke_result_t< F >, error_type >( detail::invoke( std::forward< F >( f ) ) ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type & >::value + && std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr expected< void, error_type > transform( F && f ) const & + { + return has_value() + ? ( detail::invoke( std::forward< F >( f ) ), expected< void, error_type >() ) + : make_unexpected( error() ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type && >::value + && !std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr14 expected< detail::transform_invoke_result_t< F >, error_type > transform( F && f ) && + { + return has_value() + ? expected< detail::transform_invoke_result_t< F >, error_type >( detail::invoke( std::forward< F >( f ) ) ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type && >::value + && std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr14 expected< void, error_type > transform( F && f ) && + { + return has_value() + ? ( detail::invoke( std::forward< F >( f ) ), expected< void, error_type >() ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type && >::value + && !std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr expected< detail::transform_invoke_result_t< F >, error_type > transform( F && f ) const && + { + return has_value() + ? 
expected< detail::transform_invoke_result_t< F >, error_type >( detail::invoke( std::forward< F >( f ) ) ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type && >::value + && std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr expected< void, error_type > transform( F && f ) const && + { + return has_value() + ? ( detail::invoke( std::forward< F >( f ) ), expected< void, error_type >() ) + : make_unexpected( error() ); + } +#endif + + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, error_type & > >::value + ) + > + nsel_constexpr14 expected< void, detail::transform_invoke_result_t< F, error_type & > > transform_error( F && f ) & + { + return has_value() + ? expected< void, detail::transform_invoke_result_t< F, error_type & > >() + : make_unexpected( detail::invoke( std::forward< F >( f ), error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, const error_type & > >::value + ) + > + nsel_constexpr expected< void, detail::transform_invoke_result_t< F, const error_type & > > transform_error( F && f ) const & + { + return has_value() + ? expected< void, detail::transform_invoke_result_t< F, const error_type & > >() + : make_unexpected( detail::invoke( std::forward< F >( f ), error() ) ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, error_type && > >::value + ) + > + nsel_constexpr14 expected< void, detail::transform_invoke_result_t< F, error_type && > > transform_error( F && f ) && + { + return has_value() + ? 
expected< void, detail::transform_invoke_result_t< F, error_type && > >() + : make_unexpected( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, const error_type && > >::value + ) + > + nsel_constexpr expected< void, detail::transform_invoke_result_t< F, const error_type && > > transform_error( F && f ) const && + { + return has_value() + ? expected< void, detail::transform_invoke_result_t< F, const error_type && > >() + : make_unexpected( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } +#endif +#endif // nsel_P2505R >= 3 + +// template constexpr 'see below' unwrap() const&; +// +// template 'see below' unwrap() &&; + + // factories + +// template< typename Ex, typename F> +// expected<void,E> catch_exception(F&& f); +// +// template< typename F> +// expected<decltype(func()), E> map(F&& func) ; +// +// template< typename F> +// 'see below' bind(F&& func) ; +// +// template< typename F> +// expected<void,E> catch_error(F&& f); +// +// template< typename F> +// 'see below' then(F&& func); + +private: + detail::storage_t + < + void + , E + , std::is_copy_constructible<E>::value + , std::is_move_constructible<E>::value + > + contained; +}; + +// x.x.4.6 expected<>: comparison operators + +template< typename T1, typename E1, typename T2, typename E2 + nsel_REQUIRES_T( + !std::is_void<T1>::value && !std::is_void<T2>::value + ) +> +constexpr bool operator==( expected<T1,E1> const & x, expected<T2,E2> const & y ) +{ + return bool(x) != bool(y) ? false : bool(x) ? *x == *y : x.error() == y.error(); +} + +template< typename T1, typename E1, typename T2, typename E2 + nsel_REQUIRES_T( + std::is_void<T1>::value && std::is_void<T2>::value + ) +> +constexpr bool operator==( expected<T1,E1> const & x, expected<T2,E2> const & y ) +{ + return bool(x) != bool(y) ? 
false : bool(x) || static_cast<bool>( x.error() == y.error() ); +} + +template< typename T1, typename E1, typename T2, typename E2 > +constexpr bool operator!=( expected<T1,E1> const & x, expected<T2,E2> const & y ) +{ + return !(x == y); +} + +#if nsel_P0323R <= 2 + +template< typename T, typename E > +constexpr bool operator<( expected<T,E> const & x, expected<T,E> const & y ) +{ + return (!y) ? false : (!x) ? true : *x < *y; +} + +template< typename T, typename E > +constexpr bool operator>( expected<T,E> const & x, expected<T,E> const & y ) +{ + return (y < x); +} + +template< typename T, typename E > +constexpr bool operator<=( expected<T,E> const & x, expected<T,E> const & y ) +{ + return !(y < x); +} + +template< typename T, typename E > +constexpr bool operator>=( expected<T,E> const & x, expected<T,E> const & y ) +{ + return !(x < y); +} + +#endif + +// x.x.4.7 expected: comparison with T + +template< typename T1, typename E1, typename T2 + nsel_REQUIRES_T( + !std::is_void<T1>::value + ) +> +constexpr bool operator==( expected<T1,E1> const & x, T2 const & v ) +{ + return bool(x) ? *x == v : false; +} + +template< typename T1, typename E1, typename T2 + nsel_REQUIRES_T( + !std::is_void<T1>::value + ) +> +constexpr bool operator==(T2 const & v, expected<T1,E1> const & x ) +{ + return bool(x) ? v == *x : false; +} + +template< typename T1, typename E1, typename T2 > +constexpr bool operator!=( expected<T1,E1> const & x, T2 const & v ) +{ + return bool(x) ? *x != v : true; +} + +template< typename T1, typename E1, typename T2 > +constexpr bool operator!=( T2 const & v, expected<T1,E1> const & x ) +{ + return bool(x) ? v != *x : true; +} + +#if nsel_P0323R <= 2 + +template< typename T, typename E > +constexpr bool operator<( expected<T,E> const & x, T const & v ) +{ + return bool(x) ? *x < v : true; +} + +template< typename T, typename E > +constexpr bool operator<( T const & v, expected<T,E> const & x ) +{ + return bool(x) ? 
v < *x : false; +} + +template< typename T, typename E > +constexpr bool operator>( T const & v, expected<T,E> const & x ) +{ + return bool(x) ? *x < v : false; +} + +template< typename T, typename E > +constexpr bool operator>( expected<T,E> const & x, T const & v ) +{ + return bool(x) ? v < *x : false; +} + +template< typename T, typename E > +constexpr bool operator<=( T const & v, expected<T,E> const & x ) +{ + return bool(x) ? ! ( *x < v ) : false; +} + +template< typename T, typename E > +constexpr bool operator<=( expected<T,E> const & x, T const & v ) +{ + return bool(x) ? ! ( v < *x ) : true; +} + +template< typename T, typename E > +constexpr bool operator>=( expected<T,E> const & x, T const & v ) +{ + return bool(x) ? ! ( *x < v ) : false; +} + +template< typename T, typename E > +constexpr bool operator>=( T const & v, expected<T,E> const & x ) +{ + return bool(x) ? ! ( v < *x ) : true; +} + +#endif // nsel_P0323R + +// x.x.4.8 expected: comparison with unexpected_type + +template< typename T1, typename E1 , typename E2 > +constexpr bool operator==( expected<T1,E1> const & x, unexpected_type<E2> const & u ) +{ + return (!x) ? x.get_unexpected() == u : false; +} + +template< typename T1, typename E1 , typename E2 > +constexpr bool operator==( unexpected_type<E2> const & u, expected<T1,E1> const & x ) +{ + return ( x == u ); +} + +template< typename T1, typename E1 , typename E2 > +constexpr bool operator!=( expected<T1,E1> const & x, unexpected_type<E2> const & u ) +{ + return ! ( x == u ); +} + +template< typename T1, typename E1 , typename E2 > +constexpr bool operator!=( unexpected_type<E2> const & u, expected<T1,E1> const & x ) +{ + return ! ( x == u ); +} + +#if nsel_P0323R <= 2 + +template< typename T, typename E > +constexpr bool operator<( expected<T,E> const & x, unexpected_type<E> const & u ) +{ + return (!x) ? 
( x.get_unexpected() < u ) : false; +} + +template< typename T, typename E > +constexpr bool operator<( unexpected_type<E> const & u, expected<T,E> const & x ) +{ + return (!x) ? ( u < x.get_unexpected() ) : true ; +} + +template< typename T, typename E > +constexpr bool operator>( expected<T,E> const & x, unexpected_type<E> const & u ) +{ + return ( u < x ); +} + +template< typename T, typename E > +constexpr bool operator>( unexpected_type<E> const & u, expected<T,E> const & x ) +{ + return ( x < u ); +} + +template< typename T, typename E > +constexpr bool operator<=( expected<T,E> const & x, unexpected_type<E> const & u ) +{ + return ! ( u < x ); +} + +template< typename T, typename E > +constexpr bool operator<=( unexpected_type<E> const & u, expected<T,E> const & x) +{ + return ! ( x < u ); +} + +template< typename T, typename E > +constexpr bool operator>=( expected<T,E> const & x, unexpected_type<E> const & u ) +{ + return ! ( u > x ); +} + +template< typename T, typename E > +constexpr bool operator>=( unexpected_type<E> const & u, expected<T,E> const & x ) +{ + return ! 
( x > u ); +} + +#endif // nsel_P0323R + +/// x.x.x Specialized algorithms + +template< typename T, typename E + nsel_REQUIRES_T( + ( std::is_void<T>::value || std::is_move_constructible<T>::value ) + && std::is_move_constructible<E>::value + && std17::is_swappable<T>::value + && std17::is_swappable<E>::value ) +> +void swap( expected<T,E> & x, expected<T,E> & y ) noexcept ( noexcept ( x.swap(y) ) ) +{ + x.swap( y ); +} + +#if nsel_P0323R <= 3 + +template< typename T > +constexpr auto make_expected( T && v ) -> expected< typename std::decay<T>::type > +{ + return expected< typename std::decay<T>::type >( std::forward<T>( v ) ); +} + +// expected<void> specialization: + +auto inline make_expected() -> expected<void> +{ + return expected<void>( in_place ); +} + +template< typename T > +constexpr auto make_expected_from_current_exception() -> expected<T> +{ + return expected<T>( make_unexpected_from_current_exception() ); +} + +template< typename T > +auto make_expected_from_exception( std::exception_ptr v ) -> expected<T> +{ + return expected<T>( unexpected_type<std::exception_ptr>( std::forward<std::exception_ptr>( v ) ) ); +} + +template< typename T, typename E > +constexpr auto make_expected_from_error( E e ) -> expected<T, typename std::decay<E>::type> +{ + return expected<T, typename std::decay<E>::type>( make_unexpected( e ) ); +} + +template< typename F + nsel_REQUIRES_T( ! std::is_same<typename std::result_of<F()>::type, void>::value ) +> +/*nsel_constexpr14*/ +auto make_expected_from_call( F f ) -> expected< typename std::result_of<F()>::type > +{ + try + { + return make_expected( f() ); + } + catch (...) + { + return make_unexpected_from_current_exception(); + } +} + +template< typename F + nsel_REQUIRES_T( std::is_same<typename std::result_of<F()>::type, void>::value ) +> +/*nsel_constexpr14*/ +auto make_expected_from_call( F f ) -> expected<void> +{ + try + { + f(); + return make_expected(); + } + catch (...) 
+ { + return make_unexpected_from_current_exception(); + } +} + +#endif // nsel_P0323R + +} // namespace expected_lite + +using namespace expected_lite; + +// using expected_lite::expected; +// using ... + +} // namespace future_std + +namespace std { + +// expected: hash support + +template< typename T, typename E > +struct hash< future_std::expected<T,E> > +{ + using result_type = std::size_t; + using argument_type = future_std::expected<T,E>; + + constexpr result_type operator()(argument_type const & arg) const + { + return arg ? std::hash<T>{}(*arg) : result_type{}; + } +}; + +// TBD - ?? remove? see spec. +template< typename T, typename E > +struct hash< future_std::expected<T&,E> > +{ + using result_type = std::size_t; + using argument_type = future_std::expected<T&,E>; + + constexpr result_type operator()(argument_type const & arg) const + { + return arg ? std::hash<T>{}(*arg) : result_type{}; + } +}; + +// TBD - implement +// bool(e), hash<expected<void,E>>()(e) shall evaluate to the hashing true; +// otherwise it evaluates to an unspecified value if E is exception_ptr or +// a combination of hashing false and hash<E>()(e.error()). 
+ +template< typename E > +struct hash< future_std::expected<void,E> > +{ +}; + +} // namespace std + +namespace future_std { + +// void unexpected() is deprecated && removed in C++17 + +#if nsel_CPP17_OR_GREATER || nsel_COMPILER_MSVC_VERSION > 141 +template< typename E > +using unexpected = unexpected_type<E>; +#endif + +} // namespace future_std + +#undef nsel_REQUIRES +#undef nsel_REQUIRES_0 +#undef nsel_REQUIRES_T + +nsel_RESTORE_WARNINGS() + +#endif // nsel_USES_STD_EXPECTED + +#endif // AIDGE_CORE_UTILS_FUTURE_STD_EXPECTED_H_ diff --git a/include/aidge/utilsParsing/AstNode.hpp b/include/aidge/utilsParsing/AstNode.hpp index 1158ae148a22993476adb00ecbf8ebd24101830c..bf4f73236fb65b88da309e71ba55997b5342df41 100644 --- a/include/aidge/utilsParsing/AstNode.hpp +++ b/include/aidge/utilsParsing/AstNode.hpp @@ -1,7 +1,7 @@ -#ifndef _AIDGE_AST_NODE_H_ -#define _AIDGE_AST_NODE_H_ +#ifndef AIDGE_CORE_AST_NODE_H_ +#define AIDGE_CORE_AST_NODE_H_ #include <string> #include <type_traits> @@ -12,11 +12,11 @@ namespace Aidge{ template <typename EnumType> - class AstNode: public std::enable_shared_from_this<AstNode> + class AstNode: public std::enable_shared_from_this<AstNode<EnumType>> { static_assert(std::is_enum<EnumType>::value, "AstNode EnumType must be an enum type"); public: - AstNode(std::shared_ptr<ParsingToken<EnumType>> token,std::vector<std::shared_ptr<AstNode>> child ={}):mToken(token),mChild(child){} + AstNode(std::shared_ptr<ParsingToken<EnumType>> token,std::vector<std::shared_ptr<AstNode<EnumType>>> child ={}):mToken(token),mChild(child){} /** * @brief get the type of the token * @return the type @@ -41,7 +41,7 @@ namespace Aidge{ } /** * @brief test if the node is a leaf in the tree - * @return true if a leaf + * @return true if a leaf */ bool isLeaf() const { return mChild.size() == 0; @@ -66,4 +66,4 @@ namespace Aidge{ }; } -#endif //_AIDGE_AST_NODE_H_ +#endif //AIDGE_CORE_AST_NODE_H_ diff --git a/include/aidge/utilsParsing/ParsingToken.hpp 
b/include/aidge/utilsParsing/ParsingToken.hpp index 78045cf3085a18bfd0565354fd34aef02ef395bd..e303a5eabe6f7710873468f8edc8f3e844f4175f 100644 --- a/include/aidge/utilsParsing/ParsingToken.hpp +++ b/include/aidge/utilsParsing/ParsingToken.hpp @@ -1,13 +1,15 @@ -#ifndef _AIDGE_PARSING_TOKEN_H_ -#define _AIDGE_PARSING_TOKEN_H_ +#ifndef AIDGE_CORE_PARSING_TOKEN_H_ +#define AIDGE_CORE_PARSING_TOKEN_H_ #include <string> #include <type_traits> +#include <sstream> // Include the necessary header namespace Aidge{ + template <typename EnumType> - class ParsingToken: public std::enable_shared_from_this<ParsingToken> + class ParsingToken: public std::enable_shared_from_this<ParsingToken<EnumType>> { static_assert(std::is_enum<EnumType>::value, "ParsingToken EnumType must be an enum type"); public: @@ -16,11 +18,11 @@ namespace Aidge{ * @param type one of the token type * @param lexeme String representing aditional information of the token */ - ParsingToken(const EnumType type , const std::string lexeme )mLexeme(lexeme),mType(type){} + ParsingToken(const EnumType type , const std::string lexeme ):mLexeme(lexeme),mType(type){} /** * @brief get the lexeme - * @return std::string + * @return std::string */ const std::string getLexeme(void){ return mLexeme; @@ -28,8 +30,8 @@ namespace Aidge{ /** * @brief get the token type - * - * @return ParsingToken + * + * @return ParsingToken */ const EnumType getType(void){ return mType; @@ -39,7 +41,10 @@ namespace Aidge{ * @brief copy the token * @return deep copy of the token */ - std::shared_ptr<Aidge::ParsingToken> copy(); + std::shared_ptr<ParsingToken> copy(){ + auto newToken = std::make_shared<ParsingToken<EnumType>>(mType,mLexeme); + return newToken; + } //TODO std::ostringstream rep(void){ @@ -47,6 +52,7 @@ namespace Aidge{ out << " Token (" << mLexeme <<")" << "\n"; return out; } + private: /** @@ -63,4 +69,4 @@ namespace Aidge{ }; } -#endif //_AIDGE_PARSING_TOKEN_H_ \ No newline at end of file +#endif //AIDGE_CORE_PARSING_TOKEN_H_ 
diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp index 11189f2f3c4a46b31d8e08d73bea17f27df07765..34610069079ee792ebbe4b261b57177b3bbe2997 100644 --- a/python_binding/backend/pybind_OperatorImpl.cpp +++ b/python_binding/backend/pybind_OperatorImpl.cpp @@ -10,11 +10,112 @@ ********************************************************************************/ #include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include "aidge/operator/Operator.hpp" #include "aidge/backend/OperatorImpl.hpp" namespace py = pybind11; namespace Aidge { + +/** + * @brief Trampoline class for binding + * + */ +class pyOperatorImpl: public OperatorImpl { +public: + using OperatorImpl::OperatorImpl; // Inherit constructors + + void forward() override { + PYBIND11_OVERRIDE( + void, + OperatorImpl, + forward, + + ); + } + void backward() override { + PYBIND11_OVERRIDE( + void, + OperatorImpl, + backward, + + ); + } + NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override { + PYBIND11_OVERRIDE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_required_data", + getNbRequiredData, + inputIdx + ); + } + NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override { + PYBIND11_OVERRIDE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_required_protected", + getNbRequiredProtected, + inputIdx + + ); + } + NbElts_t getRequiredMemory(const IOIndex_t outputIdx, + const std::vector<DimSize_t> &inputsSize) const override { + PYBIND11_OVERRIDE_NAME( + NbElts_t, + OperatorImpl, + "get_required_memory", + getRequiredMemory, + outputIdx, + inputsSize + + ); + } + NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override { + PYBIND11_OVERRIDE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_consumed_data", + getNbConsumedData, + inputIdx + + ); + } + NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override { + PYBIND11_OVERRIDE_NAME( + NbElts_t, + OperatorImpl, + "get_nb_produced_data", + getNbProducedData, + 
outputIdx + + ); + } + void updateConsummerProducer() override { + PYBIND11_OVERRIDE_NAME( + void, + OperatorImpl, + "update_consummer_producer", + updateConsummerProducer, + + ); + } +}; + void init_OperatorImpl(py::module& m){ - py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>>(m, "OperatorImpl"); + + py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr()) + .def(py::init<const Operator&>()) + .def("forward", &OperatorImpl::forward) + .def("backward", &OperatorImpl::backward) + .def("get_nb_required_data", &OperatorImpl::getNbRequiredData) + .def("get_nb_required_protected", &OperatorImpl::getNbRequiredProtected) + .def("get_required_memory", &OperatorImpl::getRequiredMemory) + .def("get_nb_consumed_data", &OperatorImpl::getNbConsumedData) + .def("get_nb_produced_data", &OperatorImpl::getNbProducedData) + .def("update_consummer_producer", &OperatorImpl::updateConsummerProducer) + ; } } diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index d6442723ecc79527e8eaa7d3e03a466c085dfa58..31470e0eb2c50b5386b64498f89419801b133d3a 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -26,10 +26,10 @@ namespace Aidge { template<typename T> void addCtor(py::class_<Tensor, - std::shared_ptr<Tensor>, - Data, + std::shared_ptr<Tensor>, + Data, Registrable<Tensor, - std::tuple<std::string, DataType>, + std::tuple<std::string, DataType>, std::unique_ptr<TensorImpl>(const Tensor&)>>& mTensor){ mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) { /* Request a buffer descriptor from Python */ @@ -46,24 +46,27 @@ void addCtor(py::class_<Tensor, }else{ printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n"); } - + return newTensor; - })); + })) + .def("__setitem__", (void (Tensor::*)(std::size_t, T)) &Tensor::set) + .def("__setitem__", (void 
(Tensor::*)(std::vector<std::size_t>, T)) &Tensor::set) + ; } void init_Tensor(py::module& m){ py::class_<Registrable<Tensor, - std::tuple<std::string, DataType>, + std::tuple<std::string, DataType>, std::unique_ptr<TensorImpl>(const Tensor&)>, std::shared_ptr<Registrable<Tensor, - std::tuple<std::string, DataType>, + std::tuple<std::string, DataType>, std::unique_ptr<TensorImpl>(const Tensor&)>>>(m,"TensorRegistrable"); - py::class_<Tensor, std::shared_ptr<Tensor>, - Data, + py::class_<Tensor, std::shared_ptr<Tensor>, + Data, Registrable<Tensor, - std::tuple<std::string, DataType>, + std::tuple<std::string, DataType>, std::unique_ptr<TensorImpl>(const Tensor&)>> pyClassTensor (m,"Tensor", py::multiple_inheritance(), py::buffer_protocol()); @@ -74,6 +77,8 @@ void init_Tensor(py::module& m){ .def("size", &Tensor::size) .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&)) &Tensor::resize) .def("has_impl", &Tensor::hasImpl) + .def("get_coord", &Tensor::getCoord) + .def("get_idx", &Tensor::getIdx) .def_static("get_available_backends", &Tensor::getAvailableBackends) .def("__str__", [](Tensor& b) { return b.toString(); @@ -82,15 +87,27 @@ void init_Tensor(py::module& m){ return b.size(); }) .def("__getitem__", [](Tensor& b, size_t idx)-> py::object { - // TODO : Should return error if backend not compatible with get if (idx >= b.size()) throw py::index_error(); switch(b.dataType()){ case DataType::Float64: - return py::cast(static_cast<double*>(b.getImpl()->rawPtr())[idx]); + return py::cast(b.get<double>(idx)); + case DataType::Float32: + return py::cast(b.get<float>(idx)); + case DataType::Int32: + return py::cast(b.get<int>(idx)); + default: + return py::none(); + } + }) + .def("__getitem__", [](Tensor& b, std::vector<size_t> coordIdx)-> py::object { + if (b.getIdx(coordIdx) >= b.size()) throw py::index_error(); + switch(b.dataType()){ + case DataType::Float64: + return py::cast(b.get<double>(coordIdx)); case DataType::Float32: - return 
py::cast(static_cast<float*>(b.getImpl()->rawPtr())[idx]); + return py::cast(b.get<float>(coordIdx)); case DataType::Int32: - return py::cast(static_cast<int*>(b.getImpl()->rawPtr())[idx]); + return py::cast(b.get<int>(coordIdx)); default: return py::none(); } @@ -126,12 +143,12 @@ void init_Tensor(py::module& m){ } return py::buffer_info( - tensorImpl->rawPtr(), /* Pointer to buffer */ - tensorImpl->scalarSize(), /* Size of one scalar */ - dataFormatDescriptor, /* Python struct-style format descriptor */ - b.nbDims(), /* Number of dimensions */ - dims, /* Buffer dimensions */ - strides /* Strides (in bytes) for each index */ + tensorImpl->rawPtr(), /* Pointer to buffer */ + tensorImpl->scalarSize(), /* Size of one scalar */ + dataFormatDescriptor, /* Python struct-style format descriptor */ + b.nbDims(), /* Number of dimensions */ + dims, /* Buffer dimensions */ + strides /* Strides (in bytes) for each index */ ); }); @@ -142,6 +159,6 @@ void init_Tensor(py::module& m){ // #if SIZE_MAX != 0xFFFFFFFF addCtor<double>(pyClassTensor); // #endif - + } } diff --git a/python_binding/graph/pybind_Node.cpp b/python_binding/graph/pybind_Node.cpp index 62b86982053d82bef6e0fd80e490632b95b968e5..e3666d247324fc419570611f41bbe67c7c68cc4e 100644 --- a/python_binding/graph/pybind_Node.cpp +++ b/python_binding/graph/pybind_Node.cpp @@ -136,6 +136,16 @@ void init_Node(py::module& m) { :rtype: int )mydelimiter") + .def("get_parents", &Node::getParents, + R"mydelimiter( + Get parents. + )mydelimiter") + + .def("get_children", (std::set<std::shared_ptr<Node>> (Node::*)() const) &Node::getChildren, + R"mydelimiter( + Get children. 
+ )mydelimiter") + .def("__call__", &Node::operator(), py::arg("connectors")); } } // namespace Aidge diff --git a/python_binding/graph/pybind_OpArgs.cpp b/python_binding/graph/pybind_OpArgs.cpp index 305c0b73101a97c242413ff84a5ae099764e7e77..6ea89f91945ac44f2142c5b9e8440b11ec6a1663 100644 --- a/python_binding/graph/pybind_OpArgs.cpp +++ b/python_binding/graph/pybind_OpArgs.cpp @@ -10,19 +10,20 @@ ********************************************************************************/ #include <pybind11/pybind11.h> +#include <pybind11/stl.h> + #include "aidge/graph/OpArgs.hpp" #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" -#include <pybind11/stl.h> -#include <pybind11/complex.h> -#include <pybind11/functional.h> -#include <pybind11/chrono.h> + namespace py = pybind11; namespace Aidge { void init_OpArgs(py::module& m){ py::class_<OpArgs, std::shared_ptr<OpArgs>>(m, "OpArgs") + .def(py::init<const std::shared_ptr<GraphView>&>(), py::arg("view_")) + .def(py::init<const std::shared_ptr<Node>&>(), py::arg("node_")) .def("node", &OpArgs::node) .def("view", &OpArgs::view) ; diff --git a/python_binding/operator/pybind_Add.cpp b/python_binding/operator/pybind_Add.cpp index d7099e3856d48262f0f4bbacf025f5a960a220fa..0b2323c5cfb660415ec3ae009beaa7aa78afca0b 100644 --- a/python_binding/operator/pybind_Add.cpp +++ b/python_binding/operator/pybind_Add.cpp @@ -12,7 +12,6 @@ #include <pybind11/pybind11.h> #include "aidge/operator/Add.hpp" -#include "aidge/utils/Parameter.hpp" #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/utils/Types.h" @@ -21,9 +20,11 @@ namespace py = pybind11; namespace Aidge { template <std::size_t NUM> void declare_Add(py::module &m) { - py::class_<Add_Op<NUM>, std::shared_ptr<Add_Op<NUM>>, Operator>(m, "Add_Op", py::multiple_inheritance()); + py::class_<Add_Op<NUM>, std::shared_ptr<Add_Op<NUM>>, Operator>(m, "AddOp", py::multiple_inheritance()) + .def("get_inputs_name", 
&Add_Op<NUM>::getInputsName) + .def("get_outputs_name", &Add_Op<NUM>::getOutputsName); - m.def("Add", &Add<NUM>, py::arg("name") = nullptr); + m.def("Add", &Add<NUM>, py::arg("name") = ""); } void init_Add(py::module &m) { diff --git a/python_binding/operator/pybind_AvgPooling.cpp b/python_binding/operator/pybind_AvgPooling.cpp index 66dadba7244a199bd4ca8a0dd814f20a8049a62f..fe67fcb7a26f6ea1f05577b47444df5cb271110a 100644 --- a/python_binding/operator/pybind_AvgPooling.cpp +++ b/python_binding/operator/pybind_AvgPooling.cpp @@ -8,7 +8,7 @@ * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ -#ifdef PYBIND + #include <pybind11/pybind11.h> #include <pybind11/stl.h> @@ -16,7 +16,6 @@ #include <vector> #include <array> -#include "aidge/utils/Parameter.hpp" #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/AvgPooling.hpp" #include "aidge/operator/Operator.hpp" @@ -27,52 +26,27 @@ namespace py = pybind11; namespace Aidge { template <DimIdx_t DIM> void declare_AvgPoolingOp(py::module &m) { - py::class_<AvgPooling_Op<DIM>, std::shared_ptr<AvgPooling_Op<DIM>>, Operator, PyAbstractParametrizable>( + py::class_<AvgPooling_Op<DIM>, std::shared_ptr<AvgPooling_Op<DIM>>, Operator, Attributes>( m, ("AvgPoolingOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) .def(py::init<const std::array<DimSize_t, DIM> &, - const std::array<DimSize_t, DIM> &, - const std::array<DimSize_t, (DIM<<1)> &>(), + const std::array<DimSize_t, DIM> &>(), py::arg("kernel_dims"), - py::arg("stride_dims"), - py::arg("padding_dims")); - - m.def(("AvgPooling" + std::to_string(DIM) + "D").c_str(), [](std::vector<DimSize_t>& kernel_dims, - const char* name, - std::vector<DimSize_t> &stride_dims, - std::vector<DimSize_t> &padding_dims) { - // Lambda function wrapper because PyBind fails to convert const array. - // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. 
- if (kernel_dims.size() != DIM) { - throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - if (stride_dims.size() != DIM) { - throw std::runtime_error("stride_dims size [" + std::to_string(stride_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - if (padding_dims.size() != (DIM<<1)) { - throw std::runtime_error("padding_dims size [" + std::to_string(padding_dims.size()) + "] does not match DIM [" + std::to_string(DIM<<1) +"]"); - } - DimSize_t tmp_kernel_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_kernel_dims_array[i] = kernel_dims[i]; - } - DimSize_t tmp_stride_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_stride_dims_array[i] = stride_dims[i]; - } - DimSize_t tmp_padding_dims_array[DIM<<1]; - for (size_t i = 0; i < (DIM<<1); ++i) { - tmp_padding_dims_array[i] = padding_dims[i]; - } - const DimSize_t (&kernel_dims_array)[DIM] = tmp_kernel_dims_array; - const DimSize_t (&stride_dims_array)[DIM] = tmp_stride_dims_array; - const DimSize_t (&padding_dims_array)[DIM<<1] = tmp_padding_dims_array; - return AvgPooling<DIM>(to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(padding_dims_array)); + py::arg("stride_dims")) + .def("get_inputs_name", &AvgPooling_Op<DIM>::getInputsName) + .def("get_outputs_name", &AvgPooling_Op<DIM>::getOutputsName); + + m.def(("AvgPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, + const std::string& name, + const std::vector<DimSize_t> &stride_dims) { + AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); + AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + + return AvgPooling<DIM>(to_array<DIM>(kernel_dims.begin()), name, to_array<DIM>(stride_dims.begin())); }, py::arg("kernel_dims"), - py::arg("name") = 
nullptr, - py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), - py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0)); - + py::arg("name") = "", + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1)); + } @@ -80,10 +54,9 @@ void init_AvgPooling(py::module &m) { declare_AvgPoolingOp<1>(m); declare_AvgPoolingOp<2>(m); declare_AvgPoolingOp<3>(m); - + // FIXME: // m.def("AvgPooling1D", static_cast<NodeAPI(*)(const char*, int, int, int const // (&)[1])>(&AvgPooling)); } } // namespace Aidge -#endif \ No newline at end of file diff --git a/python_binding/operator/pybind_BatchNorm.cpp b/python_binding/operator/pybind_BatchNorm.cpp index 52578c55ac0e3e1112bdbedc15bbaa3e155d9b44..cabaa2edd7053718160fa5013492d1914ee4cf16 100644 --- a/python_binding/operator/pybind_BatchNorm.cpp +++ b/python_binding/operator/pybind_BatchNorm.cpp @@ -14,7 +14,6 @@ #include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/Operator.hpp" -#include "aidge/utils/Parameter.hpp" #include "aidge/utils/Types.h" namespace py = pybind11; @@ -22,9 +21,11 @@ namespace Aidge { template <DimSize_t DIM> void declare_BatchNormOp(py::module& m) { - py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, Operator, PyAbstractParametrizable>(m, ("BatchNorm_Op" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()); + py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, Operator, Attributes>(m, ("BatchNormOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) + .def("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName) + .def("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName); - m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = nullptr); + m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = ""); } void init_BatchNorm(py::module &m) { diff --git 
a/python_binding/operator/pybind_Conv.cpp b/python_binding/operator/pybind_Conv.cpp index 3cf5d818f9b6e3bdfaf9a2d0b74ec0480beb6967..f4f7946c6ecc180f83e4bf58eee16102752f0c6e 100644 --- a/python_binding/operator/pybind_Conv.cpp +++ b/python_binding/operator/pybind_Conv.cpp @@ -11,12 +11,11 @@ #include <pybind11/pybind11.h> #include <pybind11/stl.h> - +#include <iostream> #include <string> #include <vector> #include <array> -#include "aidge/utils/Parameter.hpp" #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/operator/Operator.hpp" @@ -26,72 +25,40 @@ namespace py = pybind11; namespace Aidge { template <DimIdx_t DIM> void declare_ConvOp(py::module &m) { - py::class_<Conv_Op<DIM>, std::shared_ptr<Conv_Op<DIM>>, Operator, PyAbstractParametrizable>( + py::class_<Conv_Op<DIM>, std::shared_ptr<Conv_Op<DIM>>, Operator, Attributes>( m, ("ConvOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) .def(py::init<DimSize_t, DimSize_t, const std::array<DimSize_t, DIM> &, const std::array<DimSize_t, DIM> &, - const std::array<DimSize_t, (DIM<<1)> &, const std::array<DimSize_t, DIM> &>(), py::arg("in_channels"), py::arg("out_channels"), py::arg("kernel_dims"), py::arg("stride_dims"), - py::arg("padding_dims"), - py::arg("dilation_dims")); - + py::arg("dilation_dims")) + .def("get_inputs_name", &Conv_Op<DIM>::getInputsName) + .def("get_outputs_name", &Conv_Op<DIM>::getOutputsName) + ; + m.def(("Conv" + std::to_string(DIM) + "D").c_str(), [](DimSize_t in_channels, DimSize_t out_channels, - std::vector<DimSize_t>& kernel_dims, - const char* name, - std::vector<DimSize_t> &stride_dims, - std::vector<DimSize_t> &padding_dims, - std::vector<DimSize_t> &dilation_dims) { - // Lambda function wrapper because PyBind fails to convert const array. - // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. 
- if (kernel_dims.size() != DIM) { - throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - if (stride_dims.size() != DIM) { - throw std::runtime_error("stride_dims size [" + std::to_string(stride_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - if (padding_dims.size() != (DIM<<1)) { - throw std::runtime_error("padding_dims size [" + std::to_string(padding_dims.size()) + "] does not match DIM [" + std::to_string(DIM<<1) +"]"); - } - if (dilation_dims.size() != DIM) { - throw std::runtime_error("dilation_dims size [" + std::to_string(dilation_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - DimSize_t tmp_kernel_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_kernel_dims_array[i] = kernel_dims[i]; - } - DimSize_t tmp_stride_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_stride_dims_array[i] = stride_dims[i]; - } - DimSize_t tmp_padding_dims_array[DIM<<1]; - for (size_t i = 0; i < (DIM<<1); ++i) { - tmp_padding_dims_array[i] = padding_dims[i]; - } - DimSize_t tmp_dilation_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_dilation_dims_array[i] = dilation_dims[i]; - } - const DimSize_t (&kernel_dims_array)[DIM] = tmp_kernel_dims_array; - const DimSize_t (&stride_dims_array)[DIM] = tmp_stride_dims_array; - const DimSize_t (&padding_dims_array)[DIM<<1] = tmp_padding_dims_array; - const DimSize_t (&dilation_dims_array)[DIM] = tmp_dilation_dims_array; - return Conv<DIM>(in_channels, out_channels, to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(padding_dims_array), to_array(dilation_dims_array)); + const std::vector<DimSize_t>& kernel_dims, + const std::string& name, + const std::vector<DimSize_t> &stride_dims, + const std::vector<DimSize_t> &dilation_dims) { + AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); 
+ AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + AIDGE_ASSERT(dilation_dims.size() == DIM, "dilation_dims size [%ld] does not match DIM [%d]", dilation_dims.size(), DIM); + + return Conv<DIM>(in_channels, out_channels, to_array<DIM>(kernel_dims.begin()), name, to_array<DIM>(stride_dims.begin()), to_array<DIM>(dilation_dims.begin())); }, py::arg("in_channels"), py::arg("out_channels"), py::arg("kernel_dims"), - py::arg("name") = nullptr, + py::arg("name") = "", py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), - py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0), py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1)); - } @@ -99,7 +66,7 @@ void init_Conv(py::module &m) { declare_ConvOp<1>(m); declare_ConvOp<2>(m); declare_ConvOp<3>(m); - + // FIXME: // m.def("Conv1D", static_cast<NodeAPI(*)(const char*, int, int, int const // (&)[1])>(&Conv)); diff --git a/python_binding/operator/pybind_ConvDepthWise.cpp b/python_binding/operator/pybind_ConvDepthWise.cpp index b64409bdbb5f094e85cb094017a6fb837893a2db..4745ef345264763f1a890d566235be072c8e50d8 100644 --- a/python_binding/operator/pybind_ConvDepthWise.cpp +++ b/python_binding/operator/pybind_ConvDepthWise.cpp @@ -16,7 +16,6 @@ #include <vector> #include <array> -#include "aidge/utils/Parameter.hpp" #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/Operator.hpp" @@ -27,64 +26,32 @@ namespace py = pybind11; namespace Aidge { template <DimIdx_t DIM> void declare_ConvDepthWiseOp(py::module &m) { - py::class_<ConvDepthWise_Op<DIM>, std::shared_ptr<ConvDepthWise_Op<DIM>>, Operator, PyAbstractParametrizable>( + py::class_<ConvDepthWise_Op<DIM>, std::shared_ptr<ConvDepthWise_Op<DIM>>, Operator, Attributes>( m, ("ConvDepthWiseOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) .def(py::init<const std::array<DimSize_t, DIM> &, const std::array<DimSize_t, DIM> &, - 
const std::array<DimSize_t, (DIM<<1)> &, const std::array<DimSize_t, DIM> &>(), py::arg("kernel_dims"), py::arg("stride_dims"), - py::arg("padding_dims"), - py::arg("dilation_dims")); - - m.def(("ConvDepthWise" + std::to_string(DIM) + "D").c_str(), [](std::vector<DimSize_t>& kernel_dims, - const char* name, - std::vector<DimSize_t> &stride_dims, - std::vector<DimSize_t> &padding_dims, - std::vector<DimSize_t> &dilation_dims) { - // Lambda function wrapper because PyBind fails to convert const array. - // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. - if (kernel_dims.size() != DIM) { - throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - if (stride_dims.size() != DIM) { - throw std::runtime_error("stride_dims size [" + std::to_string(stride_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - if (padding_dims.size() != (DIM<<1)) { - throw std::runtime_error("padding_dims size [" + std::to_string(padding_dims.size()) + "] does not match DIM [" + std::to_string(DIM<<1) +"]"); - } - if (dilation_dims.size() != DIM) { - throw std::runtime_error("dilation_dims size [" + std::to_string(dilation_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - DimSize_t tmp_kernel_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_kernel_dims_array[i] = kernel_dims[i]; - } - DimSize_t tmp_stride_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_stride_dims_array[i] = stride_dims[i]; - } - DimSize_t tmp_padding_dims_array[DIM<<1]; - for (size_t i = 0; i < (DIM<<1); ++i) { - tmp_padding_dims_array[i] = padding_dims[i]; - } - DimSize_t tmp_dilation_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_dilation_dims_array[i] = dilation_dims[i]; - } - const DimSize_t (&kernel_dims_array)[DIM] = tmp_kernel_dims_array; - const DimSize_t (&stride_dims_array)[DIM] = tmp_stride_dims_array; - 
const DimSize_t (&padding_dims_array)[DIM<<1] = tmp_padding_dims_array; - const DimSize_t (&dilation_dims_array)[DIM] = tmp_dilation_dims_array; - return ConvDepthWise<DIM>(to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(padding_dims_array), to_array(dilation_dims_array)); + py::arg("dilation_dims")) + .def("get_inputs_name", &ConvDepthWise_Op<DIM>::getInputsName) + .def("get_outputs_name", &ConvDepthWise_Op<DIM>::getOutputsName); + + m.def(("ConvDepthWise" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, + const std::string& name, + const std::vector<DimSize_t> &stride_dims, + const std::vector<DimSize_t> &dilation_dims) { + AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); + AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + AIDGE_ASSERT(dilation_dims.size() == DIM, "dilation_dims size [%ld] does not match DIM [%d]", dilation_dims.size(), DIM); + + return ConvDepthWise<DIM>(to_array<DIM>(kernel_dims.begin()), name, to_array<DIM>(stride_dims.begin()), to_array<DIM>(dilation_dims.begin())); }, py::arg("kernel_dims"), - py::arg("name") = nullptr, + py::arg("name") = "", py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), - py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0), py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1)); - + } @@ -92,7 +59,7 @@ void init_ConvDepthWise(py::module &m) { declare_ConvDepthWiseOp<1>(m); declare_ConvDepthWiseOp<2>(m); declare_ConvDepthWiseOp<3>(m); - + // FIXME: // m.def("ConvDepthWise1D", static_cast<NodeAPI(*)(const char*, int, int, int const // (&)[1])>(&ConvDepthWise)); diff --git a/python_binding/operator/pybind_Div.cpp b/python_binding/operator/pybind_Div.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3492bf244952ba6ed0d77cb16de758e61fb26383 --- /dev/null +++ 
b/python_binding/operator/pybind_Div.cpp @@ -0,0 +1,27 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/operator/Div.hpp" +#include "aidge/operator/Operator.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Div(py::module& m) { + py::class_<Div_Op, std::shared_ptr<Div_Op>, Operator>(m, "DivOp", py::multiple_inheritance()) + .def("get_inputs_name", &Div_Op::getInputsName) + .def("get_outputs_name", &Div_Op::getOutputsName); + + m.def("Div", &Div, py::arg("name") = ""); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_FC.cpp b/python_binding/operator/pybind_FC.cpp index 82eaa0062b7db0e57da3d78d56e503e3a4beb19f..c6a1c70000e3e6d604a6652716667efa1c18e956 100644 --- a/python_binding/operator/pybind_FC.cpp +++ b/python_binding/operator/pybind_FC.cpp @@ -12,7 +12,6 @@ #include <pybind11/pybind11.h> #include "aidge/operator/FC.hpp" -#include "aidge/utils/Parameter.hpp" #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/utils/Types.h" @@ -21,9 +20,11 @@ namespace py = pybind11; namespace Aidge { void declare_FC(py::module &m) { - py::class_<FC_Op, std::shared_ptr<FC_Op>, Operator, PyAbstractParametrizable>(m, "FC_Op", py::multiple_inheritance()); + py::class_<FC_Op, std::shared_ptr<FC_Op>, Operator, Attributes>(m, "FCOp", py::multiple_inheritance()) + .def("get_inputs_name", &FC_Op::getInputsName) + .def("get_outputs_name", &FC_Op::getOutputsName); - m.def("FC", &FC, py::arg("out_channels"), py::arg("nobias") = false, py::arg("name") = nullptr); + m.def("FC", &FC, 
py::arg("out_channels"), py::arg("nobias") = false, py::arg("name") = ""); } void init_FC(py::module &m) { diff --git a/python_binding/operator/pybind_GenericOperator.cpp b/python_binding/operator/pybind_GenericOperator.cpp index 578d2ccd2ed143c3f9a67c0430c12aa7214cb8dc..241fc7f4a003f53de15a42859b078c54cc98b63a 100644 --- a/python_binding/operator/pybind_GenericOperator.cpp +++ b/python_binding/operator/pybind_GenericOperator.cpp @@ -11,6 +11,7 @@ #include <pybind11/pybind11.h> #include <pybind11/stl.h> +#include <pybind11/functional.h> #include <stdio.h> #include "aidge/backend/OperatorImpl.hpp" @@ -20,48 +21,13 @@ namespace py = pybind11; namespace Aidge { void init_GenericOperator(py::module& m) { - py::class_<GenericOperator_Op, std::shared_ptr<GenericOperator_Op>, Operator>(m, "GenericOperatorOp", + py::class_<GenericOperator_Op, std::shared_ptr<GenericOperator_Op>, Operator, DynamicAttributes>(m, "GenericOperatorOp", py::multiple_inheritance()) - .def("get_parameter_type", &GenericOperator_Op::getParameterType) - .def("get_parameters_name", &GenericOperator_Op::getParametersName) - .def("add_parameter", &GenericOperator_Op::addParameter<bool>) - .def("add_parameter", &GenericOperator_Op::addParameter<int>) - .def("add_parameter", &GenericOperator_Op::addParameter<float>) - .def("add_parameter", &GenericOperator_Op::addParameter<std::string>) - .def("add_parameter", &GenericOperator_Op::addParameter<std::vector<bool>>) - .def("add_parameter", &GenericOperator_Op::addParameter<std::vector<int>>) - .def("add_parameter", &GenericOperator_Op::addParameter<std::vector<float>>) - .def("add_parameter", &GenericOperator_Op::addParameter<std::vector<std::string>>) - .def("get_parameter", [](GenericOperator_Op& self, std::string key) -> py::object { - /* - This getParameter method returns the good python type without having to have - prior knowledge of the parameter type. 
- */ - py::object res = py::none(); - std::string paramType = self.getParameterType(key); - if(paramType == typeid(int).name()) - res = py::cast(self.getParameter<int>(key)); - else if(paramType == typeid(float).name()) - res = py::cast(self.getParameter<float>(key)); - else if(paramType == typeid(bool).name()) - res = py::cast(self.getParameter<bool>(key)); - else if(paramType == typeid(std::string).name()) - res = py::cast(self.getParameter<std::string>(key)); - else if(paramType == typeid(std::vector<bool>).name()) - res = py::cast(self.getParameter<std::vector<bool>>(key)); - else if(paramType == typeid(std::vector<int>).name()) - res = py::cast(self.getParameter<std::vector<int>>(key)); - else if(paramType == typeid(std::vector<float>).name()) - res = py::cast(self.getParameter<std::vector<float>>(key)); - else if(paramType == typeid(std::vector<std::string>).name()) - res = py::cast(self.getParameter<std::vector<std::string>>(key)); - else { - throw py::key_error("Failed to convert parameter type " + key + ", this issue may come from typeid function which gave an unknown key : [" + paramType + "]. 
Please open an issue asking to add the support for this key."); - } - return res; - }); + .def_readonly_static("identity", &GenericOperator_Op::Identity) + .def("compute_output_dims", &GenericOperator_Op::computeOutputDims) + .def("set_compute_output_dims", &GenericOperator_Op::setComputeOutputDims, py::arg("computation_function")); - m.def("GenericOperator", &GenericOperator, py::arg("type"), py::arg("nbDataIn"), py::arg("nbIn"), py::arg("nbOut"), - py::arg("name") = nullptr); + m.def("GenericOperator", &GenericOperator, py::arg("type"), py::arg("nb_data_in"), py::arg("nb_in"), py::arg("nb_out"), + py::arg("name") = ""); } } // namespace Aidge diff --git a/python_binding/operator/pybind_LeakyReLU.cpp b/python_binding/operator/pybind_LeakyReLU.cpp index 27a292f0baf2673f3d963f3c3b9a69892c4c6521..af7689f0e64dd4ca8f798dcb34ea968972ace464 100644 --- a/python_binding/operator/pybind_LeakyReLU.cpp +++ b/python_binding/operator/pybind_LeakyReLU.cpp @@ -13,14 +13,15 @@ #include "aidge/operator/LeakyReLU.hpp" #include "aidge/operator/Operator.hpp" -#include "aidge/utils/Parameter.hpp" namespace py = pybind11; namespace Aidge { void init_LeakyReLU(py::module& m) { - py::class_<LeakyReLU_Op, std::shared_ptr<LeakyReLU_Op>, Operator, PyAbstractParametrizable>(m, "LeakyReLU_Op", py::multiple_inheritance()); + py::class_<LeakyReLU_Op, std::shared_ptr<LeakyReLU_Op>, Operator, Attributes>(m, "LeakyReLUOp", py::multiple_inheritance()) + .def("get_inputs_name", &LeakyReLU_Op::getInputsName) + .def("get_outputs_name", &LeakyReLU_Op::getOutputsName); - m.def("LeakyReLU", &LeakyReLU, py::arg("negative_slope") = 0.0f, py::arg("name") = nullptr); + m.def("LeakyReLU", &LeakyReLU, py::arg("negative_slope") = 0.0f, py::arg("name") = ""); } } // namespace Aidge diff --git a/python_binding/operator/pybind_Matmul.cpp b/python_binding/operator/pybind_Matmul.cpp index c81845ca5e5ba3674356d16db660f4e3550e9004..fdb51b24a87ce358c1e7808873ebc569ca2227c8 100644 --- 
a/python_binding/operator/pybind_Matmul.cpp +++ b/python_binding/operator/pybind_Matmul.cpp @@ -11,8 +11,7 @@ #include <pybind11/pybind11.h> -#include "aidge/operator/Matmul.hpp" -#include "aidge/utils/Parameter.hpp" +#include "aidge/operator/MatMul.hpp" #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/utils/Types.h" @@ -20,13 +19,15 @@ namespace py = pybind11; namespace Aidge { -void declare_Matmul(py::module &m) { - py::class_<Matmul_Op, std::shared_ptr<Matmul_Op>, Operator, PyAbstractParametrizable>(m, "Matmul_Op", py::multiple_inheritance()); +void declare_MatMul(py::module &m) { + py::class_<MatMul_Op, std::shared_ptr<MatMul_Op>, Operator, Attributes>(m, "MatMulOp", py::multiple_inheritance()) + .def("get_inputs_name", &MatMul_Op::getInputsName) + .def("get_outputs_name", &MatMul_Op::getOutputsName); - m.def("Matmul", &Matmul, py::arg("out_channels"), py::arg("name") = nullptr); + m.def("MatMul", &MatMul, py::arg("out_channels"), py::arg("name") = ""); } -void init_Matmul(py::module &m) { - declare_Matmul(m); +void init_MatMul(py::module &m) { + declare_MatMul(m); } } // namespace Aidge diff --git a/python_binding/operator/pybind_MaxPooling.cpp b/python_binding/operator/pybind_MaxPooling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..907e8cfaa6cde2451677b72beab38bd9a3938735 --- /dev/null +++ b/python_binding/operator/pybind_MaxPooling.cpp @@ -0,0 +1,63 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include <string> +#include <vector> +#include <array> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/MaxPooling.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/utils/Types.h" +#include "aidge/data/Tensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +template <DimIdx_t DIM> void declare_MaxPoolingOp(py::module &m) { + py::class_<MaxPooling_Op<DIM>, std::shared_ptr<MaxPooling_Op<DIM>>, Operator, Attributes>( + m, ("MaxPoolingOp" + std::to_string(DIM) + "D").c_str(), + py::multiple_inheritance()) + .def(py::init<const std::array<DimSize_t, DIM> &, + const std::array<DimSize_t, DIM> &, + bool>(), + py::arg("kernel_dims"), + py::arg("stride_dims"), + py::arg("ceil_mode")) + .def("get_inputs_name", &MaxPooling_Op<DIM>::getInputsName) + .def("get_outputs_name", &MaxPooling_Op<DIM>::getOutputsName); + + m.def(("MaxPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, + const std::string& name, + const std::vector<DimSize_t> &stride_dims, + bool ceil_mode) { + AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); + AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + + return MaxPooling<DIM>(to_array<DIM>(kernel_dims.begin()), name, to_array<DIM>(stride_dims.begin()), ceil_mode); + }, py::arg("kernel_dims"), + py::arg("name") = "", + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), + py::arg("ceil_mode") = false); + +} + + +void init_MaxPooling(py::module &m) { + declare_MaxPoolingOp<1>(m); + declare_MaxPoolingOp<2>(m); + declare_MaxPoolingOp<3>(m); + +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp 
b/python_binding/operator/pybind_MetaOperatorDefs.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aa9f3c50e6b8c6ab9e7be46776d5fba30d775be2 --- /dev/null +++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp @@ -0,0 +1,126 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include <string> +#include <vector> +#include <array> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/MetaOperatorDefs.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/utils/Types.h" + +namespace py = pybind11; +namespace Aidge { + +template <DimIdx_t DIM> void declare_PaddedConvOp(py::module &m) { + m.def(("PaddedConv" + std::to_string(DIM) + "D").c_str(), [](DimSize_t in_channels, + DimSize_t out_channels, + const std::vector<DimSize_t>& kernel_dims, + const std::string& name, + const std::vector<DimSize_t> &stride_dims, + const std::vector<DimSize_t> &padding_dims, + const std::vector<DimSize_t> &dilation_dims) + { + AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); + AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + AIDGE_ASSERT(padding_dims.size() == 2*DIM, "padding_dims size [%ld] does not match DIM [%d]", padding_dims.size(), 2*DIM); + AIDGE_ASSERT(dilation_dims.size() == DIM, "dilation_dims size [%ld] does not match DIM [%d]", dilation_dims.size(), DIM); + + return PaddedConv<DIM>(in_channels, out_channels, to_array<DIM>(kernel_dims.begin()), name, 
to_array<DIM>(stride_dims.begin()), to_array<2*DIM>(padding_dims.begin()), to_array<DIM>(dilation_dims.begin())); + }, py::arg("in_channels"), + py::arg("out_channels"), + py::arg("kernel_dims"), + py::arg("name") = "", + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), + py::arg("padding_dims") = std::vector<DimSize_t>(2*DIM,0), + py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1)); +} + +template <DimIdx_t DIM> void declare_PaddedConvDepthWiseOp(py::module &m) { + m.def(("PaddedConvDepthWise" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, + const std::string& name, + const std::vector<DimSize_t> &stride_dims, + const std::vector<DimSize_t> &padding_dims, + const std::vector<DimSize_t> &dilation_dims) + { + AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); + AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + AIDGE_ASSERT(padding_dims.size() == 2*DIM, "padding_dims size [%ld] does not match DIM [%d]", padding_dims.size(), 2*DIM); + AIDGE_ASSERT(dilation_dims.size() == DIM, "dilation_dims size [%ld] does not match DIM [%d]", dilation_dims.size(), DIM); + + return PaddedConvDepthWise<DIM>(to_array<DIM>(kernel_dims.begin()), name, to_array<DIM>(stride_dims.begin()), to_array<2*DIM>(padding_dims.begin()), to_array<DIM>(dilation_dims.begin())); + }, py::arg("kernel_dims"), + py::arg("name") = "", + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), + py::arg("padding_dims") = std::vector<DimSize_t>(2*DIM,0), + py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1)); + +} + +template <DimIdx_t DIM> void declare_PaddedAvgPoolingOp(py::module &m) { + m.def(("PaddedAvgPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, + const std::string& name, + const std::vector<DimSize_t> &stride_dims, + const std::vector<DimSize_t> &padding_dims) + { + 
AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); + AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + AIDGE_ASSERT(padding_dims.size() == 2*DIM, "padding_dims size [%ld] does not match DIM [%d]", padding_dims.size(), 2*DIM); + + return PaddedAvgPooling<DIM>(to_array<DIM>(kernel_dims.begin()), name, to_array<DIM>(stride_dims.begin()), to_array<2*DIM>(padding_dims.begin())); + }, py::arg("kernel_dims"), + py::arg("name") = "", + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), + py::arg("padding_dims") = std::vector<DimSize_t>(2*DIM,0)); + +} + +template <DimIdx_t DIM> void declare_PaddedMaxPoolingOp(py::module &m) { + m.def(("PaddedMaxPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, + const std::string& name, + const std::vector<DimSize_t> &stride_dims, + const std::vector<DimSize_t> &padding_dims, + bool ceil_mode) + { + AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); + AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + AIDGE_ASSERT(padding_dims.size() == 2*DIM, "padding_dims size [%ld] does not match DIM [%d]", padding_dims.size(), 2*DIM); + + return PaddedMaxPooling<DIM>(to_array<DIM>(kernel_dims.begin()), name, to_array<DIM>(stride_dims.begin()), to_array<2*DIM>(padding_dims.begin()), ceil_mode); + }, py::arg("kernel_dims"), + py::arg("name") = "", + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), + py::arg("padding_dims") = std::vector<DimSize_t>(2*DIM,0), + py::arg("ceil_mode") = false); + +} + +void init_MetaOperatorDefs(py::module &m) { + declare_PaddedConvOp<1>(m); + declare_PaddedConvOp<2>(m); + declare_PaddedConvOp<3>(m); + declare_PaddedConvDepthWiseOp<1>(m); + declare_PaddedConvDepthWiseOp<2>(m); + 
declare_PaddedConvDepthWiseOp<3>(m); + declare_PaddedAvgPoolingOp<1>(m); + declare_PaddedAvgPoolingOp<2>(m); + declare_PaddedAvgPoolingOp<3>(m); + declare_PaddedMaxPoolingOp<1>(m); + declare_PaddedMaxPoolingOp<2>(m); + declare_PaddedMaxPoolingOp<3>(m); + + +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_Mul.cpp b/python_binding/operator/pybind_Mul.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2627c99005b009769e8fbb97b1f5d79e2424c997 --- /dev/null +++ b/python_binding/operator/pybind_Mul.cpp @@ -0,0 +1,27 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/operator/Mul.hpp" +#include "aidge/operator/Operator.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Mul(py::module& m) { + py::class_<Mul_Op, std::shared_ptr<Mul_Op>, Operator>(m, "MulOp", py::multiple_inheritance()) + .def("get_inputs_name", &Mul_Op::getInputsName) + .def("get_outputs_name", &Mul_Op::getOutputsName); + + m.def("Mul", &Mul, py::arg("name") = ""); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index ac9a34e0a14ace2cf264188302f52a27bf0f7222..6b535e8cf3293b26aaa64f95ca2f9a394768935f 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -20,9 +20,13 @@ void init_Operator(py::module& m){ py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") .def("output", &Operator::output, py::arg("outputIdx")) .def("input", &Operator::input, py::arg("inputIdx")) + 
.def("nb_data_inputs", &Operator::nbDataInputs) .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) .def("set_datatype", &Operator::setDatatype, py::arg("datatype")) .def("set_backend", &Operator::setBackend, py::arg("name")) + .def("forward", &Operator::forward) + // py::keep_alive forbids Python from garbage collecting the implementation while the Operator is not garbage collected! + .def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>()) ; } } diff --git a/python_binding/operator/pybind_Pow.cpp b/python_binding/operator/pybind_Pow.cpp new file mode 100644 index 0000000000000000000000000000000000000000..22866c5460381b6f494948c7410bcd67e7e46edb --- /dev/null +++ b/python_binding/operator/pybind_Pow.cpp @@ -0,0 +1,27 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/operator/Pow.hpp" +#include "aidge/operator/Operator.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Pow(py::module& m) { + py::class_<Pow_Op, std::shared_ptr<Pow_Op>, Operator>(m, "PowOp", py::multiple_inheritance()) + .def("get_inputs_name", &Pow_Op::getInputsName) + .def("get_outputs_name", &Pow_Op::getOutputsName); + + m.def("Pow", &Pow, py::arg("name") = ""); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_Producer.cpp b/python_binding/operator/pybind_Producer.cpp index 5757891a30c5b40dcfa5ff99b1f06e00376f475a..107b7ba00e4077d9f7c215257bf7fd46629481c1 100644 --- a/python_binding/operator/pybind_Producer.cpp +++ b/python_binding/operator/pybind_Producer.cpp @@ -13,7 +13,6 @@ #include <pybind11/stl.h> #include "aidge/utils/Types.h" -#include "aidge/utils/Parameter.hpp" // #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/operator/Producer.hpp" @@ -25,19 +24,22 @@ namespace Aidge { template <DimIdx_t DIM> void declare_Producer(py::module &m) { // m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name")); - m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const char*)>(&Producer), py::arg("dims"), py::arg("name") = nullptr); - + m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&)>(&Producer), py::arg("dims"), py::arg("name") = ""); + } void init_Producer(py::module &m) { py::class_<Producer_Op, std::shared_ptr<Producer_Op>, Operator>( - m, - "ProducerOp", + m, + "ProducerOp", py::multiple_inheritance()) - .def("dims", &Producer_Op::dims); - m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const 
std::shared_ptr<Tensor>, const char*)>(&Producer), py::arg("tensor"), py::arg("name") = nullptr); - + .def("dims", &Producer_Op::dims) + .def("set_output_tensor", &Producer_Op::setOutputTensor) + .def("get_inputs_name", &Producer_Op::getInputsName) + .def("get_outputs_name", &Producer_Op::getOutputsName); + m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = ""); + declare_Producer<1>(m); declare_Producer<2>(m); declare_Producer<3>(m); diff --git a/python_binding/operator/pybind_ReLU.cpp b/python_binding/operator/pybind_ReLU.cpp index e0d34d5a91a4ed1fcb8507198eb222b2d02e4e26..dbcb483e8089373bc8599c2d09fed00049e2a2ac 100644 --- a/python_binding/operator/pybind_ReLU.cpp +++ b/python_binding/operator/pybind_ReLU.cpp @@ -18,8 +18,10 @@ namespace py = pybind11; namespace Aidge { void init_ReLU(py::module& m) { - py::class_<ReLU_Op, std::shared_ptr<ReLU_Op>, Operator>(m, "ReLU_Op", py::multiple_inheritance()); + py::class_<ReLU_Op, std::shared_ptr<ReLU_Op>, Operator>(m, "ReLUOp", py::multiple_inheritance()) + .def("get_inputs_name", &ReLU_Op::getInputsName) + .def("get_outputs_name", &ReLU_Op::getOutputsName); - m.def("ReLU", &ReLU, py::arg("name") = nullptr); + m.def("ReLU", &ReLU, py::arg("name") = ""); } } // namespace Aidge diff --git a/python_binding/operator/pybind_Softmax.cpp b/python_binding/operator/pybind_Softmax.cpp index 13ba96ade4f5c5d132274e457efa5b4edcd3dc78..8e50ab7c83bf43285b357cb803c0ce3eb42f4cc7 100644 --- a/python_binding/operator/pybind_Softmax.cpp +++ b/python_binding/operator/pybind_Softmax.cpp @@ -19,8 +19,10 @@ namespace py = pybind11; namespace Aidge { void init_Softmax(py::module& m) { - py::class_<Softmax_Op, std::shared_ptr<Softmax_Op>, Operator>(m, "Softmax_Op", py::multiple_inheritance()); + py::class_<Softmax_Op, std::shared_ptr<Softmax_Op>, Operator>(m, "SoftmaxOp", py::multiple_inheritance()) + .def("get_inputs_name", 
&Softmax_Op::getInputsName) + .def("get_outputs_name", &Softmax_Op::getOutputsName); - m.def("Softmax", &Softmax, py::arg("name") = nullptr); + m.def("Softmax", &Softmax, py::arg("name") = ""); } } // namespace Aidge diff --git a/python_binding/operator/pybind_Sqrt.cpp b/python_binding/operator/pybind_Sqrt.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b70171814662c861f19b3048b018260170d37491 --- /dev/null +++ b/python_binding/operator/pybind_Sqrt.cpp @@ -0,0 +1,27 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/operator/Sqrt.hpp" +#include "aidge/operator/Operator.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Sqrt(py::module& m) { + py::class_<Sqrt_Op, std::shared_ptr<Sqrt_Op>, Operator>(m, "SqrtOp", py::multiple_inheritance()) + .def("get_inputs_name", &Sqrt_Op::getInputsName) + .def("get_outputs_name", &Sqrt_Op::getOutputsName); + + m.def("Sqrt", &Sqrt, py::arg("name") = ""); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_Sub.cpp b/python_binding/operator/pybind_Sub.cpp new file mode 100644 index 0000000000000000000000000000000000000000..10c95939646a6b605f23c42618bfbdd00ceb6e2e --- /dev/null +++ b/python_binding/operator/pybind_Sub.cpp @@ -0,0 +1,27 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/operator/Sub.hpp" +#include "aidge/operator/Operator.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Sub(py::module& m) { + py::class_<Sub_Op, std::shared_ptr<Sub_Op>, Operator>(m, "SubOp", py::multiple_inheritance()) + .def("get_inputs_name", &Sub_Op::getInputsName) + .def("get_outputs_name", &Sub_Op::getOutputsName); + + m.def("Sub", &Sub, py::arg("name") = ""); +} +} // namespace Aidge diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 42b10f7b07dab09348d06bc02c9c726aaa9c1842..d2fd317b5801a5170852e9233cb1062622b51ddf 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -18,7 +18,7 @@ void init_Data(py::module&); void init_Database(py::module& m); void init_Tensor(py::module&); void init_OperatorImpl(py::module&); -void init_Parameterizable(py::module&); +void init_Attributes(py::module&); void init_Operator(py::module&); void init_Add(py::module&); @@ -26,13 +26,20 @@ void init_AvgPooling(py::module&); void init_BatchNorm(py::module&); void init_Conv(py::module&); void init_ConvDepthWise(py::module&); +void init_Div(py::module&); void init_FC(py::module&); void init_GenericOperator(py::module&); void init_LeakyReLU(py::module&); -void init_Matmul(py::module&); +void init_MatMul(py::module&); +void init_MaxPooling(py::module&); +void init_MetaOperatorDefs(py::module&); +void init_Mul(py::module&); void init_Producer(py::module&); +void init_Pow(py::module&); void init_ReLU(py::module&); void init_Softmax(py::module&); +void init_Sqrt(py::module&); +void init_Sub(py::module&); void init_Node(py::module&); void init_GraphView(py::module&); @@ -46,16 +53,10 @@ void init_GRegex(py::module&); void init_Recipies(py::module&); void init_Scheduler(py::module&); +void init_TensorUtils(py::module&); -void 
set_python_flag(){ - // Set an env variable to know if we run with ypthon or cpp - py::module os_module = py::module::import("os"); - os_module.attr("environ")["AIDGE_CORE_WITH_PYBIND"] = "1"; -} - void init_Aidge(py::module& m){ - set_python_flag(); init_Data(m); init_Database(m); init_Tensor(m); @@ -66,19 +67,26 @@ void init_Aidge(py::module& m){ init_Connector(m); init_OperatorImpl(m); - init_Parameterizable(m); + init_Attributes(m); init_Operator(m); init_Add(m); init_AvgPooling(m); init_BatchNorm(m); init_Conv(m); init_ConvDepthWise(m); + init_Div(m); init_FC(m); init_GenericOperator(m); init_LeakyReLU(m); - init_Matmul(m); + init_MatMul(m); + init_MaxPooling(m); + init_MetaOperatorDefs(m); + init_Mul(m); + init_Pow(m); init_ReLU(m); init_Softmax(m); + init_Sqrt(m); + init_Sub(m); init_Producer(m); init_Match(m); @@ -86,6 +94,7 @@ void init_Aidge(py::module& m){ init_GRegex(m); init_Recipies(m); init_Scheduler(m); + init_TensorUtils(m); } PYBIND11_MODULE(aidge_core, m) { diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp index b4147dcb4fb82dbfe9f5b4605604725c6945ece9..93c131ef7417135bfdbc657c5c809339430616ed 100644 --- a/python_binding/recipies/pybind_Recipies.cpp +++ b/python_binding/recipies/pybind_Recipies.cpp @@ -20,24 +20,51 @@ namespace py = pybind11; namespace Aidge { void init_Recipies(py::module &m) { - m.def("fuse_mul_add", &fuseMulAdd, py::arg("nodes"), R"mydelimiter( - Recipie to Fuse MatMul and Add operators into an `aidge.FC` operator. - - Parameters - ---------- + + + m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter( + Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. 
+ +    :param graph_view: Graph view on which we want to apply the recipie +    :type graph_view: :py:class:`aidge_core.GraphView` +    )mydelimiter"); +  m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( +    Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. +    :param nodes: The MatMul and Add nodes to fuse. -    :type nodes: list of `aidge.node` +    :type nodes: list of :py:class:`aidge_core.Node` +    )mydelimiter"); + +  m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( +    Recipie to remove a flatten operator. +    :param graph_view: Graph view on which we want to apply the recipie +    :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); -  m.def("remove_flatten", &removeFlatten, py::arg("nodes"), R"mydelimiter( + m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter( Recipie to remove a flatten operator. - -    Parameters -    ---------- + :param nodes: The flatten operator to remove. -    :type nodes: list of `aidge.node` + :type nodes: list of :py:class:`aidge_core.Node` +    )mydelimiter"); + m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( + Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + :param nodes: The MatMul and Add nodes to fuse. + :type nodes: list of :py:class:`aidge_core.Node` + )mydelimiter"); + + m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter( + Recipie to fuse a BatchNorm operator into the preceding operator. 
+ +    :param graph_view: Graph view on which we want to apply the recipie +    :type graph_view: :py:class:`aidge_core.GraphView` +    )mydelimiter"); +  m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter( +    Recipie to fuse a BatchNorm operator into the preceding operator. + +    :param nodes: The BatchNorm nodes to fuse. +    :type nodes: list of :py:class:`aidge_core.Node` )mydelimiter"); - } } // namespace Aidge diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index 2490d5c55a497223b13bceee6772c2dd44e733ef..85479d41f51e74dee4079e78a37e7f3a520639e2 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -10,6 +10,7 @@ ********************************************************************************/ #include <pybind11/pybind11.h> +#include <pybind11/stl.h> #include "aidge/scheduler/Scheduler.hpp" #include "aidge/graph/GraphView.hpp" @@ -20,6 +21,8 @@ void init_Scheduler(py::module& m){ .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false) .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name")) + .def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false) + .def("get_static_scheduling", &SequentialScheduler::getStaticScheduling) ; } } diff --git a/python_binding/utils/pybind_Parameter.cpp b/python_binding/utils/pybind_Parameter.cpp index 358316ea00413813d6d482a8a4601e69af3aa992..2957876f31ad0781a36905cef3a5ae88934b6a8a 100644 --- a/python_binding/utils/pybind_Parameter.cpp +++ b/python_binding/utils/pybind_Parameter.cpp @@ -1,12 +1,36 @@ #include <pybind11/pybind11.h> -#include "aidge/utils/Parameter.hpp" +#include "aidge/utils/Attributes.hpp" +#include "aidge/utils/DynamicAttributes.hpp" namespace py = pybind11; namespace 
Aidge { -void init_Parameterizable(py::module& m){ - py::class_<PyAbstractParametrizable, std::shared_ptr<PyAbstractParametrizable>>(m, "PyAbstractParametrizable") - .def("get", &PyAbstractParametrizable::getPy, py::arg("name")) - ; +DynamicAttributes test_DynamicAttributes_binding() { + DynamicAttributes attrs; + attrs.addAttr<int>("a", 42); + attrs.addAttr<std::string>("b", "test"); + attrs.addAttr<std::vector<bool>>("c", {true, false, true}); + return attrs; } + +double test_DynamicAttributes_binding_check(DynamicAttributes& attrs) { + return attrs.getAttr<double>("d"); +} + +void init_Attributes(py::module& m){ + py::class_<Attributes, std::shared_ptr<Attributes>>(m, "Attributes") + .def("has_attr", &Attributes::hasAttr, py::arg("name")) + .def("get_attr_type", &Attributes::getAttrType, py::arg("name")) + .def("get_attrs_name", &Attributes::getAttrsName) + .def("get_attr", &Attributes::getAttrPy, py::arg("name")); + + py::class_<DynamicAttributes, std::shared_ptr<DynamicAttributes>, Attributes>(m, "DynamicAttributes") + .def("add_attr", &DynamicAttributes::addAttrPy, py::arg("name"), py::arg("value")) + .def("set_attr", &DynamicAttributes::setAttrPy, py::arg("name"), py::arg("value")) + .def("del_attr", &DynamicAttributes::delAttr, py::arg("name")); + + m.def("test_DynamicAttributes_binding", &test_DynamicAttributes_binding); + m.def("test_DynamicAttributes_binding_check", &test_DynamicAttributes_binding_check, py::arg("attrs")); +} + } diff --git a/python_binding/utils/pybind_TensorUtils.cpp b/python_binding/utils/pybind_TensorUtils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..78825a5f3b8d45f22f76c57bd780dc7019fbc123 --- /dev/null +++ b/python_binding/utils/pybind_TensorUtils.cpp @@ -0,0 +1,57 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 
which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include <string> + +#include "aidge/utils/TensorUtils.hpp" + +namespace py = pybind11; + +namespace Aidge { + +template<typename T> +void addTensorUtilsFunction(py::module &m){ + m.def("approx_eq", + & approxEq<T>, + py::arg("t1"), + py::arg("t2"), + py::arg("relative"), + py::arg("absolute"), + R"mydelimiter( + Compare two :cpp:class:`Aidge::Tensor` value wise. The comparison function is: + |t1-t2| <= absolute + relative * |t2| + + If a tensor value is different from the other tensor return False + If the tensor does not have the same size, return False + If the datatype is not the same between each tensor return False + If the templated type does not correspond to the datatype of each tensor, raise an assertion error + + :param t1: first tensor to test + :type t1: :py:class:`aidge_core.Tensor` + :param t2: second tensor to test + :type t2: :py:class:`aidge_core.Tensor` + :param relative: relative difference allowed (should be between 0 and 1) + :type relative: float + :param absolute: absolute error allowed (should be positive) + :type absolute: float + )mydelimiter"); +} + +void init_TensorUtils(py::module &m) { + addTensorUtilsFunction<float>(m); + addTensorUtilsFunction<double>(m); + addTensorUtilsFunction<int>(m); + addTensorUtilsFunction<long>(m); +} +} // namespace Aidge diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..24ce15ab7ead32f98c7ac3edcd34bb2010ff4326 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +numpy diff --git a/setup.ps1 b/setup.ps1 new file mode 100644 index 0000000000000000000000000000000000000000..61324cf4a7d64094f5ead498adf64719c3290f06 --- /dev/null +++ b/setup.ps1 @@ -0,0 +1,52 @@ +# Helper setup tool to 
automatically build aidge_core on Windows. + +# Requirements +################################################################################ +# You have either VS BuildTools or VS Community already present on your +# system, with the build tools installed. +# If not, download Visual Studio Community here: +# https://visualstudio.microsoft.com/fr/vs/community/ +# Make sure to install the "Desktop Development with C++" workload. +# Run this script in a Powershell console with Administrator rights in order to +# automatically install the dependencies, or just execute the second part if you +# already have all the dependencies satisfied. + +# Enable or disable automatic installation of requirements +# Run .\setup.ps1 -install_reqs:$false to disable it +param ([bool]$install_reqs=$true) + +# Default install path is .\install_cpp +if (-not $env:AIDGE_INSTALL_PATH) +{ + $env:AIDGE_INSTALL_PATH = $(Join-Path $pwd install_cpp) +} + +# 1. Setup environment +################################################################################ +if ($install_reqs) +{ + # Install Chocolatey + Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1')) + # Install dependencies + choco install cmake.install --installargs '"ADD_CMAKE_TO_PATH=System"' -Y + choco install git -Y + choco install python -Y + # Update PATH + $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") +} + +# 2. Compile & install aidge_core +################################################################################ +mkdir -Force build_cpp +mkdir -Force $env:AIDGE_INSTALL_PATH +Set-Location build_cpp +cmake -DCMAKE_INSTALL_PREFIX:PATH=$env:AIDGE_INSTALL_PATH -DCMAKE_BUILD_TYPE=Debug .. +if(!$?) 
{ $lastError = $LASTEXITCODE; Set-Location $PSScriptRoot; Exit $lastError } +cmake --build . -j2 +if(!$?) { $lastError = $LASTEXITCODE; Set-Location $PSScriptRoot; Exit $lastError } +cmake --install . --config Debug +if(!$?) { $lastError = $LASTEXITCODE; Set-Location $PSScriptRoot; Exit $lastError } +# Optional: run the unit tests +ctest --output-on-failure +if(!$?) { $lastError = $LASTEXITCODE; Set-Location $PSScriptRoot; Exit $lastError } +Set-Location $PSScriptRoot diff --git a/setup.py b/setup.py index 0b0f66e9132d66cdb6385d7f8c6c69ae0cc5d0e3..60807df560510ad4cfacfdd2b178aca957306439 100644 --- a/setup.py +++ b/setup.py @@ -62,15 +62,17 @@ class CMakeBuild(build_ext): os.chdir(str(build_temp)) - # Impose to use the executable of the python + # Impose to use the executable of the python # used to launch setup.py to setup PythonInterp param_py = "-DPYTHON_EXECUTABLE=" + sys.executable - - install_path = f"{build_temp}/install" if "AIDGE_INSTALL" not in os.environ else os.environ["AIDGE_INSTALL"] - self.spawn(['cmake', str(cwd), param_py, '-DTEST=OFF', f'-DCMAKE_INSTALL_PREFIX:PATH={install_path}']) + compile_type = 'Debug' + install_path = os.path.join(sys.prefix, "lib", "libAidge") if "AIDGE_INSTALL" not in os.environ else os.environ["AIDGE_INSTALL"] + + self.spawn(['cmake', str(cwd), param_py, '-DTEST=OFF', f'-DCMAKE_INSTALL_PREFIX:PATH={install_path}', f'-DCMAKE_BUILD_TYPE={compile_type}']) if not self.dry_run: - self.spawn(['make', 'all', 'install', '-j', max_jobs]) + self.spawn(['cmake', '--build', '.', '--config', compile_type, '-j', max_jobs]) + self.spawn(['cmake', '--install', '.', '--config', compile_type]) os.chdir(str(cwd)) aidge_package = build_lib / (get_project_name()) @@ -81,13 +83,13 @@ class CMakeBuild(build_ext): # Copy all shared object files from build_temp/lib to aidge_package for root, _, files in os.walk(build_temp.absolute()): for file in files: - if file.endswith('.so') and (root != str(aidge_package.absolute())): + if 
(file.endswith('.so') or file.endswith('.pyd')) and (root != str(aidge_package.absolute())): currentFile=os.path.join(root, file) - shutil.copy(currentFile, str(aidge_package.absolute())) + shutil.copy(currentFile, str(aidge_package.absolute())) # Copy version.txt in aidge_package os.chdir(os.path.dirname(__file__)) - shutil.copy("version.txt", str(aidge_package.absolute())) + shutil.copy("version.txt", str(aidge_package.absolute())) if __name__ == '__main__': @@ -100,7 +102,6 @@ if __name__ == '__main__': long_description_content_type="text/markdown", long_description="\n".join(DOCLINES[2:]), classifiers=[c for c in CLASSIFIERS.split('\n') if c], - platforms=["Linux"], packages=find_packages(where="."), include_package_data=True, ext_modules=[CMakeExtension(get_project_name())], diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..166754cc9fe9774d922ef523ab35f569673701fd --- /dev/null +++ b/src/backend/OperatorImpl.cpp @@ -0,0 +1,77 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cassert> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/utils/ErrorHandling.hpp" + +Aidge::OperatorImpl::OperatorImpl(const Operator& op): + mOp(op), + mNbConsumedData(mOp.nbInputs(), 0), + mNbProducedData(mOp.nbOutputs(), 0) +{ + //ctor +} + +Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { + assert(mOp.getInput(inputIdx) && "requires valid input"); + + // Requires the whole tensor by default + return std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->size(); +} + +Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const { + assert(mOp.getInput(inputIdx) && "requires valid input"); + + // Protect the whole tensor by default + return std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->size(); +} + +Aidge::NbElts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx, + const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const { + assert(mOp.getOutput(outputIdx) && "requires valid output"); + + // Requires the whole tensor by default, regardless of available data on inputs + return std::static_pointer_cast<Tensor>(mOp.getOutput(outputIdx))->size(); +} + +Aidge::NbElts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { + assert(static_cast<std::size_t>(inputIdx) < mNbConsumedData.size()); + return mNbConsumedData[static_cast<std::size_t>(inputIdx)]; +} + +Aidge::NbElts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const { + assert(static_cast<std::size_t>(outputIdx) < mNbProducedData.size()); + return mNbProducedData[static_cast<std::size_t>(outputIdx)]; +} + +void Aidge::OperatorImpl::updateConsummerProducer(){ + // Update producer-consumer data + for (std::size_t inputIdx = 0; 
inputIdx < mNbConsumedData.size(); ++inputIdx) { + // each input is consumed by the minimum amount for a forward pass + mNbConsumedData[inputIdx] += getNbRequiredData(static_cast<IOIndex_t>(inputIdx)); + } + + for (std::size_t outputIdx = 0; outputIdx < mNbProducedData.size(); ++outputIdx) { + mNbProducedData[outputIdx] += getRequiredMemory(outputIdx, {}); + } +} + +void Aidge::OperatorImpl::forward() { + AIDGE_THROW_OR_ABORT(std::runtime_error, "forward() not implemented"); +} + +void Aidge::OperatorImpl::backward() { + AIDGE_THROW_OR_ABORT(std::runtime_error, "backward() not implemented"); +} diff --git a/src/graph/Connector.cpp b/src/graph/Connector.cpp index f189b92b24cc5529ae8fb6d8c9faac97e296a92c..cd2ceff8b58076a5054269e4676120b94c8b5beb 100644 --- a/src/graph/Connector.cpp +++ b/src/graph/Connector.cpp @@ -39,7 +39,7 @@ std::shared_ptr<Aidge::GraphView> Aidge::generateGraph(std::vector<Connector> ct graph->add(nodesToAdd.back()); // only add, connection already done // between nodes std::vector<std::shared_ptr<Node>> parents = nodesToAdd.back()->getParents(); - std::set<std::shared_ptr<Node>> alreadyAdded = graph->getNodes(); + const std::set<std::shared_ptr<Node>>& alreadyAdded = graph->getNodes(); for (std::shared_ptr<Node> parent : parents) { if (alreadyAdded.find(parent) == alreadyAdded.end()) { buffer.push_back(parent); diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index ad412f5b86d9cf0dee0823736548baeb7c7320a7..8f8f51c89bbcc380963f355f781e8fda940dcffc 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -33,13 +33,10 @@ Aidge::Connector Aidge::GraphView::operator()( (void)input; // avoid unused warning } + IOIndex_t inID = 0; for (const Connector &ctor : ctors) { assert((ctor.node() != nullptr) && "Input Connector must be associated with a node"); - (void)ctors; // avoid unused warning - } - IOIndex_t inID = 0; - for (const Connector &ctor : ctors) { ctor.node()->addChild(shared_from_this(), 
static_cast<std::size_t>(ctor.index()), {inNode, inID++}); } @@ -128,21 +125,17 @@ Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::dataInputs() const { - IOIndex_t nbDataIn = 0U; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { - nbDataIn += inputNode->nbDataInputs(); - } - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbDataIn); - nbDataIn = 0U; + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; + for (const std::shared_ptr<Node>& inputNode : mInputNodes) { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->dataInputs(); - std::move(inputNodeinputs.begin(), inputNodeinputs.end(), - res.begin() + nbDataIn); - nbDataIn += inputNode->nbDataInputs(); - // res.insert(res.end(), (inputNode -> inputs()).begin(), (inputNode -> - // inputs()).end()); + + for (const auto& input : inputNodeinputs) { + if (mNodes.find(input.first) == mNodes.end()) { + res.push_back(input); + } + } } return res; } @@ -150,21 +143,17 @@ Aidge::GraphView::dataInputs() const { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::inputs() const { - std::size_t nbIn = 0U; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { - nbIn += inputNode->nbInputs(); - } - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbIn); - nbIn = 0U; + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; + for (const std::shared_ptr<Node>& inputNode : mInputNodes) { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->inputs(); - std::move(inputNodeinputs.begin(), inputNodeinputs.end(), - res.begin() + nbIn); - nbIn += inputNode->nbInputs(); - // res.insert(res.end(), (inputNode -> inputs()).begin(), (inputNode -> - // 
inputs()).end()); + + for (const auto& input : inputNodeinputs) { + if (mNodes.find(input.first) == mNodes.end()) { + res.push_back(input); + } + } } return res; } @@ -197,7 +186,7 @@ void Aidge::GraphView::forwardDims() { { assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty()); } - + } } // Compute dimensions of every node @@ -326,7 +315,7 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara // add learnable parameters to the graph if (includeLearnableParam) { for (IOIndex_t i = node->nbDataInputs(); i < node->nbInputs(); ++i) { - std::shared_ptr<Node> parentNode = node->getParents(static_cast<IOIndex_t>(i)); + std::shared_ptr<Node> parentNode = node->getParent(static_cast<IOIndex_t>(i)); if (parentNode) { parentNode->addView(shared_from_this()); mNodes.insert(parentNode); @@ -464,13 +453,13 @@ Aidge::GraphView::getChildren(const std::shared_ptr<Node> otherNode) const { std::shared_ptr<Aidge::Node> -Aidge::GraphView::getNode(const char *nodeName) const { +Aidge::GraphView::getNode(const std::string& nodeName) const { std::map<std::string, std::shared_ptr<Node>>::const_iterator it = - mNodeRegistry.find(std::string(nodeName)); + mNodeRegistry.find(nodeName); if (it != mNodeRegistry.end()) { return it->second; } else { - printf("No Node named %s in the current GraphView.\n", nodeName); + printf("No Node named %s in the current GraphView.\n", nodeName.c_str()); exit(-1); } } @@ -522,39 +511,47 @@ void Aidge::GraphView::link(std::string /*name1_inID*/, printf("Not implemented yet.\n"); } -void Aidge::GraphView::insert(Node & /*newNode*/, Node & /*inNode*/, - std::initializer_list<Node> /*outNodes*/, - IOIndex_t /*tensorIdx*/) { - printf("Not implemented yet.\n"); +void Aidge::GraphView::insertParent(NodePtr childNode, + NodePtr newParentNode, + IOIndex_t childInputTensorIdx, + IOIndex_t newParentInputTensorIdx, + IOIndex_t newParentOutputTensorIdx){ + NodePtr currentParentNode = 
childNode->getParent(childInputTensorIdx); + const IOIndex_t currentParentOutputTensorIdx = childNode->input(childInputTensorIdx).second; + // Remove child from current parent & current Parent from child + currentParentNode->removeChild(childNode, currentParentOutputTensorIdx); + + // Add child + currentParentNode->addChild(newParentNode,currentParentOutputTensorIdx, newParentInputTensorIdx); + newParentNode->addChild(childNode, newParentOutputTensorIdx, childInputTensorIdx); + + add(newParentNode); } + bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) { // TODO : only supports one input/output node for now assert(mNodes.size()>0 && "There must be at least one Node to replace"); bool replacable; - std::shared_ptr<Node> previousInputNode; - std::shared_ptr<Node> newInputNode; - std::shared_ptr<Node> previousOutputNode; + std::shared_ptr<Node> previousInputNode = (*inputNodes().begin()); + std::shared_ptr<Node> previousOutputNode = (*outputNodes().begin()); std::shared_ptr<Node> newOutputNode; - + auto gNew = std::make_shared<GraphView>(); gNew->add(newNodes, false); if (newNodes.empty()) { replacable = (outputNodes().size() == 1) && - (inputNodes().size() == 1) && - ((*outputNodes().begin())->nbOutputs() == 1) && - ((*inputNodes().begin())->nbInputs() == 1); - previousOutputNode = (*outputNodes().begin()); - previousInputNode = (*inputNodes().begin()); + (inputNodes().size() == 1) && + ((*outputNodes().begin())->nbOutputs() == 1) && + ((*inputNodes().begin())->nbDataInputs() == 1); newOutputNode = previousInputNode->input(0).first; } else { - replacable = ((outputNodes().size() == gNew->outputNodes().size()) && - (outputNodes().size() == 1)); - previousOutputNode = (*outputNodes().begin()); newOutputNode = (*gNew->outputNodes().begin()); - replacable = replacable && (previousOutputNode->nbOutputs() == newOutputNode->nbOutputs()); + replacable = (outputNodes().size() == gNew->outputNodes().size()) && + (outputNodes().size() == 1) && + 
(previousOutputNode->nbOutputs() == newOutputNode->nbOutputs()); } if (replacable) { @@ -673,4 +670,55 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) { mOutputNodes.erase(val); } } -} \ No newline at end of file +} + +std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const { + std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName); + + // Map for old node -> new node correspondance + std::map<NodePtr, NodePtr> oldToNewNodes; + + for (const std::shared_ptr<Node> &node_ptr : mNodes) { + oldToNewNodes[node_ptr] = cloneNode(node_ptr); + } + + // For each node, convert old node -> new node connections + for (auto &oldToNewNode : oldToNewNodes) { + if (oldToNewNode.second == nullptr) + continue; // deleted node + + // Add new node to new GraphView + newGraph->add(oldToNewNode.second, false); + + // Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr + size_t parentId = 0; + for (auto parent : oldToNewNode.first->inputs()) { + while (oldToNewNodes[parent.first] == nullptr) { + // Find next valid parent in line, going backward in the graph + assert(parent.first->nbDataInputs() <= 1 && "deleted nodes in GraphView::clone() cannot have multiple data inputs"); + const auto& parents = parent.first->inputs(); + + if (!parents.empty() && parents[0].first != nullptr // a valid parent exists + && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView + { + parent = parents[0]; + } + else { + break; + } + } + + if (oldToNewNodes[parent.first]) { + oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId); + } + + ++parentId; + } + } + + // Update OutputNodes/inputNodes + newGraph->updateInputNodes(); + newGraph->updateOutputNodes(); + + return newGraph; +} diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index b3db5befbdc8299114514d8d554d439bffc5eae2..e6a53c871f5312c68f40dc5c9a2777729470298b 100644 --- 
a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -17,8 +17,8 @@ #include <vector> #include "aidge/utils/Types.h" -Aidge::Node::Node(std::shared_ptr<Operator> op, const char *name) - : mName((name == nullptr) ? std::string() : std::string(name)), +Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) + : mName(name), mOperator(op), mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), nullptr)), mChildren(std::vector<std::vector<std::weak_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()), @@ -226,7 +226,7 @@ void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t } void Aidge::Node::addParent(const std::shared_ptr<Node> other_node, const IOIndex_t inId) { - if (getParents(inId) != nullptr) { + if (getParent(inId) != nullptr) { printf("Warning, you're replacing a Parent.\n"); } assert((inId != gk_IODefaultIndex) && (inId < nbInputs()) && "Input index out of bound."); @@ -321,6 +321,55 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { } } + /////////////////////////////////////////////////////// + // CLONE + /////////////////////////////////////////////////////// + +Aidge::NodePtr Aidge::Node::cloneSharedOperators() const { + return std::make_shared<Node>(mOperator, mName); +} + +Aidge::NodePtr Aidge::Node::cloneSharedProducers() const { + std::shared_ptr<Operator> op = (mOperator->type() == Producer_Op::Type) + ? 
mOperator + : mOperator->clone(); + + return std::make_shared<Node>(op, mName); +} + +Aidge::NodePtr Aidge::Node::clone() const { + return std::make_shared<Node>(mOperator->clone(), mName); +} + + +std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee){ + + std::set<Aidge::NodePtr> out; + nodeSee.insert(shared_from_this()); + + if(delta == 0) { + out.insert(shared_from_this()); + + }else if (delta > 0){ + for (const NodePtr& node : getChildren()) { + if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance + for (const NodePtr& ch : node->getNodeDelta(delta-1,nodeSee)){ + out.insert(ch); + } + } + } + }else{ + for (const NodePtr& node : getParents()) { + if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance + for (const NodePtr& pr : node->getNodeDelta(delta+1,nodeSee)){ + out.insert(pr); + } + } + } + } + + return out; +} ///////////////////////////////////////////////////////////////////////////////////////////// // private diff --git a/src/graph/OpArgs.cpp b/src/graph/OpArgs.cpp index f5f33fb049dec440f3bae412348c83e3427f06ce..124878fc45fe632d4a584e76a0eae6e7acfd53b9 100644 --- a/src/graph/OpArgs.cpp +++ b/src/graph/OpArgs.cpp @@ -14,13 +14,13 @@ #include "aidge/graph/OpArgs.hpp" -std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::initializer_list<OpArgs> inputs) { +std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::vector<OpArgs> inputs) { std::shared_ptr<GraphView> gv = std::make_shared<GraphView>(); for (const OpArgs& elt : inputs) { if(elt.node() != nullptr) { // >= to allow incomplete graphViews assert(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size()); - /* + /* * /!\ mn.view()->outputNodes() is a set, order of Nodes cannot be guaranted. 
* Prefer a functional description for detailed inputs */ @@ -44,7 +44,7 @@ std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::initializer_list<OpArgs } -std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::initializer_list<OpArgs> inputs) { +std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::vector<OpArgs> inputs) { std::shared_ptr<GraphView> gv = std::make_shared<GraphView>(); for(const OpArgs& elt : inputs) { if (elt.node()!=nullptr) @@ -56,7 +56,7 @@ std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::initializer_list<OpArgs> } -std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::initializer_list<OpArgs> inputs) { +std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::vector<OpArgs> inputs) { std::shared_ptr<GraphView> gv = Sequential(inputs); assert(gv->outputNodes().size() == 1U && "Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection"); std::shared_ptr<Node> lastNode = *gv->outputNodes().begin(); @@ -70,4 +70,4 @@ std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::initializer_list<OpArgs> assert(lastNode->getNbFreeDataInputs()>=1); gv->addChild(lastNode, firstNode, 0U, gk_IODefaultIndex); return gv; -} \ No newline at end of file +} diff --git a/src/graphRegex/GraphFsmInterpreter.cpp b/src/graphRegex/GraphFsmInterpreter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2984ab4fb3864244c9e32dbfcda9ef2ae080acf0 --- /dev/null +++ b/src/graphRegex/GraphFsmInterpreter.cpp @@ -0,0 +1,182 @@ +#include "aidge/graphRegex/GraphFsmInterpreter.hpp" + +using namespace Aidge; + + +GraphFsmInterpreter::GraphFsmInterpreter(const std::string graphMatchExpr,std::map<std::string,std::shared_ptr<ConditionalInterpreter>> nodesCondition):mParser(graphMatchExpr){ + mActGroupe = 0; + mNodesCondition = nodesCondition; +} +std::shared_ptr<FsmGraph> GraphFsmInterpreter::interpret(void){ + mActGroupe = 0; + std::shared_ptr<AstNode<gRegexTokenTypes>> tree = mParser.parse(); + 
return visit(tree); +} +std::shared_ptr<FsmGraph> GraphFsmInterpreter::visit(std::shared_ptr<AstNode<gRegexTokenTypes>> AstTree){ + + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>> nextAstNodes = AstTree->getChilds(); + + if(AstTree->getType() == gRegexTokenTypes::SEP){ + return sepF(visit(nextAstNodes[0]),visit(nextAstNodes[1])); + }else if(AstTree->getType() == gRegexTokenTypes::NEXT){ + return nextF(visit(nextAstNodes[0]),visit(nextAstNodes[1])); + }else if(AstTree->getType() == gRegexTokenTypes::QOM){ + return qomF(visit(nextAstNodes[0])); + }else if(AstTree->getType() == gRegexTokenTypes::QZM){ + return qzmF(visit(nextAstNodes[0])); + }else if(AstTree->getType() == gRegexTokenTypes::KEY || AstTree->getType() == gRegexTokenTypes::CKEY){ + return keyF(AstTree); + }else if(AstTree->getType() == gRegexTokenTypes::LPAREN){ + mActGroupe += 1; + std::shared_ptr<FsmGraph> out = visit(nextAstNodes[0]); + mActGroupe -= 1; + return out; + }else{ + throw std::logic_error("visit Bad token type" ); + } +} + + + + +std::shared_ptr<FsmGraph> GraphFsmInterpreter::keyF(std::shared_ptr<AstNode<gRegexTokenTypes>> AstNode){ + + + std::shared_ptr<FsmNode> start = std::make_shared<FsmNode>(false,true); + std::shared_ptr<FsmNode> valid = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + std::shared_ptr<FsmEdge> edge; + + + if(AstNode->getType() == gRegexTokenTypes::CKEY){ + edge = FsmEdgeFactory::make(start,valid,FsmEdgeTypes::COMMON,mNodesCondition,AstNode->getValue()); + }else if (AstNode->getType() == gRegexTokenTypes::KEY) + { + edge = FsmEdgeFactory::make(start,valid,FsmEdgeTypes::UNIQUE,mNodesCondition,AstNode->getValue()); + }else{ + + throw std::logic_error("keyF Bad in AST" ); + } + + graph->addEdge(edge); + graph->setGroupe(mActGroupe); + return graph; +} + +std::shared_ptr<FsmGraph> GraphFsmInterpreter::sepF(std::shared_ptr<FsmGraph> leftFsm,std::shared_ptr<FsmGraph> rigthFsm){ + + size_t idxLeft = 
leftFsm->getNbSubFsm(); + rigthFsm->incOrigineAllNodeBy(idxLeft); + leftFsm->unionG(rigthFsm); + //the rigthFsm is no longer usfull + return leftFsm; +} + +std::shared_ptr<FsmGraph> GraphFsmInterpreter::nextF(std::shared_ptr<FsmGraph> leftFsm,std::shared_ptr<FsmGraph> rigthFsm){ + /* + combine the 2 Graph + all valid node of A are merge with Start B, Start B is un Start + update the relative reference + + A B + SA -> VA + SB -> VB + A B + SA -> q -> VB + */ + leftFsm->mergeOneStartOneValid(rigthFsm); + //the rigthFsm is no longer usfull + return leftFsm; +} + +std::shared_ptr<FsmGraph> GraphFsmInterpreter::qomF(std::shared_ptr<FsmGraph> fsm){ + /* + + + valid node is connect to the child of Start with the same edge condition + A + S -> V + + A + S -> V + (E|R) + V -> S + */ + + std::vector<std::shared_ptr<FsmNode>> allStart = fsm->getStartNodes(); + std::set<std::shared_ptr<FsmNode>> allValid = fsm->getValidNodes(); + std::shared_ptr<FsmEdge> edge; + + if(allStart.size() != 1){ + throw std::logic_error("qomF Bad in AST" ); + } + + for(auto start : allStart ){ + for(auto edgeStart :start->getEdges() ){ + if (auto sharedEdge = edgeStart.lock()) { + + const std::map<size_t, int> commonRef = sharedEdge->getRelative(); + bool haveCommon = !commonRef.empty(); + + for(auto valid : allValid){ + if(haveCommon){ + /* + the // quantif case + get the go back and make a lexeme id(number) + we need to go back to the ref delta min #TODO + */ + bool hasMinRef = false; + std::pair<size_t, int> minRef; + for (const auto& entry : commonRef) { + if (!hasMinRef || std::abs(minRef.second) > std::abs(entry.second)) { + hasMinRef = true; + minRef = entry; + } + } + std::stringstream lexem; + lexem << "(" << minRef.first << ", " << minRef.second << ")"; + edge = FsmEdgeFactory::make(valid,start,FsmEdgeTypes::REF,mNodesCondition, lexem.str()); + }else{ + /* + the sequensial quantif case + no reference to common + */ + edge = 
FsmEdgeFactory::make(valid,start,FsmEdgeTypes::EMPTY,mNodesCondition,""); + + } + fsm->addEdge(edge); + } + }else{ + throw std::runtime_error("edgeStart weak pointer is expired" ); + } + } + + } + return fsm; + +} + +std::shared_ptr<FsmGraph> GraphFsmInterpreter::qzmF(std::shared_ptr<FsmGraph> fsm){ + /* + qomf and a bypass empty start to valide + */ + fsm = qomF(fsm); + + std::vector<std::shared_ptr<FsmNode>> allStart = fsm->getStartNodes(); + std::set<std::shared_ptr<FsmNode>> allValid = fsm->getValidNodes(); + std::shared_ptr<FsmEdge> edge; + + if(allStart.size() != 1){ + throw std::logic_error("qzmF Bad in AST" ); + } + + for(auto start : allStart ){ + + for(auto valid : allValid){ + edge = FsmEdgeFactory::make(start,valid,FsmEdgeTypes::EMPTY,mNodesCondition,""); + fsm->addEdge(edge); + } + } + + return fsm; + + +} \ No newline at end of file diff --git a/src/graphRegex/GraphLexer.cpp b/src/graphRegex/GraphLexer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..61214f96a090fef5d28cb0ce1a009644d9570880 --- /dev/null +++ b/src/graphRegex/GraphLexer.cpp @@ -0,0 +1,155 @@ + +#include "aidge/graphRegex/GraphLexer.hpp" + +using namespace Aidge; + + +GraphLexer::GraphLexer( const std::string gRegexExpressions ): +mRegularExpressions(gRegexExpressions){ + mPosition = 0; +} + +std::shared_ptr<ParsingToken<gRegexTokenTypes>> GraphLexer::getNextToken(void){ + std::string currentChars = ""; + while (mPosition < mRegularExpressions.length()) + { + //erase all space + if (mRegularExpressions[mPosition] != ' ') + { + currentChars += mRegularExpressions[mPosition]; + } + else + { + mPosition++; + continue; + } + + ///// + // const lent token + ///// + + if (std::regex_match(currentChars,std::regex("\\->")))// the next TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::NEXT,""); + } + else if (std::regex_match(currentChars,std::regex("\\*")))// the * TOKEN + { + mPosition++; + return 
std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::QZM,""); + } + else if (std::regex_match(currentChars,std::regex("\\+")))// the + TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::QOM,""); + } + else if (std::regex_match(currentChars,std::regex("\\(")))// the LPAREN TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::LPAREN,""); + } + else if (std::regex_match(currentChars,std::regex("\\)")))// the RPAREN TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::RPAREN,""); + } + + // + else if (std::regex_match(currentChars,std::regex(";")))// the SEP TOKEN + { + //test if the last sep + //std::string subStr = mRegularExpressions.substr(mPosition); + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::SEP,""); + } + + ///// + //unconst lent token + ///// + + else if (std::regex_match(currentChars,std::regex("[A-Za-z_0-9]")))// the KEY or CKEY + { + + //read all the key + bool isCKey = false; + std::regex keyRegex("[A-Za-z_0-9]+"); + std::regex cKeyRegex("[A-Za-z_0-9]+\\#[0-9]*"); + + while ( mPosition < mRegularExpressions.length()) { + + if(!std::regex_match(currentChars,keyRegex) && !std::regex_match(currentChars,cKeyRegex)) + { + currentChars.pop_back(); //the last char is the problemes + break; + } + else if (std::regex_match(currentChars,cKeyRegex)){ + isCKey = true; + } + mPosition++; + if (mPosition < mRegularExpressions.length()) currentChars += mRegularExpressions[mPosition]; + + } + //we end the match 2 posibility + //we are at the end of the mConditionalExpressions and we need to ensure the match + //we are not we can continu + if (mPosition == mRegularExpressions.length()-1) + { + if (!std::regex_match(currentChars,keyRegex) && !std::regex_match(currentChars,cKeyRegex)) + { + throw badTokenError(currentChars,mPosition); + } + } + + + if (isCKey){ + return 
std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::CKEY,currentChars); + } else{ + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::KEY,currentChars); + } + } + + mPosition++; + } + + + //no more to find no one match the currentChars + if (currentChars.empty()) { + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::STOP,""); // Null shared pointer ; + }else{ + throw badTokenError(currentChars,mPosition); + } + +} + +void GraphLexer::rstPosition(void){ + if (isEnd()){ + mPosition = 0; + }else{ + throw badTokenError("end rst",mPosition); + } +} + +bool GraphLexer::isEnd(void){ + return mPosition >= mRegularExpressions.length(); +} + +std::runtime_error GraphLexer::badTokenError(const std::string& currentChars,std::size_t position){ + std::ostringstream errorMessage; + errorMessage << "\nBad syntax " << currentChars << " :\n" << mRegularExpressions << "\n"; + for (std::size_t i = 0; i < position; i++) { + errorMessage << ' '; + } + errorMessage << "^\n"; + + return std::runtime_error(errorMessage.str()); +} + + const std::string GraphLexer::rep(){ + std::string out = mRegularExpressions; + out += "\n"; + for (std::size_t i = 0; i < mPosition; i++) { + out += ' '; + } + out += "^\n"; + return out ; + } \ No newline at end of file diff --git a/src/graphRegex/GraphParser.cpp b/src/graphRegex/GraphParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5aa653c482dae82c2e9fa02bfc36b2ffc821785f --- /dev/null +++ b/src/graphRegex/GraphParser.cpp @@ -0,0 +1,181 @@ +#include "aidge/graphRegex/GraphParser.hpp" + +using namespace Aidge; + +GraphParser::GraphParser(const std::string gRegexExpressions): +mLexer(gRegexExpressions) +{ + mCurrentToken = mLexer.getNextToken(); +} + + +std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::parse(void){ + + std::shared_ptr<AstNode<gRegexTokenTypes>> astTree = constructAstAllExpr(); + rstParser(); + return astTree; +} + + +void 
GraphParser::rstParser(void){ + mLexer.rstPosition(); + mCurrentToken = mLexer.getNextToken(); +} + + +void GraphParser::ackToken(gRegexTokenTypes tokenType){ + + if(mCurrentToken->getType() == tokenType ){ + try { + mCurrentToken = mLexer.getNextToken(); + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "Graph Lexer error in Parser :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } + }else{ + std::ostringstream errorMessage; + errorMessage << "Bad syntax GraphParser " << static_cast<int>(mCurrentToken->getType()) <<"!="<< static_cast<int>(tokenType) << "\n"; + errorMessage << mLexer.rep(); + throw std::runtime_error(errorMessage.str()); + } +} + +/* +exp : KEY(QOM | QZM)? | CKEY | domain +*/ +std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::constructAstExp(void) +{ + + try{ + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token = mCurrentToken->copy(); + std::shared_ptr<AstNode<gRegexTokenTypes>> node = std::make_shared<AstNode<gRegexTokenTypes>>(token); + + if (mCurrentToken->getType() == gRegexTokenTypes::KEY ){ + ackToken(gRegexTokenTypes::KEY ); + if (mCurrentToken->getType() == gRegexTokenTypes::QOM ){ + token = mCurrentToken->copy(); + ackToken(gRegexTokenTypes::QOM ); + std::shared_ptr<AstNode<gRegexTokenTypes>> newNode = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{node}); + return newNode; + }else if (mCurrentToken->getType() == gRegexTokenTypes::QZM ){ + token = mCurrentToken->copy(); + ackToken(gRegexTokenTypes::QZM ); + std::shared_ptr<AstNode<gRegexTokenTypes>> newNode = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{node}); + return newNode; + } + return node; + }else if (mCurrentToken->getType() == gRegexTokenTypes::CKEY){ + ackToken(gRegexTokenTypes::CKEY ); + return node; + }else{ + return constructAstDomain(); + } + + } catch (const 
std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "GraphParser constructAstExp :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } +} + +/* +seq :exp (NEXT seq)* +*/ +std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::constructAstSeq(void) +{ + + try{ + + std::shared_ptr<AstNode<gRegexTokenTypes>> left = constructAstExp(); + if(mCurrentToken->getType() == gRegexTokenTypes::NEXT ) + { + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token = mCurrentToken->copy(); + ackToken(gRegexTokenTypes::NEXT); + std::shared_ptr<AstNode<gRegexTokenTypes>> newNode = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{left,constructAstSeq()}); + left = newNode; + } + return left; + + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "GraphParser constructAstSeq :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } + +} + + +/* +LPAREN seq RPAREN (QOM | QZM) +*/ +std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::constructAstDomain(void) +{ + + try{ + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token ; + std::shared_ptr<AstNode<gRegexTokenTypes>> node ; + + token = mCurrentToken->copy(); + ackToken(gRegexTokenTypes::LPAREN); + node = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{constructAstSeq()}); + ackToken(gRegexTokenTypes::RPAREN); + //(QOM | QZM) + + token = mCurrentToken->copy(); + if (mCurrentToken->getType() == gRegexTokenTypes::QOM){ + ackToken(gRegexTokenTypes::QOM); + node = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{node}); + }else if (mCurrentToken->getType() == gRegexTokenTypes::QZM){ + ackToken(gRegexTokenTypes::QZM); + node = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{node}); 
+ }else{ + std::ostringstream errorMessage; + errorMessage << "Bad syntax constructAstDomain must have quantifier \n"; + throw std::runtime_error(errorMessage.str()); + } + + return node; + + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "GraphParser constructAstDomain :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } +} + +/* + allExpr: seq (SEP allExpr)* | STOP +*/ +std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::constructAstAllExpr(void) +{ + + try{ + std::shared_ptr<AstNode<gRegexTokenTypes>> left = constructAstSeq(); + if(mCurrentToken->getType() == gRegexTokenTypes::SEP ) + { + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token = mCurrentToken->copy(); + ackToken(gRegexTokenTypes::SEP); + + if(mCurrentToken->getType() == gRegexTokenTypes::STOP ) + { + return left; + } + std::shared_ptr<AstNode<gRegexTokenTypes>> newNode = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{left,constructAstAllExpr()}); + left = newNode; + } + return left; + + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "GraphParser constructAstDomain :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } +} diff --git a/src/graphRegex/GraphStrInterpreter.cpp b/src/graphRegex/GraphStrInterpreter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8ad24b5b9b0fee5fba34dd7397132bec2410fd23 --- /dev/null +++ b/src/graphRegex/GraphStrInterpreter.cpp @@ -0,0 +1,38 @@ +#include "aidge/graphRegex/GraphStrInterpreter.hpp" + +using namespace Aidge; + +GraphStrInterpreter::GraphStrInterpreter(const std::string graphMatchExpr):mParser(graphMatchExpr){ + mToTest = graphMatchExpr; + mToTest.erase(std::remove_if(mToTest.begin(), mToTest.end(), ::isspace), mToTest.end()); +} + + +std::string GraphStrInterpreter::visit(std::shared_ptr<AstNode<gRegexTokenTypes>> AstTree){ 
+ + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>> nextAstNodes = AstTree->getChilds(); + + if(AstTree->getType() == gRegexTokenTypes::SEP){ + return visit(nextAstNodes[0])+";"+visit(nextAstNodes[1]); + }else if(AstTree->getType() == gRegexTokenTypes::NEXT){ + return visit(nextAstNodes[0])+"->"+visit(nextAstNodes[1]); + }else if(AstTree->getType() == gRegexTokenTypes::QOM){ + return visit(nextAstNodes[0])+"+"; + }else if(AstTree->getType() == gRegexTokenTypes::QZM){ + return visit(nextAstNodes[0])+"*"; + }else if(AstTree->getType() == gRegexTokenTypes::KEY || AstTree->getType() == gRegexTokenTypes::CKEY){ + return AstTree->getValue(); + }else if(AstTree->getType() == gRegexTokenTypes::LPAREN){ + return "("+visit(nextAstNodes[0])+")"; + }else{ + throw std::logic_error("visit Bad token type" ); + } + + +} + + +std::string GraphStrInterpreter::interpret(void){ + std::shared_ptr<AstNode<gRegexTokenTypes>> tree = mParser.parse(); + return visit(tree); +} \ No newline at end of file diff --git a/src/graphRegex/matchFsm/FsmEdge.cpp b/src/graphRegex/matchFsm/FsmEdge.cpp new file mode 100644 index 0000000000000000000000000000000000000000..593da06abe18576d435ae55718d379aa5b682d60 --- /dev/null +++ b/src/graphRegex/matchFsm/FsmEdge.cpp @@ -0,0 +1,277 @@ +#include "aidge/graphRegex/matchFsm/FsmEdge.hpp" +#include "aidge/graphRegex/matchFsm/FsmNode.hpp" +#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" + +using namespace Aidge; + +std::map<std::string,int> FsmEdgeCommon::mCommonIdxMap; + +bool FsmEdge::isCommon(void){ + return false; +} + +size_t FsmEdge::getCommonIdx(void){ + return std::numeric_limits<std::size_t>::max(); +} +const std::map<size_t,int>& FsmEdge::getRelative(void){ + return mRelativePos; +} +void FsmEdge::updateRelative( const std::map<size_t,int>& relativePos ){ + for (const auto& kvp : relativePos) { + mRelativePos.insert(kvp); + } +} +std::shared_ptr<FsmNode> FsmEdge::getSourceNode(void){ + return mNodeSource; +} +void 
FsmEdge::reSetSouceNode(const std::shared_ptr<FsmNode>& newSource){ + mNodeSource->rmEdge(shared_from_this()); + mNodeSource = newSource; + mNodeSource->addEdge(shared_from_this()); + propagateRelativePos(); + +} +std::shared_ptr<FsmNode> FsmEdge::getDestNode(void){ + return mNodeDest; +} +void FsmEdge::reSetDestNode(const std::shared_ptr<FsmNode>& newDest){ + mNodeDest->rmParent(mNodeSource); + mNodeDest = newDest; + mNodeDest->addParent(mNodeSource); + propagateRelativePos(); +} +void FsmEdge::propagateRelativePos(void){ + + std::set<int> myRelativeID; + for (const auto& kvp : mRelativePos) { + myRelativeID.insert(kvp.first); + } + + for (const auto& nextWeakEdge : mNodeDest->getEdges()){ + + if (auto nextEdge = nextWeakEdge.lock()) { + + if(this == nextEdge.get()){ + continue; + } + + + std::set<int> nextRelativeID; + for (const auto& kvp : nextEdge->getRelative()) { + nextRelativeID.insert(kvp.first); + } + + // Find elements in myRelativeID but not in nextRelativeID + std::set<int> idxsToPush; + std::set_difference(myRelativeID.begin(), myRelativeID.end(), + nextRelativeID.begin(), nextRelativeID.end(), + std::inserter(idxsToPush, idxsToPush.begin())); + + // Find elements in nextRelativeID but not in myRelativeID + std::set<int> idxsToGet; + std::set_difference(nextRelativeID.begin(), nextRelativeID.end(), + myRelativeID.begin(), myRelativeID.end(), + std::inserter(idxsToGet, idxsToGet.begin())); + + // test for integrity we look if 2 edge refert to the samme + // ref and are link the ref dif is one + // not working for common node + // we can go deeper by find the all pass to a ref and see if the delta is good + + // Find elements present in both myRelativeID and nextRelativeID + std::set<int> idxsTotest; + for (int idx : nextRelativeID){ + if (myRelativeID.find(idx) != myRelativeID.end()){ + if (std::abs(getRelative().at(idx) - nextEdge->getRelative().at(idx)) != 1) { + throw std::runtime_error("Bad relative"); + } + } + } + + + + // this egde have more 
relative info than the next + std::map<size_t,int> tmpRelative; + // we push this info to the next + for( auto idxToPush :idxsToPush ){ + tmpRelative.insert( std::make_pair(idxToPush, getRelative().at(idxToPush) +1)); + } + if(tmpRelative.size() != 0){ + nextEdge->updateRelative(tmpRelative); + nextEdge->propagateRelativePos(); + } + tmpRelative.clear(); + + + // the next node have more info than me i need to get it + for( auto idxToGet :idxsToGet ){ + tmpRelative.insert( std::make_pair(idxToGet, nextEdge->getRelative().at(idxToGet) -1)); + } + if(tmpRelative.size() != 0){ + updateRelative(tmpRelative); + + for(auto weakParent : getSourceNode()->getParentNodes()){ + if (auto parent = weakParent.lock()) { + for(auto weakPEdge : parent->getEdges()){ + if (auto pEdge = weakPEdge.lock()) { + pEdge->propagateRelativePos(); + }else{ + throw std::runtime_error("propagateRelativePos parent edge weak pointer is expired" ); + } + } + }else{ + throw std::runtime_error("propagateRelativePos parent weak pointer is expired" ); + } + } + } + tmpRelative.clear(); + }else{ + throw std::runtime_error("propagateRelativePos edge weak pointer is expired" ); + } + } +} + +void FsmEdge::updateWeak(void){ + mNodeSource->addEdge(shared_from_this()); + mNodeDest->addParent(mNodeSource); +} + +FsmEdge::FsmEdge(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest) +:mToTest(toTest) +{ + mNodeSource = source; + mNodeDest = dest; + // wen i make the edge I init the nodes + // mNodeSource->addEdge(shared_from_this()); + // mNodeDest->addParent(mNodeSource); +} + + +/////surchage + +FsmEdgeUnique::FsmEdgeUnique(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest) +:FsmEdge(source,dest,toTest) +{ +} +const EdgeTestResult FsmEdgeUnique::test(const std::shared_ptr<FsmRunTimeContext> stmContext){ + auto opNode = stmContext->getActNode(); + + if(opNode == nullptr){ + 
return {false,std::set<NodePtr>()};//none + } + + if(mToTest->test(opNode) && opNode->getChildren().size() <= 1){ + stmContext->setValid(opNode,mToTest); + return {true,opNode->getChildren()} ; + }else{ + stmContext->addRejectedNode(opNode); + return {false,std::set<NodePtr>()}; + } +} +///////////////////// +FsmEdgeCommon::FsmEdgeCommon(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest, const std::string commonKey) +:FsmEdge(source,dest,toTest) +{ + //make a uid for common node + if(mCommonIdxMap.find(commonKey) == mCommonIdxMap.end()){ + mCommonIdxMap.insert(std::make_pair(commonKey, mCommonIdxMap.size())); + } + mCommonIdx = mCommonIdxMap[commonKey]; + propagateRelativePos(); +} + + +const EdgeTestResult FsmEdgeCommon::test(const std::shared_ptr<FsmRunTimeContext> stmContext){ + + auto opNode = stmContext->getActNode(); + + if(opNode == nullptr){ + return {false,std::set<NodePtr>()};//none + } + if(mToTest->test(opNode)){ + stmContext->setCommon(opNode,mCommonIdx); + stmContext->setValid(opNode,mToTest); + return {true,opNode->getChildren()} ; + }else{ + stmContext->addRejectedNode(opNode); + return {false,std::set<NodePtr>()}; + } +} +bool FsmEdgeCommon::isCommon(void){ + return true; + } +//////////////////// TODO FsmEdgeEmpty must be size_t +FsmEdgeRef::FsmEdgeRef(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const size_t refCommonIdx,const int deltaCommonIdx) +:FsmEdge(source,dest,nullptr),mRefCommonIdx(refCommonIdx),mdeltaCommonIdx(deltaCommonIdx) +{ + +} +const EdgeTestResult FsmEdgeRef::test(const std::shared_ptr<FsmRunTimeContext> stmContext){ + + NodePtr refNode = stmContext->getCommonNodeFromIdx(mRefCommonIdx); + if (refNode){ + std::set<std::shared_ptr<Node>> see; + return {true,refNode->getNodeDelta(mdeltaCommonIdx,see)}; + } + return {false,std::set<NodePtr>()}; +} +//////////////////// +FsmEdgeEmpty::FsmEdgeEmpty(std::shared_ptr<FsmNode> 
source,std::shared_ptr<FsmNode> dest) +:FsmEdge(source,dest,nullptr) +{} +const EdgeTestResult FsmEdgeEmpty::test(const std::shared_ptr<FsmRunTimeContext> stmContext){ + auto opNode = stmContext->getActNode(); + if(opNode == nullptr){ + return {false,std::set<NodePtr>()}; + } + return {true,std::set<NodePtr>({opNode})};//none +} + +/// factory +std::shared_ptr<FsmEdge> FsmEdgeFactory::make( +std::shared_ptr<FsmNode> source, +std::shared_ptr<FsmNode> dest, FsmEdgeTypes type, +std::map<std::string, std::shared_ptr<ConditionalInterpreter>> allTest, +const std::string lexeme) +{ + if (type == FsmEdgeTypes::EMPTY) { + if (lexeme.empty()) { + return std::make_shared<FsmEdgeEmpty>(source, dest); + } else { + throw std::invalid_argument("error lexem EMPTY"); + } + } else if (type == FsmEdgeTypes::REF) { + std::smatch m; + std::regex refRegex("\\s*\\(\\s*(\\d+)\\s*,\\s*(-?\\d+)\\s*\\)\\s*"); + if (std::regex_match(lexeme, m, refRegex)) { + int refCommonIdx = std::stoi(m[1]); + int deltaCommonIdx = std::stoi(m[2]); + return std::make_shared<FsmEdgeRef>(source, dest, refCommonIdx, deltaCommonIdx); + } else { + throw std::invalid_argument("error lexem REF " + lexeme); + } + } else if (type == FsmEdgeTypes::COMMON) { + std::smatch m; + std::regex commonRegex("\\s*(\\w+)#(\\d*)"); + if (std::regex_match(lexeme, m, commonRegex)) { + std::string edgeType = m[1]; + std::string commonId = m[2]; + size_t commonIdx = commonId.empty() ? 
0 : std::stoi(commonId) + 1; + std::string commonKey = edgeType + std::to_string(commonIdx); + return std::make_shared<FsmEdgeCommon> (source, dest, allTest.at(edgeType), commonKey); + } else { + throw std::invalid_argument("error lexem COMMON " + lexeme); + } + } else if (type == FsmEdgeTypes::UNIQUE) { + std::regex uniqueRegex("\\s*(\\w+)"); + std::smatch m; + if (std::regex_match(lexeme, m, uniqueRegex)) { + std::string edgeType = m[1]; + return std::make_shared<FsmEdgeUnique>(source, dest, allTest.at(edgeType)); + } else { + throw std::invalid_argument("error lexem UNIQUE \"" + std::string(lexeme) +" eee\""); + } + } else { + throw std::invalid_argument("Bad edge Type"); + } + } \ No newline at end of file diff --git a/src/graphRegex/matchFsm/FsmGraph.cpp b/src/graphRegex/matchFsm/FsmGraph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5a9f00d728cd2cd9f58c2228361f8393de2a3d9d --- /dev/null +++ b/src/graphRegex/matchFsm/FsmGraph.cpp @@ -0,0 +1,201 @@ +#include "aidge/graphRegex/matchFsm/FsmGraph.hpp" + +using namespace Aidge; + + + +FsmGraph::FsmGraph(/* args */){ + +} + +//TODO + std::shared_ptr<MatchResult> FsmGraph::test(std::vector<NodePtr>& startNodes){ + std::vector<std::shared_ptr<Aidge::FsmNode>> startNodesFsm = getStartNodes(); + if(startNodes.size() != startNodesFsm.size()){ + throw std::runtime_error("bad number of Start nodes"); + } + + std::vector<std::shared_ptr<FsmRunTimeContext>> walks; + for(std::size_t i = 0; i < startNodes.size(); i++){ + walks.push_back(std::make_shared<FsmRunTimeContext>(startNodesFsm[i],startNodes[i])); + } + std::vector<std::shared_ptr<FsmRunTimeContext>> nextWalks; + + std::vector<std::shared_ptr<FsmRunTimeContext>> allValidContext; + std::vector<std::shared_ptr<FsmRunTimeContext>> allContextSee; + + + + + while (!walks.empty()) + { + for(auto fsmContext : walks){ + allContextSee.push_back(fsmContext); + //if we are in a valid st we save it + //it's one solution of the posible solution of the 
matching
+            if(fsmContext->isOnValidState()){
+                //do not save the same end point twice
+                if(!std::any_of(allValidContext.begin(), allValidContext.end(),
+                    [&](std::shared_ptr<Aidge::FsmRunTimeContext> oldValid) {
+                        return fsmContext->areEqual(oldValid);
+                    })){
+                    allValidContext.push_back(fsmContext);
+                }
+
+            }
+
+            //do not test the same fsmContext twice
+            std::vector<std::shared_ptr<FsmRunTimeContext>> tmpNextWalks = fsmContext->getActState()->test(fsmContext);
+            for(auto PotentialFsmContext : tmpNextWalks){
+
+                if(!std::any_of(allContextSee.begin(), allContextSee.end(),
+                    [&](std::shared_ptr<Aidge::FsmRunTimeContext> oldSee) {
+                        return PotentialFsmContext->areEqual(oldSee);
+                    })){
+                    nextWalks.push_back(PotentialFsmContext);
+                }
+            }
+
+        }
+        walks.swap(nextWalks);
+        nextWalks.clear();
+    }
+
+
+    return std::make_shared<MatchResult>(allValidContext,getNbSubFsm());
+
+}
+
+
+///////////////
+// FSM construction
+///////////////
+const std::set<std::shared_ptr<FsmEdge>>& FsmGraph::getEdge(void){
+    return mEdges;
+}
+
+void FsmGraph::addEdge(std::shared_ptr<FsmEdge>& edge){
+    edge->updateWeak();
+    mEdges.insert(edge);
+    mAllOrigine.insert(edge->getDestNode()->getOrigine());
+    mAllOrigine.insert(edge->getSourceNode()->getOrigine());
+}
+
+const std::vector<std::shared_ptr<FsmNode>> FsmGraph::getStartNodes(void){
+    std::set<std::shared_ptr<FsmNode>> nodes = getNodes();
+    std::vector<std::shared_ptr<FsmNode>> startNodes;
+    for(auto node :nodes){
+        if(node->isStart()){
+            startNodes.push_back(node);
+        }
+    }
+    return startNodes;
+}
+
+const std::set<std::shared_ptr<FsmNode>> FsmGraph::getValidNodes(void){
+    std::set<std::shared_ptr<FsmNode>> nodes = getNodes();
+    std::set<std::shared_ptr<FsmNode>> ValidNodes;
+    for(auto node :nodes){
+        if(node->isValid()){
+            ValidNodes.insert(node);
+        }
+    }
+    //could be shortened
+    return ValidNodes;
+}
+
+const std::set<std::shared_ptr<FsmNode>> FsmGraph::getNodes(void){
+    std::set<std::shared_ptr<FsmNode>> nodes;
+    for(auto edge : mEdges){
+        nodes.insert(edge->getDestNode());
+        nodes.insert(edge->getSourceNode());
+    }
+    return nodes;
+}
+
+void FsmGraph::setGroupe(std::size_t groupeIdx){
+    std::set<std::shared_ptr<FsmNode>> nodes = getNodes();
+    for(auto node :nodes){
+        node->setGroupe(groupeIdx);
+    }
+}
+
+void FsmGraph::unionG(const std::shared_ptr<FsmGraph> fsmGraph){
+
+    for(auto edge : fsmGraph->getEdge()){
+        addEdge(edge);
+    }
+}
+
+void FsmGraph::mergeOneStartOneValid(const std::shared_ptr<FsmGraph> fsmGraph){
+    std::set<std::shared_ptr<FsmNode>> validNodes = getValidNodes();
+    std::vector<std::shared_ptr<FsmNode>> startNodes = fsmGraph->getStartNodes();
+
+    if (startNodes.size() != 1 || validNodes.size() != 1){
+
+        std::ostringstream errorMessage;
+        errorMessage <<"mergeOneStartOneValid start size: " << startNodes.size() << " valide size : " << validNodes.size()
+        <<" can only merge FSM 1 start 1 valide";
+        throw std::runtime_error(errorMessage.str());
+    }
+
+    unionG(fsmGraph);
+    //the loop is currently redundant (single element) but could be used for a future multi-merge
+    for(auto valid : validNodes){
+        valid->unValid();
+        for(auto start : startNodes){
+            start->unStart();
+            _mergeNode(start,valid);
+        }
+    }
+}
+
+std::size_t FsmGraph::getNbSubFsm(void){
+    return mAllOrigine.size();
+}
+
+void FsmGraph::incOrigineAllNodeBy(std::size_t incr){
+    std::set<std::shared_ptr<FsmNode>> nodes = getNodes();
+    for(auto node :nodes){
+        node->incOrigine(incr);
+    }
+    std::set<std::size_t> updatedOrigin;
+    for(auto origin : mAllOrigine){
+        updatedOrigin.insert(origin + incr);
+    }
+    mAllOrigine.swap(updatedOrigin);
+}
+
+void FsmGraph::_mergeNode(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest){
+    std::set<std::shared_ptr<FsmNode>> nodes = getNodes();
+
+    if(nodes.find(source) == nodes.end() || nodes.find(dest) == nodes.end()){
+        throw std::runtime_error("FsmGraph can not merge node not in the graph");
+    }
+    nodes.clear();
+
+    //propagate the source attributes
+    if(source->isValid()){
+        dest->valid();
+    }
+    if(source->isStart()){
+
dest->start(); + } + + //merge source to dest by replace source by dest in all EDGE + for(auto edge : mEdges){ + if(edge->getDestNode() == source ){ + edge->reSetDestNode(dest); + }else if(edge->getSourceNode() == source ){ + edge->reSetSouceNode(dest); + } + + } + //check is source is not in graph + nodes = getNodes(); + if(nodes.find(source) != nodes.end() ){ + throw std::runtime_error("FsmGraph merge node not effective"); + } + nodes.clear(); + +} diff --git a/src/graphRegex/matchFsm/FsmNode.cpp b/src/graphRegex/matchFsm/FsmNode.cpp new file mode 100644 index 0000000000000000000000000000000000000000..84b4a0c3fdbe0730a12a2a62db9158e2538d646f --- /dev/null +++ b/src/graphRegex/matchFsm/FsmNode.cpp @@ -0,0 +1,132 @@ +#include "aidge/graphRegex/matchFsm/FsmNode.hpp" +#include "aidge/graphRegex/matchFsm/FsmEdge.hpp" +#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" + +using namespace Aidge; + + + +FsmNode::FsmNode(bool isAValid,bool isAStart ){ + mIsAStart =isAStart; + mIsAValid =isAValid; + +} +const std::vector<std::shared_ptr<FsmRunTimeContext>> FsmNode::test( std::shared_ptr<FsmRunTimeContext> fsmContext){ + + + std::vector<std::shared_ptr<FsmRunTimeContext>> out; + + for(auto edge : mEdges){ + if (auto sharedEdge = edge.lock()) { + + std::shared_ptr<FsmNode> nextState = sharedEdge->getDestNode(); + + //make copy of the fsmContext + std::shared_ptr<FsmRunTimeContext> newFsmContext = std::make_shared<FsmRunTimeContext>(fsmContext); + + EdgeTestResult edgeRes = sharedEdge->test(newFsmContext); + + if(edgeRes.success){ + if(edgeRes.node.size() != 0){ + for(auto nextNode :edgeRes.node ){ + if(!newFsmContext->isAlreadyValid(nextNode)|| newFsmContext->isCommonDefined(nextNode) ){ + out.push_back( std::make_shared<FsmRunTimeContext>(newFsmContext,nextState,nextNode)); + + }else{ + out.push_back( std::make_shared<FsmRunTimeContext>(newFsmContext,nextState,nullptr)); + } + + } + }else{ + out.push_back( 
std::make_shared<FsmRunTimeContext>(newFsmContext,nextState,nullptr)); + } + } + newFsmContext.reset(); + + }else{ + throw std::runtime_error("test FsmNode weak pointer is expired" ); + } + + } + return out; +} + + + +std::size_t FsmNode::getOrigine(void){ + return mOrigineStm; +} +void FsmNode::incOrigine(std::size_t inc){ + mOrigineStm += inc; +} +void FsmNode::rmEdge(std::shared_ptr<FsmEdge> edge){ + mEdges.erase(edge); +} + +void FsmNode::addEdge(std::shared_ptr<FsmEdge> edge){ + std::weak_ptr<FsmEdge> edgeW(edge); + if (!edgeW.expired()) { + mEdges.insert(edgeW); + }else{ + throw std::runtime_error("addEdge FsmNode weak pointer is expired" ); + } +} + +// const std::set<std::shared_ptr<FsmNode>> FsmNode::getChildNodes(void){ +// std::set<std::shared_ptr<FsmNode>> children; +// for(auto edge : mEdges){ +// if (auto sharedEdge = edge.lock()) { +// children.insert(sharedEdge->getDestNode()); +// }else{ +// throw std::runtime_error("getChildNodes FsmNode weak pointer is expired" ); +// } +// } +// return children; +// } + + +const std::set<std::weak_ptr<FsmNode>,lex_compare<FsmNode>>& FsmNode::getParentNodes(void){ + return mParents; +} +const std::set<std::weak_ptr<FsmEdge>,lex_compare<FsmEdge>>& FsmNode::getEdges(void){ + return mEdges; +} + +void FsmNode::setGroupe(std::size_t groupeIdx){ + mGroupeStm = groupeIdx; + +} + +bool FsmNode::isValid(void){ + return mIsAValid; +} +bool FsmNode::isStart(void){ + return mIsAStart; +} +void FsmNode::unValid(void){ + mIsAValid =false; +} +void FsmNode::valid(void){ + mIsAValid =true; +} +void FsmNode::unStart(void){ + mIsAStart =false; +} +void FsmNode::start(void){ + mIsAStart =true; +} + + + +void FsmNode::addParent(std::shared_ptr<FsmNode> node){ + + std::weak_ptr<FsmNode> nodeW(node); + if (!nodeW.expired()) { + mParents.insert(nodeW); + }else{ + throw std::runtime_error("addParent FsmNode weak pointer is expired" ); + } +} +void FsmNode::rmParent(std::shared_ptr<FsmNode> node){ + mParents.erase(node); +} \ No newline 
at end of file diff --git a/src/graphRegex/matchFsm/FsmRunTimeContext.cpp b/src/graphRegex/matchFsm/FsmRunTimeContext.cpp new file mode 100644 index 0000000000000000000000000000000000000000..787cf2322a5b8e7001cdc59325345000dbb61553 --- /dev/null +++ b/src/graphRegex/matchFsm/FsmRunTimeContext.cpp @@ -0,0 +1,226 @@ +#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" +#include "aidge/graphRegex/matchFsm/FsmNode.hpp" + +using namespace Aidge; + +std::vector<std::set<NodePtr>> FsmRunTimeContext::mRejectedNodes; + +FsmRunTimeContext::FsmRunTimeContext(std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ,std::size_t idxRejeced ){ + mActOpNode = actOpNode; + mActState = actState; + + //not define case + if(idxRejeced == std::numeric_limits<std::size_t>::max()){ + mLocalIdxRejeced = mRejectedNodes.size(); + mRejectedNodes.push_back(std::set<NodePtr>()); + }else{ + if(idxRejeced > mRejectedNodes.size()-1 ){ + throw std::runtime_error("FsmRunTimeContext idxRejeced"); + } + mLocalIdxRejeced =idxRejeced; + } +} + + + +FsmRunTimeContext::FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime){ + mActOpNode = fsmRunTime->mActOpNode; + mActState = fsmRunTime->mActState; + mCommonNodes = fsmRunTime->mCommonNodes; + mValidNodes = fsmRunTime->mValidNodes; + mLocalIdxRejeced = fsmRunTime->mLocalIdxRejeced; +} +FsmRunTimeContext::FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime,std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ){ + mActOpNode = actOpNode; + mActState = actState; + mCommonNodes = fsmRunTime->mCommonNodes; + mValidNodes = fsmRunTime->mValidNodes; + mLocalIdxRejeced = fsmRunTime->mLocalIdxRejeced; +} + +void FsmRunTimeContext::addRejectedNode(NodePtr node){ + mRejectedNodes[mLocalIdxRejeced].insert(node); +} + +std::set<NodePtr> FsmRunTimeContext::getRejectedNodes(void){ + return mRejectedNodes[mLocalIdxRejeced]; +} + +bool FsmRunTimeContext::isOnValidState(void){ + return mActState->isValid(); +} + +bool 
FsmRunTimeContext::isCommonDefined(NodePtr node){ + //return mCommonNodes.find(node) != mCommonNodes.end(); + + std::set<NodePtr> nodes = getCommonNodes(); + for(const auto& nodeC : nodes){ + if(nodeC.get() == node.get()){ + return true; + } + } + return false; +} + +bool FsmRunTimeContext::isAlreadyValid(NodePtr node){ + + std::set<NodePtr> nodes = getValidNodes(); + for(const auto& nodeV : nodes){ + if(nodeV.get() == node.get()){ + return true; + } + } + return false; + + //return getValidNodes().find(node) != getValidNodes().end(); +} + +bool FsmRunTimeContext::areCompatible(std::shared_ptr<FsmRunTimeContext> fsmContext){ + /* + see if 2 context can be merge + it need to have different mValidNodes exept for common + and the same idx for the common + */ + + //common node + + for (const auto& ref : getCommon()) { + for (const auto& test : fsmContext->getCommon()) { + //same index + if(ref.second == test.second){ + if(ref.first != test.first){ + return false; + } + } + } + } + + //valid nodes + std::set<NodePtr> commonElements; + std::set<NodePtr> A = getValidNodesNoCommon(); + std::set<NodePtr> B = fsmContext->getValidNodesNoCommon(); + std::set_intersection( + A.begin(),A.end(), + B.begin(), B.end(), + std::inserter(commonElements, commonElements.end()) + ); + + + if (!commonElements.empty()) { + return false; + } + + return true; +} + +bool FsmRunTimeContext::areEqual(std::shared_ptr<FsmRunTimeContext> fsmContext){ + if(getActNode() != fsmContext->getActNode()){ + return false; + } + if (getActState() != fsmContext->getActState()){ + return false; + } + if (getValidNodes() != fsmContext->getValidNodes()){ + return false; + } + if (getCommon() != fsmContext->getCommon()){ + return false; + } + + + return true; +} + +void FsmRunTimeContext::setCommon(NodePtr node,std::size_t commonIdx){ + if(isCommonDefined(node)){ + if (mCommonNodes.at(node) != commonIdx){ + throw std::runtime_error("conflict idx in the Common node"); + } + }else{ + mCommonNodes[node] = 
commonIdx; + } +} + +void FsmRunTimeContext::setValid(NodePtr node,std::shared_ptr<ConditionalInterpreter> tag){ + //we already find a node of this type + if(mValidNodes.find(tag) != mValidNodes.end()){ + if(isAlreadyValid(node) && !isCommonDefined(node) ){ + throw std::runtime_error("setValid you valid tow time"); + } + mValidNodes[tag].insert(node); + }else{ + mValidNodes[tag] = {node}; + } + +} + +std::size_t FsmRunTimeContext::getSubStmId(void){ + return mActState->getOrigine(); +} + +NodePtr FsmRunTimeContext::getCommonNodeFromIdx(std::size_t commonIdx){ + for (const auto& pair : mCommonNodes) { + if (pair.second == commonIdx) { + return pair.first; // Return the key when the value is found + } + } + throw std::runtime_error("getCommonNodeFromIdx Value not found in the map"); +} + +std::size_t FsmRunTimeContext::getCommonNodeIdx(NodePtr node){ + if(isCommonDefined(node)){ + return mCommonNodes.at(node); + } + throw std::runtime_error("getCommonNodeIdx node not found"); +} + +std::set<NodePtr> FsmRunTimeContext::getCommonNodes(void){ + std::set<NodePtr> nodes; + // Iterate over the map and insert values into the set + for (const auto& pair : mCommonNodes) { + nodes.insert(pair.first); + } + return nodes; +} + +std::map<NodePtr,std::size_t> FsmRunTimeContext::getCommon(void){ + return mCommonNodes; +} + +std::set<NodePtr> FsmRunTimeContext::getValidNodes(void){ + + auto sharedSet = std::make_shared<std::set<NodePtr>>(); + // Create a set to store the values from the map + std::set<NodePtr> nodes; + // Iterate over the map and insert values into the set + for (const auto& pair : mValidNodes) { + nodes.insert(pair.second.begin(),pair.second.end()); + } + return nodes; +} + +std::set<NodePtr> FsmRunTimeContext::getValidNodesNoCommon(void){ + std::set<NodePtr> differenceSet; + std::set<NodePtr> valide = getValidNodes(); + std::set<NodePtr> common = getCommonNodes(); + std::set_difference(valide.begin(), valide.end(), common.begin(), 
common.end(),std::inserter(differenceSet, differenceSet.end())); + return differenceSet; +} + +std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> FsmRunTimeContext::getValid(void){ + return mValidNodes; +} + +NodePtr FsmRunTimeContext::getActNode(void){ + return mActOpNode; +} + +std::shared_ptr<FsmNode> FsmRunTimeContext::getActState(){ + return mActState; +} + + +void FsmRunTimeContext::rst(void){ + mRejectedNodes.clear(); +} + diff --git a/src/graphRegex/matchFsm/MatchResult.cpp b/src/graphRegex/matchFsm/MatchResult.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c35f1a7348e365baa8a27854ee6b0a833e342ee7 --- /dev/null +++ b/src/graphRegex/matchFsm/MatchResult.cpp @@ -0,0 +1,93 @@ +#include "aidge/graphRegex/matchFsm/MatchResult.hpp" + +using namespace Aidge; + + +MatchResult::MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm):mIdToRunTime(nbSubStm){ + mAllValid = allValid; + mNbSubStm = nbSubStm; + + //mIdToRunTimm + for (const auto& contextPtr : allValid) { + mIdToRunTime[contextPtr->getSubStmId()].push_back(contextPtr); + } + + std::vector<std::shared_ptr<FsmRunTimeContext>> precedence; + //make all solution posible + _generateCombinationd(0,precedence); + //sort by solution number of elements + std::sort(mSolve.begin(), mSolve.end(), [](const std::set<NodePtr>& set1, const std::set<NodePtr>& set2) { + return set1.size() < set2.size(); + }); + + +} + +void MatchResult::_generateCombinationd( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence){ + + //it's end , we are below the number of stm + if (idxSubStm == mNbSubStm) + { + //precedence containe a liste of FSM compatible, we just need to + //check if all the node have been valide by at least one contetext + + //1) make the set of all node for the comput graph that are valide in all the FsmRunTimeContext + std::set<NodePtr> validNode; + std::set<NodePtr> rejectNode; + for (const auto& contextPtr : 
precedence) { + std::set<NodePtr> tmpV = contextPtr->getValidNodes(); + validNode.insert(tmpV.begin(), tmpV.end()); + std::set<NodePtr> tmpR = contextPtr->getRejectedNodes(); + rejectNode.insert(tmpR.begin(),tmpR.end()); + } + // 2) all RejectedNodes need to be valide by an others stm + // if it's not the case the match is not valid + if(std::includes(validNode.begin(), validNode.end(), rejectNode.begin(), rejectNode.end())){ + //we can save the solution + mSolve.push_back(validNode); + } + precedence.pop_back(); + return; + } + + + for (const auto& contextPtrOneFsm : mIdToRunTime[idxSubStm]) + { + if(idxSubStm == 0){ + precedence.push_back(contextPtrOneFsm); + _generateCombinationd(idxSubStm+1,precedence); + + }else{ + //test if the new context is compatible whith all the context in the precedence + // + bool compatibleSolutionFsm = true; + for (const auto& contextPtrOfOtherFsm : precedence) { + if(!(contextPtrOneFsm->areCompatible(contextPtrOfOtherFsm))){ + compatibleSolutionFsm = false; + break; + } + } + + if(compatibleSolutionFsm){ + precedence.push_back(contextPtrOneFsm); + _generateCombinationd(idxSubStm+1,precedence); + } + + } + } + + if(idxSubStm != 0){ + precedence.pop_back(); + } + return; + +} + +std::set<NodePtr> MatchResult::getBiggerSolution(void){ + if(mSolve.empty()){ + return std::set<NodePtr>(); + }else{ + return mSolve[0]; + } + +} \ No newline at end of file diff --git a/src/graphmatching/NodeRegex.cpp b/src/graphmatching/NodeRegex.cpp index bbb116d1b12a31b491b26d2a64d04b416b61c6b7..9bf164f60255c17492e528b0f27dec8c53f74979 100644 --- a/src/graphmatching/NodeRegex.cpp +++ b/src/graphmatching/NodeRegex.cpp @@ -12,7 +12,7 @@ #include "aidge/graphmatching/NodeRegex.hpp" -// Verification done by the Parameter system +// Verification done by the Attribute system // Version 1 - Only test the type of the node (no need for a lexer) @@ -39,8 +39,8 @@ bool Aidge::NodeRegex::isA(std::string NodeType){ /**bool NodeRegex::_is(string &Node_op){ // Parsing 
the condition is done in the initialization of the NodeRegex - // assert parameters exist in the node with the parameter function isParam() + // assert attributes exist in the node with the attribute function hasAttr() - // get the parameters + // get the attributes }*/ diff --git a/src/nodeTester/ConditionalInterpreter.cpp b/src/nodeTester/ConditionalInterpreter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e01bdd76a28576451a1a09202d5fd1e87a4856e5 --- /dev/null +++ b/src/nodeTester/ConditionalInterpreter.cpp @@ -0,0 +1,344 @@ + +#include "aidge/nodeTester/ConditionalInterpreter.hpp" + +using namespace Aidge; + + +/////////////////////////////// +//ConditionalRegisterFunction +/////////////////////////////// + + ConditionalData* ConditionalRegisterFunction::run(const std::string key,std::vector<ConditionalData*> & datas){ + + auto lambdaIt = mWlambda.find(key); + if (lambdaIt != mWlambda.end()) { + return lambdaIt->second(datas); + }else { + throw std::runtime_error("can not run Lambda due to invalid key: " + key); + } + } + +////////////////////// +//ConditionalInterpreter +/////////////////////// + ConditionalInterpreter::ConditionalInterpreter(const std::string ConditionalExpressions) + :mLambdaRegiter() + { + + ConditionalParser conditionalParser = ConditionalParser(ConditionalExpressions); + mTree = conditionalParser.parse(); + ///lambda by default + mLambdaRegiter.insert("getType",+[](NodePtr NodeOp){return NodeOp->type();}); + + } + + + bool ConditionalInterpreter::test( const NodePtr nodeOp) + { + + clearRes(); + try{ + std::vector<ConditionalData*> r = visit({mTree},nodeOp); + + if (mResolution.size() != 1){ + throw std::runtime_error("Multi-output interpretation output"); + }else{ + if (!mResolution[0]->isTypeEqualTo<bool>()){ + throw std::runtime_error("TEST OUT MUST BE A BOOL "); + }else{ + return mResolution[0]->getValue<bool>(); + } + } + + }catch(const std::exception& e){ + std::ostringstream errorMessage; + errorMessage << 
"Error in test " << "\n\t" << e.what() << "\n"; + throw std::runtime_error(errorMessage.str()); + } + } + + void ConditionalInterpreter::insertLambda(const std::string key,std::function<bool(Aidge::NodePtr)> f){ + mLambdaRegiter.insert<std::function<bool(Aidge::NodePtr)> >(key, f); + } + + ///// + std::vector<ConditionalData*> ConditionalInterpreter::visit(const ASTNodeCh& nodes, const NodePtr nodeOp ){ + std::vector<ConditionalData*> dataVector; + + for ( std::shared_ptr<AstNode<ConditionalTokenTypes>> node : nodes) { + try{ + switch (node->getType()){ + /////////////////////////////////// + //OPERATOR + /////////////////////////////////// + case ConditionalTokenTypes::NOT: + { + visit(node->getChilds(),nodeOp); + fNot(); + } + break; + case ConditionalTokenTypes::AND: + { + visit(node->getChilds(),nodeOp); + fAnd(); + } + break; + case ConditionalTokenTypes::OR: + { + visit(node->getChilds(),nodeOp); + fOr(); + } + break; + case ConditionalTokenTypes::EQ: + { + visit(node->getChilds(),nodeOp); + fEq(); + //dataVector.insert(dataVector.end(), tmp.begin(), tmp.end()); + } + break; + case ConditionalTokenTypes::NEQ: + { + visit(node->getChilds(),nodeOp); + fNeq(); + } + break; + + /////////////////////////////////// + //VALUE + /////////////////////////////////// + + case ConditionalTokenTypes::KEY: + + break; + case ConditionalTokenTypes::INTEGER: + { + fStrToInteger(node); + } + break; + case ConditionalTokenTypes::FLOAT: + { + fStrToFloat(node); + + } + break; + case ConditionalTokenTypes::STRING: + { + fStrToStr(node); + } + break; + + case ConditionalTokenTypes::NODE: //TODO + { + + ConditionalData* data = new ConditionalData; + data->setValue<NodePtr>(nodeOp); + mResolution.push_back(data); + + } + break; + + case ConditionalTokenTypes::LAMBDA: + { + visit(node->getChilds(),nodeOp); + fLambda(node); + + } + break; + + case ConditionalTokenTypes::BOOL: //TODO + { + ConditionalData* data = new ConditionalData; + + if(node->getValue() == "true"){ + 
data->setValue<bool>(true); + }else{ + data->setValue<bool>(false); + } + + mResolution.push_back(data); + + } + break; + + case ConditionalTokenTypes::ARGSEP: + case ConditionalTokenTypes::LPAREN: + case ConditionalTokenTypes::RPAREN: + case ConditionalTokenTypes::STOP: + default: + throw std::runtime_error("NODE TYPE NOT SUPORTED IN ConditionalInterpreter"); + } + }catch(const std::exception& e){ + std::ostringstream errorMessage; + errorMessage << "Error in visiting AST for node"<< nodeOp->name() << "\n\t" << e.what() << "\n"; + throw std::runtime_error(errorMessage.str()); + } + } + + return dataVector; + } + + + ////////////////////// + //value convertor + ///////////////////// + + + void ConditionalInterpreter::fStrToInteger(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) + { + ConditionalData* data = new ConditionalData; + data->setValue<int>(std::stoi(node->getValue())); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fStrToFloat(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) + { + + ConditionalData* data = new ConditionalData; + data->setValue<float>(std::stof(node->getValue())); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fStrToStr(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) + { + ConditionalData* data = new ConditionalData; + data->setValue<std::string>(node->getValue()); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fLambda(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) + { + //if the lambda have input + ConditionalData* data; + try { + data = mLambdaRegiter.run(node->getValue(),mResolution); + } catch (const std::exception& e) { + std::ostringstream errorMessage; + errorMessage << "Error in conditional interpretation when run the "<< node->getValue() <<" Lambda\n\t" << e.what() << "\n"; + throw std::runtime_error(errorMessage.str()); + } + + clearRes(); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fEq(void) 
+ { + if (mResolution.size() != 2){ + throw std::runtime_error("EQ need 2 arg and get :" + std::to_string(mResolution.size())); + } + auto a = mResolution[0]; + auto b = mResolution[1]; + + if (a->getType() != b->getType()){ + throw std::runtime_error("EQ Unsuported between type :" + a->getType() +" "+ b->getType()); + } + + + + ConditionalData* data = new ConditionalData; + + if (a->isTypeEqualTo<int>()) { + data->setValue<bool>( a->getValue<int>() == b->getValue<int>()); + }else if (a->isTypeEqualTo<float>()){ + data->setValue<bool>( a->getValue<float>() == b->getValue<float>()); + }else if (a->isTypeEqualTo<std::string>()){ + data->setValue<bool>( a->getValue<std::string>() == b->getValue<std::string>()); + }else if (a->isTypeEqualTo<bool>()){ + data->setValue<bool>( a->getValue<bool>() == b->getValue<bool>()); + }else{ + throw std::runtime_error("EQ Unknown type encountered :" + a->getType() ); + } + + clearRes(); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fNeq(void) + { + if (mResolution.size() != 2){ + throw std::runtime_error("NEQ need 2 arg and get :" + std::to_string(mResolution.size())); + } + auto a = mResolution[0]; + auto b = mResolution[1]; + + if (a->getType() != b->getType()){ + throw std::runtime_error("NEQ Unsuported between type :" + a->getType() +" "+ b->getType()); + } + + ConditionalData* data = new ConditionalData; + + if (a->isTypeEqualTo<int>()) { + data->setValue<bool>( a->getValue<int>() != b->getValue<int>()); + }else if (a->isTypeEqualTo<float>()){ + data->setValue<bool>( a->getValue<float>() != b->getValue<float>()); + }else if (a->isTypeEqualTo<std::string>()){ + data->setValue<bool>( a->getValue<std::string>() != b->getValue<std::string>()); + }else + { + throw std::runtime_error("NEQ Unknown type encountered :" + a->getType() ); + } + + clearRes(); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fAnd(void) + { + if (mResolution.size() != 2){ + throw std::runtime_error("AND need 2 arg and 
get :" + std::to_string(mResolution.size())); + } + auto a = mResolution[0]; + auto b = mResolution[1]; + + + if (a->getType() != typeid(bool).name() || b->getType() != typeid(bool).name()){ + throw std::runtime_error("AND Unknown type encountered need bool get :" + a->getType() ); + } + + ConditionalData* data = new ConditionalData; + data->setValue<bool>( a->getValue<bool>() && b->getValue<bool>()); + + + clearRes(); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fOr(void) + { + if (mResolution.size() != 2){ + throw std::runtime_error("OR need 2 arg and get :" + std::to_string(mResolution.size())); + } + auto a = mResolution[0]; + auto b = mResolution[1]; + + + if (a->getType() != typeid(bool).name() || b->getType() != typeid(bool).name()){ + throw std::runtime_error("OR Unknown type encountered need bool get :" + a->getType() ); + } + + ConditionalData* data = new ConditionalData; + data->setValue<bool>( a->getValue<bool>() || b->getValue<bool>()); + + + clearRes(); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fNot() + { + if (mResolution.size() != 1){ + throw std::runtime_error("NOT need 1 arg and get :" + std::to_string(mResolution.size())); + } + auto a = mResolution[0]; + + if (a->getType() != typeid(bool).name()){ + throw std::runtime_error("NOT Unknown type encountered need bool get :" + a->getType() ); + } + + ConditionalData* data = new ConditionalData; + data->setValue<bool>( !a->getValue<bool>() ); + + clearRes(); + mResolution.push_back(data); + + } diff --git a/src/nodeTester/ConditionalLexer.cpp b/src/nodeTester/ConditionalLexer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9379bd8409f8f7ec4bae3e0122f88de79718e9dd --- /dev/null +++ b/src/nodeTester/ConditionalLexer.cpp @@ -0,0 +1,242 @@ +#include "aidge/nodeTester/ConditionalLexer.hpp" + +using namespace Aidge; + +////////////////// +//ConditionalLexer +////////////////// + + +ConditionalLexer::ConditionalLexer( const std::string 
ConditionalExpressions): +mConditionalExpressions(ConditionalExpressions) +{ + mPosition = 0; +} + +std::shared_ptr<ParsingToken<ConditionalTokenTypes>> ConditionalLexer::getNextToken(void){ + std::string currentChars = ""; + + while (mPosition < mConditionalExpressions.length()) + { + //erase all space + if (mConditionalExpressions[mPosition] != ' ') + { + currentChars += mConditionalExpressions[mPosition]; + } + else + { + mPosition++; + continue; + } + //performe tokenisation, find a regex and make a new token + + if (std::regex_match(currentChars,std::regex("\\&\\&")))// the AND TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::AND,""); + } + else if (std::regex_match(currentChars,std::regex("\\|\\|")))// the OR TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::OR,""); + } + else if (std::regex_match(currentChars,std::regex("\\!")))// the Not and not equ + { + mPosition++; + if ( mPosition < mConditionalExpressions.length()){ + currentChars += mConditionalExpressions[mPosition]; + if(std::regex_match(currentChars,std::regex("!="))){ + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::NEQ,""); + }else{ + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::NOT,""); + } + } + //a not at the end not ok but it's the parseur work + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::NOT,""); + } + else if (std::regex_match(currentChars,std::regex("==")))// the EQ TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::EQ,""); + } + else if (std::regex_match(currentChars,std::regex("\\(")))// the LPAREN TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::LPAREN,""); + } + else if 
(std::regex_match(currentChars,std::regex("\\)")))// the RPAREN TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::RPAREN,""); + } + else if (std::regex_match(currentChars,std::regex(",")))// the RPAREN TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::ARGSEP,""); + } + else if (std::regex_match(currentChars,std::regex("\\$")))// the ACTNode TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::NODE,""); + } + + + ///// + //non const lent token + ///// + + //LAMBDA, KEY , bool //the fuction TAG + else if (std::regex_match(currentChars,std::regex("[A-Za-z_]")))// the KEY TOKEN (a char next ) + { + //read all the key + bool isLambda = false; + std::regex keyRegex("[A-Za-z_0-9]+"); + std::regex LambdaRegex("[A-Za-z_0-9]+\\("); + + while ( mPosition < mConditionalExpressions.length()) { + if(!std::regex_match(currentChars,keyRegex) && !std::regex_match(currentChars,LambdaRegex)) + { + currentChars.pop_back(); //the last char is the problemes + break; + } + else if (std::regex_match(currentChars,LambdaRegex)){ + isLambda = true; + } + mPosition++; + if (mPosition < mConditionalExpressions.length()) currentChars += mConditionalExpressions[mPosition]; + //currentChars += mConditionalExpressions[mPosition]; + } + //we end the match 2 posibility + //we are at the end of the mConditionalExpressions and we need to ensure the match + //we are not we can continu + if (mPosition == mConditionalExpressions.length()-1) + { + if (!std::regex_match(currentChars,keyRegex) && !std::regex_match(currentChars,LambdaRegex)) + { + throw badTokenError(currentChars,mPosition); + } + //mPosition++; // we stop all by going pos > lengt + } + + + if (std::regex_match(currentChars,std::regex("(true|false)"))){ + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::BOOL,currentChars); + + } 
else if (isLambda){ + currentChars.pop_back();//pop the ( of the lambda + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::LAMBDA,currentChars); + } else{ + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::KEY,currentChars); + } + + } + //numeric value + else if (std::regex_match(currentChars,std::regex("[0-9]")))// the KEY TOKEN (a char next ) + { + //read all the key + bool isFloat = false; + std::regex integerRegex("[0-9]+$"); + std::regex floatRegex("[0-9]+\\.[0-9]*$"); + + while ( mPosition < mConditionalExpressions.length()) { + + if(!std::regex_match(currentChars,integerRegex) && !std::regex_match(currentChars,floatRegex)) + { + currentChars.pop_back(); // the last char match is not a good one + break; + } + else if (std::regex_match(currentChars,floatRegex)){ + isFloat = true; + } + mPosition++; + if (mPosition < mConditionalExpressions.length()) currentChars += mConditionalExpressions[mPosition]; + //currentChars += mConditionalExpressions[mPosition]; + } + //we end the match 2 posibility + //we are at the end of the mConditionalExpressions and we need to ensure the match + //we are not we can continu + if (mPosition == mConditionalExpressions.length()-1) + { + if (!std::regex_match(currentChars,integerRegex) && !std::regex_match(currentChars,floatRegex)) + { + throw badTokenError(currentChars,mPosition); + } + } + + if(isFloat){ + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::FLOAT,currentChars); + }else{ + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::INTEGER,currentChars); + } + + } + //string TODO + else if (std::regex_match(currentChars,std::regex("\'"))) // TODO ' or \' + { + std::regex strRegex("\'[A-Za-z_0-9\\s]*\'$"); + while ( mPosition < mConditionalExpressions.length()) { + if(std::regex_match(currentChars,strRegex)){ + break; + } + mPosition++; + if (mPosition < mConditionalExpressions.length()) 
currentChars += mConditionalExpressions[mPosition]; + //currentChars += mConditionalExpressions[mPosition]; + } + + //test the end condition + if (mPosition == mConditionalExpressions.length()-1 ){ + if (!std::regex_match(currentChars,strRegex)){ + throw badTokenError(currentChars,mPosition); + } + //mPosition++; // we stop all by going pos > lengt + } + + mPosition++; // go after the last " + //erase the " char + currentChars.pop_back(); + currentChars.erase(0,1); + + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::STRING,currentChars); + + } + + //Array TODO + + mPosition++; + } + + //no more to find no one match the currentChars + if (currentChars.empty()) { + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::STOP,""); // Null shared pointer ; + }else{ + //std::ostringstream errorMessage; + //errorMessage << "\nBad syntax " << currentChars << " :\n" << mConditionalExpressions; + throw badTokenError(currentChars,mPosition); + } + +} + +void ConditionalLexer::rstPosition(void){ + if (isEnd()){ + mPosition = 0; + }else{ + throw badTokenError("end rst",mPosition); + } + +} + +bool ConditionalLexer::isEnd(void){ + return mPosition >= mConditionalExpressions.length(); +} + +std::runtime_error ConditionalLexer::badTokenError(const std::string& currentChars,std::size_t position){ + std::ostringstream errorMessage; + errorMessage << "\nBad syntax " << currentChars << " :\n" << mConditionalExpressions << "\n"; + for (std::size_t i = 0; i < position; i++) { + errorMessage << ' '; + } + errorMessage << "^\n"; + + return std::runtime_error(errorMessage.str()); +} \ No newline at end of file diff --git a/src/nodeTester/ConditionalParser.cpp b/src/nodeTester/ConditionalParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3ca2843aabefe9f98bc8ad46a36fe03883d0baef --- /dev/null +++ b/src/nodeTester/ConditionalParser.cpp @@ -0,0 +1,188 @@ + +#include 
"aidge/nodeTester/ConditionalParser.hpp" + +using namespace Aidge; + + +////////////////////////////// +//ConditionalParser +////////////////////////////// + +ConditionalParser::ConditionalParser(const std::string ConditionalExpressions):mLexer(ConditionalExpressions){ + mCurrentToken = mLexer.getNextToken(); +} + +void ConditionalParser::rstParser(void){ + mLexer.rstPosition(); + mCurrentToken = mLexer.getNextToken(); +} + +void ConditionalParser::ackToken(ConditionalTokenTypes tokenType){ + if(mCurrentToken->getType() == tokenType ){ + + try { + mCurrentToken = mLexer.getNextToken(); + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "Conditional Lexer error in Parser :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } + }else{ + + std::ostringstream errorMessage; + errorMessage << "Bad syntax ConditionalParser " << static_cast<int>(mCurrentToken->getType()) <<"!="<< static_cast<int>(tokenType) << "\n"; + errorMessage << mLexer.rep(); + throw std::runtime_error(errorMessage.str()); + } +} + + + +std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::constructAstVal(void){ + /* + val : (KEY|INTEGER|FOAT|STRING|LAMBDA) + */ + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = mCurrentToken->copy(); + + if (token->getType() == ConditionalTokenTypes::KEY){ + ackToken(ConditionalTokenTypes::KEY); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + } + else if(token->getType() == ConditionalTokenTypes::INTEGER){ + ackToken(ConditionalTokenTypes::INTEGER); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + } + else if(token->getType() == ConditionalTokenTypes::FLOAT){ + ackToken(ConditionalTokenTypes::FLOAT); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + } + else if(token->getType() == ConditionalTokenTypes::BOOL){ + ackToken(ConditionalTokenTypes::BOOL); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + 
} + else if(token->getType() == ConditionalTokenTypes::STRING){ + ackToken(ConditionalTokenTypes::STRING); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + + }else if(token->getType() == ConditionalTokenTypes::NODE){ + ackToken(ConditionalTokenTypes::NODE); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + + }else if(token->getType() == ConditionalTokenTypes::LAMBDA){ + return constructAstLambda(); + } + + throw std::runtime_error("ConditionalParser unknow val type "+ token->rep().str() + "\n" + mLexer.rep()); + +} + +std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::constructAstLambda(void){ + /* + AstLambda : LAMBDA val (ARGSEP val)* RPAREN + */ + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> tokenLdb = mCurrentToken->copy(); + ackToken(ConditionalTokenTypes::LAMBDA); + ASTNodeCh paramLambda; + //AT LEAST ONE VALUE AS INPUT OF A LAMBDA + paramLambda.push_back(constructAstVal()); + while (mCurrentToken->getType() != ConditionalTokenTypes::RPAREN) + { + ackToken(ConditionalTokenTypes::ARGSEP); + paramLambda.push_back(constructAstVal()); + } + ackToken(ConditionalTokenTypes::RPAREN); + return std::make_shared<AstNode<ConditionalTokenTypes>>(tokenLdb,paramLambda); +} + +std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::constructAstCmpr(void){ + /* + cmpr : val (EQ|NEQ) val | LPAREN expr RPAREN + NOT ir ? + */ + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = mCurrentToken->copy(); + //we can check the type relation ir key (EQ|NEQ) val | val (EQ|NEQ) key , but val (EQ|NEQ) val is valid ? 
+ if (token->getType() == ConditionalTokenTypes::LPAREN) + { + ackToken(ConditionalTokenTypes::LPAREN); + std::shared_ptr<AstNode<ConditionalTokenTypes>> node = constructAstExpr(); + ackToken(ConditionalTokenTypes::RPAREN); + return node; + }else{ + + std::shared_ptr<AstNode<ConditionalTokenTypes>> node = constructAstVal(); + token = mCurrentToken->copy(); + if (token->getType() == ConditionalTokenTypes::EQ){ + ackToken(ConditionalTokenTypes::EQ); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token,ASTNodeCh{node,constructAstVal()}); + }else if(token->getType() == ConditionalTokenTypes::NEQ){ + ackToken(ConditionalTokenTypes::NEQ); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token,ASTNodeCh{node,constructAstVal()}); + }else{ + + throw std::runtime_error("constructAstCmpr "+ token->rep().str() + "\n" + mLexer.rep()); + } + + } +} + +std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::constructAstExpr(std::size_t precLimit /*= 0*/){ + /* + expr : cmpr ((AND | OR) cmpr)* + the NOT is not binary OP can be use in pratt + precedence H to L: TODO + AND + OR + */ + + //the not + std::shared_ptr<AstNode<ConditionalTokenTypes>> left; + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = mCurrentToken->copy(); + + if (mCurrentToken->getType() == ConditionalTokenTypes::NOT ){ + ackToken(ConditionalTokenTypes::NOT ); + left= std::make_shared<AstNode<ConditionalTokenTypes>>(token,ASTNodeCh{constructAstCmpr()}); + }else{ + left= constructAstCmpr(); + } + + //pratt + while (mCurrentToken->getType() != ConditionalTokenTypes::STOP ) //security + { + token = mCurrentToken->copy(); + //if the token is not in the map is not a operator so we consider a prec of 0 + if (ConditionalPrec.find(token->getType()) ==ConditionalPrec.end() ){ + return left; + } + + //if my actual operator have a prec <= of the last operator + std::size_t prec = ConditionalPrec.at(token->getType()); + if (prec <= precLimit){ + return left; + } + + //Act all AND and 
OR + ackToken(token->getType()); + + std::shared_ptr<AstNode<ConditionalTokenTypes>> right = constructAstExpr(prec); + + //i'm not sur what append to newNode + //std::shared_ptr<AstNode<ConditionalTokenTypes>> newNode = std::make_shared<AstNode<ConditionalTokenTypes>>(token,ASTNodeCh{left,constructAstCmpr()}); + std::shared_ptr<AstNode<ConditionalTokenTypes>> newNode = std::make_shared<AstNode<ConditionalTokenTypes>>(token,ASTNodeCh{left,right}); + left = newNode; + } + return left; +} + + +std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::parse(void){ + /* + expr : cmpr ((AND | OR) cmpr)* + cmpr : val (EQ|NEQ) val | LPAREN expr RPAREN | BOOL | LAMBDA + val : (KEY|INTEGER|FOAT|STRING|LAMBDA) + lambda : LAMBDA val (ARGSEP val)* RPAREN + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> astTree = constructAstExpr(); + + rstParser(); + return astTree; +} \ No newline at end of file diff --git a/src/operator/GenericOperator.cpp b/src/operator/GenericOperator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..192036651cfbe2df71139dd63ca3d71f07300964 --- /dev/null +++ b/src/operator/GenericOperator.cpp @@ -0,0 +1,17 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <vector> + +#include "aidge/operator/GenericOperator.hpp" + +const Aidge::GenericOperator_Op::ComputeDimsFunc Aidge::GenericOperator_Op::Identity + = [](const std::vector<std::vector<size_t>>& inputsDims) { return inputsDims; }; diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c1f58c68686d9359fa3b8ea4b5eb54244e988895 --- /dev/null +++ b/src/operator/MetaOperator.cpp @@ -0,0 +1,141 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/operator/MetaOperator.hpp" +#include "aidge/utils/ErrorHandling.hpp" + +Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph, + std::vector<NodePtr> inputNodes, + std::vector<NodePtr> outputNodes) + : Operator(type), + mGraph(graph) +{ + mInputs = std::vector<std::shared_ptr<Tensor>>(mGraph->inputs().size()); + for (std::size_t i = 0; i < mInputs.size(); ++i) { + mInputs[i] = std::make_shared<Tensor>(); + } + mOutputs = std::vector<std::shared_ptr<Tensor>>(mGraph->outputs().size()); + for (std::size_t i = 0; i < mOutputs.size(); ++i) { + mOutputs[i] = std::make_shared<Tensor>(); + } + + // Fill inputsNodes and outputsNodes when there is no ambiguity + if (inputNodes.empty()) { + AIDGE_ASSERT(mGraph->inputNodes().size() == 1, "need to specify internal nodes input mapping"); + inputNodes.push_back(*mGraph->inputNodes().begin()); + } + + if (outputNodes.empty()) { + 
AIDGE_ASSERT(mGraph->outputNodes().size() == 1, "need to specify internal nodes output mapping"); + outputNodes.push_back(*mGraph->outputNodes().begin()); + } + + AIDGE_ASSERT(mGraph->inputNodes().size() == inputNodes.size(), "wrong number of specified input nodes"); + AIDGE_ASSERT(mGraph->outputNodes().size() == outputNodes.size(), "wrong number of specified output nodes"); + + // Identify inputs that are outside the micro-graph + for (const auto& inputNode : inputNodes) { + AIDGE_ASSERT(mGraph->inView(inputNode), "input node must be in the graph"); + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = + inputNode->inputs(); + + int inputIdx = 0; // input idx relative to the current node + for (const auto& in : inputNodeinputs) { + if (in.first == nullptr || !mGraph->inView(in.first)) { + // The input is not connected inside the micro-graph + // (no connection to this input or connection outside the micro-graph) + // => it is therefore an input for the meta-operator + mInputOps.push_back(std::make_pair(inputNode->getOperator(), inputIdx)); + } + + ++inputIdx; + } + } + + // The outputs of the output nodes are also the outputs of the meta-operator + for (const auto& outputNode : outputNodes) { + AIDGE_ASSERT(mGraph->inView(outputNode), "output node must be in the graph"); + const std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> outputNodeoutputs = + outputNode->outputs(); + + for (size_t outputIdx = 0; outputIdx < outputNodeoutputs.size(); ++outputIdx) { + mOutputOps.push_back(std::make_pair(outputNode->getOperator(), outputIdx)); + } + } + + AIDGE_INTERNAL_ASSERT(mInputOps.size() == mGraph->inputs().size()); + AIDGE_INTERNAL_ASSERT(mOutputOps.size() == mGraph->outputs().size()); +} + +Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { + if (mImpl) { + return mImpl->getNbRequiredData(inputIdx); + } + else { + const auto& inputOp = mInputOps[inputIdx]; + return 
inputOp.first->getNbRequiredData(inputOp.second); + } +} + +Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const { + if (mImpl) { + return mImpl->getNbConsumedData(inputIdx); + } + else { + const auto& inputOp = mInputOps[inputIdx]; + return inputOp.first->getNbConsumedData(inputOp.second); + } +} + +Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) const { + if (mImpl) { + return mImpl->getNbProducedData(outputIdx); + } + else { + const auto& outputOp = mOutputOps[outputIdx]; + return outputOp.first->getNbProducedData(outputOp.second); + } +} + +void Aidge::MetaOperator_Op::updateConsummerProducer() { + if (mImpl) { + mImpl->updateConsummerProducer(); + } + else { + if (!mScheduler) { + // Lazy initialization + mScheduler = std::make_shared<SequentialScheduler>(mGraph); + } + + // TODO: check that generateScheduling() can be called multiple time to iteratively update the schedule. + // It could be a good idea to unify updateConsummerProducer() and generateScheduling() into a "updateScheduling()" + mScheduler->generateScheduling(); + } +} + +void Aidge::MetaOperator_Op::forward() { + if (mImpl) { + // A custom implementation exists for this meta operator + mImpl->forward(); + } + else { + // No custom implementation, use the individual operators implementations + if (!mScheduler) { + // Lazy initialization + // TODO: should we assert that a scheduler already exists at this point? 
+ // => should be created in updateConsummerProducer() + mScheduler = std::make_shared<SequentialScheduler>(mGraph); + mScheduler->generateScheduling(); + } + + mScheduler->forward(false); + } +} diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp index 99b07235e2917527160f03af997747f02947dcf9..09a17a428e1de91c0318f710e6f097573cf529a6 100644 --- a/src/operator/Operator.cpp +++ b/src/operator/Operator.cpp @@ -38,7 +38,18 @@ Aidge::NbElts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) co Aidge::NbElts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const { return mImpl->getNbProducedData(outputIdx); } +void Aidge::Operator::updateConsummerProducer(){ + mImpl->updateConsummerProducer(); +} -void Aidge::Operator::forward() { mImpl->forward(); } +void Aidge::Operator::runHooks() const { + for (auto& hook : mHooks) { + hook.second->call(); + } +} +void Aidge::Operator::forward() { + mImpl->forward(); + runHooks(); +} void Aidge::Operator::backward() { mImpl->backward(); } diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5e59582af68f66e6c54d09fac4cb1cc028493dd --- /dev/null +++ b/src/recipies/FuseBatchNorm.cpp @@ -0,0 +1,146 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#include <set> +#include <cassert> +#include <memory> +#include <string> +#include "aidge/operator/FC.hpp" +#include "aidge/operator/BatchNorm.hpp" +#include "aidge/operator/Conv.hpp" + +#include "aidge/utils/Recipies.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/operator/GenericOperator.hpp" +// Graph Regex +#include "aidge/graphmatching/GRegex.hpp" +#include "aidge/graphmatching/NodeRegex.hpp" +using namespace Aidge; + +void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){ + + assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); + + // Assert the nodes types are correct to be fused + std::shared_ptr<Node> conv; + std::shared_ptr<Node> batchnorm; + for (const auto& element : nodes) { + assert((element->type() == "Conv" || element->type() == "BatchNorm") && "Wrong type for the nodes to replace"); + if (element->type() == "Conv"){ + conv = element; + } + else if (element->type() == "BatchNorm") { + batchnorm = element; + } + } + // TODO : check if batchnorm is the only child of the Conv or FC + std::shared_ptr<Tensor> scale = batchnorm->input(1).first->getOperator()->getOutput(batchnorm->input(1).second); + std::shared_ptr<Tensor> shift = batchnorm->input(2).first->getOperator()->getOutput(batchnorm->input(2).second); + std::shared_ptr<Tensor> b_mean = batchnorm->input(3).first->getOperator()->getOutput(batchnorm->input(3).second); + std::shared_ptr<Tensor> b_var = batchnorm->input(4).first->getOperator()->getOutput(batchnorm->input(4).second); + + + // TODO : Find a way to remove the template + const float epsilon = std::static_pointer_cast<BatchNorm_Op<2>>(batchnorm->getOperator())->getAttr<float>("Epsilon"); + DimSize_t convOutDims = 
std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<DimSize_t>("OutChannels"); + + + assert(scale->size() == convOutDims); + assert(shift->size() == convOutDims); + assert(b_mean->size() == convOutDims); + assert(b_var->size() == convOutDims); + assert(epsilon > 0.0); + // TODO : no no_bias attribute ? + float meanVariance = 0.0; + unsigned int count = 0; + + for (std::size_t output = 0; output < convOutDims; ++output) { + // TODO : get suppose datatype is float .. + if (b_var->get<float>(output) > 1.0e-12) { + meanVariance += b_var->get<float>(output); + ++count; + } + else { + printf("Zero-variance: %s [%lu]\n", conv->name().c_str(), output); + } + } + if (count > 0) + meanVariance /= count; + else { + printf("variance < 1e-12 for all outputs! Is the network correctly trained?\n"); + } + + const DimSize_t channelsSize = std::dynamic_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<DimSize_t>("InChannels"); + + // TODO : suppose we have Conv2D ... + const std::array<DimSize_t, 2> kernelDims = std::dynamic_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<std::array<DimSize_t, 2>>("KernelDims"); + + std::shared_ptr<Tensor> weight = conv->input(1).first->getOperator()->getOutput(conv->input(1).second); + std::shared_ptr<Tensor> bias = conv->input(2).first->getOperator()->getOutput(conv->input(2).second); + + for (std::size_t output = 0; output < convOutDims; ++output) { + // Corrected for zero-variance issue: + // "A Quantization-Friendly Separable Convolution for MobileNets" + // https://arxiv.org/pdf/1803.08607.pdf + // to help post-training quantization + const float factor = scale->get<float>(output) + / std::sqrt(epsilon + ((b_var->get<float>(output) > 1.0e-12 || count == 0) + ? 
b_var->get<float>(output) : meanVariance)); + // Weights adjustments + for (std::size_t channel = 0; channel < channelsSize; ++channel) { + // TODO : Suppose kerneldims = 2 + for(std::size_t k0 = 0; k0 < kernelDims[0]; ++ k0){ + for(std::size_t k1 = 0; k1 < kernelDims[1]; ++ k1){ + std::vector<DimSize_t> currentIdx = {output, channel, k0, k1}; + // TODO : suppose weights are float + float weightValue = weight->get<float>(currentIdx); + weight->set<float>(currentIdx, weightValue*factor); // Update check it update Conv weights + } + } + } + + // TODO : check if noBias==true is set, then set biasValue to 0 + float biasValue = bias->get<float>(output); + + biasValue = shift->get<float>(output) + (biasValue - b_mean->get<float>(output)) * factor; + + bias->set<float>(output, biasValue); + + } + auto g = std::make_shared<GraphView>(); + g->add(std::set<std::shared_ptr<Node>>({ + batchnorm, + batchnorm->input(1).first, + batchnorm->input(2).first, + batchnorm->input(3).first, + batchnorm->input(4).first + })); + g->replaceWith({}); + +} + +void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView){ + std::map<std::string,NodeRegex*> nodesRegex ; + nodesRegex["BatchNorm"] = new NodeRegex("BatchNorm"); + nodesRegex["Conv"] = new NodeRegex("Conv"); + nodesRegex["FC"] = new NodeRegex("FC"); + + + std::vector<std::string> seqRegex; + seqRegex.push_back("Conv -> BatchNorm;"); // TODO: Add (Conv | FC) + GRegex GReg(nodesRegex, seqRegex); + Match matches = GReg.match(graphView); + std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); + for (size_t i = 0; i < matches.getNbMatch(); ++i) { + fuseBatchNorm(matchNodes[i]); + } +} diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index dc565bf0acc7747d79ec12df973a82d86fc79503..1de79890f9b597c4baff7427e01d7217f9695a44 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -20,21 +20,18 @@ #include "aidge/graph/Node.hpp" #include "aidge/operator/Producer.hpp" 
#include "aidge/operator/GenericOperator.hpp" - +// Graph Regex +#include "aidge/graphmatching/GRegex.hpp" +#include "aidge/graphmatching/NodeRegex.hpp" using namespace Aidge; -/** - * @brief Merge MatMul and Add Node into FC. - * - * @param nodes Strict set of Node to merge. - */ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ // Fuse Mulmat & Add into FC // Inputs : old nodes (pointers on mul & add) - + assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); // Too bad we lose information on the type after matching, how to keep the information after matching (not only for the type) ? - + // Step 0 : Assert the nodes types are correct to be fused std::shared_ptr<Node> add; std::shared_ptr<Node> matmul; @@ -53,18 +50,20 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ auto producer_add_bias = add->input(1); Tensor& bias_tensor = (producer_add_bias.first)->getOperator()->output(0); - // Instanciate FC + // Instanciate FC //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc"); std::shared_ptr<Node> fc = std::make_shared<Node>(std::make_shared<FC_Op>(bias_tensor.dims()[0], false)); // Step 2 : Branch existing producers & create the others // link weights & bias - if (matmul->getParents(1)==nullptr) { - matmul->getParents(0)->addChild(fc, 0, 1); + if (matmul->getParent(1)==nullptr) { + matmul->getParent(0)->addChild(fc, 0, 1); + printf("MatMul out[1] == nullptr !\n"); } else { - if (matmul->getParents(0)!=nullptr) - matmul->getParents(0)->addChild(fc, 0, 0); - matmul->getParents(1)->addChild(fc, 0, 1); + printf("MatMul out[1] != nullptr !\n"); + if (matmul->getParent(0)!=nullptr) + matmul->getParent(0)->addChild(fc, 0, 0); + matmul->input(1).first->addChild(fc, 0, 1); } (producer_add_bias.first)->addChild(fc,0,2); @@ -74,7 +73,22 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview // Maybe create a central mechanism to 
update automatically all graph views rather than each node have graphview presence memory ? auto nodeToReplace = std::make_shared<GraphView>(); - nodeToReplace->add(nodes); + nodeToReplace->add(nodes, false); nodeToReplace->replaceWith({fc}); -} \ No newline at end of file +} + +void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){ + + std::map<std::string,NodeRegex*> nodesRegex ; + nodesRegex["MatMul"] = new NodeRegex("MatMul"); + nodesRegex["Add"] = new NodeRegex("Add"); + std::vector<std::string> seqRegex; + seqRegex.push_back("MatMul -> Add;"); + GRegex GReg(nodesRegex, seqRegex); + Match matches = GReg.match(graphView); + std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); + for (size_t i = 0; i < matches.getNbMatch(); ++i) { + fuseMulAdd(matchNodes[i]); + } +} diff --git a/src/recipies/LabelGraph.cpp b/src/recipies/LabelGraph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..369336f7981198f962d8ab949309005be9ac5eb9 --- /dev/null +++ b/src/recipies/LabelGraph.cpp @@ -0,0 +1,56 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <memory> + +#include "aidge/recipies/LabelGraph.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/ConvDepthWise.hpp" +#include "aidge/operator/AvgPooling.hpp" +#include "aidge/operator/MaxPooling.hpp" + +Aidge::NodePtr Aidge::nodeLabel(NodePtr node) { + // Conv => MaxPooling + if (node->type() == Conv_Op<2>::Type) { + auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator()); + + auto newOp = std::make_shared<MaxPooling_Op<2>>(op->getAttr<ConvAttr::KernelDims>(), op->getAttr<ConvAttr::StrideDims>()); + return std::make_shared<Node>(newOp, node->name()); + } + + // ConvDepthWise => MaxPooling + if (node->type() == ConvDepthWise_Op<2>::Type) { + auto op = std::dynamic_pointer_cast<ConvDepthWise_Op<2>>(node->getOperator()); + + auto newOp = std::make_shared<MaxPooling_Op<2>>(op->getAttr<ConvDepthWiseAttr::KernelDims>(), op->getAttr<ConvDepthWiseAttr::StrideDims>()); + return std::make_shared<Node>(newOp, node->name()); + } + + // AvgPooling => MaxPooling + if (node->type() == AvgPooling_Op<2>::Type) { + auto op = std::dynamic_pointer_cast<AvgPooling_Op<2>>(node->getOperator()); + + auto newOp = std::make_shared<MaxPooling_Op<2>>(op->getAttr<AvgPoolingAttr::KernelDims>(), op->getAttr<AvgPoolingAttr::StrideDims>()); + return std::make_shared<Node>(newOp, node->name()); + } + + // MaxPooling => MaxPooling + if (node->type() == MaxPooling_Op<2>::Type) { + return node->clone(); + } + + // By default, remove the node from the graph + return nullptr; +} + +std::shared_ptr<Aidge::GraphView> Aidge::labelGraph(std::shared_ptr<GraphView> graph) { + return graph->cloneCallback(&nodeLabel); +} diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp index cc3c3324e40636a1edcbc73cdc4a9dcfeec8a026..9096c107ba505f5f18993a761273552408db721b 100644 --- a/src/recipies/RemoveFlatten.cpp +++ 
b/src/recipies/RemoveFlatten.cpp @@ -15,10 +15,38 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/utils/Recipies.hpp" +// Graph Regex +#include "aidge/graphmatching/GRegex.hpp" +#include "aidge/graphmatching/NodeRegex.hpp" + + namespace Aidge { void removeFlatten(std::set<std::shared_ptr<Node>> nodes) { + assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); + std::shared_ptr<Node> flatten; + for (const auto& element : nodes) { + assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace"); + if (element->type() == "Flatten"){ + flatten = element; + } + } auto g = std::make_shared<GraphView>(); - g->add(std::set<std::shared_ptr<Node>>({nodes})); + // TODO : avoid using replace_with and use a remove method instead + g->add(std::set<std::shared_ptr<Node>>({flatten})); g->replaceWith({}); } -} \ No newline at end of file + + void removeFlatten(std::shared_ptr<GraphView> graphView){ + std::map<std::string,NodeRegex*> nodesRegex ; + nodesRegex["Flatten"] = new NodeRegex("Flatten"); + nodesRegex["FC"] = new NodeRegex("FC"); + std::vector<std::string> seqRegex; + seqRegex.push_back("Flatten->FC;"); + GRegex GReg(nodesRegex, seqRegex); + Match matches = GReg.match(graphView); + std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); + for (size_t i = 0; i < matches.getNbMatch(); ++i) { + removeFlatten(matchNodes[i]); + } + } +} diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index fce46397ffd286a2ddbe254752b241578415e3d8..1f34091e54c0f83dae6b60589c20fb8fdf1d5064 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -20,7 +20,7 @@ #include "aidge/graph/Node.hpp" #include "aidge/utils/Types.h" -void drawProgressBar(double progress, int barWidth, const char* additionalInfo = nullptr) { +void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") { putchar('['); int pos = static_cast<int>(barWidth * 
progress); for (int i = 0; i < barWidth; ++i) { @@ -29,30 +29,25 @@ void drawProgressBar(double progress, int barWidth, const char* additionalInfo = else putchar(' '); } - printf("] %d%% | %s\r", static_cast<int>(progress * 100), (additionalInfo ? additionalInfo : "")); + printf("] %d%% | %s\r", static_cast<int>(progress * 100), additionalInfo.c_str()); fflush(stdout); } -// TODO: handle multiple inputs/outputs -void Aidge::SequentialScheduler::forward(bool frowardDims, bool verbose) { - if (frowardDims) {mGraphView->forwardDims(); } - - mScheduling.clear(); +void Aidge::SequentialScheduler::generateScheduling(bool verbose) { + // TODO: For loop on the list of node to run + // run sequencially every runnable consumers once + // TODO: handle memory allocation in scheduler + // TODO: optimize memory usage // setup initial producers list - // add each Producer Node. - std::set<std::shared_ptr<Node>> computationOver; - std::size_t computationNumber = 0; std::set<std::shared_ptr<Node>> producers; for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) { if (nodePtr->type() == "Producer") { producers.insert(nodePtr); - } else { - ++computationNumber; } } // add Data Input - // FIXME : shoudl be changed when the real system for providing + // FIXME : should be changed when the real system for providing // data is implemented for (const std::shared_ptr<Node>& nodePtr : mGraphView->inputNodes()) { for (const auto& parentPtr : nodePtr->getParents()) { @@ -81,16 +76,16 @@ void Aidge::SequentialScheduler::forward(bool frowardDims, bool verbose) { "\n\t\tR/C:\t", (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str()); for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { - printf("%ld/%ld\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), + printf("%zu/%zu\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), consumer->getOperator()->getNbRequiredData(inId)); } - printf("%ld/%ld", 
consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), + printf("%zu/%zu", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); printf("\n\t\tP:\t"); for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) { - printf("%ld\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); + printf("%zu\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); } - printf("%ld", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); + printf("%zu", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); printf("\n"); } bool isRunnable = true; @@ -112,22 +107,11 @@ void Aidge::SequentialScheduler::forward(bool frowardDims, bool verbose) { } } - // run sequencially every runnable consumers once - // TODO: handle memory allocation in scheduler - // TODO: optimize memory usage + // Push consumers in the list of nodes to run and update the consumer producer system for (const auto& runnable : runnableConsumers) { - if (verbose) - printf("run: %s\n", - (runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str()); - else - drawProgressBar(static_cast<float>(computationOver.size()) / static_cast<float>(computationNumber), 50, - (std::string("running ") + runnable->type() + "_" + - std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))) - .c_str()); - const auto tStart = std::chrono::high_resolution_clock::now(); - runnable->forward(); - const auto tEnd = std::chrono::high_resolution_clock::now(); - mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd)); + if (verbose) printf("Runnable: %s\n", (runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str()); + runnable->getOperator()->updateConsummerProducer(); + 
mStaticSchedule.push_back(runnable); } // update producers and consumers list @@ -142,13 +126,13 @@ void Aidge::SequentialScheduler::forward(bool frowardDims, bool verbose) { printf("%ld/%ld\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), consumer->getOperator()->getNbRequiredData(inId)); } - printf("%ld/%ld", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), + printf("%zu/%zu", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); printf("\n\t\tP:\t"); for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) { - printf("%ld\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); + printf("%zu\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); } - printf("%ld", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); + printf("%zu", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); printf("\n"); } bool isStillConsumer = false; @@ -165,18 +149,6 @@ void Aidge::SequentialScheduler::forward(bool frowardDims, bool verbose) { } } - bool computationOverForConsumer = true; - for (IOIndex_t parentIDi = 0; parentIDi < consumer->nbInputs(); ++parentIDi) { - if (consumer->getOperator()->getNbConsumedData(parentIDi) < - consumer->getOperator()->getNbRequiredData(parentIDi)) { - computationOverForConsumer = false; - break; - } - } - if (computationOverForConsumer) { - computationOver.insert(consumer); - } - for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) { if (consumer->getOperator()->getNbProducedData(outId) > 0) { if (verbose) printf(" also producer\n"); @@ -198,13 +170,46 @@ void Aidge::SequentialScheduler::forward(bool frowardDims, bool verbose) { if (verbose) printf("*************\n"); } while (!consumers.empty()); + +} + +// TODO: handle multiple 
inputs/outputs +void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { + // Forward dims (if allowed) + if (forwardDims) {mGraphView->forwardDims(); } + + // Generate scheduling *only if empty* + // If scheduling was already generated (in one or several steps, i.e. one or + // several successive call to generateScheduling()), do not generate it twice + if (mStaticSchedule.empty()) { + this->generateScheduling(); + } + + // Clear previous scheduling results + mScheduling.clear(); + + int cpt = 0; + for (const auto& runnable : mStaticSchedule) { + if (verbose) + printf("run: %s\n", + (runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str()); + else + drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50, + (std::string("running ") + runnable->type() + "_" + + std::to_string(reinterpret_cast<uintptr_t>(runnable.get())))); + const auto tStart = std::chrono::high_resolution_clock::now(); + runnable->forward(); + const auto tEnd = std::chrono::high_resolution_clock::now(); + mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd)); + cpt++; + } if (!verbose) drawProgressBar(1.0, 50, " "); printf("\n"); } void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const { FILE* fp = std::fopen((fileName + ".mmd").c_str(), "w"); - std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%s ms\n\n"); + std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%Q ms\n\n"); if (!mScheduling.empty()) { const auto globalStart = mScheduling[0].start; @@ -232,4 +237,4 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers( } return consumers; -} \ No newline at end of file +} diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index dc693193c6606c99b1628d23ad253015f8f8dbe6..9f014364636c70031b522b09c893e1144af3f133 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ 
-161,7 +161,7 @@ TEST_CASE("[core/graph] GraphView(addChild)") { TEST_CASE("[core/graph] GraphView(inputs)") { auto g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> conv = Conv(3, 32, {3, 3}); - g1->add(conv); + g1->add(conv, false); REQUIRE(g1->inputs() == conv->inputs()); } @@ -330,4 +330,276 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") { REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4})); REQUIRE((r1->output(0))[0].first == r4); } -} \ No newline at end of file +} + +TEST_CASE("[GraphView] clone") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("clone_g1"); + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + auto g2 = g1->clone(); + + auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); + 
dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("clone_g2"); + + SECTION("Check node cloning") { + REQUIRE(g1->getNode("conv1") != g2->getNode("conv1")); + REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w")); + REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b")); + REQUIRE(g1->getNode("conv2") != g2->getNode("conv2")); + REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w")); + REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b")); + REQUIRE(g1->getNode("conv3") != g2->getNode("conv3")); + REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w")); + REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b")); + } + + SECTION("Check operator cloning") { + REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator()); + REQUIRE(g1->getNode("conv1_w")->getOperator() != g2->getNode("conv1_w")->getOperator()); + REQUIRE(g1->getNode("conv1_b")->getOperator() != g2->getNode("conv1_b")->getOperator()); + REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator()); + REQUIRE(g1->getNode("conv2_w")->getOperator() != g2->getNode("conv2_w")->getOperator()); + REQUIRE(g1->getNode("conv2_b")->getOperator() != g2->getNode("conv2_b")->getOperator()); + REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator()); + REQUIRE(g1->getNode("conv3_w")->getOperator() != g2->getNode("conv3_w")->getOperator()); + REQUIRE(g1->getNode("conv3_b")->getOperator() != g2->getNode("conv3_b")->getOperator()); + } + + SECTION("Check new connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) != g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getInput(1) != g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getInput(2) != g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getOutput(0) != 
g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getInput(1) != g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getInput(2) != g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getOutput(0) != g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getInput(1) != g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getInput(2) != g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + SECTION("Check input-output connections") { + REQUIRE(dataProvider2->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } +} + +TEST_CASE("[GraphView] cloneSharedProducers") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = 
std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("cloneSharedProducers_g1"); + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + auto g2 = g1->cloneSharedProducers(); + + auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("cloneSharedProducers_g2"); + + SECTION("Check node cloning") { + REQUIRE(g1->getNode("conv1") != g2->getNode("conv1")); + REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w")); + REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b")); + REQUIRE(g1->getNode("conv2") != g2->getNode("conv2")); + REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w")); + REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b")); + REQUIRE(g1->getNode("conv3") != g2->getNode("conv3")); + REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w")); + REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b")); + } + + SECTION("Check operator cloning") { + 
REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator()); + REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator()); + REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator()); + REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator()); + REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator()); + REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator()); + REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator()); + REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator()); + REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator()); + } + + SECTION("Check new connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) != g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv1")->getOperator()->getOutput(0) != g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv2")->getOperator()->getOutput(0) != g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g1->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + SECTION("Check input-output connections") { + REQUIRE(dataProvider2->getOperator()->getOutput(0) == 
g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } +} + +TEST_CASE("[GraphView] cloneSharedOperators") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("cloneSharedOperators_g1"); + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0)); + 
REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + auto g2 = g1->cloneSharedOperators(); + g2->forwardDims(); + g2->save("cloneSharedOperators_g2"); + + SECTION("Check node cloning") { + REQUIRE(g1->getNode("conv1") != g2->getNode("conv1")); + REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w")); + REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b")); + REQUIRE(g1->getNode("conv2") != g2->getNode("conv2")); + REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w")); + REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b")); + REQUIRE(g1->getNode("conv3") != g2->getNode("conv3")); + REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w")); + REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b")); + } + + SECTION("Check operator cloning") { + REQUIRE(g1->getNode("conv1")->getOperator() == g2->getNode("conv1")->getOperator()); + REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator()); + REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator()); + REQUIRE(g1->getNode("conv2")->getOperator() == g2->getNode("conv2")->getOperator()); + REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator()); + REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator()); + REQUIRE(g1->getNode("conv3")->getOperator() == g2->getNode("conv3")->getOperator()); + REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator()); + REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator()); + } + + SECTION("Check input-output connections") { + 
REQUIRE(dataProvider->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); + } +} + + +TEST_CASE("[core/graph] GraphView(insertParent)") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(32, 64, {1, 1}, "conv3"); + auto g = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g->add(conv1); + g->addChild(conv2, conv1, 0); + g->addChild(conv3, conv1, 0); + g->save("graphForwardDims"); + g->forwardDims(); + + auto newConv = Conv(32, 32, {1, 1}, "newConv"); + + SECTION("Check insertParent conv2 then insertParent conv3") { + g->insertParent(conv2, newConv, 0, 0, 0); + + std::set<NodePtr> expectedConv1Children = {conv3, newConv}; + std::set<NodePtr> expectedNewConvChildren = {conv2}; + + REQUIRE(conv1->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0)); + 
REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0)); + REQUIRE(newConv->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE((newConv->getChildren()) == expectedNewConvChildren); + REQUIRE((conv1->getChildren()) == expectedConv1Children); + + g->insertParent(conv3, newConv, 0, 0, 0); + + std::set<NodePtr> expectedConv1Children2 = {newConv}; + std::set<NodePtr> expectedNewConvChildren2 = {conv2, conv3}; + + REQUIRE(conv1->getOperator()->getOutput(0) != conv3->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0)); + REQUIRE(newConv->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(newConv->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE((newConv->getChildren()) == expectedNewConvChildren2); + REQUIRE((conv1->getChildren()) == expectedConv1Children2); + + } +} diff --git a/unit_tests/graph/Test_get.cpp b/unit_tests/graph/Test_get.cpp new file mode 100644 index 0000000000000000000000000000000000000000..afd1f42ee9f5d6cd668dd5cab82172cdc298e149 --- /dev/null +++ b/unit_tests/graph/Test_get.cpp @@ -0,0 +1,55 @@ + + +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" + + +using namespace Aidge; +TEST_CASE("get Delta") { + + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); + std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 1, 1, "c3.5"); + std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 1, 1, "c4"); + std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 1, 1, "c5"); + + g1->add(conv); + g1->addChild(conv1, "c"); + + + std::set<Aidge::NodePtr> see; +conv->getNodeDelta(1,see); + + SECTION("Self return") { + see.clear(); + REQUIRE(conv->getNodeDelta(0,see) == std::set<std::shared_ptr<Node>>{conv}); + } + + + SECTION("child") { + see.clear(); + REQUIRE(conv->getNodeDelta(1,see) == std::set<std::shared_ptr<Node>>{conv1}); + } + + +} \ No newline at end of file diff --git a/unit_tests/graphMatching/Test_GRegex.cpp b/unit_tests/graphMatching/Test_GRegex.cpp index 7184fad76a921239753d4752ae1a4a61bf3aec16..2c5907d82e7c5b1d32f1fb38493c7333b68f8731 100644 --- a/unit_tests/graphMatching/Test_GRegex.cpp +++ b/unit_tests/graphMatching/Test_GRegex.cpp @@ -53,6 +53,10 @@ TEST_CASE("Create good init GRegex", "[GRegex]") { // Perform tests REQUIRE(GReg.getStmInit().size() == 1); REQUIRE(GReg.getStmFab().getNumberOfStm() == 1); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } @@ 
-101,6 +105,10 @@ TEST_CASE("Function matchFromStartNodes | One Match of Nodes sequence", "[GRegex // Perform tests REQUIRE(result == true_result); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } TEST_CASE("Function matchFromStartNodes | One Match of parallel branches ", "[GRegex]") { @@ -166,6 +174,10 @@ TEST_CASE("Function matchFromStartNodes | One Match of parallel branches ", "[GR // Perform tests REQUIRE(result == true_result); REQUIRE(wrong_start_result == empty_result); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } /* diff --git a/unit_tests/graphMatching/Test_SeqStm.cpp b/unit_tests/graphMatching/Test_SeqStm.cpp index baabbbc3c10ec751c64a65ab01c2c4d502f58cb5..db8662e3329abe153d4a0fb2b3c46b950208d6bc 100644 --- a/unit_tests/graphMatching/Test_SeqStm.cpp +++ b/unit_tests/graphMatching/Test_SeqStm.cpp @@ -79,6 +79,10 @@ TEST_CASE("Create good init SeqStm", "[SeqStm]") { REQUIRE(stm.getAllCommonNode().size() == 0); REQUIRE(stm.getAllNodeTested().size() == 0); REQUIRE(stm.getAllNodeValidated().size() == 0); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } TEST_CASE("Test testNode function", "[SeqStm]") { @@ -156,4 +160,8 @@ TEST_CASE("Test testNode function", "[SeqStm]") { REQUIRE(stm.isStmBlocked() == true); REQUIRE(stm.getAllNodeTested() == testAllNodeTested); REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } \ No newline at end of file diff --git a/unit_tests/graphMatching/Test_StmFactory.cpp b/unit_tests/graphMatching/Test_StmFactory.cpp index b595372fd97a56f2ecf2575429c63db92484bbc0..3c66d0fa817cea674de5ab849091290c976e5735 100644 --- a/unit_tests/graphMatching/Test_StmFactory.cpp +++ b/unit_tests/graphMatching/Test_StmFactory.cpp @@ -36,6 +36,10 @@ TEST_CASE("Create good init StmFactory", "[StmFactory]") { } StmFactory stmF(nodesRegex); REQUIRE(stmF.getNumberOfStm() == 
0); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } TEST_CASE("Test in makeNewStm the getStmIdx StmFactory", "[SeqStm]") { @@ -66,6 +70,10 @@ TEST_CASE("Test in makeNewStm the getStmIdx StmFactory", "[SeqStm]") { //test the number of stm REQUIRE(stmF.getNumberOfStm() == 2); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } TEST_CASE("Test in makeNewStm the stm StmFactory", "[SeqStm]") { @@ -123,6 +131,9 @@ TEST_CASE("Test in makeNewStm the stm StmFactory", "[SeqStm]") { REQUIRE(stm->getAllNodeTested() == testAllNodeTested); REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } @@ -185,5 +196,9 @@ TEST_CASE("Test in duplicateStm StmFactory", "[SeqStm]") { REQUIRE(stmD->isStmBlocked() == false); REQUIRE(stmD->getAllNodeTested().size() == 0); REQUIRE(stmD->getAllNodeValidated().size() == 0); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } diff --git a/unit_tests/graphRegex/Test_Fsm.cpp b/unit_tests/graphRegex/Test_Fsm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5950f21b323f07b380ae95f70637ca48a173481 --- /dev/null +++ b/unit_tests/graphRegex/Test_Fsm.cpp @@ -0,0 +1,195 @@ +#include <memory> + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/nodeTester/ConditionalInterpreter.hpp" + +#include "aidge/graphRegex/matchFsm/FsmNode.hpp" +#include "aidge/graphRegex/matchFsm/FsmEdge.hpp" +#include "aidge/graphRegex/matchFsm/FsmGraph.hpp" +#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" + +using namespace Aidge; + +TEST_CASE("matchFSM", "FsmEdge") { + + + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + FsmEdgeUnique 
EdgeToTest(nodeA,nodeB,toTest); + + SECTION("FsmEdgeUnique constructor") { + REQUIRE(EdgeToTest.getSourceNode() == nodeA); + REQUIRE(EdgeToTest.getDestNode() == nodeB); + REQUIRE(EdgeToTest.isCommon() == false); + } + + SECTION("FsmEdgeCommon constructor") { + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + + FsmEdgeCommon EdgeToTest(nodeA,nodeB,toTest,"A"); + + REQUIRE(EdgeToTest.getSourceNode() == nodeA); + REQUIRE(EdgeToTest.getDestNode() == nodeB); + REQUIRE(EdgeToTest.isCommon() == true); + } + + SECTION("FsmEdgeRef constructor") { + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + + FsmEdgeRef EdgeToTest(nodeA,nodeB,0,-1); + + REQUIRE(EdgeToTest.getSourceNode() == nodeA); + REQUIRE(EdgeToTest.getDestNode() == nodeB); + REQUIRE(EdgeToTest.isCommon() == false); + } + + SECTION("FsmEdgeEmpty constructor") { + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + + FsmEdgeEmpty EdgeToTest(nodeA,nodeB); + + REQUIRE(EdgeToTest.getSourceNode() == nodeA); + REQUIRE(EdgeToTest.getDestNode() == nodeB); + REQUIRE(EdgeToTest.isCommon() == false); + } + + + SECTION("FsmEdgeFactory"){ + + std::map<std::string, std::shared_ptr<ConditionalInterpreter>> allTest = { + {"A",std::make_shared<ConditionalInterpreter>("true==true")}, + {"B",std::make_shared<ConditionalInterpreter>("true==true")}, + {"C",std::make_shared<ConditionalInterpreter>("true==true")} + }; + +// 
make(std::shared_ptr<FsmNode> source, std::shared_ptr<FsmNode> dest, +// FsmEdgeTypes type,std::map<std::string, const std::shared_ptr<ConditionalInterpreter>> allTest, +// const std::string& lexeme = ""); + + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(false,true); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(true,false); +// EMPTY = 0, +// REF, +// COMMON, +// UNIQUE + + std::shared_ptr<FsmEdge> edgeE = FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::EMPTY,allTest,""); + std::shared_ptr<FsmEdge> edgeU = FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::UNIQUE,allTest,"A"); + std::shared_ptr<FsmEdge> edgeC = FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::COMMON,allTest,"A#"); + std::shared_ptr<FsmEdge> edgeR = FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::REF,allTest,"(0,1)"); + + //test detection of bad syntax lexem + REQUIRE_THROWS(FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::EMPTY,allTest,"A")); + REQUIRE_THROWS(FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::UNIQUE,allTest,"A#")); + REQUIRE_THROWS(FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::COMMON,allTest,"A")); + REQUIRE_THROWS(FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::REF,allTest,"A")); + + REQUIRE(edgeE->getSourceNode() == nodeA); + REQUIRE(edgeE->getDestNode() == nodeB); + } + + SECTION("graph constructor") { + //make the nodes + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,false); + std::shared_ptr<FsmNode> nodeC = std::make_shared<FsmNode>(false,true); + + //make the edges + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + std::shared_ptr<FsmEdge> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); + std::shared_ptr<FsmEdge> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest); + + std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + + graph->addEdge(edgeAB); + graph->addEdge(edgeBC); + + 
+ REQUIRE(graph->getValidNodes() == std::set<std::shared_ptr<FsmNode>>{nodeA}); + REQUIRE(graph->getStartNodes() == std::vector<std::shared_ptr<FsmNode>>{nodeC}); + } + + + SECTION("graph merge") { + + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + + //make the nodes + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(false,true); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,false); + std::shared_ptr<FsmNode> nodeC = std::make_shared<FsmNode>(true,false); + + //make the edges + + std::shared_ptr<FsmEdge> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); + std::shared_ptr<FsmEdge> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest); + + std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + graph->addEdge(edgeAB); + graph->addEdge(edgeBC); + + REQUIRE(graph->getValidNodes() == std::set<std::shared_ptr<FsmNode>>{nodeC}); + REQUIRE(graph->getStartNodes() == std::vector<std::shared_ptr<FsmNode>>{nodeA}); + REQUIRE(graph->getNodes() == std::set<std::shared_ptr<FsmNode>>{nodeA,nodeB,nodeC}); + + //make the nodes + std::shared_ptr<FsmNode> node2A = std::make_shared<FsmNode>(false,true); + std::shared_ptr<FsmNode> node2B = std::make_shared<FsmNode>(false,false); + std::shared_ptr<FsmNode> node2C = std::make_shared<FsmNode>(true,false); + + + std::shared_ptr<FsmEdge> edge2AB = std::make_shared<FsmEdgeUnique>(node2A,node2B,toTest); + std::shared_ptr<FsmEdge> edge2BC = std::make_shared<FsmEdgeUnique>(node2B,node2C,toTest); + + std::shared_ptr<FsmGraph> graph2 = std::make_shared<FsmGraph>(); + + + graph2->addEdge(edge2AB); + graph2->addEdge(edge2BC); + + + REQUIRE(graph2->getValidNodes() == std::set<std::shared_ptr<FsmNode>>{node2C}); + REQUIRE(graph2->getStartNodes() == std::vector<std::shared_ptr<FsmNode>>{node2A}); + REQUIRE(graph2->getNodes() == std::set<std::shared_ptr<FsmNode>>{node2A,node2B,node2C}); + + + graph->mergeOneStartOneValid(graph2); + + 
REQUIRE(graph->getValidNodes() == std::set<std::shared_ptr<FsmNode>>{node2C}); + REQUIRE(graph->getStartNodes() == std::vector<std::shared_ptr<FsmNode>>{nodeA}); + REQUIRE(graph->getNodes() == std::set<std::shared_ptr<FsmNode>>{nodeA,nodeB,nodeC,node2B,node2C}); + } + + + + +} + +// TEST_CASE("matchFSM", "FsmGraph") { + +// SECTION("FsmEdgeUnique constructor") { +// //make the nodes +// std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); +// std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + +// //make the edges +// std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); +// std::shared_ptr<FsmEdgeUnique> edge = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); + +// std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + +// graph->addEdge(edge); + + + +// } + +// } \ No newline at end of file diff --git a/unit_tests/graphRegex/Test_FsmMatch.cpp b/unit_tests/graphRegex/Test_FsmMatch.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1fe75be1a47033f75af7ccc4dc5202774444cd10 --- /dev/null +++ b/unit_tests/graphRegex/Test_FsmMatch.cpp @@ -0,0 +1,89 @@ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" + + +#include "aidge/graphRegex/GraphFsmInterpreter.hpp" + + +using namespace Aidge; +TEST_CASE("FsmMatch") { + + SECTION("Construction") { + std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { + {"A",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, + {"B",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, + {"C",std::make_shared<ConditionalInterpreter>("true==true")} + }; + + allTest["A"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + allTest["B"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == 
"Conv";}); + + std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A->A",allTest); + std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); + + + + //REQUIRE(fsm->getNodes().size() == 3); + //REQUIRE(fsm->getStartNodes().size() == 1); + + + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + + g1->add(conv); + g1->addChild(conv1, "c"); + + + REQUIRE(allTest["A"]->test(conv) == true); + REQUIRE(allTest["B"]->test(conv) == true); + + std::vector<std::shared_ptr<Node>> startNodes = {conv}; + + auto result = fsm->test(startNodes); + + REQUIRE( result->getBiggerSolution() == std::set<NodePtr>{conv,conv1}); + } + + + SECTION("2 branche graph"){ + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Fc", 1, 1, 1, "c2"); + + g1->add(conv); + g1->addChild(conv1,conv); + g1->addChild(conv2,conv); + + REQUIRE(g1->getNodes() == std::set<std::shared_ptr<Node>>({conv,conv1,conv2})); + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>({conv})); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>({conv1,conv2})); + + + ///////////// + + std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { + {"A",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, + {"B",std::make_shared<ConditionalInterpreter>("isFc($)==true")} + }; + allTest["A"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + allTest["B"]->insertLambda("isFc",+[](NodePtr NodeOp){return NodeOp->type() == "Fc";}); + + std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A#->A; A#->B",allTest); 
+ std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); + + std::vector<std::shared_ptr<Node>> startNodes = {conv,conv}; + auto result = fsm->test(startNodes); + REQUIRE( result->getBiggerSolution() == std::set<NodePtr>{conv,conv1,conv2}); + + } + +} diff --git a/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp b/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9ce090506c9a61abd928b3ae590ee838afb05999 --- /dev/null +++ b/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp @@ -0,0 +1,42 @@ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/graphRegex/GraphFsmInterpreter.hpp" + + +using namespace Aidge; +TEST_CASE("GraphFsmInterpreter", "GraphFsmInterpreter") { + + SECTION("Construction") { + std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { + {"A",std::make_shared<ConditionalInterpreter>("true==true")}, + {"B",std::make_shared<ConditionalInterpreter>("true==true")}, + {"C",std::make_shared<ConditionalInterpreter>("true==true")} + }; + + //GraphFsmInterpreter("A->B",allTest); + std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A->B",allTest); + std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); + + + REQUIRE(fsm->getNodes().size() == 3); + REQUIRE(fsm->getStartNodes().size() == 1); + REQUIRE(fsm->getEdge().size() == 2); + + for(auto node : fsm->getNodes()){ + if(node->isValid()){ + REQUIRE(node->getEdges().size() == 0); + }else{ + REQUIRE(node->getEdges().size() == 1); + } + + } + + + } + + + + + +} \ No newline at end of file diff --git a/unit_tests/graphRegex/Test_GraphLexer.cpp b/unit_tests/graphRegex/Test_GraphLexer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1b8cc8e018546ebfe3f84202d9404db27b17449b --- /dev/null +++ b/unit_tests/graphRegex/Test_GraphLexer.cpp @@ -0,0 +1,118 @@ +#include <catch2/catch_test_macros.hpp> +#include "aidge/graphRegex/GraphLexer.hpp" +#include 
"aidge/graphRegex/GraphRegexTypes.hpp" + +#include "aidge/utilsParsing/ParsingToken.hpp" + + +#include <iostream> +#include <map> +#include <functional> + +using namespace Aidge; + +// NEXT +// QOM +// QZM +// KEY +// CKEY +// SEP +// LPAREN +// RPAREN + +TEST_CASE("GraphRegex", "Lexer") { + SECTION("RandomGenerateTest") { + + std::map<gRegexTokenTypes, std::function<std::pair<std::string, std::string>()>> LexerTestMap{ + {gRegexTokenTypes::NEXT, +[](){return std::pair<std::string, std::string>("-> ","");}}, + {gRegexTokenTypes::QOM, +[](){return std::pair<std::string, std::string>("+ ","");}}, + {gRegexTokenTypes::QZM, +[](){return std::pair<std::string, std::string>("* ","");}}, + {gRegexTokenTypes::SEP, +[](){return std::pair<std::string, std::string>("; ","");}}, + + + + {gRegexTokenTypes::KEY, +[](){ + std::size_t keyLen = (std::rand() % 20)+1; + const std::string characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_"; + std::size_t randomIndex = std::rand() % characters.size(); + std::string key; + for (std::size_t i = 0; i < keyLen; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + + return std::pair<std::string, std::string>(key+" ",key);} + }, + + {gRegexTokenTypes::CKEY, +[](){ + std::size_t keyLen = (std::rand() % 20)+1; + const std::string characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_"; + const std::string num = "1234567890"; + + std::size_t randomIndex = std::rand() % characters.size(); + std::size_t randomNum = std::rand() % num.size(); + std::string key; + std::string idx; + + for (std::size_t i = 0; i < keyLen; ++i) { + key += characters[randomIndex]; + idx += num[randomNum]; + randomIndex = std::rand() % characters.size(); + randomNum = std::rand() % num.size(); + } + + return std::pair<std::string, std::string>(key+"#"+idx+" ",key+"#"+idx);} + }, + + {gRegexTokenTypes::LPAREN, +[](){return std::pair<std::string, std::string>("( ","");}}, + 
{gRegexTokenTypes::RPAREN, +[](){return std::pair<std::string, std::string>(") ","");}} + //{gRegexTokenTypes::STOP, +[](){return std::pair<std::string, std::string>("","");}} + }; + + + ////////////////// + //TEST GENERATOR + ////////////////// + const std::size_t numRandomElements = 10000; + std::vector<std::tuple<gRegexTokenTypes, std::string>> testVector; + + std::string testString; + + for (std::size_t i = 0; i < numRandomElements; ++i) { + + int randomIndex = std::rand() % LexerTestMap.size(); + // Get an iterator to the random element in the map + auto it = std::next(LexerTestMap.begin(), randomIndex); + // Access the random key and lambda value separately using structured binding + gRegexTokenTypes randomKey = it->first; + + std::function<std::pair<std::string, std::string>()> randomValue = it->second; + std::pair<std::string, std::string> result = randomValue(); + + testString += result.first; + testVector.emplace_back(randomKey, result.second); + + + } + + GraphLexer graphLexer = GraphLexer(testString); + + for (std::tuple<gRegexTokenTypes, std::string> testToken : testVector) { + gRegexTokenTypes tokenToFind = std::get<0>(testToken); + std::string lexemToFind = std::get<1>(testToken); + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token = graphLexer.getNextToken(); + + + std::ostringstream errorMessage; + errorMessage << "\n we whant :"<< lexemToFind << "\n we get : "<< token->getLexeme() <<"\n"<< "on \n" << testString << " :\n " ; + + CAPTURE(errorMessage.str()); + REQUIRE(token->getLexeme() == lexemToFind); + REQUIRE(token->getType() == tokenToFind); + } + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token = graphLexer.getNextToken(); + REQUIRE(token->getType() == gRegexTokenTypes::STOP); + } + + +} \ No newline at end of file diff --git a/unit_tests/graphRegex/Test_GraphParser.cpp b/unit_tests/graphRegex/Test_GraphParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..857caa06f4e5fa383e79ea22bfe1ca28ac0973c8 --- 
/dev/null +++ b/unit_tests/graphRegex/Test_GraphParser.cpp @@ -0,0 +1,82 @@ + +#include <catch2/catch_test_macros.hpp> +#include "aidge/graphRegex/GraphParser.hpp" +#include "aidge/utilsParsing/AstNode.hpp" +#include <iostream> + + +using namespace Aidge; + + //generative function , + std::string domain(); + std::string exp() { + int randomValue = std::rand() % 3; + switch (randomValue) { + case 0: + return "A"; + case 1 : + return "A#"; + default: + return domain(); + + } + } + + std::string seq() { + int randomValue = std::rand() % 2; + switch (randomValue) { + case 0: + return exp(); + default: + return exp()+"->"+seq(); + } + } + + std::string domain() { + int randomValue = std::rand() % 2; + + switch (randomValue) { + // case 0: + // return seq(); + // case 1: + // return seq() + "->" +domain(); + + case 0: + return "("+ seq() +")*"; + default: + return "("+ seq() +")+"; + + // case 4: + // return "("+ domain() +")*" + "->" +domain(); + // default: + // return "("+ domain() +")+" + "->" +domain(); + } + } + + std::string allExpr() { + int randomValue = std::rand() % 2; + switch (randomValue) { + case 0: + return seq(); + default : + return seq()+ ";" +allExpr(); + } + } + +/* +exp : KEY(QOM | QZM)? 
| CKEY | domain +seq :exp (NEXT seq)* +domain : LPAREN seq RPAREN (QOM | QZM) +allExpr: seq (SEP allExpr)* +*/ +TEST_CASE("GraphParser", "Test_GraphParser") { + + SECTION("Empty") { + for (int i = 0; i < 100; ++i) { + const std::string test = allExpr(); + std::cout << test <<"\n"; + GraphParser graphParser = GraphParser(test); + std::shared_ptr<AstNode<gRegexTokenTypes>> tree = graphParser.parse(); + } + } +} \ No newline at end of file diff --git a/unit_tests/graphRegex/Test_graphRegexAST.cpp b/unit_tests/graphRegex/Test_graphRegexAST.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1cdb0bc1934983a26ab742bfe8879455077219cc --- /dev/null +++ b/unit_tests/graphRegex/Test_graphRegexAST.cpp @@ -0,0 +1,71 @@ +#include <catch2/catch_test_macros.hpp> +#include "aidge/graphRegex/GraphStrInterpreter.hpp" + + +using namespace Aidge; +TEST_CASE("GraphStrInterpreter") { + + + + std::vector<std::string> tests = { + + + //sequ + "A;", + "A->B", + "A->B->C", + //seq and common + "A#", + "A#->B", + "A#->B#", + "A#->B#->C", + "A#->B#->C#", + "A->B#->C", + //sequ quantif + + "A+", + "A+->B+", + "A->B+->C", + //sequ quantif * + "A*", + "A*->B*", + "A->B*->C", + + //sequ quantif + "A*", + "A*->B+", + "A+->B*->C", + //others + + "(A#->B->C#)+", + "(A#->B)+;A#->B->C", + "B+->B->B", + "B#->R*", + "(B#->R)*", + "A->C->B#->B;B#->R", + "B#->R", + "A->C#;A->C#;A->C#;A->C#;A->C#;A->C#", + "B#->R;B#->R", + "A# -> C -> B#; B#->A#", + + // Add more test cases here + }; + + SECTION("AST Regex bijection") { + + for (const std::string& test : tests) { + std::shared_ptr<GraphStrInterpreter> strGenerator = std::make_shared<GraphStrInterpreter>(test); + std::string astString = strGenerator->interpret(); + //supress space in the test becase erase in the AST + std::string testNoS = test; + testNoS.erase(std::remove_if(testNoS.begin(), testNoS.end(), ::isspace), testNoS.end()); + //if the last char is ; (SEP) it will not in the AST and it's not a bug erase it + if 
(!testNoS.empty() && testNoS.back() == ';') { + // Remove the last character + testNoS.pop_back(); + } + //test + REQUIRE(astString == testNoS); + } + + } +} \ No newline at end of file diff --git a/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp b/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8b502fb546e2f1396b629ebc78bc1bd4d67842e2 --- /dev/null +++ b/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp @@ -0,0 +1,66 @@ + +#include <catch2/catch_test_macros.hpp> +#include "aidge/nodeTester/ConditionalInterpreter.hpp" +#include "aidge/operator/GenericOperator.hpp" + + +using namespace Aidge; + + + +TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { + + SECTION("custom Lambda") { + + const std::string test = " !toto($) == true " ; + ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + conditionalParser.insertLambda("toto",+[](NodePtr NodeOp){return false;}); + std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); + + bool result = conditionalParser.test(nodeOp); + REQUIRE(result == true); + } + + SECTION("syntax error") { + + const std::string test = "'A' == 'A' ,&& "; + REQUIRE_THROWS_AS( ConditionalInterpreter(test), std::runtime_error); + + } + + + SECTION("test false int ") { + + const std::string test = " 10 == 11 " ; + ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); + bool result = conditionalParser.test(nodeOp); + REQUIRE(result == false); + } + + SECTION("test true int ") { + const std::string test = " 42 == 42 " ; + ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); + bool result = conditionalParser.test(nodeOp); + REQUIRE(result == true); + } + + SECTION("test false str ") { + const std::string test = " 'toto' == 'Corgi' " ; + 
ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); + bool result = conditionalParser.test(nodeOp); + REQUIRE(result == false); + } + + SECTION("test true str ") { + + const std::string test = " 'Corgi' == 'Corgi' " ; + ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); + bool result = conditionalParser.test(nodeOp); + REQUIRE(result == true); + } + +} \ No newline at end of file diff --git a/unit_tests/nodeTester/Test_ConditionalLexer.cpp b/unit_tests/nodeTester/Test_ConditionalLexer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a937c27227dde4fa03ed7733df9e9552c3c1ac7b --- /dev/null +++ b/unit_tests/nodeTester/Test_ConditionalLexer.cpp @@ -0,0 +1,144 @@ +#include <catch2/catch_test_macros.hpp> +#include "aidge/nodeTester/ConditionalLexer.hpp" +#include "aidge/utilsParsing/ParsingToken.hpp" + +#include <iostream> +#include <map> +#include <functional> + +using namespace Aidge; + +TEST_CASE("nodeTester", "Lexer") { + SECTION("RandomGenerateTest") { + + std::map<ConditionalTokenTypes, std::function<std::pair<std::string, std::string>()>> LexerTestMap{ + {ConditionalTokenTypes::AND, +[](){return std::pair<std::string, std::string>("&& ","");}}, + {ConditionalTokenTypes::OR, +[](){return std::pair<std::string, std::string>("|| ","");}}, + {ConditionalTokenTypes::EQ, +[](){return std::pair<std::string, std::string>("== ","");}}, + {ConditionalTokenTypes::NEQ, +[](){return std::pair<std::string, std::string>("!= ","");}}, + + {ConditionalTokenTypes::KEY, +[](){return std::pair<std::string, std::string>("A ","A");}}, + + {ConditionalTokenTypes::BOOL, +[](){ + std::size_t keyLen = (std::rand() % 2); + const std::vector<std::string> characters = {"true","false"}; + + return std::pair<std::string, std::string>(characters[keyLen]+" ",characters[keyLen]);} + }, 
+ + {ConditionalTokenTypes::INTEGER, +[](){ + std::size_t keyLen = (std::rand() % 20)+1; + const std::string characters = "1234567890"; + std::size_t randomIndex = std::rand() % characters.size(); + std::string key; + for (std::size_t i = 0; i < keyLen; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + return std::pair<std::string, std::string>(key+" ",key);} + }, + + {ConditionalTokenTypes::FLOAT, +[](){ + std::size_t keyLen = (std::rand() % 20)+2; + const std::string characters = "1234567890"; + std::size_t randomIndex = std::rand() % characters.size(); + std::string key; + for (std::size_t i = 0; i < keyLen/2; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + key += "."; + for (std::size_t i = 0; i < keyLen/2; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + return std::pair<std::string, std::string>(key+" ",key);} + }, + + {ConditionalTokenTypes::STRING, +[](){ + std::size_t keyLen = (std::rand() % 20)+1; + const std::string characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890 "; + std::size_t randomIndex = std::rand() % characters.size(); + std::string key; + for (std::size_t i = 0; i < keyLen; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + + return std::pair<std::string, std::string>("'"+key+"' ",key);} + }, + + {ConditionalTokenTypes::LAMBDA, +[](){ + + std::size_t keyLen = (std::rand() % 20)+1; + const std::string characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"; + const std::string Startchar = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + + std::size_t randomIndex = std::rand() % characters.size(); + std::size_t randomStartIndex = std::rand() % Startchar.size(); + + std::string key; + key += Startchar[randomStartIndex]; + + for (std::size_t i = 0; i < keyLen; ++i) { + key += characters[randomIndex]; + 
randomIndex = std::rand() % characters.size(); + } + + return std::pair<std::string, std::string>(key+"( ",key);} + }, + + {ConditionalTokenTypes::ARGSEP, +[](){return std::pair<std::string, std::string>(", ","");}}, + {ConditionalTokenTypes::NODE, +[](){return std::pair<std::string, std::string>("$ ","");}}, + {ConditionalTokenTypes::LPAREN, +[](){return std::pair<std::string, std::string>("( ","");}}, + {ConditionalTokenTypes::RPAREN, +[](){return std::pair<std::string, std::string>(") ","");}} + //{ConditionalTokenTypes::STOP, +[](){return std::pair<std::string, std::string>("","");}} + }; + + + ////////////////// + //TEST GENERATOR + ////////////////// + const std::size_t numRandomElements = 100; + std::vector<std::tuple<ConditionalTokenTypes, std::string>> testVector; + + std::string testString; + + for (std::size_t i = 0; i < numRandomElements; ++i) { + + int randomIndex = std::rand() % LexerTestMap.size(); + // Get an iterator to the random element in the map + auto it = std::next(LexerTestMap.begin(), randomIndex); + // Access the random key and lambda value separately using structured binding + ConditionalTokenTypes randomKey = it->first; + + std::function<std::pair<std::string, std::string>()> randomValue = it->second; + std::pair<std::string, std::string> result = randomValue(); + + testString += result.first; + testVector.emplace_back(randomKey, result.second); + + + } + + ConditionalLexer conditionalLexer = ConditionalLexer(testString); + + for (std::tuple<ConditionalTokenTypes, std::string> testToken : testVector) { + ConditionalTokenTypes tokenToFind = std::get<0>(testToken); + std::string lexemToFind = std::get<1>(testToken); + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = conditionalLexer.getNextToken(); + + + std::ostringstream errorMessage; + errorMessage << "\n we whant :"<< lexemToFind << "\n we get : "<< token->getLexeme() <<"\n"<< "on \n" << testString << " :\n " ; + + CAPTURE(errorMessage.str()); + REQUIRE(token->getLexeme() 
== lexemToFind); + REQUIRE(token->getType() == tokenToFind); + } + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = conditionalLexer.getNextToken(); + REQUIRE(token->getType() == ConditionalTokenTypes::STOP); + } + + +} \ No newline at end of file diff --git a/unit_tests/nodeTester/Test_ConditionalParser.cpp b/unit_tests/nodeTester/Test_ConditionalParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..56adb92b41745001e1790e087f07369918794c5d --- /dev/null +++ b/unit_tests/nodeTester/Test_ConditionalParser.cpp @@ -0,0 +1,75 @@ + +#include <catch2/catch_test_macros.hpp> +#include "aidge/nodeTester/ConditionalParser.hpp" +#include "aidge/utilsParsing/AstNode.hpp" + +using namespace Aidge; + + std::string gVal() { + int randomValue = std::rand() % 5; + switch (randomValue) { + case 0: + return std::to_string(std::rand() % 101); + + case 1: + return std::to_string(std::rand() % 101)+"."+std::to_string(std::rand() % 101); + + case 2: + return " 'toto' "; + case 3: + return " A "; + + case 4: + return " A(10) "; + + default: + return " true "; + + } + } + + std::string gExpr() ; + std::string gCmpr() { + int randomValue = std::rand() % 3; + switch (randomValue) { + case 0: + return gVal() + " == " +gVal(); + case 1: + return "("+ gExpr() +")"; + default: + return gVal() + " != " +gVal(); + + } + + + return gVal() + " == " +gVal(); + } + + std::string gExpr() { + std::string out = gCmpr(); + int iterations = std::rand() % 100; + for (int i = 0; i < iterations; ++i) { + int randomValue = std::rand() % 2; + switch (randomValue) { + case 0: + return out +" && " + gCmpr(); + break; + default: + return out +" || " + gCmpr(); + break; + } + } + return out; + } + + +TEST_CASE("ConditionalParser", "ConditionalParser") { + + SECTION("Empty") { + for (int i = 0; i < 100; ++i) { + const std::string test = gExpr(); + ConditionalParser conditionalParser = ConditionalParser(test); + std::shared_ptr<AstNode<ConditionalTokenTypes>> tree = 
conditionalParser.parse(); + } + } +} \ No newline at end of file diff --git a/unit_tests/operator/Test_GenericOperator.cpp b/unit_tests/operator/Test_GenericOperator.cpp index 886326214a4a285fb32e5909da5114d74782ee46..8d634cc3a105c423b54b6003f41204aeb1fc5335 100644 --- a/unit_tests/operator/Test_GenericOperator.cpp +++ b/unit_tests/operator/Test_GenericOperator.cpp @@ -17,72 +17,72 @@ using namespace Aidge; -TEST_CASE("[core/operators] GenericOp(add & get parameters)", "[Operator]") { +TEST_CASE("[core/operators] GenericOp(add & get attributes)", "[Operator]") { SECTION("INT") { GenericOperator_Op Testop("TestOp", 1, 1, 1); - int value = 5; - const char* key = "intParam"; - Testop.addParameter(key, value); - REQUIRE(Testop.getParameter<int>(key) == value); + const char* key = "intAttr"; + Testop.addAttr(key, int(5)); + int registeredVal = Testop.getAttr<int>(key); + REQUIRE(registeredVal == 5); } SECTION("LONG") { GenericOperator_Op Testop("TestOp", 1, 1, 1); long value = 3; - const char* key = "longParam"; - Testop.addParameter(key, value); - REQUIRE(Testop.getParameter<long>(key) == value); + const char* key = "longAttr"; + Testop.addAttr(key, value); + REQUIRE(Testop.getAttr<long>(key) == value); } SECTION("FLOAT") { GenericOperator_Op Testop("TestOp", 1, 1, 1); float value = 2.0; - const char* key = "floatParam"; - Testop.addParameter(key, value); - REQUIRE(Testop.getParameter<float>(key) == value); + const char* key = "floatAttr"; + Testop.addAttr(key, value); + REQUIRE(Testop.getAttr<float>(key) == value); } SECTION("VECTOR<BOOL>") { GenericOperator_Op Testop("TestOp", 1, 1, 1); std::vector<bool> value = {true, false, false, true, true}; const char* key = "vect"; - Testop.addParameter(key, value); + Testop.addAttr(key, value); - REQUIRE(Testop.getParameter<std::vector<bool>>(key).size() == value.size()); + REQUIRE(Testop.getAttr<std::vector<bool>>(key).size() == value.size()); for (std::size_t i=0; i < value.size(); ++i){ - 
REQUIRE(Testop.getParameter<std::vector<bool>>(key)[i] == value[i]); + REQUIRE(Testop.getAttr<std::vector<bool>>(key)[i] == value[i]); } } SECTION("VECTOR<INT>") { GenericOperator_Op Testop("TestOp", 1, 1, 1); std::vector<int> value = {1, 2, 3, 4, 5, 6, 7, 8, 9}; const char* key = "vect"; - Testop.addParameter(key, value); + Testop.addAttr(key, value); - REQUIRE(Testop.getParameter<std::vector<int>>(key).size() == value.size()); + REQUIRE(Testop.getAttr<std::vector<int>>(key).size() == value.size()); for (std::size_t i=0; i < value.size(); ++i){ - REQUIRE(Testop.getParameter<std::vector<int>>(key)[i] == value[i]); + REQUIRE(Testop.getAttr<std::vector<int>>(key)[i] == value[i]); } } SECTION("MULTIPLE PARAMS") { /* - Goal : Test that the offsets are well done by adding different parameters with different size. + Goal : Test that the offsets are well done by adding different attributes with different size. */ GenericOperator_Op Testop("TestOp", 1, 1, 1); - Testop.addParameter<long>("longParam", 3); - Testop.addParameter<float>("floatParam", 2.0); - Testop.addParameter<uint8_t>("uint8Param", 5); - Testop.addParameter<long long>("llParam", 10); - REQUIRE(Testop.getParameter<long>("longParam") == 3); - REQUIRE(Testop.getParameter<float>("floatParam") == 2.0); - REQUIRE(Testop.getParameter<uint8_t>("uint8Param") == 5); - REQUIRE(Testop.getParameter<long long>("llParam") == 10); + Testop.addAttr<long>("longAttr", 3); + Testop.addAttr<float>("floatAttr", 2.0); + Testop.addAttr<uint8_t>("uint8Attr", 5); + Testop.addAttr<long long>("llAttr", 10); + REQUIRE(Testop.getAttr<long>("longAttr") == 3); + REQUIRE(Testop.getAttr<float>("floatAttr") == 2.0); + REQUIRE(Testop.getAttr<uint8_t>("uint8Attr") == 5); + REQUIRE(Testop.getAttr<long long>("llAttr") == 10); } } -TEST_CASE("[core/operator] GenericOp(type check)", "[.ass]") { +TEST_CASE("[core/operator] GenericOp(type check)", "[Operator]") { SECTION("WRONG TYPE FOR GETTER") { GenericOperator_Op Testop("TestOp", 1, 1, 1); - 
Testop.addParameter<long>("longParam", 3); + Testop.addAttr<long>("longAttr", 3); // This line should raise a failled assert - REQUIRE_THROWS(Testop.getParameter<int>("longParameter")); + REQUIRE_THROWS(Testop.getAttr<int>("longAttribute")); } } diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c090427914390369452ce3259f47830f01ab1754 --- /dev/null +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -0,0 +1,53 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/operator/MetaOperator.hpp" +#include "aidge/operator/MetaOperatorDefs.hpp" +#include "aidge/graph/GraphView.hpp" +#include <cstddef> + +using namespace Aidge; + +TEST_CASE("[core/operators] MetaOperator", "[Operator]") { + SECTION("PaddedConv") { + auto op = PaddedConv(1, 3, {3, 3}, "padded_conv", {1, 1}, {1, 1, 1, 1}); + + auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraph(); + + REQUIRE(microGraph->getNodes().size() == 2); + REQUIRE(microGraph->inputNodes().size() == 2); // 2 because Conv has inputs outside the meta-op (Producers for weight and bias) + // Order not garanteed by the GraphView + //REQUIRE((*microGraph->inputNodes().begin())->getOperator()->type() == "Pad"); + REQUIRE(microGraph->outputNodes().size() == 1); + REQUIRE((*microGraph->outputNodes().begin())->getOperator()->type() == "Conv"); + REQUIRE(op->nbInputs() == 3); + REQUIRE(op->nbDataInputs() == 1); + REQUIRE(op->nbOutputs() == 1); + + 
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(); + myInput->resize({2,3,5,5}); + op->getOperator()->associateInput(0,myInput); + op->getOperator()->computeOutputDims(); + + REQUIRE(op->getOperator()->outputDimsForwarded()); + REQUIRE(op->getOperator()->getOutput(0)->dims() == std::vector<size_t>({2,3,5,5})); + REQUIRE(op->getOperator()->getInput(0) == myInput); + // Order not garanteed by the GraphView + //REQUIRE((*microGraph->inputNodes().begin())->getOperator()->getInput(0) == myInput); + REQUIRE(op->getOperator()->getOutput(0) == (*microGraph->outputNodes().begin())->getOperator()->getOutput(0)); + + //op->getOperator()->updateConsummerProducer(); // require implementation + //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraphScheduler(); + //REQUIRE(microGraphScheduler->getStaticScheduling().size() == 2); + } +} diff --git a/unit_tests/recipies/Test_FuseMulAdd.cpp b/unit_tests/recipies/Test_FuseMulAdd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..da53642055a3146c71a211ad7816f21c9b92d6cd --- /dev/null +++ b/unit_tests/recipies/Test_FuseMulAdd.cpp @@ -0,0 +1,77 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <set> + +// #include "aidge/backend/cpu/operator/AddImpl.hpp" +// #include "aidge/backend/cpu/operator/ConvImpl.hpp" +// #include "aidge/backend/cpu/operator/FCImpl.hpp" +// #include "aidge/backend/cpu/operator/MatMulImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Add.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/operator/MatMul.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/Recipies.hpp" + +namespace Aidge { + +TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { + // generate the original GraphView + auto matmul0 = MatMul(5, "matmul0"); + auto add0 = Add<2>("add0"); + auto matmul1 = MatMul(5, "matmul1"); + auto add1 = Add<2>("add1"); + + auto b0 = Producer({5}, "B0"); + auto w0 = Producer({5, 5}, "W0"); + auto b1 = Producer({5}, "B1"); + auto w1 = Producer({5,5},"W1"); + auto input = Producer({2,5}, "input"); + + input->addChild(matmul0, 0, 0); + w0->addChild(matmul0, 0, 1); + + matmul0->addChild(add0, 0, 0); + b0->addChild(add0, 0, 1); + + add0->addChild(matmul1, 0, 0); + w1->addChild(matmul1, 0, 1); + + matmul1->addChild(add1, 0, 0); + b1->addChild(add1, 0, 1); + + auto g = std::make_shared<GraphView>(); + g->add({matmul0, add0, matmul1, add1, b0, b1}); + + // Check original graph + REQUIRE(g->getNodes() == + std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1})); + REQUIRE(((matmul0->getParent(0) == input) && (matmul0->getParent(1) == w0))); + REQUIRE(((add0->getParent(0) == matmul0) && (add0->getParent(1) == b0))); + REQUIRE(((matmul1->getParent(0) == add0) && (matmul1->getParent(1) == w1))); + REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1))); + + // Transform GraphView inplace + fuseMulAdd(g); + g->save("bonjour"); + + // Check 
new GraphView + std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); + REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1})); + REQUIRE(newNodes.size() == 6); + for (const auto& node : newNodes) { + REQUIRE(((node->type() == "Producer") || (node->type() == "FC"))); + } +} +} // namespace Aidge \ No newline at end of file diff --git a/unit_tests/recipies/Test_LabelGraph.cpp b/unit_tests/recipies/Test_LabelGraph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..873ad68f3198c6b6adf44d8c7ae31e667c63a18d --- /dev/null +++ b/unit_tests/recipies/Test_LabelGraph.cpp @@ -0,0 +1,154 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/recipies/LabelGraph.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/AvgPooling.hpp" +#include "aidge/operator/MaxPooling.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/graph/OpArgs.hpp" +#include <cstddef> + +using namespace Aidge; + +TEST_CASE("[LabelGraph] conv") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g1 = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g1->add(conv1); + g1->addChild(conv2, conv1, 0); + g1->addChild(conv3, conv2, 0); + g1->save("LabelGraph_conv_graph"); + + auto g2 = labelGraph(g1); + + auto dataProvider2 = Producer({16, 3, 224, 224}, 
"dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_conv_label"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); + } +} + +TEST_CASE("[LabelGraph] deleted node") { + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + Conv(3, 32, {3, 3}, "conv1"), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + Conv(32, 64, {3, 3}, "conv2"), + Conv(64, 10, {1, 1}, "conv3", {2, 2}) + }); + + g1->save("LabelGraph_deleted_graph"); + + auto g2 = labelGraph(g1); + + auto dataProvider2 = Producer({16, 1, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_deleted_label"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); + } + + SECTION("Check dimensions") { + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 222, 222})); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 220, 220})); + 
REQUIRE(g2->getNode("conv3")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 110, 110})); + } +} + +TEST_CASE("[LabelGraph] deleted nodes") { + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + Conv(3, 32, {3, 3}, "conv1"), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + Conv(32, 64, {3, 3}, "conv2"), + GenericOperator("Dummy_to_be_removed", 1, 1, 1), + Conv(64, 10, {1, 1}, "conv3") + }); + + g1->save("LabelGraph_deleteds_graph"); + + auto g2 = labelGraph(g1); + + auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("conv1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_deleteds_label"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); + } +} + +TEST_CASE("[LabelGraph] pooling") { + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + AvgPooling({2, 2}, "pool1"), + MaxPooling({2, 2}, "pool2"), + MaxPooling({2, 2}, "pool3", {2, 2}) + }); + + g1->save("LabelGraph_deleted_graph"); + + auto g2 = labelGraph(g1); + + auto dataProvider2 = Producer({16, 1, 224, 224}, "dataProvider"); + dataProvider2->addChild(g2->getNode("pool1"), 0); + + g2->forwardDims(); + g2->save("LabelGraph_pooling"); + + SECTION("Check resulting nodes") { + REQUIRE(g2->getNodes().size() == 3); + REQUIRE(g2->getNode("pool1")->getOperator()->type() == "MaxPooling"); + 
REQUIRE(g2->getNode("pool1")->getOperator()->getOutput(0) == g2->getNode("pool2")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("pool2")->getOperator()->type() == "MaxPooling"); + REQUIRE(g2->getNode("pool2")->getOperator()->getOutput(0) == g2->getNode("pool3")->getOperator()->getInput(0)); + REQUIRE(g2->getNode("pool3")->getOperator()->type() == "MaxPooling"); + } + + SECTION("Check dimensions") { + REQUIRE(g2->getNode("pool1")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 223, 223})); + REQUIRE(g2->getNode("pool2")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 222, 222})); + REQUIRE(g2->getNode("pool3")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 111, 111})); + } +} diff --git a/unit_tests/utils/Test_StaticAttributes.cpp b/unit_tests/utils/Test_StaticAttributes.cpp new file mode 100644 index 0000000000000000000000000000000000000000..36c2e0454b415e1cb25cc3581016530a372b9e65 --- /dev/null +++ b/unit_tests/utils/Test_StaticAttributes.cpp @@ -0,0 +1,48 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> + +#include <string> +#include <vector> + +#include "aidge/utils/StaticAttributes.hpp" + +using namespace Aidge; + +enum class TestAttr { a, b, c, d }; + +namespace { +template <> +const char *const EnumStrings<TestAttr>::data[] = { + "a", + "b", + "c", + "d" +}; +} + +using Attributes_ = StaticAttributes<TestAttr, int, float, std::string, std::vector<bool>>; +template <TestAttr e> +using attr = typename Attributes_::template attr<e>; + +TEST_CASE("[core/attributes] StaticAttribute") { + SECTION("TestAttr") { + StaticAttributes<TestAttr, int, float, std::string, std::vector<bool>> attrs( + attr<TestAttr::a>(42), + attr<TestAttr::b>(18.75), + attr<TestAttr::c>("test"), + attr<TestAttr::d>({true, false, true})); + + REQUIRE(attrs.getAttr<int>("a") == 42); + REQUIRE_THROWS(attrs.getAttr<int>("inexistant")); + } +}