From 4e5ad5e06c3170e3788090d85475038747805799 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Tue, 1 Aug 2023 08:53:17 +0000 Subject: [PATCH] Initial commit --- CMakeLists.txt | 52 ++ Makefile | 36 + README.md | 100 +-- include/backend/OperatorImpl.hpp | 60 ++ include/backend/TensorImpl.hpp | 41 ++ include/data/Data.hpp | 75 ++ include/data/Tensor.hpp | 473 ++++++++++++ include/graph/Connector.hpp | 86 +++ include/graph/GraphView.hpp | 334 +++++++++ include/graph/Node.hpp | 360 ++++++++++ include/graph/OpArgs.hpp | 86 +++ include/graphmatching/GRegex.hpp | 63 ++ include/graphmatching/Match.hpp | 44 ++ include/graphmatching/NodeRegex.hpp | 41 ++ include/graphmatching/SeqStm.hpp | 127 ++++ include/graphmatching/StmFactory.hpp | 55 ++ include/graphmatching/Utile.hpp | 50 ++ include/operator/Add.hpp | 147 ++++ include/operator/AvgPooling.hpp | 169 +++++ include/operator/BatchNorm.hpp | 161 +++++ include/operator/Conv.hpp | 200 ++++++ include/operator/ConvDepthWise.hpp | 196 +++++ include/operator/FC.hpp | 155 ++++ include/operator/GenericOperator.hpp | 165 +++++ include/operator/LeakyReLU.hpp | 127 ++++ include/operator/Matmul.hpp | 143 ++++ include/operator/MetaOperator.hpp | 28 + include/operator/Operator.hpp | 99 +++ include/operator/Producer.hpp | 141 ++++ include/operator/ReLU.hpp | 110 +++ include/operator/Softmax.hpp | 110 +++ include/scheduler/Scheduler.hpp | 71 ++ include/utils/CParameter.hpp | 110 +++ include/utils/Parameter.hpp | 197 +++++ include/utils/Recipies.hpp | 27 + include/utils/Registrar.hpp | 75 ++ include/utils/Types.h | 62 ++ include/utilsParsing/AstNode.hpp | 69 ++ include/utilsParsing/ParsingToken.hpp | 66 ++ python_binding/CMakeLists.txt | 0 .../backend/pybind_OperatorImpl.cpp | 20 + python_binding/data/pybind_Data.cpp | 37 + python_binding/data/pybind_Tensor.cpp | 147 ++++ python_binding/graph/pybind_Connector.cpp | 29 + python_binding/graph/pybind_GraphView.cpp | 60 ++ python_binding/graph/pybind_Node.cpp | 49 ++ python_binding/graph/pybind_OpArgs.cpp | 38 + .../graphmatching/pybind_GRegex.cpp | 25 + python_binding/graphmatching/pybind_Match.cpp | 25 + .../graphmatching/pybind_NodeRegex.cpp | 22 + python_binding/operator/pybind_Add.cpp | 32 + python_binding/operator/pybind_AvgPooling.cpp | 89 +++ python_binding/operator/pybind_BatchNorm.cpp | 33 + python_binding/operator/pybind_Conv.cpp | 107 +++ .../operator/pybind_ConvDepthWise.cpp | 100 +++ python_binding/operator/pybind_FC.cpp | 32 + .../operator/pybind_GenericOperator.cpp | 67 ++ python_binding/operator/pybind_LeakyReLU.cpp | 26 + python_binding/operator/pybind_Matmul.cpp | 32 + python_binding/operator/pybind_Operator.cpp | 28 + python_binding/operator/pybind_Producer.cpp | 50 ++ python_binding/operator/pybind_ReLU.cpp | 25 + python_binding/operator/pybind_Softmax.cpp | 26 + python_binding/pybind_core.cpp | 92 +++ python_binding/recipies/pybind_Recipies.cpp | 27 + python_binding/scheduler/pybind_Scheduler.cpp | 26 + python_binding/utils/pybind_Parameter.cpp | 12 + src/CMakeLists.txt | 0 src/graph/Connector.cpp | 54 ++ src/graph/GraphView.cpp | 673 ++++++++++++++++++ src/graph/Node.cpp | 327 +++++++++ src/graph/OpArgs.cpp | 73 ++ src/graphmatching/GRegex.cpp | 301 ++++++++ src/graphmatching/Match.cpp | 37 + src/graphmatching/NodeRegex.cpp | 46 ++ src/graphmatching/SeqStm.cpp | 247 +++++++ src/graphmatching/StmFactory.cpp | 150 ++++ src/operator/Operator.cpp | 44 ++ src/recipies/FuseMulAdd.cpp | 80 +++ src/recipies/RemoveFlatten.cpp | 28 + src/scheduler/Scheduler.cpp | 235 ++++++ 
tests/CMakeLists.txt | 25 + tests/graph/Test_Connector.cpp | 257 +++++++ tests/graph/Test_GraphView.cpp | 333 +++++++++ tests/graphMatching/Test_GRegex.cpp | 306 ++++++++ tests/graphMatching/Test_NodeRegex.cpp | 44 ++ tests/graphMatching/Test_SeqStm.cpp | 159 +++++ tests/graphMatching/Test_StmFactory.cpp | 189 +++++ tests/operator/Test_GenericOperator.cpp | 77 ++ 89 files changed, 9570 insertions(+), 82 deletions(-) create mode 100644 CMakeLists.txt create mode 100644 Makefile create mode 100644 include/backend/OperatorImpl.hpp create mode 100644 include/backend/TensorImpl.hpp create mode 100644 include/data/Data.hpp create mode 100644 include/data/Tensor.hpp create mode 100644 include/graph/Connector.hpp create mode 100644 include/graph/GraphView.hpp create mode 100644 include/graph/Node.hpp create mode 100644 include/graph/OpArgs.hpp create mode 100644 include/graphmatching/GRegex.hpp create mode 100644 include/graphmatching/Match.hpp create mode 100644 include/graphmatching/NodeRegex.hpp create mode 100755 include/graphmatching/SeqStm.hpp create mode 100644 include/graphmatching/StmFactory.hpp create mode 100644 include/graphmatching/Utile.hpp create mode 100644 include/operator/Add.hpp create mode 100644 include/operator/AvgPooling.hpp create mode 100644 include/operator/BatchNorm.hpp create mode 100644 include/operator/Conv.hpp create mode 100644 include/operator/ConvDepthWise.hpp create mode 100644 include/operator/FC.hpp create mode 100644 include/operator/GenericOperator.hpp create mode 100644 include/operator/LeakyReLU.hpp create mode 100644 include/operator/Matmul.hpp create mode 100644 include/operator/MetaOperator.hpp create mode 100644 include/operator/Operator.hpp create mode 100644 include/operator/Producer.hpp create mode 100644 include/operator/ReLU.hpp create mode 100644 include/operator/Softmax.hpp create mode 100644 include/scheduler/Scheduler.hpp create mode 100644 include/utils/CParameter.hpp create mode 100644 include/utils/Parameter.hpp create mode 100644 include/utils/Recipies.hpp create mode 100644 include/utils/Registrar.hpp create mode 100644 include/utils/Types.h create mode 100644 include/utilsParsing/AstNode.hpp create mode 100644 include/utilsParsing/ParsingToken.hpp create mode 100644 python_binding/CMakeLists.txt create mode 100644 python_binding/backend/pybind_OperatorImpl.cpp create mode 100644 python_binding/data/pybind_Data.cpp create mode 100644 python_binding/data/pybind_Tensor.cpp create mode 100644 python_binding/graph/pybind_Connector.cpp create mode 100644 python_binding/graph/pybind_GraphView.cpp create mode 100644 python_binding/graph/pybind_Node.cpp create mode 100644 python_binding/graph/pybind_OpArgs.cpp create mode 100644 python_binding/graphmatching/pybind_GRegex.cpp create mode 100644 python_binding/graphmatching/pybind_Match.cpp create mode 100644 python_binding/graphmatching/pybind_NodeRegex.cpp create mode 100644 python_binding/operator/pybind_Add.cpp create mode 100644 python_binding/operator/pybind_AvgPooling.cpp create mode 100644 python_binding/operator/pybind_BatchNorm.cpp create mode 100644 python_binding/operator/pybind_Conv.cpp create mode 100644 python_binding/operator/pybind_ConvDepthWise.cpp create mode 100644 python_binding/operator/pybind_FC.cpp create mode 100644 python_binding/operator/pybind_GenericOperator.cpp create mode 100644 python_binding/operator/pybind_LeakyReLU.cpp create mode 100644 python_binding/operator/pybind_Matmul.cpp create mode 100644 python_binding/operator/pybind_Operator.cpp create mode 100644 
python_binding/operator/pybind_Producer.cpp create mode 100644 python_binding/operator/pybind_ReLU.cpp create mode 100644 python_binding/operator/pybind_Softmax.cpp create mode 100644 python_binding/pybind_core.cpp create mode 100644 python_binding/recipies/pybind_Recipies.cpp create mode 100644 python_binding/scheduler/pybind_Scheduler.cpp create mode 100644 python_binding/utils/pybind_Parameter.cpp create mode 100644 src/CMakeLists.txt create mode 100644 src/graph/Connector.cpp create mode 100644 src/graph/GraphView.cpp create mode 100644 src/graph/Node.cpp create mode 100644 src/graph/OpArgs.cpp create mode 100644 src/graphmatching/GRegex.cpp create mode 100644 src/graphmatching/Match.cpp create mode 100644 src/graphmatching/NodeRegex.cpp create mode 100755 src/graphmatching/SeqStm.cpp create mode 100644 src/graphmatching/StmFactory.cpp create mode 100644 src/operator/Operator.cpp create mode 100644 src/recipies/FuseMulAdd.cpp create mode 100644 src/recipies/RemoveFlatten.cpp create mode 100644 src/scheduler/Scheduler.cpp create mode 100644 tests/CMakeLists.txt create mode 100644 tests/graph/Test_Connector.cpp create mode 100644 tests/graph/Test_GraphView.cpp create mode 100644 tests/graphMatching/Test_GRegex.cpp create mode 100644 tests/graphMatching/Test_NodeRegex.cpp create mode 100644 tests/graphMatching/Test_SeqStm.cpp create mode 100644 tests/graphMatching/Test_StmFactory.cpp create mode 100644 tests/operator/Test_GenericOperator.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 000000000..7cea83d97 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,52 @@ + +if (BUILD_CORE_ALONE) + cmake_minimum_required(VERSION 3.11) + project(Aidge_Core) + add_compile_options(-Wall -Wextra -fPIC) +endif() + +if (PYBIND) + add_definitions(-DPYBIND) + Include(FetchContent) + + FetchContent_Declare( + PyBind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG v2.10.4 # or a later release + ) + + FetchContent_MakeAvailable(PyBind11) + file(GLOB_RECURSE pybind_src_files "python_binding/*.cpp") + pybind11_add_module(aidge_core MODULE ${pybind_src_files} NO_EXTRAS) + target_include_directories(aidge_core PUBLIC ${pybind11_INCLUDE_DIRS} "python_binding") + target_link_libraries(aidge_core PUBLIC core) + # generate_python_binding(aidge_core core) +endif() + +add_library(core STATIC) + +# Add include directories +target_include_directories(core PUBLIC "include") + +# Containers module +file(GLOB_RECURSE src_files "src/*.cpp") +target_sources(core PRIVATE ${src_files}) + +set_property(TARGET core PROPERTY POSITION_INDEPENDENT_CODE ON) + +if (PYBIND) + target_include_directories(core PUBLIC $<BUILD_INTERFACE:${PYTHON_INCLUDE_DIRS}>) + target_link_libraries(core PRIVATE ${PYTHON_LIBRARIES}) +endif() + +if (NOT BUILD_CORE_ALONE) + # Activate compile time reducer for aidge_core + set_target_properties(core PROPERTIES COTIRE_ADD_UNITY_BUILD FALSE) + # set_target_properties(n2d2_cpu_lib PROPERTIES COTIRE_CXX_PREFIX_HEADER_INIT "include/utils/Precompiled.hpp") + cotire(core) +endif() + + +if (TESTS) + add_subdirectory(tests) +endif() \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..d10afab29 --- /dev/null +++ b/Makefile @@ -0,0 +1,36 @@ +# This makefile does nothing but delegate the actual building to cmake +BUILDDIR := build +MAKEFLAGS := --no-print-directory + +all: core_with_pybind + +core_only: + mkdir -p ${BUILDDIR}; \ + cd ${BUILDDIR}; \ + cmake -DBUILD_CORE_ALONE=ON -DCMAKE_BUILD_TYPE=Release -DPYBIND=OFF -DTESTS=OFF
..; \ + ${MAKE} ${MAKEFLAGS}; + +core_tests: + mkdir -p ${BUILDDIR}; \ + cd ${BUILDDIR}; \ + cmake -DBUILD_CORE_ALONE=ON -DCMAKE_BUILD_TYPE=Debug -DPYBIND=OFF -DTESTS=ON ..; \ + ${MAKE} ${MAKEFLAGS}; \ + cd tests; \ + ctest --output-on-failure || true; + +core_with_pybind: + mkdir -p ${BUILDDIR}; \ + cd ${BUILDDIR}; \ + cmake -DBUILD_CORE_ALONE=ON -DCMAKE_BUILD_TYPE=Release -DPYBIND=ON -DTESTS=OFF ..; \ + ${MAKE} ${MAKEFLAGS}; + +core_with_pybind_tests: + mkdir -p ${BUILDDIR}; \ + cd ${BUILDDIR}; \ + cmake -DBUILD_CORE_ALONE=ON -DCMAKE_BUILD_TYPE=Debug -DPYBIND=ON -DTESTS=ON ..; \ + ${MAKE} ${MAKEFLAGS}; \ + cd tests; \ + ctest --output-on-failure || true; + +clean: + if [ -d "${BUILDDIR}" ]; then rm -rf ${BUILDDIR}; fi \ No newline at end of file diff --git a/README.md b/README.md index 8223049ac..70ca91e42 100644 --- a/README.md +++ b/README.md @@ -1,92 +1,28 @@ -# aidge_core +# Aidge Core library +This repository contains the C++ code of the Aidge Core library. +## Compilation -## Getting started - -To make it easy for you to get started with GitLab, here's a list of recommended next steps. - -Already a pro? Just edit this README.md and make it your own. Want to make it easy? [Use the template at the bottom](#editing-this-readme)! - -## Add your files - -- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files -- [ ] [Add files using the command line](https://docs.gitlab.com/ee/gitlab-basics/add-file.html#add-a-file-using-the-command-line) or push an existing Git repository with the following command: - +To compile only the Core library, run ``` -cd existing_repo -git remote add origin https://git-dscin.intra.cea.fr/aidge/aidge_core.git -git branch -M main -git push -uf origin main +make core_only ``` -## Integrate with your tools - -- [ ] [Set up project integrations](https://git-dscin.intra.cea.fr/aidge/aidge_core/-/settings/integrations) - -## Collaborate with your team - -- [ ] [Invite team members and collaborators](https://docs.gitlab.com/ee/user/project/members/) -- [ ] [Create a new merge request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html) -- [ ] [Automatically close issues from merge requests](https://docs.gitlab.com/ee/user/project/issues/managing_issues.html#closing-issues-automatically) -- [ ] [Enable merge request approvals](https://docs.gitlab.com/ee/user/project/merge_requests/approvals/) -- [ ] [Automatically merge when pipeline succeeds](https://docs.gitlab.com/ee/user/project/merge_requests/merge_when_pipeline_succeeds.html) - -## Test and Deploy - -Use the built-in continuous integration in GitLab.
- -- [ ] [Get started with GitLab CI/CD](https://docs.gitlab.com/ee/ci/quick_start/index.html) -- [ ] [Analyze your code for known vulnerabilities with Static Application Security Testing(SAST)](https://docs.gitlab.com/ee/user/application_security/sast/) -- [ ] [Deploy to Kubernetes, Amazon EC2, or Amazon ECS using Auto Deploy](https://docs.gitlab.com/ee/topics/autodevops/requirements.html) -- [ ] [Use pull-based deployments for improved Kubernetes management](https://docs.gitlab.com/ee/user/clusters/agent/) -- [ ] [Set up protected environments](https://docs.gitlab.com/ee/ci/environments/protected_environments.html) - -*** - -# Editing this README - -When you're ready to make this README your own, just edit this file and use the handy template below (or feel free to structure it however you want - this is just a starting point!). Thank you to [makeareadme.com](https://www.makeareadme.com/) for this template. - -## Suggestions for a good README -Every project is different, so consider which of these sections apply to yours. The sections used in the template are suggestions for most open source projects. Also keep in mind that while a README can be too long and detailed, too long is better than too short. If you think your README is too long, consider utilizing another form of documentation rather than cutting out information. - -## Name -Choose a self-explaining name for your project. - -## Description -Let people know what your project can do specifically. Provide context and add a link to any reference visitors might be unfamiliar with. A list of Features or a Background subsection can also be added here. If there are alternatives to your project, this is a good place to list differentiating factors. - -## Badges -On some READMEs, you may see small images that convey metadata, such as whether or not all the tests are passing for the project. You can use Shields to add some to your README. Many services also have instructions for adding a badge. - -## Visuals -Depending on what you are making, it can be a good idea to include screenshots or even a video (you'll frequently see GIFs rather than actual videos). Tools like ttygif can help, but check out Asciinema for a more sophisticated method. - -## Installation -Within a particular ecosystem, there may be a common way of installing things, such as using Yarn, NuGet, or Homebrew. However, consider the possibility that whoever is reading your README is a novice and would like more guidance. Listing specific steps helps remove ambiguity and gets people to using your project as quickly as possible. If it only runs in a specific context like a particular programming language version or operating system or has dependencies that have to be installed manually, also add a Requirements subsection. - -## Usage -Use examples liberally, and show the expected output if you can. It's helpful to have inline the smallest example of usage that you can demonstrate, while providing links to more sophisticated examples if they are too long to reasonably include in the README. - -## Support -Tell people where they can go to for help. It can be any combination of an issue tracker, a chat room, an email address, etc. - -## Roadmap -If you have ideas for releases in the future, it is a good idea to list them in the README. - -## Contributing -State if you are open to contributions and what your requirements are for accepting them. - -For people who want to make changes to your project, it's helpful to have some documentation on how to get started. 
Perhaps there is a script that they should run or some environment variables that they need to set. Make these steps explicit. These instructions could also be useful to your future self. +To compile the Core library + the associated unit tests, run +``` +make core_tests +``` -You can also document commands to lint the code or run tests. These steps help to ensure high code quality and reduce the likelihood that the changes inadvertently break something. Having instructions for running tests is especially helpful if it requires external setup, such as starting a Selenium server for testing in a browser. +To compile the Core library with the Python binding, run +``` +make core_with_pybind +``` +Important: this is the default target, so running plain `make` is equivalent. -## Authors and acknowledgment -Show your appreciation to those who have contributed to the project. -## License -For open source projects, say how it is licensed. +To compile the Core library with the Python binding + the associated unit tests, run +``` +make core_with_pybind_tests +``` -## Project status -If you have run out of energy or time for your project, put a note at the top of the README saying that development has slowed down or stopped completely. Someone may choose to fork your project or volunteer to step in as a maintainer or owner, allowing your project to keep going. You can also make an explicit request for maintainers. diff --git a/include/backend/OperatorImpl.hpp b/include/backend/OperatorImpl.hpp new file mode 100644 index 000000000..a2c97c607 --- /dev/null +++ b/include/backend/OperatorImpl.hpp @@ -0,0 +1,60 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_OPERATORIMPL_H__ +#define __AIDGE_OPERATORIMPL_H__ + +#include <cstddef> +#include <vector> +#include "utils/Types.h" + +namespace Aidge { +class OperatorImpl { +public: + virtual void forward(){}; + virtual void backward() {} + + /** + * @brief Minimum amount of data from a specific input required by the + * implementation to be run. + * + * @param inputIdx Index of the input analysed. + * @return NbElts_t + */ + virtual NbElts_t getNbRequiredData(IOIndex_t inputIdx) const = 0; + + // Amount of input data that cannot be overwritten during the execution. + virtual NbElts_t getNbRequiredProtected(IOIndex_t inputIdx) const = 0; + + // Memory required at an output for a given input size. + virtual NbElts_t getRequiredMemory(IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const = 0; + + /** + * @brief Total amount of consumed data from a specific input. + * + * @param inputIdx Index of the input analysed. + * @return NbElts_t + */ + virtual NbElts_t getNbConsumedData(IOIndex_t inputIdx) const = 0; + + /** + * @brief Total amount of produced data ready to be used on a specific output. + * + * @param outputIdx Index of the output analysed.
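+ * @details Illustrative note (not from the original sources): a scheduler + * may compare this value with a consumer's getNbConsumedData() on the linked + * input to determine how much fresh data is available before running the + * consumer.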
+ * @return NbElts_t + */ + virtual NbElts_t getNbProducedData(IOIndex_t outputIdx) const = 0; + + virtual ~OperatorImpl() = default; +}; +} // namespace Aidge + +#endif /* __AIDGE_OPERATORIMPL_H__ */ diff --git a/include/backend/TensorImpl.hpp b/include/backend/TensorImpl.hpp new file mode 100644 index 000000000..f4c38d59b --- /dev/null +++ b/include/backend/TensorImpl.hpp @@ -0,0 +1,41 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_TENSORIMPL_H__ +#define __AIDGE_TENSORIMPL_H__ + +#include <cstddef> +#include <cstdio> +#include "utils/Types.h" + +namespace Aidge { +class TensorImpl { +public: + TensorImpl() = delete; + TensorImpl(const char *backend) : mBackend(backend){}; + virtual void copy(const void *src, NbElts_t length) = 0; + virtual void *rawPtr() = 0; + virtual void setRawPtr(void* /*ptr*/) + { + printf("Cannot set raw pointer for backend %s\n", mBackend); + }; + virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes) + constexpr const char *backend() const { return mBackend; } + virtual ~TensorImpl() = default; + virtual bool operator==(const TensorImpl &othImpl) const = 0; + +private: + const char *mBackend; +}; + +} // namespace Aidge + +#endif /* __AIDGE_TENSORIMPL_H__ */ diff --git a/include/data/Data.hpp b/include/data/Data.hpp new file mode 100644 index 000000000..ddf3c3f1b --- /dev/null +++ b/include/data/Data.hpp @@ -0,0 +1,75 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0.
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_DATA_H__ +#define __AIDGE_DATA_H__ + +#include "utils/Parameter.hpp" + +namespace Aidge { +enum class DataType { + Float64, + Float32, + Float16, + BFloat16, + Binary, + Ternary, + Int2, + Int3, + Int4, + Int5, + Int6, + Int7, + Int8, + Int16, + Int32, + Int64, + UInt2, + UInt3, + UInt4, + UInt5, + UInt6, + UInt7, + UInt8, + UInt16, + UInt32, + UInt64 +}; + +class Data { +public: + constexpr Data(const char* type): mType(type) {}; + constexpr const char* type() const { + return mType; + } + virtual ~Data() = default; + +private: + const char* mType; +}; +} + +namespace { +template <typename T> struct NativeType { static const Aidge::DataType type; }; +template <> const Aidge::DataType NativeType<double>::type = Aidge::DataType::Float64; +template <> const Aidge::DataType NativeType<float>::type = Aidge::DataType::Float32; +template <> const Aidge::DataType NativeType<long>::type = Aidge::DataType::Int64; +template <> const Aidge::DataType NativeType<int>::type = Aidge::DataType::Int32; + +template <> +const char* const EnumStrings<Aidge::DataType>::data[] + = {"Float64", "Float32", "Float16", "BFloat16", "Binary", "Ternary", + "Int2", "Int3", "Int4", "Int5", "Int6", "Int7", "Int8", "Int16", + "Int32", "Int64", "UInt2", "UInt3", "UInt4", "UInt5", "UInt6", + "UInt7", "UInt8", "UInt16", "UInt32", "UInt64"}; +} + +#endif /* __AIDGE_DATA_H__ */ \ No newline at end of file diff --git a/include/data/Tensor.hpp b/include/data/Tensor.hpp new file mode 100644 index 000000000..30aa14462 --- /dev/null +++ b/include/data/Tensor.hpp @@ -0,0 +1,473 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_CORE_DATA_TENSOR_H__ +#define __AIDGE_CORE_DATA_TENSOR_H__ + +#include <cstring> +#include <set> +#include <memory> +#include <numeric> +#include <string> +#include <vector> + +#include "backend/TensorImpl.hpp" +#include "data/Data.hpp" +#include "utils/Registrar.hpp" +#include "utils/Types.h" + +namespace Aidge { + +// Helper to create default arrays +template <typename T, std::size_t ... Is> +constexpr std::array<T, sizeof...(Is)> +create_array_impl(T value, std::index_sequence<Is...>) +{ + // cast Is to void to remove the warning: unused value + return {{(static_cast<void>(Is), value)...}}; +} + +template <typename T, std::size_t N> +constexpr std::array<T, N> create_array(const T& value) +{ + return create_array_impl(value, std::make_index_sequence<N>()); +} + + +// Helper to convert vector to array +template <typename T, typename Iter, std::size_t... Is> +constexpr auto to_array(Iter &iter, std::index_sequence<Is...>) -> std::array<T, sizeof...(Is)> { + return {{((void)Is, T(*iter++))...}}; +} + +/** + * @brief Convert an object with an iterator to an std::array. 
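+ * @details Usage sketch (illustrative): std::vector<int> v{1, 2, 3}; + * auto arr = to_array<3>(v.cbegin()); // -> std::array<int, 3>{1, 2, 3}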
+ */ +template <std::size_t N, typename U = void, typename Iter, typename V = typename std::iterator_traits<Iter>::value_type, + typename T = std::conditional_t<std::is_same<U, void>{}, V, U>> +constexpr auto to_array(Iter iter) -> std::array<T, N> { + return to_array<T>(iter, std::make_index_sequence<N>{}); +} + +namespace detail { + +template <class T, std::size_t N, std::size_t... I> +constexpr std::array<std::remove_cv_t<T>, N> to_array_impl(T (&a)[N], std::index_sequence<I...>) { + return {{a[I]...}}; +} + +} // namespace detail + +/** + * @brief Convert a C-style array into a C++ std::array. + * + * @tparam T Data type. + * @tparam N Number of elements. + * @param a C-style array to convert. + * @return constexpr std::array<std::remove_cv_t<T>, N> + */ +template <class T, std::size_t N> +constexpr std::array<std::remove_cv_t<T>, N> to_array(T (&a)[N]) { + return detail::to_array_impl(a, std::make_index_sequence<N>{}); +} + +template <typename T, std::size_t N, std::size_t... I> +constexpr std::array<T, N + 1> append(std::array<T, N> a, T t, std::index_sequence<I...>) { + return std::array<T, N + 1>{a[I]..., t}; +} + +template <typename T, std::size_t N, std::size_t... I> +constexpr std::array<T, N + 1> append(T t, std::array<T, N> a, std::index_sequence<I...>) { + return std::array<T, N + 1>{t, a[I]...}; +} + +/** + * @brief Create a new array concatenating the initial one with the value to + * add. + * @details append([1,2,7], 3) -> [1,2,7,3] + * + * @tparam T Data type. + * @tparam N Number of elements in the initial array. + * @param a Initial array. + * @param t Element to add. + * @return constexpr std::array<T, N + 1> + */ +template <typename T, std::size_t N> +constexpr std::array<T, N + 1> append(std::array<T, N> a, T t) { + return append(a, t, std::make_index_sequence<N>()); +} + +template <typename T, std::size_t N> +constexpr std::array<T, N + 1> append(T t, std::array<T, N> a) { + return append(t, a, std::make_index_sequence<N>()); +} + +// Generic helper for initializing a Tensor +template <typename T, std::size_t SIZE_0> +struct Array1D { + T data[SIZE_0]; +}; + +template <typename T, std::size_t SIZE_0, std::size_t SIZE_1> +struct Array2D { + T data[SIZE_0][SIZE_1]; +}; + +template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2> +struct Array3D { + T data[SIZE_0][SIZE_1][SIZE_2]; +}; + +template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2, std::size_t SIZE_3> +struct Array4D { + T data[SIZE_0][SIZE_1][SIZE_2][SIZE_3]; +}; + +class Tensor : public Data, + public Registrable<Tensor, std::tuple<std::string, DataType>, std::unique_ptr<TensorImpl>(const Tensor &)> { + private: + DataType mDataType; + std::vector<DimSize_t> mDims; + std::unique_ptr<TensorImpl> mImpl; + std::shared_ptr<Tensor> mGrad; + + // Cached data + std::size_t mSize; // number of elements in the tensor + std::size_t mSizeM1; // for a tensor of N dimensions, number of elements in the N-1 + // first dimensions + + public: + static constexpr const char *Type = "Tensor"; + + Tensor(DataType dataType = DataType::Float32) : Data(Type), mDataType(dataType), mDims({}), mSize(0), mSizeM1(0) { + // ctor + } + Tensor(const Tensor& otherTensor) + : Data(Type), + mDataType(otherTensor.mDataType), + mDims(otherTensor.mDims), + mSize(otherTensor.mSize), + mSizeM1(otherTensor.mSizeM1) + { + if (otherTensor.hasImpl()) { + mImpl = Registrar<Tensor>::create({otherTensor.mImpl->backend(), dataType()})(*this); + mImpl->copy(otherTensor.mImpl->rawPtr(), mSize); + } + } + 
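+  // Usage sketch (illustrative, not part of the original commit): the + // Array1D-Array4D constructors below allow brace-initialisation of a Tensor, + // assuming a "cpu" TensorImpl has been registered by a backend library: + //   Tensor t = Array2D<int, 2, 2>{{{1, 2}, {3, 4}}}; + //   t.print(); // displays the 2x2 content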
+ template <typename T, std::size_t SIZE_0> + constexpr Tensor(Array1D<T, SIZE_0> &&arr) + : Data(Type), + mDataType(NativeType<T>::type), + mDims({SIZE_0}), + mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)), + mSize(SIZE_0), + mSizeM1(SIZE_0) { + mImpl->copy(&arr.data[0], SIZE_0); + } + + template <typename T, std::size_t SIZE_0> + constexpr Tensor &operator=(Array1D<T, SIZE_0> &&arr) { + resize({SIZE_0}); + if (!mImpl) { + mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this); + } + mImpl->copy(&arr.data[0], SIZE_0); + return *this; + } + template <typename T, std::size_t SIZE_0, std::size_t SIZE_1> + constexpr Tensor(Array2D<T, SIZE_0, SIZE_1> &&arr) + : Data(Type), + mDataType(NativeType<T>::type), + mDims({SIZE_0, SIZE_1}), + mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)), + mSize(SIZE_0 * SIZE_1), + mSizeM1(SIZE_1) { + mImpl->copy(&arr.data[0][0], SIZE_0 * SIZE_1); + } + + template <typename T, std::size_t SIZE_0, std::size_t SIZE_1> + constexpr Tensor &operator=(Array2D<T, SIZE_0, SIZE_1> &&arr) { + resize({SIZE_0, SIZE_1}); + if (!mImpl) { + mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this); + } + mImpl->copy(&arr.data[0][0], SIZE_0 * SIZE_1); + return *this; + } + template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2> + constexpr Tensor(Array3D<T, SIZE_0, SIZE_1, SIZE_2> &&arr) + : Data(Type), + mDataType(NativeType<T>::type), + mDims({SIZE_0, SIZE_1, SIZE_2}), + mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)), + mSize(SIZE_0 * SIZE_1 * SIZE_2), + mSizeM1(SIZE_1 * SIZE_2) { + mImpl->copy(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2); + } + + template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2> + constexpr Tensor &operator=(Array3D<T, SIZE_0, SIZE_1, SIZE_2> &&arr) { + resize({SIZE_0, SIZE_1, SIZE_2}); + if (!mImpl) { + mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this); + } + mImpl->copy(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2); + return *this; + } + template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2, std::size_t SIZE_3> + constexpr Tensor(Array4D<T, SIZE_0, SIZE_1, SIZE_2, SIZE_3> &&arr) + : Data(Type), + mDataType(NativeType<T>::type), + mDims({SIZE_0, SIZE_1, SIZE_2, SIZE_3}), + mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)), + mSize(SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3), + mSizeM1(SIZE_1 * SIZE_2 * SIZE_3) { + mImpl->copy(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3); + } + + template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2, std::size_t SIZE_3> + constexpr Tensor &operator=(Array4D<T, SIZE_0, SIZE_1, SIZE_2, SIZE_3> &&arr) { + resize({SIZE_0, SIZE_1, SIZE_2, SIZE_3}); + if (!mImpl) { + mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this); + } + mImpl->copy(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3); + return *this; + } + + Tensor &operator=(const Tensor &t) { + resize(t.dims()); + setDatatype(t.dataType()); + if (t.hasImpl()) { + setBackend(t.mImpl->backend()); + mImpl->copy(t.mImpl->rawPtr(), size()); + } + else { + mImpl = nullptr; + } + return *this; + } + + bool operator==(const Tensor &otherTensor) const { + if ((!mImpl && !otherTensor.mImpl) || (dataType() != otherTensor.dataType()) || + (dims() != otherTensor.dims()) || (mImpl->backend() != otherTensor.mImpl->backend())) { + return false; + } + return *mImpl == *(otherTensor.mImpl); + } + + inline void 
setBackend(const std::string &name) { + if (mImpl) { + if (strcmp(mImpl->backend(), name.c_str()) != 0) { + // Backend change: create new impl, copy from old to new and replace + // impl + std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({name, mDataType})(*this); + newImpl->copy(mImpl->rawPtr(), size()); + mImpl = std::move(newImpl); + } + } else + mImpl = Registrar<Tensor>::create({name, mDataType})(*this); + } + static std::set<std::string> getAvailableBackends(){ + std::set<std::string> backendsList; + for(std::tuple<std::string, DataType> tupleKey : Registrar<Tensor>::getKeys()) + backendsList.insert(std::get<0>(tupleKey)); + return backendsList; + } + + constexpr DataType dataType() const { return mDataType; } + + /** + * @brief Set the DataType of the Tensor and converts data + * if the Tensor has already been initialized. + * @param dt DataType. + */ + void setDatatype(const DataType dt) { + if (mImpl && (dataType() != dt)) { + // get ptr before changing Tensor backend or the type difference will trigger a warning + const void *data = mImpl->rawPtr(); + mDataType = dt; + std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), dt})(*this); + newImpl->copy(data, size()); // /!\ it does not cast data but reinterpret them + mImpl = std::move(newImpl); + } + mDataType = dt; + } + + constexpr const std::unique_ptr<TensorImpl> &getImpl() { return mImpl; } + + bool hasImpl() const + { + return (mImpl) ? true : false; + } + + inline std::size_t nbDims() const { return mDims.size(); } + + template <DimIdx_t DIM> + constexpr std::array<DimSize_t, DIM> dims() const { + assert(DIM == mDims.size() && "wrong number of dimensions"); + return to_array<DIM>(mDims.cbegin()); + } + + constexpr const std::vector<DimSize_t> &dims() const { return mDims; } + + constexpr std::size_t size() const { return mSize; } + + constexpr std::size_t sizeM1() const { return mSizeM1; } + +// deducing std::array size_type and declaring DIM accordingly + template <std::array<DimSize_t, 1>::size_type DIM> + void resize(const std::array<DimSize_t, DIM> &dims) { + static_assert(DIM<=MaxDim,"Too many tensor dimensions required by resize, not supported"); + mDims.assign(dims.begin(), dims.end()); + computeSize(); + } + void resize(const std::vector<DimSize_t> &dims) { + mDims = dims; + computeSize(); + } + bool empty() const { return mDims.empty(); } + template <typename expectedType, std::array<std::size_t, 1>::size_type DIM> + constexpr expectedType &get(std::array<std::size_t, DIM> idx) { + assert(DIM == mDims.size()); + assert(mImpl); + std::size_t unfoldedIdx = 0; + for (std::size_t i = 0; i < DIM - std::size_t(1); ++i) { + unfoldedIdx = (unfoldedIdx + idx[i]) * mDims[i + 1]; + } + unfoldedIdx += idx[DIM - 1]; + return static_cast<expectedType *>(mImpl->rawPtr())[unfoldedIdx]; + } + + std::string toString() { + std::string res; + std::size_t dim = 0; + std::size_t *dimVals = new std::size_t[nbDims()]; + for (std::size_t i = 0; i < nbDims(); ++i) { + dimVals[i] = 0; + } + std::size_t counter = 0; + res += "{\n"; + if (nbDims()>=2){ + while (counter < mSize) { + std::string spaceString = std::string((dim+1)<<1,' '); + if (dim < nbDims()-2) { + if (dimVals[dim] == 0) { + res += spaceString + "{\n"; + ++dim; + } else if (dimVals[dim] < static_cast<std::size_t>(dims()[dim])) { + res += spaceString + "},\n" + spaceString + "{\n"; + ++dim; + } else { + res += spaceString + "}\n"; + dimVals[dim--] = 0; + dimVals[dim]++; + } + } else { + for (; dimVals[dim] < 
static_cast<std::size_t>(dims()[dim]); ++dimVals[dim]) { + res += spaceString + "{"; + for (DimSize_t j = 0; j < dims()[dim + 1] - 1; ++j) { + switch (mDataType) + { + case DataType::Int32: + res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[counter++]) + ","; + break; + case DataType::Float64: + res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[counter++]) + ","; + break; + default: + res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[counter++]) + ","; + break; + } + } + switch (mDataType) + { + case DataType::Int32: + res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[counter++]) + "}"; + break; + case DataType::Float64: + res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[counter++]) + "}"; + break; + default: + res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[counter++]) + "}"; + break; + } + if (dimVals[dim] < static_cast<std::size_t>(dims()[dim] - 1)) { + res += ","; + } + res += "\n"; + } + dimVals[dim--] = 0; + dimVals[dim]++; + } + } + for(int i = static_cast<int>(dim); i>=0; --i) { + res += std::string((dim+1)<<1,' ') + "}\n"; + } + }else{ + for (DimSize_t j = 0; j < dims()[0]; ++j) { + switch (mDataType) + { + case DataType::Int32: + res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "\n"); + break; + case DataType::Float64: + res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "\n"); + break; + default: + res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "\n"); + break; + } + } + } + + delete[] dimVals; // fix: release the index buffer allocated at the top of toString() + + res += "}"; + return res; + } + + inline void print() { printf("%s\n", toString().c_str()); } + + std::shared_ptr<Tensor> grad() { + if (!mGrad) { + mGrad = std::make_shared<Tensor>(mDataType); + mGrad->resize(mDims); + + if (mImpl) mGrad->setBackend(mImpl->backend()); + } + + return mGrad; + } + +private: + ///\bug not protected against overflow, see ThaliaCommonPack for a solution + std::size_t computeSize() { + if (mDims.empty()) { + mSizeM1 = DimSize_t(0); + mSize = DimSize_t(0); + } + else if (mDims.size() == 1) + { + mSizeM1 = mDims[0]; + mSize = mDims[0]; + } + else { + mSizeM1 = std::accumulate(++mDims.begin(),mDims.end(), DimSize_t(1), std::multiplies<DimSize_t>()); + mSize = static_cast<std::size_t>(mSizeM1 * mDims[0]); + } + + return mSize; + } +}; +} // namespace Aidge + +#endif /* __AIDGE_CORE_DATA_TENSOR_H__ */ diff --git a/include/graph/Connector.hpp b/include/graph/Connector.hpp new file mode 100644 index 000000000..5ab5651da --- /dev/null +++ b/include/graph/Connector.hpp @@ -0,0 +1,86 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#ifndef __AIDGE_CORE_GRAPH_CONNECTOR_H__ +#define __AIDGE_CORE_GRAPH_CONNECTOR_H__ + +#include <cassert> +#include <memory> +#include <vector> + +#include "utils/Types.h" + +namespace Aidge { + +class Node; +class GraphView; +/** + * @brief Object meant for a simpler and more intuitive user API. + * + * example: + * Connector x; + * x = Conv(...)(x); + * Connector y = Split(3)(x[0]); // Error!
Cannot slice a Connector with one output only + * Connector y = Split(3)(x); + * CustomLayer cl(...); + * Connector z = cl(y) // Error! y has multiple outputs, must specify which one to use + * Connector z1 = cl(y[0]); + * Connector z2 = cl(y[1]); + * Connector z3 = cl(y[2]); + * x = Sum(...)(z1, z2, z3); + * GraphView g = x.generateGraph(); + */ +class Connector { + private: + std::shared_ptr<Node> mNode; + ///\brief output id + ///\details gk_IODefaultIndex is reserved for? + ///\bug Is negative value pertinent? + IOIndex_t mOutputId = gk_IODefaultIndex; + + public: + Connector() : mNode(nullptr) { + // ctor + } + Connector(std::shared_ptr<Node> node); + + ~Connector() = default; + + public: + Connector operator[](IOIndex_t index) { + assert((size() > 1) && "Cannot refer a slice of the output."); + return Connector(mNode, index); + } + + public: + IONb_t size() const; + + inline std::shared_ptr<Node> node() const { return mNode; } + + inline IOIndex_t index() const { return mOutputId; } + + private: + Connector(std::shared_ptr<Node> node, IOIndex_t index) : mNode(node) { + assert((index >= 0) && (static_cast<IONb_t>(index) < size()) && + "Non-valid output index.\n"); + mOutputId = index; + } +}; + +/** + * @brief Generate a GraphView from a list of output Connectors + * + * @param ctors list of output Connector for the graph to generate. + * @return std::shared_ptr<GraphView> + */ +std::shared_ptr<GraphView> generateGraph(std::vector<Connector> ctors); +} // namespace Aidge + +#endif /* __AIDGE_CORE_GRAPH_CONNECTOR_H__ */ \ No newline at end of file diff --git a/include/graph/GraphView.hpp b/include/graph/GraphView.hpp new file mode 100644 index 000000000..b258f78db --- /dev/null +++ b/include/graph/GraphView.hpp @@ -0,0 +1,334 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_CORE_GRAPH_GRAPHVIEW_H__ +#define __AIDGE_CORE_GRAPH_GRAPHVIEW_H__ + +#include <map> +#include <memory> +#include <set> +#include <string> +#include <utility> +#include <vector> + +#include "graph/Node.hpp" +#include "utils/Types.h" + +namespace Aidge { +enum class DataType; +class GraphView : public std::enable_shared_from_this<GraphView> { +private: + /// @brief Name of the GraphView + std::string mName; + + /// @brief Set of nodes included in the GraphView + std::set<NodePtr> mNodes; + + /// @brief Map of the named Nodes included in the GraphView, indexed by name + std::map<std::string, NodePtr> mNodeRegistry; + + /// @brief Nodes without input link + std::set<NodePtr> mInputNodes; + + /// @brief Nodes without output link + std::set<NodePtr> mOutputNodes; + +public: + GraphView(std::string name="") + : mName(name) + { + // ctor + } + + GraphView(std::set<NodePtr> nodes, std::string name="") + : mName(name) + { + add(nodes); + } + + bool operator==(const GraphView &gv) const + { + return mNodes == gv.mNodes; + } + + NodePtr operator[](std::string name) + { + assert(mNodeRegistry.find(name) != mNodeRegistry.end() && "Could not find Node in the GraphView."); + return mNodeRegistry.at(name); + } + +/////////////////////////////////////////////////////// +// FUNCTIONAL DESCRIPTION +/////////////////////////////////////////////////////// + + Connector operator()(const std::vector<Connector> ctors); + +/////////////////////////////////////////////////////// +// INNER +/////////////////////////////////////////////////////// +public: + /** + * @brief Name of the GraphView. + * @return std::string + */ + std::string name() const; + + /** + * @brief Set the GraphView name. + * @warning Undefined behaviour when several GraphViews have the same name. + * @param name New name for the GraphView. + */ + void setName(const std::string &name); + + /** + * @brief Save the GraphView as a Mermaid graph in a .md file at the + * specified location. + * @param path Path of the output .md file. + */ + void save(std::string path, bool verbose = false) const; + + inline bool inView(NodePtr nodePtr) const { + return mNodes.find(nodePtr) != mNodes.end(); + } + +/////////////////////////////////////////////////////// +// TENSOR MANAGEMENT +/////////////////////////////////////////////////////// +public: + inline std::set<NodePtr> inputNodes() const noexcept { return mInputNodes; } + inline std::set<NodePtr> outputNodes() const noexcept { return mOutputNodes; } + + inline bool isInputNode(NodePtr nodePtr) const { + return (mInputNodes.find(nodePtr) != mInputNodes.end()) ? true : false; + } + inline bool isOutputNode(NodePtr nodePtr) const { + return (mOutputNodes.find(nodePtr) != mOutputNodes.end()) ? true : false; + } + + /** + * @brief List data input Tensors of the graph input nodes. + * @return std::vector<std::pair<NodePtr, IOIndex_t>> + */ + std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; + + /** + * @brief List data input Tensors of the Node with the given name. + * @param name Name of the Node. + * @return std::vector<std::pair<NodePtr, IOIndex_t>> + */ + inline auto dataInputs(std::string name) const { return mNodeRegistry.at(name)->dataInputs(); } + + /** + * @brief List input Tensors of the graph input nodes.
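+ * @details Inputs not linked to any Parent are reported as + * <nullptr, gk_IODefaultIndex> pairs, mirroring Node::inputs() (illustrative + * note, not from the original sources).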
+ * @return std::vector<std::pair<NodePtr, IOIndex_t>> + */ + std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; + + std::vector<std::pair<NodePtr, IOIndex_t>> inputs(std::string name) const; + + /** + * @brief List output Tensors of the graph output nodes. + * @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> + */ + std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const; + + /** + * @brief Specific i-th output Tensor of the GraphView. + * @param nodeName Name of the Node of which to show the output. + * @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> + */ + std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs( + std::string nodeName) const; + + void forwardDims(); + + void setBackend(const std::string &backend); + void setDatatype(const DataType &datatype); + +/////////////////////////////////////////////////////// +// TOPOLOGY +/////////////////////////////////////////////////////// +public: + /** + * @brief Get the Parents of inputNodes. + * @return std::set<NodePtr> + */ + std::set<NodePtr> getParents() const; + std::vector<NodePtr> getParents(const std::string nodeName) const; + std::vector<std::vector<NodePtr>> getOrderedParents() const; + + /** + * @brief Get the Children of outputNodes. + * @return std::set<NodePtr> + */ + std::set<NodePtr> getChildren() const; + std::vector<std::vector<NodePtr>> getChildren(const std::string nodeName) const; + std::set<NodePtr> getChildren( + const NodePtr otherNode) const; // TODO change it for a vector<vector> ? + + /** + * @brief Getter for the set of Nodes of the GraphView. + * @return std::set<NodePtr> + */ + inline std::set<NodePtr> getNodes() const { return mNodes; } + + /** + * @brief Get the Node with the corresponding name if it is in the + * GraphView. + * @param nodeName name of the node. + * @return NodePtr returns a new empty Node if the one asked for + * was not found. + */ + NodePtr getNode(const char *nodeName) const; + + /** + * @brief Remove a Node from the current GraphView scope without affecting its connections. + * @param nodePtr Node to remove + * @param includeLearnableParam Whether learnable parameters should also be removed. Default true. + */ + void remove(NodePtr nodePtr, bool includeLearnableParam = true); + + // Surrounding nodes management + + void setInputId(IOIndex_t inID, IOIndex_t newNodeOutID); + + /** + * @brief Include a Node in the current GraphView. + * @param otherNode Node to add. + * @param includeLearnableParam Whether non-data inputs, such as weights and + * biases, should be included in the GraphView automatically. Default: true. + */ + void add(NodePtr otherNode, bool includeLearnableParam = true); + void add(std::set<NodePtr> otherNodes, + bool includeLearnableParam = true); + + /** + * @brief Include every Node inside another GraphView in the current + * GraphView. + * @param otherGraph GraphView containing the Nodes to include. + */ + void add(std::shared_ptr<GraphView> otherGraph); + + /** + * @brief Include a Node in the current GraphView and link it to another + * already contained Node. + * + * @param toOtherNode Pointer to the Node to add. + * @param fromOutNode Pointer to the already included Node the new Node will + * be linked to (it will become a parent of the new Node). If the GraphView + * only has one output Node, then defaults to this Node. + * @param fromTensor Output Tensor ID of the already included Node. Default to + * 0. + * @param toTensor Input Tensor ID of the new Node.
Default to gk_IODefaultIndex, meaning + * first available data input for the Node. + */ + void addChild(NodePtr toOtherNode, NodePtr fromOutNode = nullptr, + const IOIndex_t fromTensor = IOIndex_t(0), + IOIndex_t toTensor = gk_IODefaultIndex); + + /** + * @brief Include a Node in the current GraphView and link it to another + * already contained Node. + * + * @param toOtherNode Pointer to the Node to add. + * @param fromOutNodeName Name of the already included Node the new Node will + * be linked to (it will become a parent of the new Node). As a name is + * optional, ensure such a Node is in the GraphView or an error will be raised. + * @param fromTensor Output Tensor ID of the already included Node. Default to + * 0. + * @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning + * first available data input for the Node. + */ + inline void addChild(NodePtr toOtherNode, std::string fromOutNodeName, + const IOIndex_t fromTensor = IOIndex_t(0), + IOIndex_t toTensor = gk_IODefaultIndex) { + assert(mNodeRegistry.find(fromOutNodeName) != mNodeRegistry.end() && + "No Node with this name found in the GraphView."); + addChild(toOtherNode, mNodeRegistry.at(fromOutNodeName), fromTensor, toTensor); + } + + /** + * @brief Include another GraphView's content in the current GraphView and link + * the two sets by linking one Node from each GraphView. + * @param toOtherView Pointer to the GraphView whose content should be added. + * @param fromOutNode Pair of pointer to Node and Tensor ID for specifying the + * connection. If the GraphView including the other one has only one output + * Node, then it defaults to the first output Tensor of this Node. + * @param toNode Pair of pointer to Node and Tensor ID for specifying the + * connection. If the GraphView whose content is included has only one input + * Node, then it defaults to the first available data input Tensor of this + * Node. + */ + void addChild(std::shared_ptr<GraphView> toOtherView, + std::pair<NodePtr, IOIndex_t> fromOutNode = + std::pair<NodePtr, IOIndex_t>(nullptr, IOIndex_t(0)), + std::pair<NodePtr, IOIndex_t> toNode = + std::pair<NodePtr, IOIndex_t>(nullptr, gk_IODefaultIndex)); + + /** + * @brief Swap two Node instances if possible. + * @param node + * @param otherNode + * @return true + * @return false + */ + bool swap(Node &node, Node &otherNode); + + void link(std::string name1_inID, std::string name2_outID); + + void insert(Node &newNode, Node &inNode, std::initializer_list<Node> outNodes, + IOIndex_t tensorIdx); + + /** + * @brief Replace the current GraphView with the set of given Nodes if possible. + * @param newNodes Set of Nodes. + * @return true + * @return false + */ + bool replaceWith(std::set<NodePtr> newNodes); + void updateInputNodes(); + /** + * @brief Recompute the set of output Nodes from scratch. + */ + void updateOutputNodes(); + +private: +/////////////////////////////////////////////////////// +// TENSOR MANAGEMENT +/////////////////////////////////////////////////////// + + IONb_t getNbDataInputs() const; + + IONb_t getNbFreeDataInputs() const; + + + void updateInputNodes(NodePtr node); + + /** + * @brief Update the set of output Nodes with a new Node, checking if it can be + * added and removing any Node not part of mOutputNodes anymore.
+ * @param nodePtr + */ + void updateOutputNodes(NodePtr node); + + /////////////////////////////////////////////////////// + // TOPOLOGY + /////////////////////////////////////////////////////// + + void _forwardDims(std::set<NodePtr> listNodes); + + void removeInputNode(const std::string nodeName); + void removeOutputNode(const std::string nodeName); +}; +} // namespace Aidge + +#endif /* __AIDGE_CORE_GRAPH_GRAPHVIEW_H__ */ \ No newline at end of file diff --git a/include/graph/Node.hpp b/include/graph/Node.hpp new file mode 100644 index 000000000..94a63becb --- /dev/null +++ b/include/graph/Node.hpp @@ -0,0 +1,360 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_CORE_GRAPH_NODE_H__ +#define __AIDGE_CORE_GRAPH_NODE_H__ + +#include <cassert> +#include <memory> +#include <set> +#include <string> +#include <vector> +#include <utility> + +#include "graph/Connector.hpp" +#include "operator/Operator.hpp" +#include "utils/Types.h" + +namespace Aidge { + +using NodePtr = std::shared_ptr<Node>; + +class GraphView; + +class Node : public std::enable_shared_from_this<Node> { +private: + std::string mName; // Name of the Node. Should be unique + + std::set<std::shared_ptr<GraphView>> mViews = + std::set<std::shared_ptr<GraphView>>(); // Set of pointers to GraphView + // instances including this Node + // instance + const std::shared_ptr<Operator> + mOperator; // Pointer to the associated Operator + + std::vector<NodePtr> + mParents; // List of parent nodes (Parent --> Node --> Child) + std::vector<std::vector<NodePtr>> + mChildren; // List of child nodes for each output (Parent --> Node --> + // Child) + std::vector<std::vector<IOIndex_t>> mIdInChildren; // InID of Child node. + std::vector<IOIndex_t> mIdOutParents; // OutID of Parent node. Default: gk_IODefaultIndex. + +public: + Node() = delete; + Node(std::shared_ptr<Operator> op, const char *name = nullptr); + + virtual ~Node() = default; + + friend bool operator==(const Node &lhs, const Node &rhs) { + return lhs.shared_from_this() == rhs.shared_from_this(); + } + +public: + /////////////////////////////////////////////////////// + // FUNCTIONAL DESCRIPTION + /////////////////////////////////////////////////////// + + Connector operator()(const std::vector<Connector> ctors); + +public: + /////////////////////////////////////////////////////// + // INNER + /////////////////////////////////////////////////////// + + /** + * @brief Name of the node. + * @return std::string + */ + inline std::string name() const noexcept { return mName; } + + /** + * @brief Set the node name. + * @warning Undefined behaviour when several Nodes have the same name. + * @param name New name for the node. + */ + void setName(const std::string &name); + + /** + * @brief Type of the node. 
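+ * @details The type string is provided by the wrapped Operator, e.g. "Conv" + * or "ReLU" (illustrative values; see the Operator headers listed above).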
+ * @return std::string + */ + inline std::string type() const { return mOperator->type(); } + + /////////////////////////////////////////////////////// + // OPERATORS + /////////////////////////////////////////////////////// + + /** + * @brief Run forward() function of the associated Operator + */ + void forward(); + + /** + * @brief Run backward() function of the associated Operator + */ + void backward(); + + /** + * @brief Get the Operator object of the Node + * @return std::shared_ptr<Operator> + */ + inline std::shared_ptr<Operator> getOperator() const { return mOperator; } + + /////////////////////////////////////////////////////// + // TENSOR MANAGEMENT + /////////////////////////////////////////////////////// + + /** + * @brief Whether or not every input of the Node is linked to a Tensor. + * If true then the Node is ready to be executed. + * @return true + * @return false + */ + bool valid() const; + + /** + * @brief List of pairs <Parent, ID of the data input>. When an input is not + * linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. + * @return std::vector<std::pair<NodePtr, IOIndex_t>> + */ + std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; + + /** + * @brief List of pairs <Parent, ID of the input>. When an input is not linked + * to any Parent, the pair is <nullptr, gk_IODefaultIndex>. + * @return std::vector<std::pair<NodePtr, IOIndex_t>> + */ + std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; + + /** + * @brief Parent and its output Tensor ID linked to the inID-th input Tensor. + * If the input is not linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. + * @param inID + * @return std::pair<NodePtr, IOIndex_t> + */ + inline std::pair<NodePtr, IOIndex_t> input(IOIndex_t inID) const { + assert((inID != gk_IODefaultIndex) && (static_cast<IONb_t>(inID) < nbInputs()) && "Input index out of bound."); + return std::pair<NodePtr, IOIndex_t>(mParents[inID], + mIdOutParents[inID]); + } + + /** + * @brief Set a fixed value for the specified input by creating a Producer wrapping the given Tensor. + * + * @param idx input index + * @param tensor constant tensor to add as parent for specified index. + */ + void setInput(const IOIndex_t idx, const std::shared_ptr<Tensor> tensor); + + /** + * @brief Get the lowest index of a data input that is not linked to any + * Parent (i.e. whose entry in the Parent list is nullptr). + * @return IOIndex_t + */ + inline IOIndex_t getFirstFreeDataInput() const { + IOIndex_t i = 0; + for (; (static_cast<IONb_t>(i) < nbDataInputs()) && (input(i).second >= 0); ++i) {} + // assert((i<nbDataInputs()) && "No free data input for Node"); + return (static_cast<IONb_t>(i) < nbDataInputs()) ? i : gk_IODefaultIndex; + } + + IONb_t getNbFreeDataInputs() const; + + /** + * @brief List input ids of children linked to outputs of the node + * @return std::vector<std::vector<std::pair<NodePtr, + * IOIndex_t>>> + */ + std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const; + + /** + * @brief Children and their input Tensor ID linked to the outID-th output + * Tensor.
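+ * @details A single output Tensor may feed several children, hence the + * vector of <child Node, input index> pairs (illustrative note).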
+   * @param outID
+   * @return std::vector<std::pair<NodePtr, IOIndex_t>>
+   */
+  std::vector<std::pair<NodePtr, IOIndex_t>>
+  output(IOIndex_t outID) const;
+
+  /**
+   * @brief Total number of inputs, including data and learnable parameter inputs.
+   * @details [data, data, weight, bias] => 4
+   * @return IONb_t
+   */
+  inline IONb_t nbInputs() const noexcept { return getOperator()->nbInputs(); }
+
+  /**
+   * @brief Number of inputs specifically for data.
+   * @details [data, data, weight, bias] => 2
+   * @return IONb_t
+   */
+  inline IONb_t nbDataInputs() const noexcept {
+    return getOperator()->nbDataInputs();
+  }
+
+  /**
+   * @brief Number of inputs linked to a Parent's output.
+   * @return IONb_t
+   */
+  IONb_t nbValidInputs() const;
+
+  /**
+   * @brief Getter for the number of Output Tensors of the Node.
+   * @return IONb_t
+   */
+  inline IONb_t nbOutputs() const noexcept { return getOperator()->nbOutputs(); }
+
+  IONb_t nbValidOutputs() const;
+
+  ///////////////////////////////////////////////////////
+  //        TOPOLOGY
+  ///////////////////////////////////////////////////////
+
+  /**
+   * @brief Set of pointers to each GraphView containing this Node.
+   * @return std::set<std::shared_ptr<GraphView>>
+   */
+  inline std::set<std::shared_ptr<GraphView>> views() const noexcept {
+    return mViews;
+  }
+
+  /**
+   * @brief Add a GraphView pointer to the list of GraphView containing
+   * the current Node. This feature allows transparent GraphViews.
+   * @param graphPtr Pointer to GraphView to add to the list.
+   */
+  inline void addView(const std::shared_ptr<GraphView> graphPtr) {
+    mViews.insert(graphPtr);
+  }
+
+  inline void removeView(const std::shared_ptr<GraphView> graphPtr) {
+    if (mViews.find(graphPtr) != mViews.end()) {
+      mViews.erase(graphPtr);
+    }
+  }
+
+  /**
+   * @brief Link another Node to an output of the current Node.
+   * @param otherNode Pointer to the other Node.
+   * @param outId ID of the output Tensor to connect to the other Node.
+   * Defaults to 0.
+   * @param otherInId ID of the input Tensor to connect to the current Node.
+   * Defaults to the first available data input.
+   */
+  void addChild(NodePtr otherNode,
+                const IOIndex_t outId = IOIndex_t(0),
+                IOIndex_t otherInId = gk_IODefaultIndex);
+
+  /**
+   * @brief Link a Node from a specific GraphView to the current Node.
+   * @param otherView Pointer to the GraphView whose content should be
+   * linked to the current Node.
+   * @param outId ID of the output Tensor to connect to the other Node.
+   * Defaults to 0.
+   * @param otherInId Pair of pointer to Node and Tensor ID for specifying the
+   * connection. If the GraphView whose content is linked has only one input
+   * Node, then it defaults to the first available data input Tensor of this
+   * Node.
+   */
+  void addChild(std::shared_ptr<GraphView> otherView,
+                const IOIndex_t outId = IOIndex_t(0),
+                std::pair<NodePtr, IOIndex_t> otherInId =
+                    std::pair<NodePtr, IOIndex_t>(nullptr, gk_IODefaultIndex));
+
+  /**
+   * @brief Get the list of parent Nodes. Each input is linked to at most one
+   * Node; if none is linked, the corresponding parent is nullptr.
+   * @return std::vector<NodePtr>
+   */
+  std::vector<NodePtr> getParents() const;
+
+  inline NodePtr &getParents(IOIndex_t inID) {
+    assert(inID != gk_IODefaultIndex);
+    return mParents.at(inID);
+  }
+
+  NodePtr popParent(const IOIndex_t inID);
+
+  bool removeParent(const IOIndex_t inID);
+
+  /**
+   * @brief Get the Children object. Children do not include any nullptr since
+   * an output may be linked to nobody and the Node would still work fine.
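+   * @details Aggregates the children of every output of this Node into a
+   * single set.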
+   * @return std::set<NodePtr>
+   */
+  std::set<NodePtr> getChildren() const;
+
+  std::vector<std::vector<NodePtr>> getOrderedChildren() const;
+
+  std::vector<NodePtr> getChildren(IOIndex_t outID) const;
+
+  /**
+   * @brief Remove registered child from children lists if possible.
+   * @param nodePtr Node to remove.
+   * @param outId Output index. Default 0.
+   * @return true Child found and removed for given output index.
+   * @return false Child not found at given index. Nothing removed.
+   */
+  bool removeChild(const NodePtr nodePtr, const IOIndex_t outId = 0);
+
+  /**
+   * @brief Remove every link between this Node and the surrounding nodes, and
+   * conversely.
+   */
+  void resetConnections(bool includeLearnableParam = false);
+
+private:
+  ///////////////////////////////////////////////////////
+  //        OPERATORS
+  ///////////////////////////////////////////////////////
+
+  // void setOperator(const std::shared_ptr<Operator> op_ptr);
+
+  ///////////////////////////////////////////////////////
+  //        TENSOR MANAGEMENT
+  ///////////////////////////////////////////////////////
+
+  void setInputId(IOIndex_t inID, IOIndex_t newNodeOutID);
+
+  ///////////////////////////////////////////////////////
+  //        TOPOLOGY
+  ///////////////////////////////////////////////////////
+
+  /**
+   * @brief addChild() implementation specialized in adding Nodes.
+   * @param other_node
+   * @param outID
+   * @param other_inID
+   */
+  void addChildOp(NodePtr other_node, const IOIndex_t outID,
+                  IOIndex_t other_inID);
+
+  /**
+   * @brief addChild() implementation specialized in adding GraphViews.
+   *
+   * @param other_graph
+   * @param outID
+   * @param other_inID
+   */
+  void addChildView(std::shared_ptr<GraphView> other_graph,
+                    const IOIndex_t outID,
+                    std::pair<NodePtr, IOIndex_t> other_inID);
+
+  /**
+   * @brief Add a Node to the list of parents.
+   * @param other_node Node to add to parents list.
+   * @param inID index for adding the parent.
+   */
+  void addParent(const NodePtr other_node, const IOIndex_t inID);
+};
+} // namespace Aidge
+
+#endif /* __AIDGE_CORE_GRAPH_NODE_H__ */
\ No newline at end of file
diff --git a/include/graph/OpArgs.hpp b/include/graph/OpArgs.hpp
new file mode 100644
index 000000000..dd0cfe1cc
--- /dev/null
+++ b/include/graph/OpArgs.hpp
@@ -0,0 +1,86 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef __AIDGE_CORE_GRAPH_OPARGS_H__
+#define __AIDGE_CORE_GRAPH_OPARGS_H__
+
+#include <memory>
+#include <cassert>
+
+namespace Aidge {
+class Node;
+class GraphView;
+
+/**
+ * @brief Intermediate representation for Structural description.
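+ * @details Wraps either a Node or a GraphView so that both can be passed
+ * interchangeably to the graph generation helpers declared below. A
+ * hypothetical usage sketch (operator factories as declared in this library):
+ * @code
+ * std::shared_ptr<GraphView> g =
+ *     Sequential({Conv(3, 32, {3, 3}, "conv1"), FC(10, false, "fc1")});
+ * @endcode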
+ */
+class OpArgs {
+private:
+    std::shared_ptr<Node> mNode = nullptr;
+    std::shared_ptr<GraphView> mView = nullptr;
+
+public:
+    OpArgs(const std::shared_ptr<GraphView>& view_)
+     : mView(view_) {assert(mView && "The GraphView provided should not be a nullptr.");}
+
+    OpArgs(const std::shared_ptr<Node>& node_)
+     : mNode(node_) {assert(mNode && "The Node provided should not be a nullptr.");}
+
+    inline std::shared_ptr<Node> node() const noexcept {
+        return mNode;
+    }
+
+    inline std::shared_ptr<GraphView> view() const noexcept {
+        return mView;
+    }
+};
+
+
+/////////////////////////////
+// Sequential
+
+/**
+ * @brief Create a GraphView by linking every input with the next
+ * one in a sequential way. Nodes linked with the Sequential graph
+ * generation instructions must have a single output.
+ * Sequential(A, B, C) returns A-->B-->C.
+ * @param inputs List of Node and GraphView to link sequentially.
+ * @return std::shared_ptr<GraphView> Pointer to the generated view.
+ */
+std::shared_ptr<GraphView> Sequential(std::initializer_list<OpArgs> inputs);
+
+/////////////////////////////
+// Parallel
+
+/**
+ * @brief Creates a GraphView with the provided Nodes without linking them.
+ * @param inputs List of Node and GraphView to add in parallel.
+ * @return std::shared_ptr<GraphView> Pointer to the generated view.
+ */
+std::shared_ptr<GraphView> Parallel(std::initializer_list<OpArgs> inputs);
+
+/////////////////////////////
+// Residual
+
+/**
+ * @brief Create a GraphView by linking every input with the next
+ * one in a sequential way. Finally, the output of the first element is
+ * added as another input of the last element. Nodes linked with the Residual
+ * graph generation instructions must have a single output.
+ * Residual(A, B, C) returns A-->B-->C and A-->C.
+ * @param inputs List of Node and GraphView to link sequentially.
+ * @return std::shared_ptr<GraphView> Pointer to the generated view.
+ */
+std::shared_ptr<GraphView> Residual(std::initializer_list<OpArgs> inputs);
+
+} // namespace Aidge
+
+#endif /* __AIDGE_CORE_GRAPH_OPARGS_H__ */
\ No newline at end of file
diff --git a/include/graphmatching/GRegex.hpp b/include/graphmatching/GRegex.hpp
new file mode 100644
index 000000000..5a49bcd8f
--- /dev/null
+++ b/include/graphmatching/GRegex.hpp
@@ -0,0 +1,63 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + + +#ifndef __AIDGE_GREGEX_H__ +#define __AIDGE_GREGEX_H__ + +#include <stdexcept> // for exception, runtime_error, out_of_range +#include <regex> +#include <memory> // for shared_ptr +#include <algorithm> // for next_permutation + +#include "graphmatching/Utile.hpp" +#include "graphmatching/StmFactory.hpp" +#include "graphmatching/SeqStm.hpp" +#include "graphmatching/NodeRegex.hpp" +#include "graphmatching/Match.hpp" + + +namespace Aidge{ + +class GRegex { +// __init__(self,nodes_regex:dict,seq_regexps:list) + + StmFactory mStmFab; + std::vector<SeqStm*> mStmInit; + +public: + GRegex(const std::map<std::string,NodeRegex*>& nodesRegex,std::vector<std::string>& seqRegexps ); + + std::set<NodeTmp> matchFromStartNodes(const std::vector<NodeTmp> startNodes,const std::shared_ptr<GraphView> graphToMatch); + + bool walk_validation_all_stm_are_valid(const std::vector<std::vector<SeqStm*>> all_stm); + + bool walk_validation_all_node_read_validate_by_one_stm(const std::vector<std::vector<SeqStm*>> all_stm); + + bool walk_validation_common_nodes_same_tag_for_all_stm(const std::vector<std::vector<SeqStm*>> all_stm); + + std::set<NodeTmp> get_all_validate_nodes(const std::vector<std::vector<SeqStm*>> all_stm); + + std::vector<SeqStm*> getStmInit() const { + return mStmInit; + } + + StmFactory getStmFab() const { + return mStmFab; + } + + //std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> match(const std::shared_ptr<GraphView> graphToMatch); + Match match(const std::shared_ptr<GraphView> graphToMatch); + +}; + +} +#endif //__AIDGE_GREGEX_H__ \ No newline at end of file diff --git a/include/graphmatching/Match.hpp b/include/graphmatching/Match.hpp new file mode 100644 index 000000000..2651bf3ae --- /dev/null +++ b/include/graphmatching/Match.hpp @@ -0,0 +1,44 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_MATCH_H__ +#define __AIDGE_MATCH_H__ + +#include <vector> +#include <set> +#include <iostream> +#include <cassert> +#include "graphmatching/Utile.hpp" + + +namespace Aidge{ + +class Match { + +public: + Match(); + + size_t getNbMatch(); + + void insert(std::vector<NodeTmp> startnodes, std::set<NodeTmp> matchnodes); + + std::vector<std::vector<NodeTmp>> getStartNodes(); + + std::vector<std::set<NodeTmp>> getMatchNodes(); + +protected: + std::vector<std::vector<NodeTmp>> mStartNodes; + std::vector<std::set<NodeTmp>> mMatchNodes; + +}; + +} +#endif //__AIDGE_MATCH_H__ \ No newline at end of file diff --git a/include/graphmatching/NodeRegex.hpp b/include/graphmatching/NodeRegex.hpp new file mode 100644 index 000000000..24a8ed225 --- /dev/null +++ b/include/graphmatching/NodeRegex.hpp @@ -0,0 +1,41 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef __AIDGE_NODEREGEX_H__
+#define __AIDGE_NODEREGEX_H__
+#include <cstdlib>
+#include <iostream>
+#include <cstring>
+#include "graph/Node.hpp"
+
+
+namespace Aidge {
+
+class NodeRegex
+{
+    public:
+    std::string mCondition;
+
+    NodeRegex(const std::string c){
+        mCondition = c;
+    };
+
+    // Version 1 - Only tests the type of the node (no need for a lexer)
+    // Input : Node_op
+    // Output : bool
+    // return mCondition == Node_op.type
+    bool _is(std::shared_ptr<Node> &Node_op);
+    bool isA(std::string NodeType);
+};
+
+}
+
+#endif /* __AIDGE_NODEREGEX_H__ */
\ No newline at end of file
diff --git a/include/graphmatching/SeqStm.hpp b/include/graphmatching/SeqStm.hpp
new file mode 100755
index 000000000..0abcc3d0d
--- /dev/null
+++ b/include/graphmatching/SeqStm.hpp
@@ -0,0 +1,127 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef __AIDGE_SEQSTM_H__
+#define __AIDGE_SEQSTM_H__
+
+#include <iostream>
+#include <map>
+#include <regex>
+#include <set>
+#include <stdexcept>  // for exception, runtime_error, out_of_range
+#include <string>
+#include <utility>
+#include <vector>
+
+
+#include "graphmatching/NodeRegex.hpp"
+#include "graphmatching/Utile.hpp"
+
+
+namespace Aidge {
+
+class SeqStm {
+
+private:
+    const int mStmIdx;
+    const std::vector<std::vector<int>> mTransitionMatrix;
+    // string key, such as 'A', used in sequence expressions like A->B
+    const std::map<std::string, NodeRegex *> mNodesRegex;
+    // mTypeToIdxTransition key: std::pair<node_type, common_tag>
+    // mTypeToIdxTransition value: index in the transition matrix
+    const std::map<NodeTypeKey, int> mTypeToIdxTransition;
+
+    int mActSt;
+    std::set<NodeTmp> mAllNodeValidated;
+    std::set<NodeTmp> mAllNodeTested;
+    std::set<std::pair<NodeTmp, std::string>> mAllCommonNode;
+    bool mStmIsValid;
+
+    std::pair<NodeRegex *, std::string> getNodeRegexAndCommonAt(int idxType);
+
+    /**
+     * @brief Advance the STM on a node type.
+     * @return The common tag.
+     */
+    std::string transitionOnNodeType(NodeType nodeType);
+
+public:
+    SeqStm(const int mStmIdx,
+           const std::vector<std::vector<int>> &mTransitionMatrix,
+           const std::map<std::string, NodeRegex *> &mNodesRegex,
+           const std::map<NodeTypeKey, int> &mTypeToIdxTransition, int mActSt,
+           std::set<NodeTmp> mAllNodeValidated, std::set<NodeTmp> mAllNodeTested,
+           std::set<std::pair<NodeTmp, std::string>> mAllCommonNode,
+           bool mStmIsValid);
+
+    //////////////////////////////////////
+    // STM test
+    /////////////////////////////////////
+
+    /**
+     * @brief Whether a state is a valid (final) one.
+     * @return bool
+     */
+    bool isAValidSt(int st) {
+        std::size_t size = mTransitionMatrix.size();
+        return st == static_cast<int>(size - 1);
+    }
+
+    /**
+     * @brief True if the STM is blocked in its current state.
+     * @return bool
+     */
+    bool isStmBlocked() { return mActSt == -1; }
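+
+    // Note: for a sequence expression such as "A->B", testNode() advances the
+    // state while the visited nodes match the NodeRegex bound to 'A' and then
+    // to 'B'; a state of -1 means no transition was possible (STM blocked).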
+
+    /**
+     * @brief True if the STM reached a valid (final) state.
+     * @return bool
+     */
+    bool isValid() { return mStmIsValid; }
+
+    /////////////////////////////////////
+    // utils
+    /////////////////////////////////////
+    /**
+     * @brief Extract the type of a node.
+     * @return NodeType
+     */
+    NodeType getTheNodeType(NodeTmp node);
+
+    void drawStm();
+    /////////////////////////////////////
+    // getters
+    /////////////////////////////////////
+
+    std::set<std::pair<NodeTmp, std::string>> getAllCommonNode() {
+        return mAllCommonNode;
+    }
+    std::set<NodeTmp> getAllNodeTested() { return mAllNodeTested; }
+
+    std::set<NodeTmp> getAllNodeValidated() { return mAllNodeValidated; }
+
+    SeqStm *duplicateStm();
+
+    int getStmIdx() { return mStmIdx; }
+
+    int getState() { return mActSt; }
+    //////////////////////////////////////////
+    // USE
+    //////////////////////////////////////////
+    /**
+     * @brief Test the STM on a node.
+     * @return Pair of the new STM state and the common tag.
+     */
+    std::pair<int, std::string> testNode(const NodeTmp node);
+};
+} // namespace Aidge
+
+#endif /* __AIDGE_SEQSTM_H__ */
\ No newline at end of file
diff --git a/include/graphmatching/StmFactory.hpp b/include/graphmatching/StmFactory.hpp
new file mode 100644
index 000000000..2e5e84511
--- /dev/null
+++ b/include/graphmatching/StmFactory.hpp
@@ -0,0 +1,55 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef __AIDGE_STMFACTORY_H__
+#define __AIDGE_STMFACTORY_H__
+
+#include <map>
+#include <utility>
+#include <set>
+#include <string>
+#include <vector>
+#include <iostream>
+#include <stdexcept>  // for exception, runtime_error, out_of_range
+#include <regex>
+
+#include "graphmatching/NodeRegex.hpp"
+#include "graphmatching/SeqStm.hpp"
+#include "graphmatching/Utile.hpp"
+
+namespace Aidge{
+
+
+
+class StmFactory {
+
+    const std::map<std::string,NodeRegex*>& mNodesRegex;
+    std::size_t mCmptStm = 0;
+public:
+    StmFactory(const std::map<std::string,NodeRegex*>& nodesRegex);
+    //StmFactory(){};
+
+    SeqStm* makeNewStm(const std::string& sequRegex);
+    SeqStm* duplicateStm(SeqStm* stm);
+
+    std::size_t getNumberOfStm(){
+        return mCmptStm;
+    }
+private:
+
+    ParsingReturn initParsingSequRegex(const std::string& sequRegex);
+
+    std::vector<std::vector<int>> initTransitionMatrix(ParsingReturn& parsing);
+
+};
+}
+
+#endif //__AIDGE_STMFACTORY_H__
\ No newline at end of file
diff --git a/include/graphmatching/Utile.hpp b/include/graphmatching/Utile.hpp
new file mode 100644
index 000000000..251eafd83
--- /dev/null
+++ b/include/graphmatching/Utile.hpp
@@ -0,0 +1,50 @@
+
+/**
+ * @file
+ * @brief
+ * @version file 1.0.0
+ * @author vl241552
+ * @copyright
+ *  Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory.
+ *  All rights reserved.
+ */
+
+#ifndef _utile_H_
+#define _utile_H_
+
+#include <map>
+
+#include "graph/Node.hpp"
+
+namespace Aidge {
+
+using NodeTmp = std::shared_ptr<Node>;
+using NodeType = std::string;
+using CommonTag = std::string;
+using NodeTypeKey = std::pair<NodeType, CommonTag>;
+
+// type def
+// struct NodeTypeKey {
+//     NodeType nodeType;
+//     std::string commonTag;
+
+//     // for map find
+//     bool operator<(const NodeTypeKey& other) const {
+//         if (nodeType != other.nodeType or commonTag != other.commonTag) {
+//             return false;
+//         } else {
+//             return true;
+//         }
+//     }
+
+// };
+
+struct ParsingReturn {
+    std::map<NodeTypeKey, int> typeToIdxTransition;
+    std::vector<std::pair<NodeTypeKey, std::string>> transition;
+};
+
+} // namespace Aidge
+
+#endif //_utile_H_
\ No newline at end of file
diff --git a/include/operator/Add.hpp b/include/operator/Add.hpp
new file mode 100644
index 000000000..42720af85
--- /dev/null
+++ b/include/operator/Add.hpp
@@ -0,0 +1,147 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef __AIDGE_CORE_OPERATOR_ADD_H__
+#define __AIDGE_CORE_OPERATOR_ADD_H__
+
+#include <numeric>
+#include <vector>
+#include <cmath>
+#include <memory>
+#include <array>
+
+#include "utils/Registrar.hpp"
+#include "operator/Operator.hpp"
+#include "data/Tensor.hpp"
+#include "graph/Node.hpp"
+#include "utils/Types.h"
+
+namespace Aidge {
+
+template <std::size_t NUM>
+class Add_Op : public Operator,
+    public Registrable<Add_Op<NUM>, std::string, std::unique_ptr<OperatorImpl>(const Add_Op<NUM>&)> {
+public:
+    // FIXME: change accessibility
+    std::array<std::shared_ptr<Tensor>, NUM> mInputs;
+    // Note: the constructor always re-initializes mOutput; do not call
+    // shared_from_this() here, the object is not owned by a shared_ptr yet.
+    const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
+
+public:
+    static constexpr const char* Type = "Add";
+
+    constexpr Add_Op()
+            : Operator(Type),
+            mOutput(std::make_shared<Tensor>())
+    {
+        assert(NUM > 0 && "Add should have at least one input");
+        for (std::size_t i = 0; i<NUM; ++i) {
+            mInputs[i] = std::make_shared<Tensor>();
+        }
+        setDatatype(DataType::Float32);
+    }
+
+    // Data operator[](const char* inputName) override final {
+    //     std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] :
+    //         (strcmp(inputName, "weight") ? mInputs[1] :
+    //         (strcmp(inputName, "bias") ?
mInputs[2] : + // nullptr)); + // assert((in!=nullptr) && "No such parameter"); + // return *in; + // } + + constexpr void associateInput(IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator."); + assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); + + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } + + constexpr void computeOutputDims() override final { + if (!mInputs[0]->empty()) { + const auto expectedDims = mInputs[0]->dims(); + std::size_t nonEmptyInputTensor = 1; + for (; nonEmptyInputTensor<NUM && (!mInputs[nonEmptyInputTensor]->empty()); ++nonEmptyInputTensor) { + assert(expectedDims == mInputs[nonEmptyInputTensor]->dims()); + } + if (nonEmptyInputTensor == NUM) { + mOutput->resize(expectedDims); + } + } + } + + bool outputDimsForwarded() const override final { + std::size_t forwarded = 0; + for (; forwarded < NUM && (!mInputs[forwarded]->empty()); ++forwarded) {} + return ((forwarded==NUM) && !(mOutput->empty())); + } + + // void checkDims() const override final { + // assert(outputDimsForwarded()); + // for (const auto& in : mInputs) { + // assert(in->dims() == mOutput->dims()); + // } + // } + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator."); + return *(mInputs[inputIdx].get()); + } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator."); + return mInputs[inputIdx]; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "Add Operators has only 1 outputs"); + return mOutput; + } + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator."); + return std::static_pointer_cast<Data>(mInputs[inputIdx]); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string& name) { + mImpl = Registrar<Add_Op<NUM>>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + for (std::size_t i = 0; i < NUM; ++i) { + mInputs[i]->setBackend(name); + } + } + + void setDatatype(const DataType& datatype) { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + for (std::size_t i = 0; i < NUM; ++i) { + mInputs[i]->setDatatype(datatype); + } + } + + inline IONb_t nbInputs() const noexcept override final { return NUM; } + inline IONb_t nbDataInputs() const noexcept override final { return NUM; } + inline IONb_t nbOutputs() const noexcept override final { return 1; } +}; + +template <std::size_t NUM> +inline std::shared_ptr<Node> Add(const char* name = nullptr) { + return std::make_shared<Node>(std::make_shared<Add_Op<NUM>>(), name); +} +} + +#endif /* __AIDGE_CORE_OPERATOR_ADD_H__ */ diff --git a/include/operator/AvgPooling.hpp b/include/operator/AvgPooling.hpp new file mode 100644 index 000000000..c3321dc24 --- /dev/null +++ b/include/operator/AvgPooling.hpp @@ -0,0 +1,169 @@ 
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef __AIDGE_CORE_OPERATOR_AVGPOOLING_H__
+#define __AIDGE_CORE_OPERATOR_AVGPOOLING_H__
+
+#include <array>
+#include <numeric>
+#include <vector>
+#include <cmath>
+
+#include "data/Tensor.hpp"
+#include "graph/Node.hpp"
+#include "operator/Operator.hpp"
+#include "operator/Producer.hpp"
+#include "utils/Parameter.hpp"
+#include "utils/Registrar.hpp"
+#include "utils/Types.h"
+
+namespace Aidge {
+enum class AvgPoolingParam { StrideDims, KernelDims, PaddingDims };
+
+template <DimIdx_t DIM>
+class AvgPooling_Op : public Operator,
+                public Registrable<AvgPooling_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const AvgPooling_Op<DIM> &)>,
+                public Parameterizable<AvgPoolingParam,
+                                       std::array<DimSize_t, DIM>,
+                                       std::array<DimSize_t, DIM>,
+                                       std::array<DimSize_t, (DIM<<1) >> {
+private:
+    // FIXME: change accessibility
+    std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>();
+    const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
+
+public:
+    static constexpr const char *Type = "AvgPooling";
+
+    AvgPooling_Op() = delete;
+
+    using Parameterizable_ = Parameterizable<AvgPoolingParam,
+                                             std::array<DimSize_t, DIM>,
+                                             std::array<DimSize_t, DIM>,
+                                             std::array<DimSize_t, (DIM<<1)> >;
+    template <AvgPoolingParam e>
+    using param = typename Parameterizable_::template param<e>;
+
+    constexpr AvgPooling_Op(const std::array<DimSize_t, DIM> &kernel_dims,
+                            const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
+                            const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0))
+        : Operator(Type),
+          Parameterizable_(param<AvgPoolingParam::StrideDims>(stride_dims),
+                           param<AvgPoolingParam::KernelDims>(kernel_dims),
+                           param<AvgPoolingParam::PaddingDims>(padding_dims)),
+          mOutput(std::make_shared<Tensor>()) {
+        setDatatype(DataType::Float32);
+    }
+
+    constexpr void associateInput(IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
+        assert(inputIdx < 1 && "AvgPooling operator supports only 1 input");
+        assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
+
+        mInput = std::dynamic_pointer_cast<Tensor>(data);
+    }
+
+    constexpr void computeOutputDims() override final {
+        if (!mInput->empty()) {
+            std::array<DimSize_t, DIM + 2> outputDims = {};
+
+            for (std::size_t dim = 0; dim < this->template get<AvgPoolingParam::KernelDims>().size() ; ++dim) {
+                outputDims[dim+2] = 1 + static_cast<DimSize_t>(
+                                        std::floor(static_cast<float>(mInput->dims()[dim+2] -
+                                                                this->template get<AvgPoolingParam::KernelDims>()[dim] +
+                                                                this->template get<AvgPoolingParam::PaddingDims>()[dim] +
+                                                                this->template get<AvgPoolingParam::PaddingDims>()[dim+DIM]) /
+                                        static_cast<float>(this->template get<AvgPoolingParam::StrideDims>()[dim])));
+            }
+            outputDims[1] = mInput->dims()[1];
+            outputDims[0] = mInput->dims()[0];
+            mOutput->resize(outputDims);
+        }
+    }
+
+    bool outputDimsForwarded() const override final { return !(mOutput->empty()); }
+
+
+    inline Tensor& input(const IOIndex_t inputIdx) const override final {
+        assert(inputIdx == 0 && "AvgPooling operator supports only 1 input");
+        return *(mInput.get());
} + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "AvgPooling Operators supports only 1 inputs"); + return mInput; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "AvgPooling Operators has only 1 outputs"); + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "operators supports only 1 inputs"); + return std::static_pointer_cast<Data>(mInput); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string &name) { + mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInput->setBackend(name); + } + + void setDatatype(const DataType &datatype) { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInput->setDatatype(datatype); + } + + inline IONb_t nbInputs() const noexcept override final { return 1; } + inline IONb_t nbDataInputs() const noexcept override final { return 1; } + inline IONb_t nbOutputs() const noexcept override final { return 1; } +}; + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> AvgPooling(const std::array<DimSize_t, DIM> &kernel_dims, + const char *name = nullptr, + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) { + // FIXME: properly handle default w&b initialization in every cases + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by AvgPooling, not supported"); + auto avgPool = std::make_shared<Node>(std::make_shared<AvgPooling_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, padding_dims), name); + return avgPool; +} + +template <DimSize_t DIM> +inline std::shared_ptr<Node> AvgPooling( + DimSize_t const (&kernel_dims)[DIM], + const char *name = nullptr, + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) { + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by AvgPooling, not supported"); + return AvgPooling(to_array(kernel_dims), name, stride_dims, padding_dims); +} +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::AvgPoolingParam>::data[] = {"StrideDims", + "KernelDims", "PaddingDims"}; +} + +#endif /* __AIDGE_CORE_OPERATOR_AVGPOOLING_H__ */ diff --git a/include/operator/BatchNorm.hpp b/include/operator/BatchNorm.hpp new file mode 100644 index 000000000..d46f971f1 --- /dev/null +++ b/include/operator/BatchNorm.hpp @@ -0,0 +1,161 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_CORE_OPERATOR_BATCHNORM_H__ +#define __AIDGE_CORE_OPERATOR_BATCHNORM_H__ + +#include <array> +#include <memory> +#include <vector> + +#include "utils/Types.h" +#include "data/Tensor.hpp" +#include "graph/Node.hpp" +#include "operator/Operator.hpp" +#include "operator/Producer.hpp" +#include "utils/Parameter.hpp" +#include "utils/Registrar.hpp" + +namespace Aidge { +enum class BatchNormParam { Epsilon, Momentum }; + + +template <DimIdx_t DIM> +class BatchNorm_Op : public Operator, + public Registrable<BatchNorm_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const BatchNorm_Op<DIM> &)>, + public Parameterizable<BatchNormParam, float, float> { +public: + // FIXME: change accessibility + std::array<std::shared_ptr<Tensor>, 5> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(), + std::make_shared<Tensor>(), std::make_shared<Tensor>(), + std::make_shared<Tensor>()}; + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + + public: + static constexpr const char *Type = "BatchNorm"; + + BatchNorm_Op() = delete; + + using Parameterizable_ = Parameterizable<BatchNormParam, float, float>; + template <BatchNormParam e> + using param = typename Parameterizable_::template param<e>; + + constexpr BatchNorm_Op(float epsilon, float momentum) + : Operator(Type), + Parameterizable_(param<BatchNormParam::Epsilon>(epsilon), + param<BatchNormParam::Momentum>(momentum)), + mOutput(std::make_shared<Tensor>()) { + setDatatype(DataType::Float32); + } + + // Data operator[](const char* inputName) override final { + // std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] : + // (strcmp(inputName, "weight") ? mInputs[1] : + // (strcmp(inputName, "bias") ? 
mInputs[2] : + // nullptr)); + // assert((in!=nullptr) && "No such parameter"); + // return *in; + // } + + constexpr void associateInput(IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < 5 && "operators supports only 5 inputs"); + assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); + + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } + + constexpr void computeOutputDims() override final { + if (!mInputs[0]->empty()) { + for (std::size_t i = nbDataInputs(); i < nbInputs(); ++i) { + if(mInputs[i]->size() != mInputs[0]->dims()[1]) { + assert(!mInputs[0]->hasImpl() && "Incompatible size with already implemented learnable parameter"); + mInputs[i]->resize(std::array<DimSize_t, 1>({mInputs[0]->dims()[1]})); + } + } + mOutput->resize(mInputs[0]->dims()); + } + } + + bool outputDimsForwarded() const override final { return !(mOutput->empty()); } + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 5 && "operators supports only 5 inputs"); + return *(mInputs[inputIdx].get()); } + + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 5 && "BatchNorm Operators supports only 5 inputs"); + return mInputs[inputIdx]; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert((outputIdx == 0) && "BatchNorm Operator has only 1 output"); + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 5 && "operators supports only 5 inputs"); + return std::static_pointer_cast<Data>(mInputs[inputIdx]); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string &name) { + mImpl = Registrar<BatchNorm_Op<DIM>>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInputs[1]->setBackend(name); + mInputs[2]->setBackend(name); + mInputs[3]->setBackend(name); + mInputs[4]->setBackend(name); + } + + void setDatatype(const DataType &datatype) { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInputs[1]->setDatatype(datatype); + mInputs[2]->setDatatype(datatype); + mInputs[3]->setDatatype(datatype); + mInputs[4]->setDatatype(datatype); + } + + inline IONb_t nbInputs() const noexcept override final { return 5; } + inline IONb_t nbDataInputs() const noexcept override final { return 1; } + inline IONb_t nbOutputs() const noexcept override final { return 1; } +}; + +template <DimSize_t DIM> +inline std::shared_ptr<Node> BatchNorm(const float epsilon = 1.0e-5F, + const float momentum = 0.1F, + const char *name = nullptr) { + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported"); + auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name); + addProducer(batchNorm, 1, std::array<DimSize_t,0>({}), "scale"); + addProducer(batchNorm, 2, std::array<DimSize_t,0>({}), "shift"); + addProducer(batchNorm, 3, std::array<DimSize_t,0>({}), "batch_mean"); + addProducer(batchNorm, 4, std::array<DimSize_t,0>({}), "batch_variance"); + return batchNorm; +} +} // namespace Aidge + +namespace { 
+template <> +const char *const EnumStrings<Aidge::BatchNormParam>::data[] = { "Epsilon", "Momentum" }; +} + +#endif // __AIDGE_CORE_OPERATOR_BATCHNORM_H__ \ No newline at end of file diff --git a/include/operator/Conv.hpp b/include/operator/Conv.hpp new file mode 100644 index 000000000..2e9f67df5 --- /dev/null +++ b/include/operator/Conv.hpp @@ -0,0 +1,200 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_CORE_OPERATOR_CONV_H__ +#define __AIDGE_CORE_OPERATOR_CONV_H__ + +#include <array> +#include <cmath> +#include <numeric> +#include <vector> + +#include "data/Tensor.hpp" +#include "graph/Node.hpp" +#include "operator/Operator.hpp" +#include "operator/Producer.hpp" +#include "utils/Parameter.hpp" +#include "utils/Registrar.hpp" +#include "utils/Types.h" + +namespace Aidge { +enum class ConvParam { StrideDims, DilationDims, InChannels, OutChannels, KernelDims, PaddingDims }; + +template <DimIdx_t DIM> +class Conv_Op : public Operator, + public Registrable<Conv_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Conv_Op<DIM> &)>, + public Parameterizable<ConvParam, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, + DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >> { +public: + // FIXME: change accessibility + std::array<std::shared_ptr<Tensor>, 3> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(), + std::make_shared<Tensor>()}; + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + + public: + static constexpr const char *Type = "Conv"; + + Conv_Op() = delete; + + using Parameterizable_ = Parameterizable<ConvParam, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, + DimSize_t, DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >>; + template <ConvParam e> + using param = typename Parameterizable_::template param<e>; + + constexpr Conv_Op(DimSize_t in_channels, + DimSize_t out_channels, + const std::array<DimSize_t, DIM> &kernel_dims, + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) + : Operator(Type), + Parameterizable_(param<ConvParam::StrideDims>(stride_dims), + param<ConvParam::DilationDims>(dilation_dims), + param<ConvParam::InChannels>(in_channels), + param<ConvParam::OutChannels>(out_channels), + param<ConvParam::KernelDims>(kernel_dims), + param<ConvParam::PaddingDims>(padding_dims)), + mOutput(std::make_shared<Tensor>()) { + setDatatype(DataType::Float32); + } + + // Data operator[](const char* inputName) override final { + // std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] : + // (strcmp(inputName, "weight") ? mInputs[1] : + // (strcmp(inputName, "bias") ? 
mInputs[2] : + // nullptr)); + // assert((in!=nullptr) && "No such parameter"); + // return *in; + // } + + // std::shared_ptr<Conv_Op> clone() const override final { + + // } + + constexpr void associateInput(IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < 3 && "operators supports only 3 inputs"); + assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); + + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } + + constexpr void computeOutputDims() override final { + if (!mInputs[0]->empty()) { + std::array<DimSize_t, DIM + 2> outputDims = {}; + + for (std::size_t dim = 0; dim < this->template get<ConvParam::KernelDims>().size() ; ++dim) { + const DimSize_t kernelExtent = this->template get<ConvParam::DilationDims>()[dim] * + (this->template get<ConvParam::KernelDims>()[dim] - 1) + + 1; + + outputDims[dim+2] = 1 + static_cast<DimSize_t>( + floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent + + this->template get<ConvParam::PaddingDims>()[dim] + + this->template get<ConvParam::PaddingDims>()[dim+DIM]) / + static_cast<float>(this->template get<ConvParam::StrideDims>()[dim]))); + } + + outputDims[1] = this->template get<ConvParam::OutChannels>(); + outputDims[0] = mInputs[0]->dims()[0]; + mOutput->resize(outputDims); + } + } + + bool outputDimsForwarded() const override final { return !(mOutput->empty()); } + + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 3 && "operators supports only 3 inputs"); + return *(mInputs[inputIdx].get()); } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 3 && "Conv Operators supports only 3 inputs"); + return mInputs[inputIdx]; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert((outputIdx == 0) && "Conv Operator has only 1 output"); + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 3 && "operators supports only 3 inputs"); + return std::static_pointer_cast<Data>(mInputs[inputIdx]); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string &name) { + mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInputs[1]->setBackend(name); + mInputs[2]->setBackend(name); + } + + void setDatatype(const DataType &datatype) { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInputs[0]->setDatatype(datatype); + mInputs[1]->setDatatype(datatype); + mInputs[2]->setDatatype(datatype); + } + + inline IONb_t nbInputs() const noexcept override final { return 3; } + inline IONb_t nbDataInputs() const noexcept override final { return 1; } + inline IONb_t nbOutputs() const noexcept override final { return 1; } +}; + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> Conv(DimSize_t in_channels, + DimSize_t out_channels, + const std::array<DimSize_t, DIM> &kernel_dims, + const char *name = nullptr, + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, (DIM<<1)> &padding_dims = 
create_array<DimSize_t,(DIM<<1)>(0), + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { + // FIXME: properly handle default w&b initialization in every cases + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Conv, not supported"); + auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, padding_dims, dilation_dims), name); + // addProducer(conv, 1, append(append(kernel_dims, in_channels), out_channels), "w"); + addProducer(conv, 1, append(out_channels, append(in_channels, kernel_dims)), "w"); + addProducer(conv, 2, {out_channels}, "b"); + return conv; +} + +template <DimSize_t DIM> +inline std::shared_ptr<Node> Conv( + DimSize_t in_channels, + DimSize_t out_channels, + DimSize_t const (&kernel_dims)[DIM], + const char *name = nullptr, + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Conv, not supported"); + return Conv(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); +} +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::ConvParam>::data[] = {"StrideDims", "DilationDims", "InChannels", "OutChannels", + "KernelDims", "PaddingDims"}; +} + +#endif /* __AIDGE_CORE_OPERATOR_CONV_H__ */ diff --git a/include/operator/ConvDepthWise.hpp b/include/operator/ConvDepthWise.hpp new file mode 100644 index 000000000..fa268a32e --- /dev/null +++ b/include/operator/ConvDepthWise.hpp @@ -0,0 +1,196 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_CORE_OPERATOR_CONVDEPTHWISE_H__ +#define __AIDGE_CORE_OPERATOR_CONVDEPTHWISE_H__ + +#include <array> +#include <cmath> +#include <numeric> +#include <vector> + +#include "data/Tensor.hpp" +#include "graph/Node.hpp" +#include "operator/Operator.hpp" +#include "operator/Producer.hpp" +#include "utils/Parameter.hpp" +#include "utils/Registrar.hpp" +#include "utils/Types.h" + +namespace Aidge { +enum class ConvDepthWiseParam { StrideDims, DilationDims, Channels, KernelDims, PaddingDims }; + +template <DimIdx_t DIM> +class ConvDepthWise_Op : public Operator, + public Registrable<ConvDepthWise_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const ConvDepthWise_Op<DIM> &)>, + public Parameterizable<ConvDepthWiseParam, + std::array<DimSize_t, DIM>, + std::array<DimSize_t, DIM>, + DimSize_t, + std::array<DimSize_t, DIM>, + std::array<DimSize_t, (DIM<<1) >> { + public: + // FIXME: change accessibility + std::array<std::shared_ptr<Tensor>, 3> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(), + std::make_shared<Tensor>()}; + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + + public: + static constexpr const char *Type = "ConvDepthWise"; + + ConvDepthWise_Op() = delete; + + using Parameterizable_ = Parameterizable<ConvDepthWiseParam, + std::array<DimSize_t, DIM>, + std::array<DimSize_t, DIM>, + DimSize_t, + std::array<DimSize_t, DIM>, + std::array<DimSize_t, (DIM<<1) >>; + template <ConvDepthWiseParam e> + using param = typename Parameterizable_::template param<e>; + + constexpr ConvDepthWise_Op(const std::array<DimSize_t, DIM> &kernel_dims, + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) + : Operator(Type), + Parameterizable_(param<ConvDepthWiseParam::StrideDims>(stride_dims), + param<ConvDepthWiseParam::DilationDims>(dilation_dims), + param<ConvDepthWiseParam::Channels>(0), + param<ConvDepthWiseParam::KernelDims>(kernel_dims), + param<ConvDepthWiseParam::PaddingDims>(padding_dims)), + mOutput(std::make_shared<Tensor>()) { + setDatatype(DataType::Float32); + } + + constexpr void associateInput(IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < 3 && "operators supports only 3 inputs"); + assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); + + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } + + constexpr void computeOutputDims() override final { + if (!mInputs[0]->empty()) { + std::array<DimSize_t, DIM + 2> outputDims = {}; + + for (std::size_t dim = 0; dim < this->template get<ConvDepthWiseParam::KernelDims>().size() ; ++dim) { + const DimSize_t kernelExtent = this->template get<ConvDepthWiseParam::DilationDims>()[dim] * + (this->template get<ConvDepthWiseParam::KernelDims>()[dim] - 1) + + 1; + + outputDims[dim+2] = 1 + static_cast<DimSize_t>( + floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent + + this->template get<ConvDepthWiseParam::PaddingDims>()[dim] + + this->template get<ConvDepthWiseParam::PaddingDims>()[dim+DIM]) / + static_cast<float>(this->template get<ConvDepthWiseParam::StrideDims>()[dim]))); + } + this->template get<ConvDepthWiseParam::Channels>() = mInputs[0]->dims()[1]; + // 
std::array<DimSize_t, DIM+2> weightDims = append(mInputs[0]->dims()[1],append(1, this->template get<ConvDepthWiseParam::KernelDims>())); + // if (mInputs[1]->empty()) { + // mInputs[1]->resize(weightDims); + // } + // if (mInputs[2]->empty()) { + // mInputs[2]->resize({mInputs[0]->dims()[1]}); + // } + outputDims[1] = mInputs[0]->dims()[1]; + outputDims[0] = mInputs[0]->dims()[0]; + mOutput->resize(outputDims); + } + } + + bool outputDimsForwarded() const override final { return !(mOutput->empty()); } + + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 3 && "operators supports only 3 inputs"); + return *(mInputs[inputIdx].get()); + } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 3 && "ConvDepthWise Operators supports only 3 inputs"); + return mInputs[inputIdx]; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert((outputIdx == 0) && "ConvDepthWise Operator has only 1 output"); + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 3 && "operators supports only 3 inputs"); + return std::static_pointer_cast<Data>(mInputs[inputIdx]); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + return std::static_pointer_cast<Data>(mOutput); + } + + + + void setBackend(const std::string &name) { + mImpl = Registrar<ConvDepthWise_Op<DIM>>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInputs[1]->setBackend(name); + mInputs[2]->setBackend(name); + } + + void setDatatype(const DataType &datatype) { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInputs[0]->setDatatype(datatype); + mInputs[1]->setDatatype(datatype); + mInputs[2]->setDatatype(datatype); + } + + inline IONb_t nbInputs() const noexcept override final { return 3; } + inline IONb_t nbDataInputs() const noexcept override final { return 1; } + inline IONb_t nbOutputs() const noexcept override final { return 1; } +}; + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> ConvDepthWise(const std::array<DimSize_t, DIM> &kernel_dims, + const char *name = nullptr, + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { + // FIXME: properly handle default w&b initialization in every cases + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ConvDepthWise, not supported"); + auto convDW = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, padding_dims, dilation_dims), name); + addProducer(convDW, 1, std::array<DimSize_t,0>({}), "w"); + addProducer(convDW, 2, std::array<DimSize_t,0>({}), "b"); + return convDW; +} + +template <DimSize_t DIM> +inline std::shared_ptr<Node> ConvDepthWise( + DimSize_t const (&kernel_dims)[DIM], + const char *name = nullptr, + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), + const 
std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ConvDepthWise, not supported"); + return ConvDepthWise(to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); +} +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::ConvDepthWiseParam>::data[] = {"StrideDims", "DilationDims", "Channels", + "KernelDims", "PaddingDims"}; +} + +#endif /* __AIDGE_CORE_OPERATOR_CONVDEPTHWISE_H__ */ diff --git a/include/operator/FC.hpp b/include/operator/FC.hpp new file mode 100644 index 000000000..998957909 --- /dev/null +++ b/include/operator/FC.hpp @@ -0,0 +1,155 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_CORE_OPERATOR_FC_H__ +#define __AIDGE_CORE_OPERATOR_FC_H__ + +#include <array> +#include <cmath> +#include <numeric> +#include <memory> +#include <vector> + +#include "utils/Types.h" +#include "data/Tensor.hpp" +#include "graph/Node.hpp" +#include "operator/Operator.hpp" +#include "operator/Producer.hpp" +#include "utils/Parameter.hpp" +#include "utils/Registrar.hpp" + +namespace Aidge { +enum class FCParam { OutChannels, NoBias }; + +class FC_Op : public Operator, + public Registrable<FC_Op, + std::string, + std::unique_ptr<OperatorImpl>(const FC_Op &)>, + public Parameterizable<FCParam, DimSize_t, bool> { +public: + // FIXME: change accessibility + std::array<std::shared_ptr<Tensor>, 3> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(), std::make_shared<Tensor>()}; + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char* Type = "FC"; + + FC_Op() = delete; + + using Parameterizable_ = Parameterizable<FCParam, DimSize_t, bool>; + template <FCParam e> using param = typename Parameterizable_::template param<e>; + + FC_Op(DimSize_t out_channels, bool noBias) + : Operator(Type), + Parameterizable_( + param<FCParam::OutChannels>(out_channels), + param<FCParam::NoBias>(noBias)), + mOutput(std::make_shared<Tensor>()) + { + setDatatype(DataType::Float32); + } + + void associateInput(IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < 3 && "operators supports only 3 inputs"); + assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); + if (inputIdx == 2) { + assert(std::dynamic_pointer_cast<Tensor>(data)->size() == ((this->template get<FCParam::NoBias>()) == false ? 
static_cast<std::size_t>(this->template get<FCParam::OutChannels>()) : 0));
+            assert(std::dynamic_pointer_cast<Tensor>(data)->nbDims() == 1);
+        }
+        mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
+        if (inputIdx == 0 && mInputs[0]->nbDims() == 1)
+            mInputs[inputIdx]->resize(std::array<DimSize_t, 2>({1, mInputs[inputIdx]->size()}));
+    }
+
+    void computeOutputDims() override final {
+        if (!mInputs[0]->empty()) {
+            // weights: <out_channels, in_features>
+            std::array<DimSize_t, 2> weightDims = {this->template get<FCParam::OutChannels>(), static_cast<DimSize_t>(mInputs[0]->sizeM1())};
+            // output: <batch, out_channels>
+            std::array<DimSize_t, 2> outputDims = {mInputs[0]->dims()[0], this->template get<FCParam::OutChannels>()};
+
+            mInputs[1]->resize(weightDims);
+            mOutput->resize(outputDims);
+        }
+    }
+
+    bool outputDimsForwarded() const override final {
+        return !(mOutput->empty());
+    }
+
+
+    inline Tensor& input(const IOIndex_t inputIdx) const override final {
+        assert(inputIdx < 3 && "FC operator supports only 3 inputs");
+        return *(mInputs[inputIdx].get()); }
+    inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
+
+
+    inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
+        assert(inputIdx < 3 && "FC Operator supports only 3 inputs");
+        return mInputs[inputIdx];
+    }
+    inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
+        assert((outputIdx == 0) && "FC Operator has only 1 output");
+        return mOutput;
+    }
+
+
+    std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
+        assert(inputIdx < 3 && "FC operator supports only 3 inputs");
+        return std::static_pointer_cast<Data>(mInputs[inputIdx]);
+    }
+    std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
+        assert(outputIdx == 0 && "FC operator supports only 1 output");
+        return std::static_pointer_cast<Data>(mOutput);
+    }
+
+
+    void setBackend(const std::string& name) {
+        mImpl = Registrar<FC_Op>::create(name)(*this);
+        mOutput->setBackend(name);
+
+        // FIXME: temporary workaround
+        mInputs[0]->setBackend(name);
+        mInputs[1]->setBackend(name);
+        mInputs[2]->setBackend(name);
+    }
+
+    void setDatatype(const DataType& datatype) {
+        mOutput->setDatatype(datatype);
+
+        // FIXME: temporary workaround
+        mInputs[0]->setDatatype(datatype);
+        mInputs[1]->setDatatype(datatype);
+        mInputs[2]->setDatatype(datatype);
+    }
+
+
+    inline IONb_t nbInputs() const noexcept override final { return 3; }
+    inline IONb_t nbDataInputs() const noexcept override final { return 1; }
+    inline IONb_t nbOutputs() const noexcept override final { return 1; }
+};
+
+inline std::shared_ptr<Node> FC(DimSize_t out_channels, bool noBias = false, const char* name = nullptr) {
+    // FIXME: properly handle default w&b initialization in every cases
+    auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(out_channels, noBias), name);
+    addProducer(fc, 1, {out_channels, 1}, "w");
+    addProducer(fc, 2, {(noBias ?
diff --git a/include/operator/GenericOperator.hpp b/include/operator/GenericOperator.hpp
new file mode 100644
index 000000000..c3c8c61d7
--- /dev/null
+++ b/include/operator/GenericOperator.hpp
@@ -0,0 +1,165 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef __AIDGE_CORE_OPERATOR_GENERICOPERATOR_H__
+#define __AIDGE_CORE_OPERATOR_GENERICOPERATOR_H__
+
+#include <memory>
+#include <vector>
+#include <string>
+#include <cassert>
+
+#include "graph/Node.hpp"
+#include "operator/Operator.hpp"
+#include "utils/CParameter.hpp"
+#include "utils/Registrar.hpp"
+#include "utils/Types.h"
+
+namespace Aidge {
+class GenericOperator_Op
+    : public Operator,
+      public Registrable<GenericOperator_Op, std::string, std::unique_ptr<OperatorImpl>(std::shared_ptr<GenericOperator_Op>)> {
+   private:
+    CParameter mParams;
+    IONb_t mNbDataIn;
+    IONb_t mNbIn;
+    IONb_t mNbOut;
+    std::vector<std::shared_ptr<Tensor>> mInputs;
+    std::vector<std::shared_ptr<Tensor>> mOutputs;
+
+   public:
+    GenericOperator_Op(const char *type, IONb_t nbDataIn, IONb_t nbIn, IONb_t nbOut)
+        : Operator(type), mNbDataIn(nbDataIn), mNbIn(nbIn), mNbOut(nbOut)
+    {
+        mInputs = std::vector<std::shared_ptr<Tensor>>(nbIn);
+        for (std::size_t i = 0; i < nbIn; ++i) {
+            mInputs[i] = std::make_shared<Tensor>();
+        }
+        mOutputs = std::vector<std::shared_ptr<Tensor>>(nbOut);
+        for (std::size_t i = 0; i < nbOut; ++i) {
+            mOutputs[i] = std::make_shared<Tensor>();
+        }
+    }
+
+    /**
+     * @brief Get the Parameter object identified by its name.
+     * @tparam T expected parameter type.
+     * @param key Parameter name.
+     * @details asserts if T is not the actual parameter type, if the parameter
+     * does not exist or if the internal parameter position is invalid.
+     * @todo Returning a T const& ? But dangerous => may get an address within
+     * param buffer that will get invalid after the CParam death.
+     * @note at() throws if the parameter does not exist, using find to test
+     * for parameter existence
+     * @return template<class T> The parameter.
+     */
+    template <class T>
+    T getParameter(std::string const &key) const {
+        return mParams.Get<T>(key);
+    }
+
+    ///\brief Add a parameter value, identified by its name
+    ///\tparam T expected parameter type
+    ///\param i_ParamName Parameter name
+    ///\param i_Value Parameter value
+    ///\todo Pass i_Value by ref if large or not trivial
+    ///\bug If parameter already exists, its value is changed but written in the
+    /// internal buffer in a new location (previous value is still in memory at
+    /// its previous location)
+    template <class T>
+    void addParameter(std::string const &key, T const &value) {
+        mParams.Add<T>(key, value);
+    }
+
+    std::string getParameterType(std::string const &key) { return mParams.getParamType(key); }
+
+    std::vector<std::string> getParametersName() { return mParams.getParametersName(); }
+
+    // Override virtual Operator methods
+    void associateInput(IOIndex_t /*inputIdx*/, std::shared_ptr<Data> /*data*/) override final {
+        printf("Info: using associateInput() on a GenericOperator.\n");
+    }
+
+    void computeOutputDims() override final {
+        assert(false && "Cannot compute output dim of a GenericOperator");
+    }
+
+    bool outputDimsForwarded() const override final {
+        assert(false && "GenericOperator cannot forward dims");
+        return false;
+    }
+
+    std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
+        assert((inputIdx < mNbIn) && "input index out of range for this instance of GenericOperator");
+        printf("Info: using getRawInput() on a GenericOperator.\n");
+        return std::static_pointer_cast<Data>(mInputs[inputIdx]);
+    }
+
+    inline Tensor& input(const IOIndex_t inputIdx) const override final {
+        assert((inputIdx < mNbIn) && "input index out of range for this instance of GenericOperator");
+        printf("Info: using input() on a GenericOperator.\n");
+        return *mInputs[inputIdx];
+    }
+
+
+    std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
+        assert((inputIdx < mNbIn) && "input index out of range for this instance of GenericOperator");
+        printf("Info: using getInput() on a GenericOperator.\n");
+        return mInputs[inputIdx];
+    }
+    inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
+        assert((outputIdx < mNbOut) && "output index out of range for this instance of GenericOperator");
+        printf("Info: using getOutput() on a GenericOperator.\n");
+        return mOutputs[outputIdx];
+    }
+
+
+    std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
+        assert((outputIdx < mNbOut) && "output index out of range for this instance of GenericOperator");
+        printf("Info: using getRawOutput() on a GenericOperator.\n");
+        return std::static_pointer_cast<Data>(mOutputs[outputIdx]);
+    }
+
+    Tensor& output(const IOIndex_t outputIdx) const override final {
+        assert((outputIdx < mNbOut) && "output index out of range for this instance of GenericOperator");
+        printf("Info: using output() on a GenericOperator.\n");
+        return *mOutputs[outputIdx];
+    }
+
+    ~GenericOperator_Op() = default;
+
+    void setBackend(const std::string & /*name*/) { printf("setBackend: not available yet.\n"); }
+    void setDatatype(const DataType & /*datatype*/) { printf("setDatatype: not available yet.\n"); }
+    void forward() override final { printf("forward: not available yet.\n"); }
+    void backward() override final { printf("backward: not available yet.\n"); }
+
+    inline IONb_t nbInputs() const noexcept override final { return mNbIn; };
+    inline IONb_t nbDataInputs() const noexcept override final { return mNbDataIn; };
+    inline IONb_t nbOutputs() const noexcept override final { return mNbOut; };
+};
+
+/**
+ * @brief Fictitious custom operator not associated with any implementation.
+ * Allows importing unknown operators and simulating new ones.
+ * @param type Type of the fictitious operator.
+ * @param nbDataIn Number of input data.
+ * @param nbIn Number of input data + number of learnt parameters.
+ * @param nbOut Number of output data.
+ * @param name (optional) name of the Operator.
+ * @return std::shared_ptr<Node> Node associated with the Generic Operator.
+ */
+inline std::shared_ptr<Node> GenericOperator(const char *type, IONb_t nbDataIn, IONb_t nbIn, IONb_t nbOut,
+                                             const char *name = nullptr) {
+    return std::make_shared<Node>(std::make_shared<GenericOperator_Op>(type, nbDataIn, nbIn, nbOut), name);
+}
+} // namespace Aidge
+
+#endif /* __AIDGE_CORE_OPERATOR_GENERICOPERATOR_H__ */
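+// Editor's note: a short usage sketch (hypothetical; type, name and parameter
+// values are illustrative) for the GenericOperator factory above, e.g. when
+// importing an operator type the framework does not implement natively:
+//
+//     auto node = Aidge::GenericOperator("Gather", 1, 1, 1, "gather0");
+//     auto op = std::static_pointer_cast<Aidge::GenericOperator_Op>(node->getOperator());
+//     op->addParameter<int>("axis", 0);
+//     assert(op->getParameter<int>("axis") == 0);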
diff --git a/include/operator/LeakyReLU.hpp b/include/operator/LeakyReLU.hpp
new file mode 100644
index 000000000..cdc172084
--- /dev/null
+++ b/include/operator/LeakyReLU.hpp
@@ -0,0 +1,127 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef __AIDGE_CORE_OPERATOR_LEAKYRELU_H__
+#define __AIDGE_CORE_OPERATOR_LEAKYRELU_H__
+
+#include <cstring>
+#include <vector>
+#include <memory>
+
+#include "utils/Parameter.hpp"
+#include "utils/Registrar.hpp"
+#include "operator/Operator.hpp"
+#include "backend/OperatorImpl.hpp"
+#include "data/Tensor.hpp"
+#include "data/Data.hpp"
+#include "graph/Node.hpp"
+#include "utils/Types.h"
+
+namespace Aidge {
+enum class LeakyReLUParam {
+    NegativeSlope
+};
+
+class LeakyReLU_Op : public Operator,
+    public Registrable<LeakyReLU_Op, std::string, std::unique_ptr<OperatorImpl>(const LeakyReLU_Op&)>,
+    public Parameterizable<LeakyReLUParam, float> {
+public:
+    // FIXME: change accessibility
+    std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>();
+    const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
+
+public:
+    static constexpr const char* Type = "LeakyReLU";
+
+    LeakyReLU_Op() = delete;
+
+    using Parameterizable_ = Parameterizable<LeakyReLUParam, float>;
+    template <LeakyReLUParam e> using param = typename Parameterizable_::template param<e>;
+
+    LeakyReLU_Op(float negativeSlope)
+        : Operator(Type),
+          Parameterizable_(
+              param<LeakyReLUParam::NegativeSlope>(negativeSlope))
+    {
+        setDatatype(DataType::Float32);
+    }
+
+    void associateInput(IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
+        assert(inputIdx == 0 && "operator supports only 1 input");
+        assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
+        mInput = std::dynamic_pointer_cast<Tensor>(data);
+    }
+
+    void computeOutputDims() override final {
+        if (!mInput->empty())
+            mOutput->resize(mInput->dims());
+    }
+
+    bool outputDimsForwarded() const override final {
+        return !(mOutput->empty());
+    }
+
+
+    inline Tensor& input(const IOIndex_t /*inputIdx*/) const override final { return *(mInput.get()); }
+    inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
+
+
+    inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
+        assert((inputIdx == 0) && "LeakyReLU Operator has only 1 input");
+        return mInput;
+    }
+    inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
+        assert((outputIdx == 0) && "LeakyReLU Operator has only 1 output");
+        return mOutput;
+    }
+
+
+    std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
+        assert(inputIdx == 0 && "operator supports only 1 input");
+        return std::static_pointer_cast<Data>(mInput);
+    }
+    std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
+        assert(outputIdx == 0 && "operator supports only 1 output");
+        return mOutput;
+    }
+
+
+    void setBackend(const std::string& name) {
+        mImpl = Registrar<LeakyReLU_Op>::create(name)(*this);
+        mOutput->setBackend(name);
+
+        // FIXME: temporary workaround
+        mInput->setBackend(name);
+    }
+    void setDatatype(const DataType& datatype) {
+        mOutput->setDatatype(datatype);
+
+        // FIXME: temporary workaround
+        mInput->setDatatype(datatype);
+    }
+
+    inline IONb_t nbInputs() const noexcept override final { return 1; }
+    inline IONb_t nbDataInputs() const noexcept override final { return 1; }
+    inline IONb_t nbOutputs() const noexcept override final { return 1; }
+};
+
+inline std::shared_ptr<Node> LeakyReLU(float negativeSlope = 0.0f, const char* name = nullptr) {
+    // FIXME: properly handle default w&b initialization in every case
+    return std::make_shared<Node>(std::make_shared<LeakyReLU_Op>(negativeSlope), name);
+}
+}
+
+namespace {
+template <>
+const char* const EnumStrings<Aidge::LeakyReLUParam>::data[]
+    = {"NegativeSlope"};
+}
+
+#endif /* __AIDGE_CORE_OPERATOR_LEAKYRELU_H__ */
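+// Editor's note: usage sketch (hypothetical; slope value and node name are
+// illustrative) for the LeakyReLU factory above. The slope can be read back
+// through the Parameterizable accessor declared in utils/Parameter.hpp:
+//
+//     auto lrelu = Aidge::LeakyReLU(0.01f, "lrelu0");
+//     auto op = std::static_pointer_cast<Aidge::LeakyReLU_Op>(lrelu->getOperator());
+//     float slope = op->get<Aidge::LeakyReLUParam::NegativeSlope>(); // == 0.01f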
diff --git a/include/operator/Matmul.hpp b/include/operator/Matmul.hpp
new file mode 100644
index 000000000..948cbb4dd
--- /dev/null
+++ b/include/operator/Matmul.hpp
@@ -0,0 +1,143 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_CORE_OPERATOR_MATMUL_H__ +#define __AIDGE_CORE_OPERATOR_MATMUL_H__ + +#include <array> +#include <cmath> +#include <numeric> +#include <memory> +#include <vector> + +#include "utils/Types.h" +#include "data/Tensor.hpp" +#include "graph/Node.hpp" +#include "operator/Operator.hpp" +#include "operator/Producer.hpp" +#include "utils/Parameter.hpp" +#include "utils/Registrar.hpp" + +namespace Aidge { +enum class MatmulParam { OutChannels }; + +class Matmul_Op : public Operator, + public Registrable<Matmul_Op, + std::string, + std::unique_ptr<OperatorImpl>(const Matmul_Op &)>, + public Parameterizable<MatmulParam, DimSize_t> { +public: + std::array<std::shared_ptr<Tensor>, 2> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>()}; + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char* Type = "Matmul"; + + Matmul_Op() = delete; + + using Parameterizable_ = Parameterizable<MatmulParam, DimSize_t>; + template <MatmulParam e> using param = typename Parameterizable_::template param<e>; + + Matmul_Op(DimSize_t out_channels) + : Operator(Type), + Parameterizable_( + param<MatmulParam::OutChannels>(out_channels)), + mOutput(std::make_shared<Tensor>()) + { + setDatatype(DataType::Float32); + } + + void associateInput(IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < 2 && "operators supports only 2 inputs"); + assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + if (!mInputs[0]->empty()) { + // <in_features**, out_channels> + std::array<DimSize_t, 2> weightDims = {static_cast<DimSize_t>(mInputs[0]->size()), this->template get<MatmulParam::OutChannels>()}; + // <out_channels, batch> + std::array<DimSize_t, 1> outputDims = {this->template get<MatmulParam::OutChannels>()}; + + mInputs[1]->resize(weightDims); + mOutput->resize(outputDims); + } + } + + bool outputDimsForwarded() const override final { + return !(mOutput->empty()); + } + + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 2 && "operators supports only 2 inputs"); + return *(mInputs[inputIdx].get()); } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 2 && "MatMul Operators has 2 inputs"); + return mInputs[inputIdx]; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert((outputIdx == 0) && "MatMul Operators has 1 output"); + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < 2 && "operators supports only 2 inputs"); + return std::static_pointer_cast<Data>(mInputs[inputIdx]); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string& name) { + mImpl = Registrar<Matmul_Op>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInputs[0]->setBackend(name); + mInputs[1]->setBackend(name); + } + + void 
setDatatype(const DataType& datatype) {
+        mOutput->setDatatype(datatype);
+
+        // FIXME: temporary workaround
+        mInputs[0]->setDatatype(datatype);
+        mInputs[1]->setDatatype(datatype);
+    }
+
+
+    inline IONb_t nbInputs() const noexcept override final { return 2; }
+    inline IONb_t nbDataInputs() const noexcept override final { return 1; }
+    inline IONb_t nbOutputs() const noexcept override final { return 1; }
+};
+
+inline std::shared_ptr<Node> Matmul(DimSize_t out_channels, const char* name = nullptr) {
+    // FIXME: properly handle default w&b initialization in every case
+    auto matmul = std::make_shared<Node>(std::make_shared<Matmul_Op>(out_channels), name);
+    addProducer(matmul, 1, {1, out_channels}, "w");
+    return matmul;
+}
+} // namespace Aidge
+
+namespace {
+template <>
+const char *const EnumStrings<Aidge::MatmulParam>::data[] = {"OutChannels"};
+}
+
+#endif /* __AIDGE_CORE_OPERATOR_MATMUL_H__ */
diff --git a/include/operator/MetaOperator.hpp b/include/operator/MetaOperator.hpp
new file mode 100644
index 000000000..5d4bad51c
--- /dev/null
+++ b/include/operator/MetaOperator.hpp
@@ -0,0 +1,28 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef __AIDGE_CORE_OPERATOR_METAOPERATOR_H__
+#define __AIDGE_CORE_OPERATOR_METAOPERATOR_H__
+
+#include "operator/Operator.hpp"
+
+namespace Aidge {
+class MetaOperator : public Operator {
+public:
+    MetaOperator()
+        : Operator("MetaOp")
+    {
+    }
+    ~MetaOperator() = default;
+};
+}
+
+#endif /* __AIDGE_CORE_OPERATOR_METAOPERATOR_H__ */
diff --git a/include/operator/Operator.hpp b/include/operator/Operator.hpp
new file mode 100644
index 000000000..19d761841
--- /dev/null
+++ b/include/operator/Operator.hpp
@@ -0,0 +1,99 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_CORE_OPERATOR_OPERATOR_H__ +#define __AIDGE_CORE_OPERATOR_OPERATOR_H__ + +#include <memory> +#include <string> +#include <vector> + +#include "backend/OperatorImpl.hpp" +#include "data/Data.hpp" +#include "data/Tensor.hpp" +#include "utils/Types.h" + +namespace Aidge { + +class Operator : public std::enable_shared_from_this<Operator> { +protected: + std::unique_ptr<OperatorImpl> mImpl; // implementation of the operator + +private: + std::string mType; + +public: + Operator() = delete; + Operator(const char* type) : mType(type) {} + virtual ~Operator(); + + +public: + + virtual void associateInput(IOIndex_t inputIdx, std::shared_ptr<Data> data) = 0; + virtual void computeOutputDims() = 0; + virtual bool outputDimsForwarded() const = 0; + virtual std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const = 0; + virtual std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const = 0; + virtual Tensor& input(const IOIndex_t /*inputIdx*/) const = 0; + virtual std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const = 0; + virtual std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const = 0; + virtual Tensor& output(const IOIndex_t /*outputIdx*/) const = 0; + +/////////////////////////////////////////////////////// +// IMPLEMENTATION +/////////////////////////////////////////////////////// + + virtual void setBackend(const std::string& name) = 0; + virtual void setDatatype(const DataType& datatype) = 0; + + /** + * @brief Minimum amount of data from a specific input for one computation pass. + * @param inputIdx Index of the input analysed. + * @return NbElts_t + */ + NbElts_t getNbRequiredData(IOIndex_t inputIdx) const; + + /** + * @brief Amount of data from a specific input actually used in one computation pass. + * + * @param inputIdx Index of the input analysed. + * @return NbElts_t + */ + NbElts_t getNbConsumedData(IOIndex_t inputIdx) const; + + /** + * @brief Amount of data ready to be used on a specific output. + * + * @param outputIdx Index of the output analysed. + * @return NbElts_t + */ + NbElts_t getNbProducedData(IOIndex_t outputIdx) const; + + virtual void forward(); + + virtual void backward(); + +/////////////////////////////////////////////////////// +// INNER +/////////////////////////////////////////////////////// + + std::string type() const { + return mType; + } + + virtual IONb_t nbInputs() const noexcept = 0; + virtual IONb_t nbDataInputs() const noexcept = 0; + virtual IONb_t nbOutputs() const noexcept = 0; +}; +} // namespace Aidge + +#endif /* __AIDGE_CORE_OPERATOR_OPERATOR_H__ */ diff --git a/include/operator/Producer.hpp b/include/operator/Producer.hpp new file mode 100644 index 000000000..fdf109681 --- /dev/null +++ b/include/operator/Producer.hpp @@ -0,0 +1,141 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_CORE_OPERATOR_PRODUCER_H__ +#define __AIDGE_CORE_OPERATOR_PRODUCER_H__ + +#include <array> +#include <vector> + +#include "utils/Types.h" +#include "data/Tensor.hpp" +#include "graph/Node.hpp" +#include "operator/Operator.hpp" +#include "utils/Parameter.hpp" +#include "utils/Registrar.hpp" + +namespace Aidge { + +class Producer_Op + : public Operator, + public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>( + const Producer_Op &)> { +private: + std::shared_ptr<Tensor> mOutput; + +public: + static constexpr const char* Type = "Producer"; + + template <std::size_t DIM> + Producer_Op(const std::array<DimSize_t, DIM>& dims) + : Operator(Type), + mOutput(std::make_shared<Tensor>()) + { + //ctor + setDatatype(DataType::Float32); + mOutput->resize(dims); + } + + Producer_Op(const std::shared_ptr<Tensor> tensor) + : Operator(Type), + mOutput(tensor) + { + setDatatype(tensor->dataType()); + } + + void associateInput(IOIndex_t /*inputIdx*/, std::shared_ptr<Data> /*data*/) override final { + assert(false && "Producer operator takes no input"); + } + + constexpr void computeOutputDims() override final {} + + constexpr bool outputDimsForwarded() const override final {return true;} + + + inline Tensor& input(const IOIndex_t /*inputIdx*/) const override final { assert(false); } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t /*inputIdx*/) const override final { + assert(false && "Producer Operator has no input"); + return nullptr; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert((outputIdx == 0) && "Producer Operator has only 1 output"); + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t /*inputIdx*/) const override final { + assert(false && "Producer operator takes no input"); + return nullptr; + } + + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + return std::static_pointer_cast<Data>(mOutput); + } + + inline const std::vector<DimSize_t> dims() const noexcept { return mOutput->dims(); } + + void setBackend(const std::string& name) { + mImpl = Registrar<Producer_Op>::create(name)(*this); + mOutput->setBackend(name); + } + void setDatatype(const DataType& datatype) { + mOutput->setDatatype(datatype); + } + + inline IONb_t nbInputs() const noexcept override final { return 0; }; + inline IONb_t nbDataInputs() const noexcept override final { return 0; }; + inline IONb_t nbOutputs() const noexcept override final { return 1; }; + +public: + void forward() override final { + printf("Basic Producer forward() function.\n"); + } + void backward() override final { + printf("Basic Producer backward() function.\n"); + } +}; + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> Producer(const std::array<DimSize_t, DIM> &dims, const char *name = nullptr) { + static_assert(DIM<=MaxDim,"Too many tensor dimensions required by Producer, not supported"); + return std::make_shared<Node>(std::make_shared<Producer_Op>(dims), name); +} + +template <std::size_t DIM> +inline std::shared_ptr<Node> Producer(DimSize_t const (&dims)[DIM], const char *name = nullptr) { + return Producer(to_array(dims), name); +} + +inline 
std::shared_ptr<Node> Producer(const std::shared_ptr<Tensor> tensor, const char *name = nullptr) {
+    return std::make_shared<Node>(std::make_shared<Producer_Op>(tensor), name);
+}
+
+template <std::array<DimSize_t, 1>::size_type DIM>
+void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, const std::array<DimSize_t, DIM>& dims, const char* extension) {
+    assert(inputIdx != gk_IODefaultIndex);
+    static_assert(DIM<=MaxDim,"Too many tensor dimensions required by addProducer, not supported");
+    // Keep the generated name alive for the whole call: taking c_str() of a
+    // temporary std::string would leave a dangling pointer.
+    const std::string prodName = otherNode->name().empty() ? std::string() : (otherNode->name() + std::string("_") + std::string(extension));
+    auto prod = Producer(dims, prodName.empty() ? nullptr : prodName.c_str());
+    prod->addChild(otherNode, 0, inputIdx);
+    otherNode->getOperator()->associateInput(inputIdx, prod->getOperator()->getRawOutput(0));
+}
+
+template <std::size_t DIM>
+void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, DimSize_t const (&dims)[DIM], const char* extension) {
+    addProducer(otherNode, inputIdx, to_array(dims), extension);
+}
+} // namespace Aidge
+
+#endif /* __AIDGE_CORE_OPERATOR_PRODUCER_H__ */
\ No newline at end of file
diff --git a/include/operator/ReLU.hpp b/include/operator/ReLU.hpp
new file mode 100644
index 000000000..33583cf4b
--- /dev/null
+++ b/include/operator/ReLU.hpp
@@ -0,0 +1,110 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef __AIDGE_CORE_OPERATOR_RELU_H__
+#define __AIDGE_CORE_OPERATOR_RELU_H__
+
+#include <cassert>
+#include <cstring>
+#include <memory>
+#include <vector>
+
+#include "utils/Registrar.hpp"
+#include "operator/Operator.hpp"
+#include "backend/OperatorImpl.hpp"
+#include "data/Tensor.hpp"
+#include "data/Data.hpp"
+#include "graph/Node.hpp"
+#include "utils/Types.h"
+
+namespace Aidge {
+
+class ReLU_Op : public Operator,
+    public Registrable<ReLU_Op, std::string, std::unique_ptr<OperatorImpl>(const ReLU_Op&)> {
+public:
+    // FIXME: change accessibility
+    std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>();
+    const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
+
+public:
+    static constexpr const char* Type = "ReLU";
+
+    ReLU_Op()
+        : Operator(Type)
+    {
+        setDatatype(DataType::Float32);
+    }
+
+    void associateInput(IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
+        assert(inputIdx == 0 && "operator supports only 1 input");
+        assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
+        mInput = std::dynamic_pointer_cast<Tensor>(data);
+    }
+
+    void computeOutputDims() override final {
+        if (!mInput->empty())
+            mOutput->resize(mInput->dims());
+    }
+
+    bool outputDimsForwarded() const override final {
+        return !(mOutput->empty());
+    }
+
+
+    inline Tensor& input(const IOIndex_t /*inputIdx*/) const override final { return *(mInput.get()); }
+    inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
+
+
+    inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
+        assert((inputIdx == 0) && "ReLU Operator has only 1 input");
+        return mInput;
+    }
+    inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
+ assert((outputIdx == 0) && "ReLU Operator has only 1 output"); + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "operator supports only 1 input"); + return std::static_pointer_cast<Data>(mInput); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string& name) { + mImpl = Registrar<ReLU_Op>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInput->setBackend(name); + } + void setDatatype(const DataType& datatype) { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInput->setDatatype(datatype); + } + + inline IONb_t nbInputs() const noexcept override final { return 1; } + inline IONb_t nbDataInputs() const noexcept override final { return 1; } + inline IONb_t nbOutputs() const noexcept override final { return 1; } +}; + +inline std::shared_ptr<Node> ReLU(const char* name = nullptr) { + // FIXME: properly handle default w&b initialization in every cases + return std::make_shared<Node>(std::make_shared<ReLU_Op>(), name); +} +} + +#endif /* __AIDGE_CORE_OPERATOR_RELU_H__ */ diff --git a/include/operator/Softmax.hpp b/include/operator/Softmax.hpp new file mode 100644 index 000000000..8c35ead5a --- /dev/null +++ b/include/operator/Softmax.hpp @@ -0,0 +1,110 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_CORE_OPERATOR_SOFTMAX_H__ +#define __AIDGE_CORE_OPERATOR_SOFTMAX_H__ + +#include <cassert> +#include <memory> +#include <vector> + +#include "utils/Registrar.hpp" +#include "operator/Operator.hpp" +#include "backend/OperatorImpl.hpp" +#include "data/Tensor.hpp" +#include "data/Data.hpp" +#include "graph/Node.hpp" +#include "utils/Types.h" + +namespace Aidge { + +class Softmax_Op : public Operator, + public Registrable<Softmax_Op, std::string, std::unique_ptr<OperatorImpl>(const Softmax_Op&)> { +public: + // FIXME: change accessibility + std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char* Type = "Softmax"; + + Softmax_Op() + : Operator(Type) + { + setDatatype(DataType::Float32); + } + + void associateInput(IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx == 0 && "operator supports only 1 input"); + assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); + mInput = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + if (!mInput->empty()) + mOutput->resize(mInput->dims()); + } + + bool outputDimsForwarded() const override final { + return !(mOutput->empty()); + } + + + inline Tensor& input(const IOIndex_t /*inputIdx*/) const override final { return *(mInput.get()); } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert((inputIdx == 0) && "Softmax Operator has only 1 input"); + return mInput; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert((outputIdx == 0) && "Softmax Operator has only 1 output"); + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "operator supports only 1 input"); + return std::static_pointer_cast<Data>(mInput); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string& name) { + mImpl = Registrar<Softmax_Op>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInput->setBackend(name); + } + void setDatatype(const DataType& datatype) { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInput->setDatatype(datatype); + } + + inline IONb_t nbInputs() const noexcept override final { return 1; } + inline IONb_t nbDataInputs() const noexcept override final { return 1; } + inline IONb_t nbOutputs() const noexcept override final { return 1; } +}; + +inline std::shared_ptr<Node> Softmax(const char* name = nullptr) { + // FIXME: properly handle default w&b initialization in every cases + return std::make_shared<Node>(std::make_shared<Softmax_Op>(), name); +} +} + +#endif /* __AIDGE_CORE_OPERATOR_SOFTMAX_H__ */ diff --git a/include/scheduler/Scheduler.hpp b/include/scheduler/Scheduler.hpp new file mode 100644 index 000000000..2abe90e11 --- /dev/null +++ b/include/scheduler/Scheduler.hpp @@ -0,0 +1,71 @@ 
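+// Editor's note: usage sketch (hypothetical, not part of the original commit)
+// for the SequentialScheduler declared in this header: build it from a
+// GraphView, run forward(), then dump the measured schedule.
+//
+//     Aidge::SequentialScheduler scheduler(graphView); // graphView: std::shared_ptr<Aidge::GraphView>
+//     scheduler.forward(/*forwardDims=*/true, /*verbose=*/false);
+//     scheduler.saveSchedulingDiagram("schedule");     // writes a Markdown file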
+/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef __AIDGE_SCHEDULER_H__ +#define __AIDGE_SCHEDULER_H__ + +#include <chrono> +#include <memory> +#include <set> +#include <string> +#include <vector> + +namespace Aidge { +class Node; +class GraphView; + +class SequentialScheduler { +public: + struct SchedulingElement { + SchedulingElement( + std::shared_ptr<Node> node_, + std::chrono::time_point<std::chrono::high_resolution_clock> start_, + std::chrono::time_point<std::chrono::high_resolution_clock> end_) + : node(node_), start(start_), end(end_) {} + + std::shared_ptr<Node> node; + std::chrono::time_point<std::chrono::high_resolution_clock> start; + std::chrono::time_point<std::chrono::high_resolution_clock> end; + }; + + SequentialScheduler(std::shared_ptr<GraphView> graphView) + : mGraphView(graphView) + { + // ctor + }; + ~SequentialScheduler() = default; + + /** + * @brief Run the provided Computational Graph with a batch of data + */ + void forward(bool forwardDims = true, bool verbose = false); + + /** + * @brief Save in a Markdown file the order of layers execution. + * @param fileName Name of the generated file. + */ + void saveSchedulingDiagram(const std::string& fileName) const; + +private: + /** + * @brief Set of layers receiving an input from currently processing layers + * + * @param producers Set of layers ready to run. + * @return std::set<std::shared_ptr<Node>> + */ + std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const; + + std::shared_ptr<GraphView> mGraphView; + std::vector<SchedulingElement> mScheduling; +}; +} // namespace Aidge + +#endif /* __AIDGE_SCHEDULER_H__ */ \ No newline at end of file diff --git a/include/utils/CParameter.hpp b/include/utils/CParameter.hpp new file mode 100644 index 000000000..c7d0ea23d --- /dev/null +++ b/include/utils/CParameter.hpp @@ -0,0 +1,110 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef __AIDGE_CPARAMETER_H__
+#define __AIDGE_CPARAMETER_H__
+
+#include <assert.h>
+#include <cstdint>
+#include <map>
+#include <numeric>
+#include <string>
+#include <typeinfo>
+#include <vector>
+
+namespace Aidge {
+
+///\todo store also a fix-sized code that indicates the type
+///\todo managing complex types or excluding non-trivial, non-aggregate types
+class CParameter
+{
+public:
+    // not copyable, not movable
+    CParameter(CParameter const &) = delete;
+    CParameter(CParameter &&) = delete;
+    CParameter &operator=(CParameter const &) = delete;
+    CParameter &operator=(CParameter &&) = delete;
+    CParameter() : m_Params({}){};
+
+    /**
+     * \brief Returning a parameter identified by its name
+     * \tparam T expected parameter type
+     * \param i_ParamName Parameter name
+     * \details asserts if T is not the actual parameter type, if the parameter does not
+     * exist or if the internal parameter position is invalid.
+     * \todo Returning a T const& ? But dangerous => the client may get an address within
+     * param buffer that will get invalid after the CParam death.
+     * \note at() throws if the parameter does not exist, using find to test for parameter existence
+     */
+    template<class T> T Get(std::string const &i_ParamName) const
+    {
+        assert(m_Params.find(i_ParamName) != m_Params.end());
+        assert(m_Types.find(i_ParamName) != m_Types.end());
+        assert(m_Params.at(i_ParamName) <= m_OffSet);
+        assert(typeid(T).name() == m_Types.at(i_ParamName));
+        return *reinterpret_cast<T *>(m_BeginBuffer + m_Params.at(i_ParamName));
+    }
+
+    ///\brief Add a parameter value, identified by its name
+    ///\tparam T expected parameter type
+    ///\param i_ParamName Parameter name
+    ///\param i_Value Parameter value
+    ///\todo Pass i_Value by ref if large or not trivial
+    ///\bug If parameter already exists, its value is changed but written in the
+    /// internal buffer in a new location (previous value is still in memory at its previous location)
+    template<class T> void Add(std::string const &i_ParamName, T const &i_Value)
+    {
+        m_Buffer.resize(m_Buffer.size() + (sizeof(T) / sizeof(uint8_t)));
+        m_BeginBuffer = m_Buffer.data(); // Update buffer ptr in case of memory reordering
+        *reinterpret_cast<T *>(m_BeginBuffer + m_OffSet)
+            = i_Value; // Black-magic used to add any type into the vector
+        m_Params[i_ParamName] = m_OffSet; // Copy pointer offset
+        m_OffSet += sizeof(T); // Increment offset
+        m_Types[i_ParamName] = typeid(i_Value).name();
+    }
+
+    std::string getParamType(std::string const &i_ParamName){
+        return m_Types[i_ParamName];
+    }
+
+    std::vector<std::string> getParametersName(){
+        std::vector<std::string> parametersName;
+        for(auto const& it: m_Params)
+            parametersName.push_back(it.first);
+        return parametersName;
+    }
+
+
+    ~CParameter() = default;
+
+private:
+    ///\brief Storing the offset rather than the address: the buffer may be reallocated
+    std::map<std::string, std::size_t> m_Params; // { Param name : offset }
+
+    ///\brief Map to check type error
+    /* Note : I tried this : `std::map<std::string, std::type_info const *> m_Types;`
+       but it looks like the type_info object was destroyed.
+       I am not a huge fan of storing a string and making string comparisons.
+       Maybe we can use a custom enum type (or is there a standard solution ?)
+    */
+    std::map<std::string, std::string> m_Types;
+
+    ///\brief All parameters values concatenated in raw binary form.
+    std::vector<uint8_t> m_Buffer = {};
+
+    ///\brief Starting address of the buffer
+    uint8_t *m_BeginBuffer = m_Buffer.data();
+
+    ///\brief Offset, in number of uint8_t, of the next parameter to write
+    std::size_t m_OffSet = 0;
+};
+
+}
+
+
+#endif /* __AIDGE_CPARAMETER_H__ */
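+// Editor's note: usage sketch (hypothetical, values illustrative) for the
+// CParameter store above; retrieval is type-checked against typeid at runtime.
+//
+//     Aidge::CParameter params;
+//     params.Add<float>("epsilon", 1.0e-5f);
+//     params.Add<int>("axis", 1);
+//     float eps = params.Get<float>("epsilon"); // asserts on a type mismatch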
diff --git a/include/utils/Parameter.hpp b/include/utils/Parameter.hpp
new file mode 100644
index 000000000..6a8fcca41
--- /dev/null
+++ b/include/utils/Parameter.hpp
@@ -0,0 +1,197 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef __AIDGE_CORE_UTILS_PARAMETER_H__
+#define __AIDGE_CORE_UTILS_PARAMETER_H__
+
+#ifdef PYBIND
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <string> // Add this include to print errors
+#endif
+#include <array>
+#include <tuple>
+#include <cassert>
+#include <cstddef>
+#include <cstring>
+
+#ifdef PYBIND
+namespace py = pybind11;
+#endif
+
+namespace {
+// This is the type that will hold all the strings. Each enumerate type will
+// declare its own specialization.
+template <typename T> struct EnumStrings {
+    static const char* const data[];
+};
+}
+
+namespace Aidge {
+template<class T, std::size_t N>
+constexpr std::size_t size(T (&)[N]) { return N; }
+
+#ifdef PYBIND
+/* This abstract class makes it possible to avoid binding Parameterizable.
+* Otherwise we would need to bind every possible template instance of Parameterizable.
+* Every operator can access the methods of this class by inheriting from
+* PyAbstractParametrizable in the binding code.
+*/
+class PyAbstractParametrizable{
+    public:
+        /* Bindable get function, does not require any templating.
+         * This is thanks to py::object, which allows the function to
+         * be agnostic of its return type.
+         */
+        virtual py::object getPy(const char* /*name*/) = 0;
+};
+#endif
+
+template <class PARAM_ENUM, class ...T>
+class Parameterizable
+#ifdef PYBIND
+    : public PyAbstractParametrizable
+#endif
+    {
+public:
+    using Parameters = std::tuple<T...>;
+
+    // Helper class to pass to the constructor
+    template <PARAM_ENUM paramEnum>
+    class param {
+    public:
+        constexpr param(const typename std::tuple_element<static_cast<std::size_t>(paramEnum),std::tuple<T...>>::type& v) : value(v) {}
+        const typename std::tuple_element<static_cast<std::size_t>(paramEnum),std::tuple<T...>>::type value;
+    };
+
+/*
+    // Direct tuple initialization
+    Parameterizable(T... params) : mParams({params...}) {
+
+    }
+*/
+
+    // Constructor for parameters initialization.
+    // Compile-time guarantee that every parameter is initialized.
+    template <PARAM_ENUM ...paramEnum> // non-type parameter pack
+    constexpr Parameterizable(const param<paramEnum>&&... params) {
+        // Check number of params consistency
+        static_assert(sizeof...(params) == std::tuple_size<std::tuple<T...>>::value, "wrong number of parameters in constructor");
+        // static_assert(size(EnumStrings<PARAM_ENUM>::data) == std::tuple_size<std::tuple<T...>>::value, "wrong number of parameters in enum string");
+
+        // Check no duplicates
+        constexpr std::array<PARAM_ENUM, std::tuple_size<std::tuple<T...>>::value> pe = { paramEnum... };
+        static_assert(!hasDuplicates(pe), "duplicate parameter"); // requires C++14
+
+        // Init params with constructor arguments
+        const std::array<PARAM_ENUM, std::tuple_size<std::tuple<T...>>::value> p = { ((void)(get<paramEnum>() = params.value), paramEnum) ... };
+        (void)p; // avoid unused warning
+    }
+
+    // Compile-time access with enum
+    template <PARAM_ENUM paramEnum>
+    constexpr typename std::tuple_element<static_cast<std::size_t>(paramEnum),std::tuple<T...>>::type& get() {
+        return std::get<static_cast<std::size_t>(paramEnum)>(mParams);
+    }
+
+    template <PARAM_ENUM paramEnum>
+    constexpr const typename std::tuple_element<static_cast<std::size_t>(paramEnum),std::tuple<T...>>::type& get() const {
+        return std::get<static_cast<std::size_t>(paramEnum)>(mParams);
+    }
+
+    // Runtime access with enum
+    template <typename R>
+    constexpr R& get(PARAM_ENUM paramEnum) {
+        return get<R>(static_cast<std::size_t>(paramEnum));
+    }
+
+    template <typename R>
+    constexpr const R& get(PARAM_ENUM paramEnum) const {
+        return get<R>(static_cast<std::size_t>(paramEnum));
+    }
+
+    // Runtime existence check with name
+    constexpr bool isParam(const char* name) const {
+        for (std::size_t i = 0; i < size(EnumStrings<PARAM_ENUM>::data); ++i) {
+            if (strcmp(EnumStrings<PARAM_ENUM>::data[i], name) == 0) {
+                return true;
+            }
+        }
+
+        return false;
+    }
+
+    // Runtime access with name
+    template <typename R>
+    constexpr R& get(const char* name) {
+        for (std::size_t i = 0; i < size(EnumStrings<PARAM_ENUM>::data); ++i) {
+            if (strcmp(EnumStrings<PARAM_ENUM>::data[i], name) == 0) {
+                return get<R>(i);
+            }
+        }
+
+        assert(false && "parameter not found");
+    }
+
+    template <typename R, std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value-1>
+    constexpr typename std::enable_if<(SIZE > 0), R&>::type get(std::size_t i) {
+        if (i == SIZE) {
+            if (std::is_same<R, typename std::tuple_element<SIZE,std::tuple<T...>>::type>::value) {
+                return reinterpret_cast<R&>(std::get<SIZE>(mParams));
+            }
+            else {
+                assert(false && "wrong parameter type");
+            }
+        }
+        else {
+            return get<R, SIZE-1>(i);
+        }
+    }
+
+    template <typename R, std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value-1>
+    constexpr typename std::enable_if<(SIZE <= 0), R&>::type get(std::size_t i) {
+        assert(false && "parameter not found");
+    }
+
+    constexpr const std::tuple<T...>& getParams() const {
+        return mParams;
+    }
+
+    #ifdef PYBIND
+    py::object getPy(const char* name){
+        for (std::size_t i = 0; i < size(EnumStrings<PARAM_ENUM>::data); ++i) {
+            if (strcmp(EnumStrings<PARAM_ENUM>::data[i], name) == 0) {
+                // https://github.com/pybind/pybind11/blob/f3e0602802c7840992c97f4960515777cad6a5c7/include/pybind11/pytypes.h#L1119-L1138
+                // A normal accessor would not work, as we convert the tuple to a py::object which can be anything
+                return py::detail::accessor_policies::tuple_item::get(py::cast(mParams), static_cast<py::size_t>(i));
+            }
+        }
+        throw py::value_error("Parameter : " + std::string(name) + " does not exist.");
+    };
+    #endif
+
+private:
+    template <typename V, std::size_t N>
+    static constexpr bool hasDuplicates(const std::array<V, N>& array) {
+        for (std::size_t i = 1; i < N; i++) {
+            for (std::size_t j = 0; j < i; j++) {
+                if (array[i] == array[j]) {
+                    return true;
+                }
+            }
+        }
+
+        return false;
+    }
+
+    std::tuple<T...> mParams;
+};
+}
+
+#endif /* __AIDGE_CORE_UTILS_PARAMETER_H__ */
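+// Editor's note: a sketch (hypothetical names) of the pattern the operators in
+// this commit follow to use Parameterizable: an enum, an EnumStrings
+// specialization, and typed access by enum value or by name.
+//
+//     enum class MyOpParam { Alpha };
+//     namespace { template <> const char* const EnumStrings<MyOpParam>::data[] = {"Alpha"}; }
+//     class MyOp : public Aidge::Parameterizable<MyOpParam, float> { /* ... */ };
+//     // op.get<MyOpParam::Alpha>() -> float&, op.get<float>("Alpha") -> float&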
diff --git a/include/utils/Recipies.hpp b/include/utils/Recipies.hpp
new file mode 100644
index 000000000..71ed8feb9
--- /dev/null
+++ b/include/utils/Recipies.hpp
@@ -0,0 +1,27 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef __AIDGE_CORE_UTILS_RECIPIES_H__
+#define __AIDGE_CORE_UTILS_RECIPIES_H__
+
+#include "graph/Node.hpp"
+#include "graph/GraphView.hpp"
+
+namespace Aidge{
+
+void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes);
+void removeFlatten(std::shared_ptr<GraphView> view);
+
+
+}
+
+
+#endif /* __AIDGE_CORE_UTILS_RECIPIES_H__ */
\ No newline at end of file
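+// Editor's note: usage sketch (hypothetical; the exact matching workflow is an
+// assumption) for the graph-transformation recipes declared above, applied to
+// an imported GraphView:
+//
+//     Aidge::removeFlatten(graphView);  // graphView: std::shared_ptr<Aidge::GraphView>
+//     Aidge::fuseMulAdd(matchedNodes);  // matchedNodes: std::set<std::shared_ptr<Aidge::Node>>,
+//                                       // e.g. a Matmul/Add pair found via GRegex matching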
diff --git a/include/utils/Registrar.hpp b/include/utils/Registrar.hpp
new file mode 100644
index 000000000..8348eb98d
--- /dev/null
+++ b/include/utils/Registrar.hpp
@@ -0,0 +1,75 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef __AIDGE_CORE_UTILS_REGISTRAR_H__
+#define __AIDGE_CORE_UTILS_REGISTRAR_H__
+
+#ifdef PYBIND
+#include <pybind11/pybind11.h>
+#endif
+
+#include <cstdlib>
+#include <functional>
+#include <map>
+#include <string>
+#include <tuple>
+#include <typeinfo>
+#include <vector>
+#include <cassert>
+
+namespace Aidge {
+#ifdef PYBIND
+namespace py = pybind11;
+#endif
+
+template <class DerivedClass, class Key, class Func> // curiously recurring template pattern
+class Registrable {
+public:
+    typedef Key registrar_key;
+    typedef std::function<Func> registrar_type;
+
+    static std::map<Key, std::function<Func>>& registry()
+    {
+        #ifdef PYBIND
+        if (std::getenv("AIDGE_CORE_WITH_PYBIND")){
+            std::string name = std::string("registrar_")+typeid(Registrable<DerivedClass, Key, Func>).name();
+            static auto shared_data = reinterpret_cast<std::map<Key, std::function<Func>> *>(py::get_shared_data(name));
+            if (!shared_data)
+                shared_data = static_cast<std::map<Key, std::function<Func>> *>(py::set_shared_data(name, new std::map<Key, std::function<Func>>()));
+            return *shared_data;
+        }
+        #endif // PYBIND
+        static std::map<Key, std::function<Func>> rMap;
+        return rMap;
+    }
+
+};
+
+template <class C>
+struct Registrar {
+    Registrar(const typename C::registrar_key& key, typename C::registrar_type func) {
+        //printf("REGISTRAR: %s\n", key.c_str());
+        bool newInsert;
+        std::tie(std::ignore, newInsert) = C::registry().insert(std::make_pair(key, func));
+        //assert(newInsert && "registrar already exists");
+    }
+
+    static auto create(const typename C::registrar_key& key){
+        const auto it = C::registry().find(key);
+        assert(it != C::registry().end() && "invalid registrar key");
+
+        return (*it).second;
+    }
+    static std::vector<typename C::registrar_key> getKeys(){
+        std::vector<typename C::registrar_key> keys;
+        for(auto keyValue : C::registry())
+            keys.push_back(keyValue.first);
+        return keys;
+    }
+};
+}
+
+#endif // __AIDGE_CORE_UTILS_REGISTRAR_H__
\ No newline at end of file
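+// Editor's note: sketch (hypothetical; ReLUImplCpu is an illustrative name, not
+// part of this commit) of how a backend registers and resolves an operator
+// implementation through the Registrable/Registrar pair above:
+//
+//     static Aidge::Registrar<Aidge::ReLU_Op> regReLUCpu("cpu",
+//         [](const Aidge::ReLU_Op& op) { return std::make_unique<ReLUImplCpu>(op); });
+//     auto makeImpl = Aidge::Registrar<Aidge::ReLU_Op>::create("cpu"); // asserts if key is unknown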
diff --git a/include/utils/Types.h b/include/utils/Types.h
new file mode 100644
index 000000000..f626c6352
--- /dev/null
+++ b/include/utils/Types.h
@@ -0,0 +1,62 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+
+#ifndef __AIDGE_TYPES_H__
+#define __AIDGE_TYPES_H__
+
+#include <limits>
+#include <type_traits>
+#include <cstddef>
+#include <cstdint>
+
+namespace Aidge
+{
+//////////////////////////////////////
+///          Tensor
+//////////////////////////////////////
+
+/// @brief Number of elements used for scheduling
+using NbElts_t = std::size_t;
+constexpr NbElts_t MaxElts = std::numeric_limits<NbElts_t>::max();
+
+///\brief Signed dimension size for Tensor (allow for negative coordinates).
+using Coord_t = std::make_signed<std::size_t>::type;
+constexpr Coord_t MaxCoord = std::numeric_limits<Coord_t>::max();
+
+///\brief Unsigned value for the size of each dimension for a Tensor.
+using DimSize_t = std::size_t;
+constexpr DimSize_t MaxDimSize = std::numeric_limits<DimSize_t>::max();
+
+///\brief Unsigned index for a Tensor's number of dimensions.
+using DimIdx_t = std::uint8_t;
+constexpr DimIdx_t MaxDim = std::numeric_limits<DimIdx_t>::max();
+
+//////////////////////////////////////
+///          Operator/Nodes
+//////////////////////////////////////
+
+///\brief Signed integral type to hold an IO index.
+///\details <0 values reserved
+///\todo Change it for an unsigned value with default to numeric_limit and max to numeric_limit-1
+using IOIndex_t = std::make_signed<std::uint16_t>::type;
+/// @brief Default for absence of connection
+constexpr IOIndex_t gk_IODefaultIndex = -1;
+constexpr IOIndex_t gk_IOMaxIndex = std::numeric_limits<IOIndex_t>::max();
+
+///\brief Number of input/output connections for a Node/Operator
+using IONb_t = std::uint16_t;
+constexpr IONb_t gk_IOMaxNb = std::numeric_limits<IONb_t>::max();
+
+
+} // namespace Aidge
+
+#endif // __AIDGE_TYPES_H__
\ No newline at end of file
diff --git a/include/utilsParsing/AstNode.hpp b/include/utilsParsing/AstNode.hpp
new file mode 100644
index 000000000..28d17a543
--- /dev/null
+++ b/include/utilsParsing/AstNode.hpp
@@ -0,0 +1,69 @@
+
+
+#ifndef _AIDGE_AST_NODE_H_
+#define _AIDGE_AST_NODE_H_
+
+#include <string>
+#include <type_traits>
+#include <vector>
+#include <memory>
+#include "utilsParsing/ParsingToken.hpp"
+
+namespace Aidge{
+
+    template <typename EnumType>
+    class AstNode: public std::enable_shared_from_this<AstNode<EnumType>>
+    {
+        static_assert(std::is_enum<EnumType>::value, "AstNode EnumType must be an enum type");
+        public:
+        AstNode(std::shared_ptr<ParsingToken<EnumType>> token,std::vector<std::shared_ptr<AstNode>> child ={}):mToken(token),mChild(child){}
+        /**
+         * @brief get the type of the token
+         * @return the type
+         */
+        EnumType getType() const{
+            return mToken->getType();
+        }
+
+        /**
+         * @brief get the lexeme of the token
+         * @return the lexeme
+         */
+        std::string getValue() const{
+            return mToken->getLexeme();
+        }
+        /**
+         * @brief get the children of the node
+         * @return children
+         */
+        const std::vector<std::shared_ptr<AstNode>>& getChilds() const {
+            return mChild;
+        }
+        /**
+         * @brief test if the node is a leaf in the tree
+         * @return true if a leaf
+         */
+        bool isLeaf() const {
+            return mChild.size() == 0;
+        }
+
+        /**
+         * @brief get the number of children
+         * @return the number of children
+         */
+        std::size_t nbChild() const{
+            return mChild.size();
+        }
+        private:
+        /**
+         * @brief the token of the node
+         */
+        const std::shared_ptr<ParsingToken<EnumType>> mToken;
+        /**
+         * @brief list of children
+         */
+        const std::vector<std::shared_ptr<AstNode>> mChild;
+    };
+}
+
+#endif //_AIDGE_AST_NODE_H_
diff --git a/include/utilsParsing/ParsingToken.hpp b/include/utilsParsing/ParsingToken.hpp
new file mode 100644
index 000000000..78045cf30
--- /dev/null
+++ b/include/utilsParsing/ParsingToken.hpp
@@ -0,0 +1,66 @@
+
+#ifndef _AIDGE_PARSING_TOKEN_H_
+#define _AIDGE_PARSING_TOKEN_H_
+
+#include <memory>
+#include <sstream>
+#include <string>
+#include <type_traits>
+
+namespace Aidge{
+    template <typename EnumType>
+    class ParsingToken: public std::enable_shared_from_this<ParsingToken<EnumType>>
+    {
+        static_assert(std::is_enum<EnumType>::value, "ParsingToken EnumType must be an enum type");
+        public:
+        /**
+         * @brief Token container
+         * @param type one of the token types
+         * @param lexeme String representing additional information of the token
+         */
+        ParsingToken(const EnumType type, const std::string lexeme) : mLexeme(lexeme), mType(type) {}
+
+        /**
+         * @brief get the lexeme
+         * @return std::string
+         */
+        const std::string getLexeme(void){
+            return mLexeme;
+        }
+
+        /**
+         * @brief get the token type
+         *
+         * @return EnumType
+         */
+        const
EnumType getType(void){ + return mType; + } + + /** + * @brief copy the token + * @return deep copy of the token + */ + std::shared_ptr<Aidge::ParsingToken> copy(); + + //TODO + std::ostringstream rep(void){ + std::ostringstream out; + out << " Token (" << mLexeme <<")" << "\n"; + return out; + } + private: + + /** + * @brief additional information of the token + */ + const std::string mLexeme; + + /** + * @brief type of the token + * @see ConditionalTokenTypes + */ + const EnumType mType; + + }; +} + +#endif //_AIDGE_PARSING_TOKEN_H_ \ No newline at end of file diff --git a/python_binding/CMakeLists.txt b/python_binding/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp new file mode 100644 index 000000000..ca413a7a2 --- /dev/null +++ b/python_binding/backend/pybind_OperatorImpl.cpp @@ -0,0 +1,20 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include "backend/OperatorImpl.hpp" + +namespace py = pybind11; +namespace Aidge { +void init_OperatorImpl(py::module& m){ + py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>>(m, "OperatorImpl"); +} +} diff --git a/python_binding/data/pybind_Data.cpp b/python_binding/data/pybind_Data.cpp new file mode 100644 index 000000000..dfa841bdb --- /dev/null +++ b/python_binding/data/pybind_Data.cpp @@ -0,0 +1,37 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include "data/Data.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Data(py::module& m){ + // TODO : extend with more values ! + py::enum_<DataType>(m, "DataType") + .value("Float64", DataType::Float64) + .value("Float32", DataType::Float32) + .value("Float16", DataType::Float16) + .value("Int8", DataType::Int8) + .value("Int32", DataType::Int32) + .value("Int64", DataType::Int64) + .value("UInt8", DataType::UInt8) + .value("UInt32", DataType::UInt32) + .value("UInt64", DataType::UInt64) + ; + + py::class_<Data, std::shared_ptr<Data>>(m,"Data") + .def(py::init<const char*>()); + + +} +} diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp new file mode 100644 index 000000000..38c01cdd2 --- /dev/null +++ b/python_binding/data/pybind_Tensor.cpp @@ -0,0 +1,147 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include <pybind11/operators.h> +#include <pybind11/numpy.h> + +#include <algorithm> +#include <cstdio> + +#include "data/Tensor.hpp" +#include "data/Data.hpp" +#include "utils/Registrar.hpp" +#include "utils/Types.h" +#include "backend/TensorImpl.hpp" + +namespace py = pybind11; +namespace Aidge { + + +template<typename T> +void addCtor(py::class_<Tensor, + std::shared_ptr<Tensor>, + Data, + Registrable<Tensor, + std::tuple<std::string, DataType>, + std::unique_ptr<TensorImpl>(const Tensor&)>>& mTensor){ + mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) { + /* Request a buffer descriptor from Python */ + py::buffer_info info = b.request(); + Tensor* newTensor = new Tensor(); + newTensor->setDatatype(NativeType<T>::type); + const std::vector<DimSize_t> dims(info.shape.begin(), info.shape.end()); + newTensor->resize(dims); + // TODO: find a better way to choose the backend + std::set<std::string> availableBackends = Tensor::getAvailableBackends(); + if (availableBackends.find("cpu") != availableBackends.end()){ + newTensor->setBackend("cpu"); + newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr)); + }else{ + printf("Warning: could not use the 'cpu' backend, make sure you have imported the aidge_cpu module.\n"); + } + + return newTensor; + })); +} + + +void init_Tensor(py::module& m){ + py::class_<Registrable<Tensor, + std::tuple<std::string, DataType>, + std::unique_ptr<TensorImpl>(const Tensor&)>, + std::shared_ptr<Registrable<Tensor, + std::tuple<std::string, DataType>, + std::unique_ptr<TensorImpl>(const Tensor&)>>>(m,"TensorRegistrable"); + + py::class_<Tensor, std::shared_ptr<Tensor>, + Data, + Registrable<Tensor, + std::tuple<std::string, DataType>, + std::unique_ptr<TensorImpl>(const Tensor&)>> pyClassTensor + (m,"Tensor", py::multiple_inheritance(), py::buffer_protocol()); + + pyClassTensor.def(py::init<>()) + .def("set_backend", &Tensor::setBackend, py::arg("name")) + .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims) + .def("dtype", &Tensor::dataType) + .def("size", &Tensor::size) + .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&)) &Tensor::resize) + .def("has_impl", &Tensor::hasImpl) + .def_static("get_available_backends", &Tensor::getAvailableBackends) + .def("__str__", [](Tensor& b) { + return b.toString(); + }) + .def("__len__", [](Tensor& b) -> size_t{ + return b.size(); + }) + .def("__getitem__", [](Tensor& b, size_t idx)-> py::object { + // TODO: should raise an error if the backend is not compatible with get + if (idx >= b.size()) throw py::index_error(); + switch(b.dataType()){ + case DataType::Float64: + return py::cast(static_cast<double*>(b.getImpl()->rawPtr())[idx]); + case DataType::Float32: + return py::cast(static_cast<float*>(b.getImpl()->rawPtr())[idx]); + case DataType::Int32: + return py::cast(static_cast<int*>(b.getImpl()->rawPtr())[idx]); + default: + return py::none(); + } + }) + .def_buffer([](Tensor& b) -> py::buffer_info { + const std::unique_ptr<TensorImpl>& tensorImpl = b.getImpl(); + + std::vector<ssize_t> dims; + std::vector<ssize_t> strides; + ssize_t stride = tensorImpl->scalarSize(); + + for (unsigned int dim = b.nbDims(); dim > 0; dim--) { + dims.push_back(b.dims()[dim-1]); + strides.push_back(stride); + stride *= b.dims()[dim-1]; + } + std::reverse(dims.begin(), dims.end()); + std::reverse(strides.begin(), strides.end()); + + std::string
dataFormatDescriptor; + switch(b.dataType()){ + case DataType::Float64: + dataFormatDescriptor = py::format_descriptor<double>::format(); + break; + case DataType::Float32: + dataFormatDescriptor = py::format_descriptor<float>::format(); + break; + case DataType::Int32: + dataFormatDescriptor = py::format_descriptor<int>::format(); + break; + default: + throw py::value_error("Unsupported data format"); + } + + return py::buffer_info( + tensorImpl->rawPtr(), /* Pointer to buffer */ + tensorImpl->scalarSize(), /* Size of one scalar */ + dataFormatDescriptor, /* Python struct-style format descriptor */ + b.nbDims(), /* Number of dimensions */ + dims, /* Buffer dimensions */ + strides /* Strides (in bytes) for each index */ + ); + }); + + // TODO : If the ctor with the right data type does not exist, pybind will always convert the data to INT ! + // Need to find a way to avoid this ! + addCtor<int>(pyClassTensor); + addCtor<float>(pyClassTensor); +// #if SIZE_MAX != 0xFFFFFFFF + addCtor<double>(pyClassTensor); +// #endif + +} +} diff --git a/python_binding/graph/pybind_Connector.cpp b/python_binding/graph/pybind_Connector.cpp new file mode 100644 index 000000000..a937fb4f2 --- /dev/null +++ b/python_binding/graph/pybind_Connector.cpp @@ -0,0 +1,29 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include "graph/Connector.hpp" +#include "graph/Node.hpp" +#include "graph/GraphView.hpp" + +namespace py = pybind11; +namespace Aidge { +void init_Connector(py::module& m){ + py::class_<Connector, std::shared_ptr<Connector>>(m, "Connector") + .def(py::init<>()) + .def(py::init<std::shared_ptr<Node>>()) + .def("__getitem__", &Connector::operator[], py::arg("key")) + ; + m.def("generate_graph", &Aidge::generateGraph, py::arg("output_connectors")); + // m.def("generate_graph", (std::shared_ptr<GraphView>(*)(std::vector<Connector>)) &Aidge::generateGraph, py::arg("output_connectors")); +} +} diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp new file mode 100644 index 000000000..b7ef2e166 --- /dev/null +++ b/python_binding/graph/pybind_GraphView.cpp @@ -0,0 +1,60 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include <memory> +#include <string> +#include "graph/GraphView.hpp" +#include "graph/Node.hpp" +#include "utils/Types.h" +#include "data/Data.hpp" + +namespace py = pybind11; +namespace Aidge { +void init_GraphView(py::module& m) { + py::class_<GraphView, std::shared_ptr<GraphView>>(m, "GraphView") + .def(py::init<>()) + .def("save", &GraphView::save, py::arg("path"), py::arg("verbose") = false) + .def("get_output_nodes", &GraphView::outputNodes) + .def("get_input_nodes", &GraphView::inputNodes) + .def("add", (void (GraphView::*)(std::shared_ptr<Node>, bool)) & GraphView::add, + py::arg("other_node"), py::arg("include_learnable_parameters") = true) + .def("add_child", + (void (GraphView::*)(std::shared_ptr<Node>, + std::shared_ptr<Node>, + const IOIndex_t, + IOIndex_t)) & + GraphView::addChild, + py::arg("toOtherNode"), py::arg("fromOutNode") = nullptr, + py::arg("fromTensor") = 0U, py::arg("toTensor") = gk_IODefaultIndex) + .def("replace_with", &GraphView::replaceWith, py::arg("new_nodes")) + .def("get_nodes", &GraphView::getNodes) + .def("get_node", &GraphView::getNode, py::arg("node_name")) + .def("forward_dims", &GraphView::forwardDims) + .def("__call__", &GraphView::operator(), py::arg("connectors")) + .def("set_datatype", &GraphView::setDatatype, py::arg("datatype")) + .def("set_backend", &GraphView::setBackend, py::arg("backend")) + ; +} +} // namespace Aidge diff --git a/python_binding/graph/pybind_Node.cpp b/python_binding/graph/pybind_Node.cpp new file mode 100644 index 000000000..0d957eb2c --- /dev/null +++ b/python_binding/graph/pybind_Node.cpp @@ -0,0 +1,49 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0.
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include <memory> + +#include "graph/GraphView.hpp" +#include "graph/Node.hpp" +#include "utils/Types.h" + +namespace py = pybind11; +namespace Aidge { +void init_Node(py::module& m) { + py::class_<Node, std::shared_ptr<Node>>(m, "Node") + .def("name", &Node::name) + .def("type", &Node::type) + .def("get_operator", &Node::getOperator) + .def("set_name", &Node::setName, py::arg("name")) + .def("add_child", + (void (Node::*)(std::shared_ptr<Node>, const IOIndex_t, IOIndex_t)) & + Node::addChild, + py::arg("other_node"), py::arg("out_id") = 0, py::arg("other_in_id") = -1) + .def("add_child", + (void (Node::*)(std::shared_ptr<GraphView>, const IOIndex_t, + std::pair<std::shared_ptr<Node>, IOIndex_t>)) & + Node::addChild, + py::arg("other_graph"), py::arg("out_id") = 0, + py::arg("other_in_id") = + std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex)) + .def("inputs", &Node::inputs) + .def("input", &Node::input, py::arg("inID")) + .def("outputs", &Node::outputs) + .def("output", &Node::output, py::arg("outID")) + .def("get_nb_inputs", &Node::nbInputs) + .def("get_nb_datainputs", &Node::nbDataInputs) + .def("get_nb_outputs", &Node::nbOutputs) + .def("__call__", &Node::operator(), py::arg("connectors")); +} +} // namespace Aidge diff --git a/python_binding/graph/pybind_OpArgs.cpp b/python_binding/graph/pybind_OpArgs.cpp new file mode 100644 index 000000000..fa7a97431 --- /dev/null +++ b/python_binding/graph/pybind_OpArgs.cpp @@ -0,0 +1,38 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include "graph/OpArgs.hpp" +#include "graph/Node.hpp" +#include "graph/GraphView.hpp" +#include <pybind11/stl.h> +#include <pybind11/complex.h> +#include <pybind11/functional.h> +#include <pybind11/chrono.h> + + +namespace py = pybind11; +namespace Aidge { +void init_OpArgs(py::module& m){ + py::class_<OpArgs, std::shared_ptr<OpArgs>>(m, "OpArgs") + .def("node", &OpArgs::node) + .def("view", &OpArgs::view) + ; + + py::implicitly_convertible<Node, OpArgs>(); + py::implicitly_convertible<GraphView, OpArgs>(); + + m.def("sequential", &Sequential, py::arg("inputs")); + m.def("parallel", &Parallel, py::arg("inputs")); + m.def("residual", &Residual, py::arg("inputs")); + +} +} diff --git a/python_binding/graphmatching/pybind_GRegex.cpp b/python_binding/graphmatching/pybind_GRegex.cpp new file mode 100644 index 000000000..4eefccec1 --- /dev/null +++ b/python_binding/graphmatching/pybind_GRegex.cpp @@ -0,0 +1,25 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include "graph/GraphView.hpp" +#include "graphmatching/GRegex.hpp" + +namespace py = pybind11; +namespace Aidge { +void init_GRegex(py::module& m){ + py::class_<GRegex, std::shared_ptr<GRegex>>(m, "GRegex") + .def(py::init<const std::map<std::string,NodeRegex*>&, std::vector<std::string>&>(), py::arg("nodesRegex"), py::arg("seqRegexps")) + .def("match", &GRegex::match, py::arg("graphToMatch")) + ; +} +} diff --git a/python_binding/graphmatching/pybind_Match.cpp b/python_binding/graphmatching/pybind_Match.cpp new file mode 100644 index 000000000..8646ff91d --- /dev/null +++ b/python_binding/graphmatching/pybind_Match.cpp @@ -0,0 +1,25 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include "graphmatching/Match.hpp" + +namespace py = pybind11; +namespace Aidge { +void init_Match(py::module& m){ + py::class_<Match, std::shared_ptr<Match>>(m, "Match") + .def(py::init<>()) + .def("get_nb_match", &Match::getNbMatch) + .def("get_start_nodes", &Match::getStartNodes) + .def("get_match_nodes", &Match::getMatchNodes); +} +} diff --git a/python_binding/graphmatching/pybind_NodeRegex.cpp b/python_binding/graphmatching/pybind_NodeRegex.cpp new file mode 100644 index 000000000..1ee4c9332 --- /dev/null +++ b/python_binding/graphmatching/pybind_NodeRegex.cpp @@ -0,0 +1,22 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include "graphmatching/NodeRegex.hpp" + +namespace py = pybind11; +namespace Aidge { +void init_NodeRegex(py::module& m){ + py::class_<NodeRegex, std::shared_ptr<NodeRegex>>(m, "NodeRegex") + .def(py::init<const std::string>(), py::arg("condition")) + ; +} +} diff --git a/python_binding/operator/pybind_Add.cpp b/python_binding/operator/pybind_Add.cpp new file mode 100644 index 000000000..3585db040 --- /dev/null +++ b/python_binding/operator/pybind_Add.cpp @@ -0,0 +1,32 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "operator/Add.hpp" +#include "utils/Parameter.hpp" +#include "backend/OperatorImpl.hpp" +#include "operator/Operator.hpp" +#include "utils/Types.h" + +namespace py = pybind11; +namespace Aidge { + +template <std::size_t NUM> void declare_Add(py::module &m) { + py::class_<Add_Op<NUM>, std::shared_ptr<Add_Op<NUM>>, Operator>(m, "Add_Op", py::multiple_inheritance()); + + m.def("Add", &Add<NUM>, py::arg("name") = nullptr); +} + +void init_Add(py::module &m) { + declare_Add<2>(m); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_AvgPooling.cpp b/python_binding/operator/pybind_AvgPooling.cpp new file mode 100644 index 000000000..142f29b5c --- /dev/null +++ b/python_binding/operator/pybind_AvgPooling.cpp @@ -0,0 +1,89 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#ifdef PYBIND +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include <string> +#include <vector> +#include <array> + +#include "utils/Parameter.hpp" +#include "backend/OperatorImpl.hpp" +#include "operator/AvgPooling.hpp" +#include "operator/Operator.hpp" +#include "utils/Types.h" +#include "data/Tensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +template <DimIdx_t DIM> void declare_AvgPoolingOp(py::module &m) { + py::class_<AvgPooling_Op<DIM>, std::shared_ptr<AvgPooling_Op<DIM>>, Operator, PyAbstractParametrizable>( + m, ("AvgPoolingOp" + std::to_string(DIM) + "D").c_str(), + py::multiple_inheritance()) + .def(py::init<const std::array<DimSize_t, DIM> &, + const std::array<DimSize_t, DIM> &, + const std::array<DimSize_t, (DIM<<1)> &>(), + py::arg("kernel_dims"), + py::arg("stride_dims"), + py::arg("padding_dims")); + + m.def(("AvgPooling" + std::to_string(DIM) + "D").c_str(), [](std::vector<DimSize_t>& kernel_dims, + const char* name, + std::vector<DimSize_t> &stride_dims, + std::vector<DimSize_t> &padding_dims) { + // Lambda function wrapper because PyBind fails to convert const array. + // So we use a vector that we convert in this function to a const DimSize_t [DIM] array.
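+ // The checks below validate that each Python list holds exactly DIM values + // (2*DIM for padding); the values are then copied into fixed-size stack + // arrays so that they can bind to the const DimSize_t (&)[DIM] references + // expected by the AvgPooling factory.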
+ if (kernel_dims.size() != DIM) { + throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); + } + if (stride_dims.size() != DIM) { + throw std::runtime_error("stride_dims size [" + std::to_string(stride_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); + } + if (padding_dims.size() != (DIM<<1)) { + throw std::runtime_error("padding_dims size [" + std::to_string(padding_dims.size()) + "] does not match DIM [" + std::to_string(DIM<<1) +"]"); + } + DimSize_t tmp_kernel_dims_array[DIM]; + for (size_t i = 0; i < DIM; ++i) { + tmp_kernel_dims_array[i] = kernel_dims[i]; + } + DimSize_t tmp_stride_dims_array[DIM]; + for (size_t i = 0; i < DIM; ++i) { + tmp_stride_dims_array[i] = stride_dims[i]; + } + DimSize_t tmp_padding_dims_array[DIM<<1]; + for (size_t i = 0; i < (DIM<<1); ++i) { + tmp_padding_dims_array[i] = padding_dims[i]; + } + const DimSize_t (&kernel_dims_array)[DIM] = tmp_kernel_dims_array; + const DimSize_t (&stride_dims_array)[DIM] = tmp_stride_dims_array; + const DimSize_t (&padding_dims_array)[DIM<<1] = tmp_padding_dims_array; + return AvgPooling<DIM>(to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(padding_dims_array)); + }, py::arg("kernel_dims"), + py::arg("name") = nullptr, + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), + py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0)); + +} + + +void init_AvgPooling(py::module &m) { + declare_AvgPoolingOp<1>(m); + declare_AvgPoolingOp<2>(m); + declare_AvgPoolingOp<3>(m); + + // FIXME: + // m.def("AvgPooling1D", static_cast<NodeAPI(*)(const char*, int, int, int const + // (&)[1])>(&AvgPooling)); +} +} // namespace Aidge +#endif \ No newline at end of file diff --git a/python_binding/operator/pybind_BatchNorm.cpp b/python_binding/operator/pybind_BatchNorm.cpp new file mode 100644 index 000000000..579b983bc --- /dev/null +++ b/python_binding/operator/pybind_BatchNorm.cpp @@ -0,0 +1,33 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <string> + +#include "operator/BatchNorm.hpp" +#include "operator/Operator.hpp" +#include "utils/Parameter.hpp" +#include "utils/Types.h" + +namespace py = pybind11; +namespace Aidge { + +template <DimSize_t DIM> +void declare_BatchNormOp(py::module& m) { + py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, Operator, PyAbstractParametrizable>(m, ("BatchNorm_Op" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()); + + m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = nullptr); +} + +void init_BatchNorm(py::module &m) { + declare_BatchNormOp<2>(m); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_Conv.cpp b/python_binding/operator/pybind_Conv.cpp new file mode 100644 index 000000000..663afe1bb --- /dev/null +++ b/python_binding/operator/pybind_Conv.cpp @@ -0,0 +1,107 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include <string> +#include <vector> +#include <array> + +#include "utils/Parameter.hpp" +#include "backend/OperatorImpl.hpp" +#include "operator/Conv.hpp" +#include "operator/Operator.hpp" +#include "utils/Types.h" + +namespace py = pybind11; +namespace Aidge { + +template <DimIdx_t DIM> void declare_ConvOp(py::module &m) { + py::class_<Conv_Op<DIM>, std::shared_ptr<Conv_Op<DIM>>, Operator, PyAbstractParametrizable>( + m, ("ConvOp" + std::to_string(DIM) + "D").c_str(), + py::multiple_inheritance()) + .def(py::init<DimSize_t, + DimSize_t, + const std::array<DimSize_t, DIM> &, + const std::array<DimSize_t, DIM> &, + const std::array<DimSize_t, (DIM<<1)> &, + const std::array<DimSize_t, DIM> &>(), + py::arg("in_channels"), + py::arg("out_channels"), + py::arg("kernel_dims"), + py::arg("stride_dims"), + py::arg("padding_dims"), + py::arg("dilation_dims")); + + m.def(("Conv" + std::to_string(DIM) + "D").c_str(), [](DimSize_t in_channels, + DimSize_t out_channels, + std::vector<DimSize_t>& kernel_dims, + const char* name, + std::vector<DimSize_t> &stride_dims, + std::vector<DimSize_t> &padding_dims, + std::vector<DimSize_t> &dilation_dims) { + // Lambda function wrapper because PyBind fails to convert const array. + // So we use a vector that we convert in this function to a const DimSize_t [DIM] array.
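+ // Same conversion pattern as the AvgPooling wrapper: validate the vector + // sizes, then materialize const DimSize_t[DIM] arrays for the Conv factory. + // Python-side usage sketch (a non-authoritative example, assuming a built + // aidge_core module): + // conv = aidge_core.Conv2D(3, 16, [3, 3], name="conv1")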
+ if (kernel_dims.size() != DIM) { + throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); + } + if (stride_dims.size() != DIM) { + throw std::runtime_error("stride_dims size [" + std::to_string(stride_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); + } + if (padding_dims.size() != (DIM<<1)) { + throw std::runtime_error("padding_dims size [" + std::to_string(padding_dims.size()) + "] does not match DIM [" + std::to_string(DIM<<1) +"]"); + } + if (dilation_dims.size() != DIM) { + throw std::runtime_error("dilation_dims size [" + std::to_string(dilation_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); + } + DimSize_t tmp_kernel_dims_array[DIM]; + for (size_t i = 0; i < DIM; ++i) { + tmp_kernel_dims_array[i] = kernel_dims[i]; + } + DimSize_t tmp_stride_dims_array[DIM]; + for (size_t i = 0; i < DIM; ++i) { + tmp_stride_dims_array[i] = stride_dims[i]; + } + DimSize_t tmp_padding_dims_array[DIM<<1]; + for (size_t i = 0; i < (DIM<<1); ++i) { + tmp_padding_dims_array[i] = padding_dims[i]; + } + DimSize_t tmp_dilation_dims_array[DIM]; + for (size_t i = 0; i < DIM; ++i) { + tmp_dilation_dims_array[i] = dilation_dims[i]; + } + const DimSize_t (&kernel_dims_array)[DIM] = tmp_kernel_dims_array; + const DimSize_t (&stride_dims_array)[DIM] = tmp_stride_dims_array; + const DimSize_t (&padding_dims_array)[DIM<<1] = tmp_padding_dims_array; + const DimSize_t (&dilation_dims_array)[DIM] = tmp_dilation_dims_array; + return Conv<DIM>(in_channels, out_channels, to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(padding_dims_array), to_array(dilation_dims_array)); + }, py::arg("in_channels"), + py::arg("out_channels"), + py::arg("kernel_dims"), + py::arg("name") = nullptr, + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), + py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0), + py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1)); + +} + + +void init_Conv(py::module &m) { + declare_ConvOp<1>(m); + declare_ConvOp<2>(m); + declare_ConvOp<3>(m); + + // FIXME: + // m.def("Conv1D", static_cast<NodeAPI(*)(const char*, int, int, int const + // (&)[1])>(&Conv)); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_ConvDepthWise.cpp b/python_binding/operator/pybind_ConvDepthWise.cpp new file mode 100644 index 000000000..bbd888d1b --- /dev/null +++ b/python_binding/operator/pybind_ConvDepthWise.cpp @@ -0,0 +1,100 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include <string> +#include <vector> +#include <array> + +#include "utils/Parameter.hpp" +#include "backend/OperatorImpl.hpp" +#include "operator/ConvDepthWise.hpp" +#include "operator/Operator.hpp" +#include "utils/Types.h" +#include "data/Tensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +template <DimIdx_t DIM> void declare_ConvDepthWiseOp(py::module &m) { + py::class_<ConvDepthWise_Op<DIM>, std::shared_ptr<ConvDepthWise_Op<DIM>>, Operator, PyAbstractParametrizable>( + m, ("ConvDepthWiseOp" + std::to_string(DIM) + "D").c_str(), + py::multiple_inheritance()) + .def(py::init<const std::array<DimSize_t, DIM> &, + const std::array<DimSize_t, DIM> &, + const std::array<DimSize_t, (DIM<<1)> &, + const std::array<DimSize_t, DIM> &>(), + py::arg("kernel_dims"), + py::arg("stride_dims"), + py::arg("padding_dims"), + py::arg("dilation_dims")); + + m.def(("ConvDepthWise" + std::to_string(DIM) + "D").c_str(), [](std::vector<DimSize_t>& kernel_dims, + const char* name, + std::vector<DimSize_t> &stride_dims, + std::vector<DimSize_t> &padding_dims, + std::vector<DimSize_t> &dilation_dims) { + // Lambda function wrapper because PyBind fails to convert const array. + // So we use a vector that we convert in this function to a const DimSize_t [DIM] array. + if (kernel_dims.size() != DIM) { + throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); + } + if (stride_dims.size() != DIM) { + throw std::runtime_error("stride_dims size [" + std::to_string(stride_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); + } + if (padding_dims.size() != (DIM<<1)) { + throw std::runtime_error("padding_dims size [" + std::to_string(padding_dims.size()) + "] does not match DIM [" + std::to_string(DIM<<1) +"]"); + } + if (dilation_dims.size() != DIM) { + throw std::runtime_error("dilation_dims size [" + std::to_string(dilation_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); + } + DimSize_t tmp_kernel_dims_array[DIM]; + for (size_t i = 0; i < DIM; ++i) { + tmp_kernel_dims_array[i] = kernel_dims[i]; + } + DimSize_t tmp_stride_dims_array[DIM]; + for (size_t i = 0; i < DIM; ++i) { + tmp_stride_dims_array[i] = stride_dims[i]; + } + DimSize_t tmp_padding_dims_array[DIM<<1]; + for (size_t i = 0; i < (DIM<<1); ++i) { + tmp_padding_dims_array[i] = padding_dims[i]; + } + DimSize_t tmp_dilation_dims_array[DIM]; + for (size_t i = 0; i < DIM; ++i) { + tmp_dilation_dims_array[i] = dilation_dims[i]; + } + const DimSize_t (&kernel_dims_array)[DIM] = tmp_kernel_dims_array; + const DimSize_t (&stride_dims_array)[DIM] = tmp_stride_dims_array; + const DimSize_t (&padding_dims_array)[DIM<<1] = tmp_padding_dims_array; + const DimSize_t (&dilation_dims_array)[DIM] = tmp_dilation_dims_array; + return ConvDepthWise<DIM>(to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(padding_dims_array), to_array(dilation_dims_array)); + }, py::arg("kernel_dims"), + py::arg("name") = nullptr, + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), + py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0), + py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1)); + +} + + +void init_ConvDepthWise(py::module &m) { + declare_ConvDepthWiseOp<1>(m); + declare_ConvDepthWiseOp<2>(m); + declare_ConvDepthWiseOp<3>(m); +
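+ // Python-side usage sketch (a non-authoritative example, assuming a built + // aidge_core module; stride, padding and dilation fall back to their + // defaults when omitted): + // dw = aidge_core.ConvDepthWise2D([3, 3], name="dw1", stride_dims=[2, 2])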
+ // FIXME: + // m.def("ConvDepthWise1D", static_cast<NodeAPI(*)(const char*, int, int, int const + // (&)[1])>(&ConvDepthWise)); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_FC.cpp b/python_binding/operator/pybind_FC.cpp new file mode 100644 index 000000000..293ae8418 --- /dev/null +++ b/python_binding/operator/pybind_FC.cpp @@ -0,0 +1,32 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "operator/FC.hpp" +#include "utils/Parameter.hpp" +#include "backend/OperatorImpl.hpp" +#include "operator/Operator.hpp" +#include "utils/Types.h" + +namespace py = pybind11; +namespace Aidge { + +void declare_FC(py::module &m) { + py::class_<FC_Op, std::shared_ptr<FC_Op>, Operator>(m, "FC_Op", py::multiple_inheritance()); + + m.def("FC", &FC, py::arg("out_channels"), py::arg("nobias") = false, py::arg("name") = nullptr); +} + +void init_FC(py::module &m) { + declare_FC(m); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_GenericOperator.cpp b/python_binding/operator/pybind_GenericOperator.cpp new file mode 100644 index 000000000..7aa5d42ba --- /dev/null +++ b/python_binding/operator/pybind_GenericOperator.cpp @@ -0,0 +1,67 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include <stdio.h> + +#include "backend/OperatorImpl.hpp" +#include "operator/GenericOperator.hpp" +#include "operator/Operator.hpp" +namespace py = pybind11; +namespace Aidge { + +void init_GenericOperator(py::module& m) { + py::class_<GenericOperator_Op, std::shared_ptr<GenericOperator_Op>, Operator>(m, "GenericOperatorOp", + py::multiple_inheritance()) + .def("get_parameter_type", &GenericOperator_Op::getParameterType) + .def("get_parameters_name", &GenericOperator_Op::getParametersName) + .def("add_parameter", &GenericOperator_Op::addParameter<bool>) + .def("add_parameter", &GenericOperator_Op::addParameter<int>) + .def("add_parameter", &GenericOperator_Op::addParameter<float>) + .def("add_parameter", &GenericOperator_Op::addParameter<std::string>) + .def("add_parameter", &GenericOperator_Op::addParameter<std::vector<bool>>) + .def("add_parameter", &GenericOperator_Op::addParameter<std::vector<int>>) + .def("add_parameter", &GenericOperator_Op::addParameter<std::vector<float>>) + .def("add_parameter", &GenericOperator_Op::addParameter<std::vector<std::string>>) + .def("get_parameter", [](GenericOperator_Op& self, std::string key) -> py::object { + /* + This getParameter wrapper returns the parameter with its proper Python type, + without requiring prior knowledge of the parameter's C++ type.
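+ + Python-side sketch (a non-authoritative example: "my_param" is a + hypothetical parameter name, and add_parameter is assumed to take a + (name, value) pair): + node = aidge_core.GenericOperator("MyOp", 1, 1, 1, name="gop") + op = node.get_operator() + op.add_parameter("my_param", [1, 2, 3]) + assert op.get_parameter("my_param") == [1, 2, 3]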
+ */ + py::object res = py::none(); + std::string paramType = self.getParameterType(key); + if(paramType == typeid(int).name()) + res = py::cast(self.getParameter<int>(key)); + else if(paramType == typeid(float).name()) + res = py::cast(self.getParameter<float>(key)); + else if(paramType == typeid(bool).name()) + res = py::cast(self.getParameter<bool>(key)); + else if(paramType == typeid(std::string).name()) + res = py::cast(self.getParameter<std::string>(key)); + else if(paramType == typeid(std::vector<bool>).name()) + res = py::cast(self.getParameter<std::vector<bool>>(key)); + else if(paramType == typeid(std::vector<int>).name()) + res = py::cast(self.getParameter<std::vector<int>>(key)); + else if(paramType == typeid(std::vector<float>).name()) + res = py::cast(self.getParameter<std::vector<float>>(key)); + else if(paramType == typeid(std::vector<std::string>).name()) + res = py::cast(self.getParameter<std::vector<std::string>>(key)); + else { + throw py::key_error("Failed to convert parameter " + key + ": the typeid-based dispatch returned an unsupported type name [" + paramType + "]. Please open an issue asking for support of this type."); + } + return res; + }); + + m.def("GenericOperator", &GenericOperator, py::arg("type"), py::arg("nbDataIn"), py::arg("nbIn"), py::arg("nbOut"), + py::arg("name") = nullptr); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_LeakyReLU.cpp b/python_binding/operator/pybind_LeakyReLU.cpp new file mode 100644 index 000000000..b8ffb9c5d --- /dev/null +++ b/python_binding/operator/pybind_LeakyReLU.cpp @@ -0,0 +1,26 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "operator/LeakyReLU.hpp" +#include "operator/Operator.hpp" +#include "utils/Parameter.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_LeakyReLU(py::module& m) { + py::class_<LeakyReLU_Op, std::shared_ptr<LeakyReLU_Op>, Operator, PyAbstractParametrizable>(m, "LeakyReLU_Op", py::multiple_inheritance()); + + m.def("LeakyReLU", &LeakyReLU, py::arg("negative_slope") = 0.0f, py::arg("name") = nullptr); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_Matmul.cpp b/python_binding/operator/pybind_Matmul.cpp new file mode 100644 index 000000000..7a2748fcd --- /dev/null +++ b/python_binding/operator/pybind_Matmul.cpp @@ -0,0 +1,32 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0.
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "operator/Matmul.hpp" +#include "utils/Parameter.hpp" +#include "backend/OperatorImpl.hpp" +#include "operator/Operator.hpp" +#include "utils/Types.h" + +namespace py = pybind11; +namespace Aidge { + +void declare_Matmul(py::module &m) { + py::class_<Matmul_Op, std::shared_ptr<Matmul_Op>, Operator>(m, "Matmul_Op", py::multiple_inheritance()); + + m.def("Matmul", &Matmul, py::arg("out_channels"), py::arg("name") = nullptr); +} + +void init_Matmul(py::module &m) { + declare_Matmul(m); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp new file mode 100644 index 000000000..cf682cc73 --- /dev/null +++ b/python_binding/operator/pybind_Operator.cpp @@ -0,0 +1,28 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include "backend/OperatorImpl.hpp" +#include "operator/Operator.hpp" +#include <pybind11/stl.h> + +namespace py = pybind11; +namespace Aidge { +void init_Operator(py::module& m){ + py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") + .def("output", &Operator::output, py::arg("outputIdx")) + .def("input", &Operator::input, py::arg("inputIdx")) + .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) + .def("set_datatype", &Operator::setDatatype, py::arg("datatype")) + .def("set_backend", &Operator::setBackend, py::arg("name")) + ; +} +} diff --git a/python_binding/operator/pybind_Producer.cpp b/python_binding/operator/pybind_Producer.cpp new file mode 100644 index 000000000..c47329941 --- /dev/null +++ b/python_binding/operator/pybind_Producer.cpp @@ -0,0 +1,50 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include "utils/Types.h" +#include "utils/Parameter.hpp" +// #include "backend/OperatorImpl.hpp" +#include "operator/Operator.hpp" +#include "operator/Producer.hpp" +#include "data/Tensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +template <DimIdx_t DIM> +void declare_Producer(py::module &m) { + // m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name")); + m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const char*)>(&Producer), py::arg("dims"), py::arg("name") = nullptr); + +} + + +void init_Producer(py::module &m) { + py::class_<Producer_Op, std::shared_ptr<Producer_Op>, Operator>( + m, + "ProducerOp", + py::multiple_inheritance()) + .def("dims", &Producer_Op::dims); + m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const char*)>(&Producer), py::arg("tensor"), py::arg("name") = nullptr); + + declare_Producer<1>(m); + declare_Producer<2>(m); + declare_Producer<3>(m); + declare_Producer<4>(m); + declare_Producer<5>(m); + declare_Producer<6>(m); + + // m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name")); +} +} diff --git a/python_binding/operator/pybind_ReLU.cpp b/python_binding/operator/pybind_ReLU.cpp new file mode 100644 index 000000000..794951828 --- /dev/null +++ b/python_binding/operator/pybind_ReLU.cpp @@ -0,0 +1,25 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "operator/ReLU.hpp" +#include "operator/Operator.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_ReLU(py::module& m) { + py::class_<ReLU_Op, std::shared_ptr<ReLU_Op>, Operator>(m, "ReLU_Op", py::multiple_inheritance()); + + m.def("ReLU", &ReLU, py::arg("name") = nullptr); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_Softmax.cpp b/python_binding/operator/pybind_Softmax.cpp new file mode 100644 index 000000000..1cccab6e3 --- /dev/null +++ b/python_binding/operator/pybind_Softmax.cpp @@ -0,0 +1,26 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <string> + +#include "operator/Softmax.hpp" +#include "operator/Operator.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Softmax(py::module& m) { + py::class_<Softmax_Op, std::shared_ptr<Softmax_Op>, Operator>(m, "Softmax_Op", py::multiple_inheritance()); + + m.def("Softmax", &Softmax, py::arg("name") = nullptr); +} +} // namespace Aidge diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp new file mode 100644 index 000000000..b861f881c --- /dev/null +++ b/python_binding/pybind_core.cpp @@ -0,0 +1,92 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +namespace py = pybind11; + +namespace Aidge { +void init_Data(py::module&); +void init_Tensor(py::module&); +void init_OperatorImpl(py::module&); +void init_Parameterizable(py::module&); +void init_Operator(py::module&); + +void init_Add(py::module&); +void init_AvgPooling(py::module&); +void init_BatchNorm(py::module&); +void init_Conv(py::module&); +void init_ConvDepthWise(py::module&); +void init_FC(py::module&); +void init_GenericOperator(py::module&); +void init_LeakyReLU(py::module&); +void init_Matmul(py::module&); +void init_Producer(py::module&); +void init_ReLU(py::module&); +void init_Softmax(py::module&); + +void init_Node(py::module&); +void init_GraphView(py::module&); +void init_OpArgs(py::module&); +void init_Connector(py::module&); + +void init_Match(py::module&); +void init_NodeRegex(py::module&); +void init_GRegex(py::module&); + +void init_Recipies(py::module&); + +void init_Scheduler(py::module&); + + +void set_python_flag(){ + // Set an environment variable so we can know whether we are running from Python or C++ + py::module os_module = py::module::import("os"); + os_module.attr("environ")["AIDGE_CORE_WITH_PYBIND"] = "1"; +} + +void init_Aidge(py::module& m){ + set_python_flag(); + init_Data(m); + init_Tensor(m); + + init_Node(m); + init_GraphView(m); + init_OpArgs(m); + init_Connector(m); + + init_OperatorImpl(m); + init_Parameterizable(m); + init_Operator(m); + init_Add(m); + init_AvgPooling(m); + init_BatchNorm(m); + init_Conv(m); + init_ConvDepthWise(m); + init_FC(m); + init_GenericOperator(m); + init_LeakyReLU(m); + init_Matmul(m); + init_ReLU(m); + init_Softmax(m); + + init_Producer(m); + init_Match(m); + init_NodeRegex(m); + init_GRegex(m); + init_Recipies(m); + init_Scheduler(m); +} + +PYBIND11_MODULE(aidge_core, m) { + init_Aidge(m); +} +} diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp new file mode 100644 index 000000000..c9a1e0384 --- /dev/null +++ b/python_binding/recipies/pybind_Recipies.cpp @@ -0,0 +1,27 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0.
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include <string> + +#include "utils/Recipies.hpp" + +namespace py = pybind11; + +namespace Aidge { +void init_Recipies(py::module &m) { + m.def("fuse_mul_add", &fuseMulAdd, py::arg("nodes")); + m.def("remove_flatten", &removeFlatten, py::arg("view")); + +} +} // namespace Aidge diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp new file mode 100644 index 000000000..0f2598c75 --- /dev/null +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -0,0 +1,26 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include "scheduler/Scheduler.hpp" +#include "graph/GraphView.hpp" + +namespace py = pybind11; +namespace Aidge { +void init_Scheduler(py::module& m){ + py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>>(m, "SequentialScheduler") + .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) + .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false) + .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name")) + ; +} +} + diff --git a/python_binding/utils/pybind_Parameter.cpp b/python_binding/utils/pybind_Parameter.cpp new file mode 100644 index 000000000..95d7d93a3 --- /dev/null +++ b/python_binding/utils/pybind_Parameter.cpp @@ -0,0 +1,12 @@ +#include <pybind11/pybind11.h> +#include "utils/Parameter.hpp" + +namespace py = pybind11; +namespace Aidge { +void init_Parameterizable(py::module& m){ + py::class_<PyAbstractParametrizable, std::shared_ptr<PyAbstractParametrizable>>(m, "PyAbstractParametrizable") + .def("get", &PyAbstractParametrizable::getPy, py::arg("name")) + ; +} +} + diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/src/graph/Connector.cpp b/src/graph/Connector.cpp new file mode 100644 index 000000000..4297453f8 --- /dev/null +++ b/src/graph/Connector.cpp @@ -0,0 +1,54 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "graph/Connector.hpp" + +#include <map> +#include <set> + +#include "graph/GraphView.hpp" +#include "graph/Node.hpp" +#include "utils/Types.h" + +Aidge::Connector::Connector(std::shared_ptr<Aidge::Node> node) { + mNode = node; + if (mNode->nbOutputs() == 1U) { + mOutputId = 0; + } +} + +Aidge::IONb_t Aidge::Connector::size() const { return mNode->nbOutputs(); } + +std::shared_ptr<Aidge::GraphView> Aidge::generateGraph(std::vector<Connector> ctors) { + std::shared_ptr<GraphView> graph = std::make_shared<GraphView>(); + std::vector<std::shared_ptr<Node>> nodesToAdd = std::vector<std::shared_ptr<Node>>(); + for (const Connector& ctor : ctors) { + nodesToAdd.push_back(ctor.node()); + } + std::vector<std::shared_ptr<Node>> buffer = {}; + + while (!nodesToAdd.empty()) { + while (!nodesToAdd.empty()) { + graph->add(nodesToAdd.back()); // only add; connections between + // nodes are already in place + std::vector<std::shared_ptr<Node>> parents = nodesToAdd.back()->getParents(); + std::set<std::shared_ptr<Node>> alreadyAdded = graph->getNodes(); + for (std::shared_ptr<Node> parent : parents) { + if (alreadyAdded.find(parent) == alreadyAdded.end()) { + buffer.push_back(parent); + } + } + nodesToAdd.pop_back(); + } + nodesToAdd.insert(nodesToAdd.end(), buffer.begin(), buffer.end()); + buffer = {}; + } + return graph; +} \ No newline at end of file diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp new file mode 100644 index 000000000..3cce4f449 --- /dev/null +++ b/src/graph/GraphView.cpp @@ -0,0 +1,673 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <algorithm> +#include <cassert> +#include <cstdio> +#include <cstring> +#include <iterator> +#include <utility> + +#include "utils/Types.h" +#include "graph/GraphView.hpp" +#include "data/Tensor.hpp" + +/////////////////////////////////////////////////////// +// FUNCTIONAL DESCRIPTION +/////////////////////////////////////////////////////// + +Aidge::Connector Aidge::GraphView::operator()( + const std::vector<Aidge::Connector> ctors) { + // TODO: allow for multiple inputNodes?
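+ // This call operator wires each incoming Connector to the free data inputs + // of the view's single input node, then returns a Connector on its single + // output node, so that GraphViews can be composed like functions.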
+ assert((inputNodes().size() == 1U) && "Too many input Nodes for the GraphView, undefined behaviour"); + std::shared_ptr<Node> inNode = *inputNodes().begin(); + assert((ctors.size() == static_cast<std::size_t>(inNode->nbDataInputs())) && "Wrong number of arguments.\n"); + for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inNode->inputs()) { + assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n"); + } + + for (const Connector &ctor : ctors) { + assert((ctor.node() != nullptr) && + "Input Connector must be associated with a node"); + } + IOIndex_t inID = 0; + for (const Connector &ctor : ctors) { + ctor.node()->addChild(shared_from_this(), static_cast<std::size_t>(ctor.index()), + {inNode, inID++}); + } + return Connector(*(outputNodes().begin())); +} + +/////////////////////////////////////////////////////// +// INNER +/////////////////////////////////////////////////////// + +std::string Aidge::GraphView::name() const { return mName; } + +void Aidge::GraphView::setName(const std::string &name) { mName = name; } + + +void Aidge::GraphView::save(std::string path, bool verbose) const { + FILE *fp = std::fopen((path + ".mmd").c_str(), "w"); + std::fprintf(fp, + "%%%%{init: {'flowchart': { 'curve': 'monotoneY'}, " + "'fontFamily': 'Verdana' } }%%%%\nflowchart TB\n\n"); + + std::map<const std::string, std::size_t> typeCounter; + std::map<std::shared_ptr<Node>, std::string> namePtrTable; + + // Start by creating every node + for (const std::shared_ptr<Node> &node_ptr : mNodes) { + const std::string currentType = node_ptr->type(); + if (typeCounter.find(currentType) == typeCounter.end()) + typeCounter[currentType] = 0; + ++typeCounter[currentType]; + + const std::string givenName = + (node_ptr->name().empty()) + ? 
currentType + std::to_string(typeCounter[currentType]) + : node_ptr->name(); + namePtrTable[node_ptr] = + (currentType + "_" + std::to_string(typeCounter[currentType])); + std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(), + givenName.c_str()); + } + // Write every link + std::size_t emptyInputCounter = 0; + for (const std::shared_ptr<Node> &node_ptr : mNodes) { + for (const std::shared_ptr<Node> &pa_ptr : node_ptr->getParents()) { + if ((pa_ptr == nullptr) || !inView(pa_ptr)) { + std::fprintf(fp, "input%zu((in - %zu))-->%s\n", emptyInputCounter, + emptyInputCounter, namePtrTable[node_ptr].c_str()); + ++emptyInputCounter; + } else { + std::fprintf(fp, "%s-->%s\n", namePtrTable[pa_ptr].c_str(), + namePtrTable[node_ptr].c_str()); + } + } + } + if (verbose) { + for (const auto &c : typeCounter) { + std::printf("%s - %zu\n", c.first.c_str(), c.second); + } + } + + std::fprintf(fp, "\n"); + std::fclose(fp); +} + +/////////////////////////////////////////////////////// +// TENSOR MANAGEMENT +/////////////////////////////////////////////////////// + +Aidge::IONb_t Aidge::GraphView::getNbDataInputs() const { + IONb_t nbDataInput = static_cast<IONb_t>(0); + assert(outputNodes().size() == static_cast<std::size_t>(1)); + for (const std::shared_ptr<Node> &inNode : inputNodes()) { + nbDataInput += inNode->nbDataInputs(); + } + return nbDataInput; +} + +Aidge::IONb_t Aidge::GraphView::getNbFreeDataInputs() const { + IONb_t nbIn = 0; + for (const std::shared_ptr<Node> inputNode : mInputNodes) { + nbIn += inputNode->getNbFreeDataInputs(); + } + return nbIn; +} + + +std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> +Aidge::GraphView::dataInputs() const { + IONb_t nbDataIn = 0U; + for (const std::shared_ptr<Node> inputNode : mInputNodes) { + nbDataIn += inputNode->nbDataInputs(); + } + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbDataIn); + nbDataIn = 0U; + for (const std::shared_ptr<Node> inputNode : mInputNodes) { + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = + inputNode->dataInputs(); + std::move(inputNodeinputs.begin(), inputNodeinputs.end(), + res.begin() + nbDataIn); + nbDataIn += inputNode->nbDataInputs(); + // res.insert(res.end(), (inputNode -> inputs()).begin(), (inputNode -> + // inputs()).end()); + } + return res; +} + + +std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> +Aidge::GraphView::inputs() const { + std::size_t nbIn = 0U; + for (const std::shared_ptr<Node> inputNode : mInputNodes) { + nbIn += inputNode->nbInputs(); + } + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbIn); + nbIn = 0U; + for (const std::shared_ptr<Node> inputNode : mInputNodes) { + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = + inputNode->inputs(); + std::move(inputNodeinputs.begin(), inputNodeinputs.end(), + res.begin() + nbIn); + nbIn += inputNode->nbInputs(); + // res.insert(res.end(), (inputNode -> inputs()).begin(), (inputNode -> + // inputs()).end()); + } + return res; +} + + +std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> +Aidge::GraphView::inputs(std::string name) const { + return mNodeRegistry.at(name)->inputs(); +} + +void Aidge::GraphView::forwardDims() { + // setInputs + // Link every tensor to the right pointer + // following parent - children informations + for (std::shared_ptr<Node> nodePtr : getNodes()) { + for 
(IOIndex_t i = 0; static_cast<IONb_t>(i) < nodePtr->nbInputs(); ++i) { + // assess if the input was not already set and is a Tensor then link it to parent output + std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i); + if (inputI.first) { + if ( std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) { + if ((strcmp(nodePtr->getOperator()->getRawInput(i)->type(), Tensor::Type) == 0) && (strcmp(inputI.first->getOperator()->getRawOutput(inputI.second)->type(), Tensor::Type)==0)) { + // assert provided Data is of "Tensor" type + nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second)); + } + else { + assert(false && "Non-tensor entries not handled yet.\n"); + } + } + } else + { + assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty()); + } + + } + } + // Compute dimensions of every node + _forwardDims(inputNodes()); +} + +void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) { + // TODO: support multi-inputs/outputs + std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>(); + for (std::shared_ptr<Node> nodePtr : listNodes) { + if (!nodePtr->getOperator()->outputDimsForwarded()) { + nodePtr->getOperator()->computeOutputDims(); + } + if (!nodePtr->getOperator()->outputDimsForwarded()) { + nextList.insert(nodePtr); + } else { + std::set<std::shared_ptr<Node>> children = nodePtr->getChildren(); + nextList.insert(children.begin(), children.end()); + } + } + if (nextList.empty()) { + for (std::shared_ptr<Node> nodePtr : getNodes()) { + if (!nodePtr->getOperator()->outputDimsForwarded()) { + nextList.insert(nodePtr); + } + } + } + if (!nextList.empty()) { + _forwardDims(nextList); + } +} + +void Aidge::GraphView::setBackend(const std::string &backend) { + for (auto node : getNodes()) { + node->getOperator()->setBackend(backend); + } +} + +void Aidge::GraphView::setDatatype(const DataType &datatype) { + for (auto node : getNodes()) { + node->getOperator()->setDatatype(datatype); + } +} + +void Aidge::GraphView::updateOutputNodes() { + mOutputNodes.clear(); + for (const std::shared_ptr<Node> go_it : mNodes) { + if (go_it->nbOutputs() != + go_it->nbValidOutputs()) { // an output linked to nothing + mOutputNodes.insert(go_it); + continue; + } + for (const std::shared_ptr<Node> ch_ptr : go_it->getChildren()) { + if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph + mOutputNodes.insert(go_it); + break; + } + } + } +} + +void Aidge::GraphView::updateOutputNodes(std::shared_ptr<Node> node) { + if (node->nbOutputs() != + node->nbValidOutputs()) { // an output linked to nothing + mOutputNodes.insert(node); + } else { // don't enter if was already added to outputNodes + for (const std::shared_ptr<Node> &ch_ptr : node->getChildren()) { + if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph + mOutputNodes.insert(node); + break; + } + } + } + // update other outputNodes + for (const std::shared_ptr<Node> &pa_ptr : + node->getParents()) { // check if any parent is in OutputNodes too + if ((pa_ptr != nullptr) && + (mOutputNodes.find(pa_ptr) != + mOutputNodes.end())) { // it's a match! 
+
+void Aidge::GraphView::setBackend(const std::string &backend) {
+  for (auto node : getNodes()) {
+    node->getOperator()->setBackend(backend);
+  }
+}
+
+void Aidge::GraphView::setDatatype(const DataType &datatype) {
+  for (auto node : getNodes()) {
+    node->getOperator()->setDatatype(datatype);
+  }
+}
+
+void Aidge::GraphView::updateOutputNodes() {
+  mOutputNodes.clear();
+  for (const std::shared_ptr<Node> go_it : mNodes) {
+    if (go_it->nbOutputs() !=
+        go_it->nbValidOutputs()) { // an output linked to nothing
+      mOutputNodes.insert(go_it);
+      continue;
+    }
+    for (const std::shared_ptr<Node> ch_ptr : go_it->getChildren()) {
+      if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
+        mOutputNodes.insert(go_it);
+        break;
+      }
+    }
+  }
+}
+
+void Aidge::GraphView::updateOutputNodes(std::shared_ptr<Node> node) {
+  if (node->nbOutputs() !=
+      node->nbValidOutputs()) { // an output linked to nothing
+    mOutputNodes.insert(node);
+  } else { // don't enter if it was already added to outputNodes
+    for (const std::shared_ptr<Node> &ch_ptr : node->getChildren()) {
+      if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
+        mOutputNodes.insert(node);
+        break;
+      }
+    }
+  }
+  // update other outputNodes
+  for (const std::shared_ptr<Node> &pa_ptr :
+       node->getParents()) { // check if any parent is in OutputNodes too
+    if ((pa_ptr != nullptr) &&
+        (mOutputNodes.find(pa_ptr) !=
+         mOutputNodes.end())) { // it's a match! Must check whether the
+                                // outputNode found is still an outputNode
+      bool remove = (pa_ptr->nbOutputs() == pa_ptr->nbValidOutputs());
+      for (const std::shared_ptr<Node> ch_ptr : pa_ptr->getChildren()) {
+        if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph
+          remove = false;
+          break;
+        }
+      }
+      if (remove) {
+        mOutputNodes.erase(pa_ptr);
+      }
+    }
+  }
+}
+
+std::vector<
+    std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
+Aidge::GraphView::outputs() const {
+  std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>>
+      outputTensors;
+  for (const std::shared_ptr<Node> outputNode : mOutputNodes) {
+    std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>>
+        tmpOutputs = (outputNode->outputs());
+    outputTensors.insert(outputTensors.end(), tmpOutputs.begin(),
+                         tmpOutputs.end());
+  }
+  return outputTensors;
+}
+
+std::vector<
+    std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
+Aidge::GraphView::outputs(std::string nodeName) const {
+  return mNodeRegistry.at(nodeName)->outputs();
+}
+
+void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/,
+                                  Aidge::IOIndex_t /*newNodeOutID*/) {
+  printf("Not implemented yet.\n");
+}
+
+void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) {
+  // add to the GraphView nodes
+  node->addView(shared_from_this());
+  mNodes.insert(node);
+  if (!(node->name()).empty())
+    mNodeRegistry.insert(std::make_pair(node->name(), node));
+  // add learnable parameters to the graph
+  if (includeLearnableParam) {
+    for (IONb_t i = node->nbDataInputs(); i < node->nbInputs(); ++i) {
+      std::shared_ptr<Node> parentNode = node->getParents(static_cast<IOIndex_t>(i));
+      if (parentNode) {
+        parentNode->addView(shared_from_this());
+        mNodes.insert(parentNode);
+        if (!(parentNode->name()).empty())
+          mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode));
+        // check if the Node is an input node
+        updateInputNodes(parentNode);
+      }
+    }
+  }
+  // check if the Node is an input node
+  updateInputNodes(node);
+  // check if the Node is an output node
+  updateOutputNodes(node);
+}
+
+void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
+  for (auto& nodePtr : otherNodes) { add(nodePtr, includeLearnableParam); }
+}
+
+void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
+  for (const std::shared_ptr<Node> &node_ptr : graph->getNodes()) {
+    node_ptr->addView(shared_from_this());
+    mNodes.insert(node_ptr);
+    if (!(node_ptr->name()).empty())
+      mNodeRegistry.insert(std::make_pair(node_ptr->name(), node_ptr));
+    // if node_ptr is part of graph inputNodes or outputNodes
+    // if (graph->isInputNode(node_ptr) || graph->isOutputNode(node_ptr)) {
+    // Update OutputNodes/inputNodes
+    updateInputNodes();
+    updateOutputNodes();
+  }
+}
+
+void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode,
+                                std::shared_ptr<Node> fromOutNode,
+                                const Aidge::IOIndex_t fromTensor,
+                                Aidge::IOIndex_t toTensor) {
+  if (fromOutNode)
+    assert(inView(fromOutNode) && "Output Node not found in the GraphView.");
+  else {
+    assert((outputNodes().size() == 1U) &&
+           "Must specify an outputNode or have only one.");
+    fromOutNode = *(outputNodes().begin());
+  }
+  fromOutNode->addChild(toOtherNode, fromTensor, toTensor);
+  add(toOtherNode);
+}
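+
+// Usage sketch (node names hypothetical): if g currently has conv1 as its
+// only output node, g->addChild(relu1, nullptr, 0, 0) lets the GraphView pick
+// conv1 itself, wires conv1's output 0 to relu1's input 0, and registers
+// relu1 in g. The overload below does the same when the destination is
+// another GraphView.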
+
+void Aidge::GraphView::addChild(
+    std::shared_ptr<GraphView> toOtherView,
+    std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> fromOutNode,
+    std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t> toNode) {
+  // assert output node is valid
+  if (!fromOutNode.first) {
+    assert(outputNodes().size() == 1U &&
+           "If no output node is provided, the graph should have only one to "
+           "make the choice explicit.");
+    fromOutNode.first = *(outputNodes().begin());
+  } else
+    assert(inView(fromOutNode.first));
+  // assert input node is valid
+  if (!toNode.first) {
+    assert(toOtherView->inputNodes().size() == 1U &&
+           "If no input node is provided, the other graph should have only "
+           "one to make the choice explicit.");
+    toNode.first = *(toOtherView->inputNodes().begin());
+  } else {
+    assert(toOtherView->inView(toNode.first));
+  }
+  // Tensor assertions are performed in the Node addChild method
+  fromOutNode.first->addChild(toNode.first, fromOutNode.second, toNode.second);
+  // once linking is done, add the other graph to the current graph
+  add(toOtherView);
+}
+
+std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents() const {
+  // TODO: choose if we return a set or a vector
+  std::set<std::shared_ptr<Node>> parents;
+  for (const std::shared_ptr<Node> inputNode : mInputNodes) {
+    parents.insert(inputNode->getParents().begin(),
+                   inputNode->getParents().end());
+  }
+  return parents;
+}
+
+std::vector<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents(const std::string nodeName) const {
+  std::map<std::string, std::shared_ptr<Node>>::const_iterator it = mNodeRegistry.find(nodeName);
+  if (it == mNodeRegistry.end()) {
+    printf("No node named %s in graph %s.\n", nodeName.c_str(), name().c_str());
+    exit(-1);
+  }
+  return (it->second)->getParents();
+}
+
+std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
+Aidge::GraphView::getOrderedParents() const {
+  std::vector<std::vector<std::shared_ptr<Node>>> parents;
+  for (const std::shared_ptr<Node> inputNode : mInputNodes) {
+    parents.push_back(inputNode->getParents());
+  }
+  return parents;
+}
+
+std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getChildren() const {
+  std::set<std::shared_ptr<Node>> children;
+  for (const std::shared_ptr<Node> outputNode : mOutputNodes) {
+    children.insert((outputNode->getChildren()).begin(),
+                    (outputNode->getChildren()).end());
+  }
+  return children;
+}
+
+std::vector<std::vector<std::shared_ptr<Aidge::Node>>>
+Aidge::GraphView::getChildren(const std::string nodeName) const {
+  std::map<std::string, std::shared_ptr<Node>>::const_iterator it =
+      mNodeRegistry.find(nodeName);
+  if (it == mNodeRegistry.end()) {
+    printf("No node named %s in graph %s.\n", nodeName.c_str(),
+           name().c_str());
+    exit(-1);
+  }
+  return (it->second)->getOrderedChildren();
+}
+
+std::set<std::shared_ptr<Aidge::Node>>
+Aidge::GraphView::getChildren(const std::shared_ptr<Node> otherNode) const {
+  std::set<std::shared_ptr<Node>>::const_iterator it = mNodes.find(otherNode);
+  if (it == mNodes.end()) {
+    printf("No such node in graph.\n");
+    exit(-1);
+  }
+  return (*it)->getChildren();
+}
+
+
+std::shared_ptr<Aidge::Node>
+Aidge::GraphView::getNode(const char *nodeName) const {
+  std::map<std::string, std::shared_ptr<Node>>::const_iterator it =
+      mNodeRegistry.find(std::string(nodeName));
+  if (it != mNodeRegistry.end()) {
+    return it->second;
+  } else {
+    printf("No Node named %s in the current GraphView.\n", nodeName);
+    exit(-1);
+  }
+}
+
+
+void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnableParam) {
+  if (mNodes.find(nodePtr) != mNodes.end()) {
+    mNodes.erase(nodePtr);
+    nodePtr->removeView(shared_from_this());
+  }
+  if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); }
+  // same for learnable parameters
+
+  if (includeLearnableParam) {
+    for (IONb_t i = nodePtr->nbDataInputs(); i < nodePtr->nbInputs(); ++i) {
+      auto inputI = nodePtr->input(i);
+      bool removeNode = true;
+      for (const auto& parentOutput : inputI.first->outputs()) {
+        for (const auto& childOfParentOutput : parentOutput) {
+          // only remove the learnable parameter if it is not related to any other Node in the GraphView
+          if (childOfParentOutput.first != nodePtr) {
+            removeNode = false;
+            break;
+          }
+        }
+      }
+      if (removeNode) {
+        // assert the Learnable Parameter is in the GraphView scope
+        if (mNodes.find(inputI.first) != mNodes.end()) {
+          mNodes.erase(inputI.first);
+          inputI.first->removeView(shared_from_this());
+        }
+        if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); }
+      }
+    }
+  }
+  updateInputNodes();
+  updateOutputNodes();
+}
+
+
+bool Aidge::GraphView::swap(Node & /*node*/, Node & /*otherNode*/) {
+  printf("Swap() not implemented yet. Return false.\n");
+  return false;
+}
+
+void Aidge::GraphView::link(std::string /*name1_inID*/,
+                            std::string /*name2_outID*/) {
+  printf("Not implemented yet.\n");
+}
+
+void Aidge::GraphView::insert(Node & /*newNode*/, Node & /*inNode*/,
+                              std::initializer_list<Node> /*outNodes*/,
+                              IOIndex_t /*tensorIdx*/) {
+  printf("Not implemented yet.\n");
+}
+
+bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
+  // TODO : only supports one input/output node for now
+  assert(mNodes.size()>0 && "There must be at least one Node to replace");
+
+  bool replacable;
+  std::shared_ptr<Node> previousInputNode;
+  std::shared_ptr<Node> newInputNode;
+  std::shared_ptr<Node> previousOutputNode;
+  std::shared_ptr<Node> newOutputNode;
+
+  auto gNew = std::make_shared<GraphView>();
+  gNew->add(newNodes, false);
+
+  if (newNodes.empty()) {
+    replacable = (outputNodes().size() == 1) &&
+                 (inputNodes().size() == 1) &&
+                 ((*outputNodes().begin())->nbOutputs() == 1) &&
+                 ((*inputNodes().begin())->nbInputs() == 1);
+    previousOutputNode = (*outputNodes().begin());
+    previousInputNode = (*inputNodes().begin());
+    newOutputNode = previousInputNode->input(0).first;
+  } else {
+    replacable = ((outputNodes().size() == gNew->outputNodes().size()) &&
+                  (outputNodes().size() == 1));
+    previousOutputNode = (*outputNodes().begin());
+    newOutputNode = (*gNew->outputNodes().begin());
+    replacable = replacable && (previousOutputNode->nbOutputs() == newOutputNode->nbOutputs());
+  }
+
+  if (replacable) {
+    auto copyOutputs = previousOutputNode->outputs();
+
+    // manage Views for newNodes
+    // only keep the views common to every node of the replaced set
+    std::set<std::shared_ptr<GraphView>> commonGraphViews = (*mNodes.begin())->views();
+    for (const auto& nodePtr : mNodes) {
+      const auto nodeView = nodePtr->views();
+      std::set<std::shared_ptr<GraphView>> intersection;
+      std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(),
+                            nodeView.begin(), nodeView.end(),
+                            std::inserter(intersection, intersection.begin()));
+      commonGraphViews = intersection;
+    }
+
+    // clean Nodes to replace
+    std::set<std::shared_ptr<Node>> copyNode = mNodes;
+    for (auto& nodePtr : copyNode) { nodePtr->resetConnections(true); }
+
+    // copy output connections
+    for (IONb_t o = 0; o < previousOutputNode->nbOutputs(); ++o) {
+      auto outputPairs = copyOutputs[o];
+      for (const auto& onePair : outputPairs) {
+        newOutputNode->addChild(onePair.first, o, onePair.second);
+      }
+    }
+
+    // insert new Nodes in the right GraphViews
+    for (auto& graphPtr : 
commonGraphViews) { + graphPtr->add(newNodes, false); + if (newNodes.empty()) { + graphPtr->updateInputNodes(); + graphPtr->updateOutputNodes(); + } + } + } + return replacable; +} + +void Aidge::GraphView::updateInputNodes() { + mInputNodes.clear(); + for (const std::shared_ptr<Node> go_ptr : mNodes) { + for (const std::shared_ptr<Node> pa_ptr : go_ptr->getParents()) { + if ((pa_ptr == nullptr) || + (mNodes.find(pa_ptr) == + mNodes.end())) { // Parent doesn't exist || Parent not in the graph + mInputNodes.insert(go_ptr); + break; + } + } + } +} + +void Aidge::GraphView::updateInputNodes(std::shared_ptr<Node> node) { + // add node_ptr to inputNode if it can + std::size_t filledWithKnownInputs = 0U; + bool wasAdded = mInputNodes.find(node) != mInputNodes.end(); + for (const std::shared_ptr<Node> pa_ptr : node->getParents()) { + if ((pa_ptr == nullptr) || + (mNodes.find(pa_ptr) == + mNodes.end())) { // Parent doesn't exist || Parent not in the graph + mInputNodes.insert(node); + wasAdded = true; + break; + } + ++filledWithKnownInputs; + } + if (filledWithKnownInputs == node->nbInputs() && wasAdded) { + mInputNodes.erase(node); + } + // update other inputNodes + for (const std::shared_ptr<Node> ch_ptr : + node->getChildren()) { // check if any child is in InputNodes too + if (mInputNodes.find(ch_ptr) != + mInputNodes.end()) { // it's a match! Must check if the inputNode found + // is still an inputNode + // change here + bool remove = true; + for (const std::shared_ptr<Node> pa_ptr : ch_ptr->getParents()) { + if (pa_ptr == nullptr || + mNodes.find(pa_ptr) == + mNodes + .end()) { // Parent doesn't exist || Parent not in the graph + remove = false; + break; + } + } + if (remove) { + mInputNodes.erase(ch_ptr); + } + } + } +} + + +void Aidge::GraphView::removeInputNode(const std::string nodeName) { + std::map<std::string, std::shared_ptr<Node>>::iterator it = + mNodeRegistry.find(nodeName); + if (it != mNodeRegistry.end()) { + const std::shared_ptr<Node> val = (*it).second; + if (mInputNodes.find(val) != mInputNodes.end()) { + mInputNodes.erase(val); + } + } +} + +void Aidge::GraphView::removeOutputNode(const std::string nodeName) { + std::map<std::string, std::shared_ptr<Node>>::iterator it = + mNodeRegistry.find(nodeName); + if (it != mNodeRegistry.end()) { + const std::shared_ptr<Node> val = (*it).second; + if (mOutputNodes.find(val) != mOutputNodes.end()) { + mOutputNodes.erase(val); + } + } +} \ No newline at end of file diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp new file mode 100644 index 000000000..3dbc24322 --- /dev/null +++ b/src/graph/Node.cpp @@ -0,0 +1,327 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "graph/Node.hpp" + +#include "graph/GraphView.hpp" +#include "operator/Producer.hpp" +#include <memory> +#include <vector> +#include "utils/Types.h" + +Aidge::Node::Node(std::shared_ptr<Operator> op, const char *name) + : mName((name == nullptr) ? 
std::string() : std::string(name)), + mOperator(op), + mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), nullptr)), + mChildren(std::vector<std::vector<std::shared_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()), + std::vector<std::shared_ptr<Node>>())), + mIdInChildren( + std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), std::vector<IOIndex_t>())), + mIdOutParents(std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) { + // ctor +} + +/////////////////////////////////////////////////////// +// FUNCTIONAL DESCRIPTION +/////////////////////////////////////////////////////// + +Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> ctors) { + assert((ctors.size() == nbDataInputs()) && "Wrong number of arguments.\n"); + for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inputs()) { + assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n"); + } + IOIndex_t i = 0; + for (const Connector &ctor : ctors) { + if (ctor.node() != nullptr) { // ctor must be associated with a node + ctor.node()->addChild(shared_from_this(), ctor.index(), i++); + } + } + return Connector(shared_from_this()); +} + +/////////////////////////////////////////////////////// +// INNER +/////////////////////////////////////////////////////// + +void Aidge::Node::setName(const std::string &name) { mName = name; } + +/////////////////////////////////////////////////////// +// OPERATORS +/////////////////////////////////////////////////////// + +void Aidge::Node::forward() { + assert((mOperator != nullptr) && "No Operator interface provided, can't run forward().\n"); + mOperator->forward(); +} + +void Aidge::Node::backward() { + assert((mOperator != nullptr) && "No Operator interface provided, can't run backward().\n"); + mOperator->backward(); +} + +/////////////////////////////////////////////////////// +// TENSOR MANAGEMENT +/////////////////////////////////////////////////////// + +bool Aidge::Node::valid() const { + for (IOIndex_t i = 0; static_cast<IONb_t>(i) < nbInputs(); ++i) { + if (mIdOutParents[static_cast<std::size_t>(i)] == gk_IODefaultIndex) { + return false; + } + } + return true; +} + +Aidge::IONb_t Aidge::Node::getNbFreeDataInputs() const { + IONb_t nbFreeDataIn = 0; + for (IOIndex_t i = 0; static_cast<IONb_t>(i) < nbInputs(); ++i) { + if (input(i).second < 0) { + ++nbFreeDataIn; + } + } + return nbFreeDataIn; +} + +std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> +Aidge::Node::dataInputs() const { + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbDataInputs()); + for (std::size_t i = 0; i < static_cast<std::size_t>(nbDataInputs()); ++i) { + res[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]); + } + return res; +} + +std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::inputs() const { + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbInputs()); + for (std::size_t i = 0; i < nbInputs(); ++i) { + res[i] = + std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]); + } + return res; +} + +void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) { + assert((idx != gk_IODefaultIndex) && (static_cast<IONb_t>(idx) < nbInputs()) && "Parent index out of bound."); + if 
(mParents[idx] != nullptr) { + mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]); + removeParent(idx); + } + std::shared_ptr<Node> newConstantNode = Producer(tensor); + newConstantNode->addChild(shared_from_this(), 0, idx); + for (auto& graphPtr : views()) { + graphPtr->add(newConstantNode); + } +} + +std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>> +Aidge::Node::outputs() const { + std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> listOutputs = + std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>(mIdInChildren.size()); + for (std::size_t i = 0; i < mIdInChildren.size(); ++i) { + listOutputs[i] = output(static_cast<IOIndex_t>(i)); + } + return listOutputs; +} + +std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> +Aidge::Node::output(Aidge::IOIndex_t outID) const { + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs = + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outID].size()); + for (std::size_t i = 0; i < mIdInChildren[outID].size(); ++i) { + listOutputs[i] = + std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outID][i], mIdInChildren[outID][i]); + } + return listOutputs; +} + +Aidge::IONb_t Aidge::Node::nbValidInputs() const { + IONb_t counter = 0; + for (IONb_t i = 0; i < nbInputs(); ++i) { + if (mIdOutParents[static_cast<std::size_t>(i)] < 0) ++counter; + } + return counter; +} + +Aidge::IONb_t Aidge::Node::nbValidOutputs() const { + IONb_t counter = 0; + if (mIdInChildren.size() == 0) return 0; + for (std::size_t i = 0; i < nbOutputs(); ++i) { + if (mIdInChildren[i].size() > 0U) counter++; + } + return counter; +} + +void Aidge::Node::setInputId(IOIndex_t inId, IOIndex_t newNodeOutID) { + assert(inId != gk_IODefaultIndex && (static_cast<IONb_t>(inId) < nbInputs()) && "Must be a valid index"); + if (mIdOutParents[inId] != gk_IODefaultIndex) { + std::printf("Warning: filling a Tensor already attributed\n"); + auto originalParent = input(inId); + // remove original parent reference to child + // find the output ID for original Parent + // find first occurence of child in the output's children + originalParent.first->removeChild(shared_from_this(), originalParent.second); + } + mIdOutParents[inId] = newNodeOutID; +} + +/////////////////////////////////////////////////////// +// TOPOLOGY +/////////////////////////////////////////////////////// + +void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, IOIndex_t otherInId) { + assert((otherInId != gk_IODefaultIndex) && (static_cast<IONb_t>(otherInId) < otherNode->nbInputs()) && + "Input index out of bound."); + assert((outId != gk_IODefaultIndex) && (static_cast<IONb_t>(outId) < nbOutputs()) && "Output index out of bound."); + if (otherNode->input(otherInId).second >= 0) { + std::printf("Warning, the %d-th Parent of the child node already existed.\n", otherInId); + } + // manage tensors and potential previous parent + otherNode->setInputId(otherInId, outId); + otherNode->getOperator()->associateInput(otherInId, getOperator()->getRawOutput(outId)); + // manage nodes + mChildren[outId].push_back(otherNode); + mIdInChildren[outId].push_back(otherInId); + otherNode->addParent(shared_from_this(), otherInId); +} + +void Aidge::Node::addChildView(std::shared_ptr<GraphView> other_graph, const IOIndex_t outID, + std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { + assert((otherInId.second != gk_IODefaultIndex) && + (static_cast<IONb_t>(otherInId.second) < 
otherInId.first->nbInputs()) && + "Other graph input index out of bound."); + assert((outID != gk_IODefaultIndex) && (static_cast<IONb_t>(outID) < nbOutputs()) && "Output index out of bound."); + std::set<std::shared_ptr<Node>> inNodes = other_graph->inputNodes(); + if (inNodes.size() == std::size_t(0)) { // no input Node + printf("Cannot add GraphView to the Node. No input node detected.\n"); + } else // inNodes.size() >= 1 + { + assert((inNodes.find(otherInId.first) != inNodes.end())); // assert it really is an input node + addChildOp(otherInId.first, outID, otherInId.second); + } +} + +void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId, IOIndex_t otherInId) { + otherInId = (otherInId >= 0) ? otherInId : otherNode->getFirstFreeDataInput(); + addChildOp(otherNode, outId, otherInId); +} + +void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t outId, + std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { + if (!otherInId.first) { + assert((otherView->inputNodes().size() == 1U) && + "Specify an input Node for the GraphView. More or less than one " + "Node is not explicit."); + otherInId.first = *(otherView->inputNodes().begin()); + } + otherInId.second = (otherInId.second >= 0) ? otherInId.second : otherInId.first->getFirstFreeDataInput(); + addChildView(otherView, outId, otherInId); +} + +void Aidge::Node::addParent(const std::shared_ptr<Node> other_node, const IOIndex_t inId) { + if (getParents(inId) != nullptr) { + printf("Warning, you're replacing a Parent.\n"); + } + assert((inId != gk_IODefaultIndex) && (static_cast<IONb_t>(inId) < nbInputs()) && "Input index out of bound."); + mParents[inId] = other_node; +} + +std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getParents() const { return mParents; } + +std::shared_ptr<Aidge::Node> Aidge::Node::popParent(const IOIndex_t inId) { + assert((inId != gk_IODefaultIndex) && (static_cast<IONb_t>(inId) < nbInputs()) && "Input index out of bound."); + std::shared_ptr<Node> val = mParents[inId]; + removeParent(inId); + return val; +} + +bool Aidge::Node::removeParent(const IOIndex_t inId) { + assert((inId != gk_IODefaultIndex) && (static_cast<IONb_t>(inId) < nbInputs()) && "Parent index out of bound."); + if (mParents[inId]) { + mParents[inId] = nullptr; + mIdOutParents[inId] = gk_IODefaultIndex; + return true; + } + return false; +} + +std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { + std::set<std::shared_ptr<Node>> children; + for (const std::vector<std::shared_ptr<Node>> &childrenOfOneOutput : mChildren) { + children.insert(childrenOfOneOutput.begin(), childrenOfOneOutput.end()); + } + return children; +} + +std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const { return mChildren; } + +std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(IOIndex_t outID) const { + assert((outID != gk_IODefaultIndex) && (static_cast<IONb_t>(outID) < nbOutputs()) && "Output index out of bound."); + return mChildren[outID]; +} + +bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, const Aidge::IOIndex_t outId) { + assert((outId != gk_IODefaultIndex) && (static_cast<IONb_t>(outId) < nbOutputs()) && "Child index out of bound."); + bool removed = false; + for (std::size_t j = 0; j < mChildren[outId].size(); ++j) { + if (mChildren[outId][j] == nodePtr) { + mChildren[outId].erase(mChildren[outId].begin() + j); + mIdInChildren[outId].erase(mIdInChildren[outId].begin() + j); + removed = true; + break; + 
}
+  }
+  return removed;
+}
+
+void Aidge::Node::resetConnections(bool includeLearnableParam) {
+  // remove every parent's reference to this Node
+  IONb_t nbRemovedInputs = includeLearnableParam ? nbInputs() : nbDataInputs();
+  for (IOIndex_t i = 0; static_cast<IONb_t>(i) < nbRemovedInputs; ++i) {
+    std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i);
+    if (parent.first) {
+      // remove every link from the parent's output to this Node
+      while(parent.first->removeChild(shared_from_this(), parent.second) == true) {}
+    }
+    // every reference to this object as a child has been removed;
+    // removing references to parents.
+    mParents[i] = nullptr;
+    mIdOutParents[i] = gk_IODefaultIndex;
+  }
+  for (IOIndex_t i = 0; static_cast<IONb_t>(i) < nbOutputs(); ++i) {
+    for (std::pair<std::shared_ptr<Node>, IOIndex_t> child : output(i)) {
+      child.first->removeParent(child.second);
+    }
+    mChildren[i] = std::vector<std::shared_ptr<Node>>();
+    mIdInChildren[i] = std::vector<IOIndex_t>();
+  }
+  // removing this Node from every GraphView it belongs to
+  for (auto& graph : views()) {
+    // if keeping connections with Learnable Parameters, then also remove them from the graph
+    graph->remove(shared_from_this(), !includeLearnableParam);
+  }
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////
+// private
+
+///////////////////////////////////////////////////////
+// FUNCTIONAL DESCRIPTION
+///////////////////////////////////////////////////////
+
+///////////////////////////////////////////////////////
+// OPERATORS
+///////////////////////////////////////////////////////
+
+///////////////////////////////////////////////////////
+// TENSOR MANAGEMENT
+///////////////////////////////////////////////////////
diff --git a/src/graph/OpArgs.cpp b/src/graph/OpArgs.cpp
new file mode 100644
index 000000000..93ceff0aa
--- /dev/null
+++ b/src/graph/OpArgs.cpp
@@ -0,0 +1,73 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "graph/Node.hpp"
+#include "graph/GraphView.hpp"
+#include "graph/OpArgs.hpp"
+
+
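+// A minimal usage sketch of the three composition helpers below (operator
+// factory names such as Conv/ReLU exist in this repo, but the exact arguments
+// here are hypothetical):
+//   auto g = Sequential({Conv(3, 32, {3, 3}, "conv1"), ReLU("relu1")});
+//   auto p = Parallel({n1, n2});     // groups nodes, creates no link
+//   auto r = Residual({n3, n4, n5}); // Sequential + skip link to the last node
+// Sequential() wires each element's free data inputs to the outputs of the
+// previous element; Residual() additionally connects the first input node to
+// the last node's remaining free data input.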
+std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::initializer_list<OpArgs> inputs) {
+  std::shared_ptr<GraphView> gv = std::make_shared<GraphView>();
+  for (const OpArgs& elt : inputs) {
+    if(elt.node() != nullptr) {
+      // >= to allow incomplete graphViews
+      assert(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size());
+      /*
+       * /!\ mn.view()->outputNodes() is a set, the order of Nodes cannot be guaranteed.
+       * Prefer a functional description for detailed inputs.
+       */
+      for (const std::shared_ptr<Node>& node_ptr : gv->outputNodes()) {
+        node_ptr -> addChild(elt.node()); // addChild() already checks that node_ptr->nbOutputs() == 1
+      }
+      gv->add(elt.node());
+    }
+    else {
+      for (std::shared_ptr<Node> node_in : elt.view()->inputNodes()) {
+        // >= to allow incomplete graphViews
+        assert(static_cast<std::size_t>(node_in->getNbFreeDataInputs()) >= gv->outputNodes().size());
+        for (std::shared_ptr<Node> node_out : gv->outputNodes()) {
+          node_out -> addChild(node_in); // assert one output Tensor per output Node
+        }
+      }
+      gv->add(elt.view());
+    }
+  }
+  return gv;
+}
+
+
+std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::initializer_list<OpArgs> inputs) {
+  std::shared_ptr<GraphView> gv = std::make_shared<GraphView>();
+  for(const OpArgs& elt : inputs) {
+    if (elt.node()!=nullptr)
+      gv->add(elt.node());
+    else
+      gv->add(elt.view());
+  }
+  return gv;
+}
+
+
+std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::initializer_list<OpArgs> inputs) {
+  std::shared_ptr<GraphView> gv = Sequential(inputs);
+  assert(gv->outputNodes().size() == 1U && "Zero or more than one output Node for the GraphView, don't know which one to choose for the residual connection");
+  std::shared_ptr<Node> lastNode = *gv->outputNodes().begin();
+  assert(gv->inputNodes().size() == 2U && "A residual connection requires exactly two input Nodes for the GraphView, don't know which ones to connect");
+  std::shared_ptr<Node> firstNode = nullptr;
+  for (const std::shared_ptr<Node> node_ptr : gv->inputNodes()) {
+    if (node_ptr != lastNode) {
+      firstNode = node_ptr;
+    }
+  }
+  assert(lastNode->getNbFreeDataInputs()>=1);
+  gv->addChild(lastNode, firstNode, 0U, gk_IODefaultIndex);
+  return gv;
+}
\ No newline at end of file
diff --git a/src/graphmatching/GRegex.cpp b/src/graphmatching/GRegex.cpp
new file mode 100644
index 000000000..80bc724c1
--- /dev/null
+++ b/src/graphmatching/GRegex.cpp
@@ -0,0 +1,301 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "graphmatching/GRegex.hpp" +#include "graph/GraphView.hpp" + +using namespace Aidge; + +GRegex::GRegex(const std::map<std::string,NodeRegex*>& nodesRegex,std::vector<std::string>& seqRegexps ):mStmFab(nodesRegex){ + + + //setup all the STM + for (const std::string& sequRegex : seqRegexps) { + mStmInit.push_back(mStmFab.makeNewStm(sequRegex)); + } + +} + +bool GRegex::walk_validation_all_stm_are_valid(const std::vector<std::vector<SeqStm*>> all_stm){ + //test if all stm type are in a valid state + std::vector<int> number_of_valid; + number_of_valid.resize(all_stm.size()); + + for (std::size_t i = 0; i < all_stm.size(); ++i) { + number_of_valid[i] = 0; + for (auto it = all_stm[i].begin(); it != all_stm[i].end(); ++it) { + SeqStm* stm = *it; + if (stm->isValid()){ + number_of_valid[i] +=1; + } + } + } + + for (std::size_t i = 0; i < number_of_valid.size(); ++i) { + if (number_of_valid[i] == 0) { + //std::cout << "NO MATCH at least one stm are not valid" << std::endl; + return false; + } + if (number_of_valid[i] > 1) { + //std::cout << "NO MATCH multiple brach match of stm (// quantification)" << std::endl; + return false; + } + } + return true; +} + +bool GRegex::walk_validation_all_node_read_validate_by_one_stm(const std::vector<std::vector<SeqStm*>> all_stm){ + std::set<NodeTmp> all_stm_node_tested; + std::set<NodeTmp> all_stm_node_validated; + + for (std::size_t i = 0; i < all_stm.size(); ++i) { + //std::cout << "all stm index " << i << " on dimension 1 of size " << all_stm.size() <<std::endl; + for (std::size_t j = 0; j < all_stm[i].size(); ++j) { + //std::cout << "all stm index " << j << " on dimension 2 of size " << all_stm[i].size() <<std::endl; + + std::set<NodeTmp> stm_node_tested = all_stm[i][j]->getAllNodeTested(); + std::set<NodeTmp> stm_node_validated = all_stm[i][j]->getAllNodeValidated(); + + all_stm_node_tested.insert(stm_node_tested.begin(), stm_node_tested.end()); + all_stm_node_validated.insert(stm_node_validated.begin(), stm_node_validated.end()); + } + } + + + std::set<NodeTmp> test_but_not_valid; + for (const auto& x : all_stm_node_tested) { + if (all_stm_node_validated.find(x) == all_stm_node_validated.end()) { + test_but_not_valid.insert(x); + } + } + + + if (!test_but_not_valid.empty()) { + std::cout << "NO MATCH. The node(s) "; + for (const auto& x : test_but_not_valid) { + std::cout << x.get() << ", "; + } + std::cout << " have been tested but not validated." << std::endl; + return false; + } + return true; + +} + +bool GRegex::walk_validation_common_nodes_same_tag_for_all_stm(const std::vector<std::vector<SeqStm*>> all_stm){ + std::map<NodeTmp, std::pair<std::string,int>> node_to_common_tag; + for (std::size_t i = 0; i < all_stm.size(); ++i) { + for (auto it = all_stm[i].begin(); it != all_stm[i].end(); ++it) { + SeqStm* stm = *it; + + if (!stm->isValid()){ + continue; + } + + for (const auto& pair : stm->getAllCommonNode()) { + const NodeTmp node = pair.first; + const std::string common_tag = pair.second; + + if (node_to_common_tag.find(node) != node_to_common_tag.end()) { + std::string tag = node_to_common_tag[node].first; + int& occurence = node_to_common_tag[node].second; + if (tag!=common_tag){ + std::cout << "NO MATCH. 
The node " << node << " have two different tags "<< tag << " and " << common_tag << std::endl; + return false; + } else { + occurence += 1; + } + } else { + node_to_common_tag.insert(std::make_pair(node, std::make_pair(common_tag, 1))); + } + } + } + } + /*std::cout << "Node to common tag "; + for (const auto& x : node_to_common_tag) { + std::cout << "(" << x.first << ", " << "[" << x.second.first << ", " << x.second.second << "]" << ") ; "; + } + std::cout << std::endl;*/ + + + for (const auto& pair : node_to_common_tag) { + const std::pair<std::string, int> tag_occurence_pair = pair.second; + if (tag_occurence_pair.second < 1){ + //std::cout << "NO MATCH. The common tag " << tag_occurence_pair.first << " did not match " << std::endl; + return false; + } + } + + return true; +} + +std::set<NodeTmp> GRegex::get_all_validate_nodes(const std::vector<std::vector<SeqStm*>> all_stm){ + std::set<NodeTmp> all_stm_node_validated; + + for (std::size_t i = 0; i < all_stm.size(); ++i) { + for (std::size_t j = 0; j < all_stm[i].size(); ++j) { + std::set<NodeTmp> stm_node_validated = all_stm[i][j]->getAllNodeValidated(); + all_stm_node_validated.insert(stm_node_validated.begin(), stm_node_validated.end()); + } + } + return all_stm_node_validated; +} + + +std::set<NodeTmp> GRegex::matchFromStartNodes(const std::vector<NodeTmp> startNodes,const std::shared_ptr<GraphView> graphToMatch){ + std::set<NodeTmp> empty_set_return; + //ASSERT + if(startNodes.size() != mStmInit.size()){ + throw std::runtime_error ("bad GRegex start nodes"); + } + + //init the walk + std::vector<std::vector<SeqStm*>> allStm; + std::vector<std::pair<NodeTmp,SeqStm*>> currentWalk; + + for (SeqStm* seqStmPtr : mStmInit) { + SeqStm* newStm = mStmFab.duplicateStm(seqStmPtr); + std::size_t idxStart = newStm->getStmIdx(); + currentWalk.push_back(std::make_pair(startNodes[idxStart],newStm)); + allStm.push_back(std::vector<SeqStm*>()); + } + + //walk + while (currentWalk.size()!=0) + { + std::vector<std::pair<NodeTmp,SeqStm*>> newWalk; + for (const auto& pair : currentWalk) { + const NodeTmp node = pair.first; + SeqStm* stmPtr = pair.second; + + std::pair<int,std::string> test = stmPtr->testNode(node); + int res = test.first; + std::string commonTag = test.second; + + std::set<NodeTmp> next_nodes = graphToMatch->getChildren(node); + + /*std::cout << "Next nodes : " ; + for (const auto& x : next_nodes) { + std::cout << x->name() << ", "; + } + std::cout << std::endl;*/ + + // Test Match + if (commonTag == "" && next_nodes.size() > 1) { + std::cout << "NO MATCH. The node " << node.get() << " is not common and has more than one child" << std::endl; + return empty_set_return; + } + + // If there is no more nodes --> Archive the branch + if (res == -1 || next_nodes.empty()) { + int indexToInsert = stmPtr->getStmIdx(); + allStm[indexToInsert].push_back(stmPtr); + //std::cout << "No more nodes --> STM archived : " << indexToInsert << std::endl; + continue; // TODEV : replace this with 'else' that encapsulate the rest of the function ? 
+ } + + bool first = true; + + // Use an iterator to read through the next_nodes + std::set<NodeTmp>::iterator it; + for (it = next_nodes.begin(); it != next_nodes.end(); ++it) { + // Access the current element using the iterator + std::shared_ptr<Aidge::Node> next_node = *it; + if (first){ + newWalk.push_back(std::make_pair(next_node, stmPtr)); + first = false; + } else { + SeqStm* new_stmPtr = mStmFab.duplicateStm(stmPtr); + newWalk.push_back(std::make_pair(next_node, new_stmPtr)); + } + } + } + currentWalk = newWalk; + } + + //std::cout << "Walk finished" << std::endl; + + if (!walk_validation_all_stm_are_valid(allStm)){ + return empty_set_return; + } + //std::cout << "walk_validation_all_stm_are_valid finished" << std::endl; + + + if (!walk_validation_all_node_read_validate_by_one_stm(allStm)){ + return empty_set_return; + } + //std::cout << "walk_validation_all_node_read_validate_by_one_stm finished" << std::endl; + + + if (!walk_validation_common_nodes_same_tag_for_all_stm(allStm)){ + return empty_set_return; + } + //std::cout << "walk_validation_common_nodes_same_tag_for_all_stm finished" << std::endl; + + //std::cout << "MATCH" << std::endl; + + return get_all_validate_nodes(allStm); + +} + + + +Match GRegex::match(const std::shared_ptr<GraphView> graphToMatch){ + + //std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matches; + //std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matches; + Match matches; + std::size_t nbStartNodes = mStmInit.size(); + std::set<NodeTmp> allNodes = graphToMatch->getNodes(); + std::size_t nbAllNodes = allNodes.size(); + + std::vector<std::size_t> indices(nbStartNodes, 0); + + while (true) { + // Generate all permutations of the current combination + do { + std::vector<NodeTmp> startNodes; + //std::cout <<"start nodes :"; + for (std::size_t i = 0; i < nbStartNodes; ++i) { + auto it = std::begin(allNodes); + std::advance(it, indices[i]); + //std::cout << (*it).get() << " "; + startNodes.push_back(*it); + } + //std::cout <<"\n"; + + std::set<NodeTmp> match = matchFromStartNodes(startNodes, graphToMatch); + //std::cout << "match size : " << match.size() << " "; + if(match.size() != 0){ + //matches.push_back(std::make_pair(startNodes,match)); + //matches.insert(std::make_pair(startNodes,match)); + matches.insert(startNodes,match); + } + + } while (std::next_permutation(indices.begin(), indices.end())); + + // Generate the next combination with replacement + std::size_t i = nbStartNodes - 1; + while (true) { + if (indices[i] < nbAllNodes - 1) { + ++indices[i]; + break; + } + if (i == 0) { + return matches; + } + --i; + } + std::fill(indices.begin() + i + 1, indices.end(), indices[i]); + } + + return matches; +} \ No newline at end of file diff --git a/src/graphmatching/Match.cpp b/src/graphmatching/Match.cpp new file mode 100644 index 000000000..9a87fac7d --- /dev/null +++ b/src/graphmatching/Match.cpp @@ -0,0 +1,37 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "graphmatching/Match.hpp" + +using namespace Aidge; + +Match::Match(){ + //ctr +} + +size_t Match::getNbMatch(){ + assert(mStartNodes.size() == mMatchNodes.size() && "Match corrupted"); + return mStartNodes.size(); +} + +void Match::insert(std::vector<NodeTmp> startnodes, std::set<NodeTmp> matchnodes){ + assert(mStartNodes.size() == mMatchNodes.size() && "Match corrupted"); + mStartNodes.push_back(startnodes); + mMatchNodes.push_back(matchnodes); +} + +std::vector<std::vector<NodeTmp>> Match::getStartNodes(){ + return mStartNodes; +} + +std::vector<std::set<NodeTmp>> Match::getMatchNodes(){ + return mMatchNodes; +} \ No newline at end of file diff --git a/src/graphmatching/NodeRegex.cpp b/src/graphmatching/NodeRegex.cpp new file mode 100644 index 000000000..8ba6332dd --- /dev/null +++ b/src/graphmatching/NodeRegex.cpp @@ -0,0 +1,46 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "graphmatching/NodeRegex.hpp" + + +// Verification done by the Parameter system + + +// Version 1 - Only test the type of the node (no need for a lexer) +// Input : Node_op +// Output : bool +// return mCondition == Node_op.type +bool Aidge::NodeRegex::_is(std::shared_ptr<Node> &Node_op){ + + std::string NodeType = Node_op->type(); + + return strcmp(NodeType.c_str(), mCondition.c_str()) == 0; +} + + +bool Aidge::NodeRegex::isA(std::string NodeType){ + + return strcmp(NodeType.c_str(), mCondition.c_str()) == 0; +} + +// Version 2 - Test the node to an advanced condition +// Input : Node_op +// Output : bool +// return mCondition applied on Node +/**bool NodeRegex::_is(string &Node_op){ + // Parsing the condition is done in the initialization of the NodeRegex + + // assert parameters exist in the node with the parameter function isParam() + + // get the parameters + +}*/ diff --git a/src/graphmatching/SeqStm.cpp b/src/graphmatching/SeqStm.cpp new file mode 100755 index 000000000..89c932bce --- /dev/null +++ b/src/graphmatching/SeqStm.cpp @@ -0,0 +1,247 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "graphmatching/SeqStm.hpp" + +using namespace Aidge; + + + + + /////////////////////////////////////////////////////// + + SeqStm::SeqStm( + const int stmIdx, + const std::vector<std::vector<int>>& transitionMatrix, + const std::map<std::string,NodeRegex*>& nodesRegex, + const std::map<NodeTypeKey,int>& typeToIdxTransition, + int actSt, + std::set<NodeTmp> allNodeValidated, + std::set<NodeTmp> allNodeTested, + std::set<std::pair<NodeTmp,std::string>> allCommonNode, + bool stmIsValid):mStmIdx(stmIdx), + mTransitionMatrix(transitionMatrix), + mNodesRegex(nodesRegex), + mTypeToIdxTransition(typeToIdxTransition) + { + + //assert + if (transitionMatrix.size() == 0){ + throw std::runtime_error ("no transitionMatrix"); + } + if(transitionMatrix[0].size() == 0 || transitionMatrix[0].size() != typeToIdxTransition.size()){ + throw std::runtime_error ("bad transitionMatrix"); + } + int size = static_cast<int>(transitionMatrix.size()); + if (actSt >= size){ + throw std::runtime_error ("bad actSt"); + } + + + mActSt = actSt; + mAllNodeValidated = allNodeValidated; + mAllNodeTested = allNodeTested; + mAllCommonNode = allCommonNode; + mStmIsValid = stmIsValid; + + } + + SeqStm* SeqStm::duplicateStm(){ + + //deep copy of the set + // std::set<Node> cAllNodeValidated(mAllNodeValidated.begin(), mAllNodeValidated.end()); + // std::set<Node> cAllNodeTested(mAllNodeTested.begin(), mAllNodeTested.end()); + + // std::set<std::pair<Node,std::string>> cAllCommonNode; + // for (const auto& p : mAllCommonNode) { + // cAllCommonNode.insert(p); + // } + + auto newStm = new SeqStm( + mStmIdx, + mTransitionMatrix, + mNodesRegex, + mTypeToIdxTransition, + mActSt, + mAllNodeValidated, + mAllNodeTested, + mAllCommonNode, + mStmIsValid + ); + + return newStm; + } + + + std::pair<NodeRegex*,std::string> SeqStm::getNodeRegexAndCommonAt(int idxType) + { + //std::cout << "!" 
<< idxType << "\n";
+    for (auto const& x : mTypeToIdxTransition)
+    {
+        //x.second is the value : idx in mTransitionMatrix for the type
+        //x.first is a pair of the node regex class and a string that is the common tag '',#,#n
+        if (x.second == idxType ){
+
+            if (mNodesRegex.find(x.first.first) != mNodesRegex.end()){
+                return std::make_pair(mNodesRegex.find(x.first.first)->second, x.first.second);
+            }else{
+                throw std::runtime_error ("a type is not defined in NodesRegex");
+            }
+        }
+    }
+    throw std::runtime_error ("bad idx in mNodesRegex");
+    return std::make_pair(nullptr,nullptr);
+    }
+
+
+    NodeType SeqStm::getTheNodeType(NodeTmp node)
+    {
+        // the node used to be a string '{type}{idx}' and we just want the type
+        // // std::regex re("([a-zA-Z]+)[0-9]+");
+        // // std::smatch match;
+        // // if (std::regex_search(node, match, re) == true) {
+        // //     return match.str(1);
+        // // }
+        // // throw std::runtime_error ("Type node not found");
+        // // return "";
+
+        //return node->name();
+        return node->type();
+    }
+
+
+    std::string SeqStm::transitionOnNodeType(NodeType nodeType){
+
+        if (!isStmBlocked()){
+            int idxType = 0;
+            for (auto & nextSt : mTransitionMatrix[mActSt]) {
+                // There is a next state for this type
+                //std::cout << "transition matrix next state -> "<< nextSt<<"\n" ;
+                if (nextSt != -1){
+                    //std::cout << "next -> "<< nextSt<< " "<< isAValidSt(nextSt) <<"\n" ;
+                    auto nodeRegex = getNodeRegexAndCommonAt(idxType);
+                    //std::cout << "-> "<< nodeRegex.second<<"\n" ;
+                    if (nodeRegex.first->isA(nodeType)){
+                        //std::cout << "nodetype tested !"<<"\n" ;
+                        if(isAValidSt(nextSt)){
+                            //std::cout << "Valid state !"<<"\n" ;
+                            mStmIsValid = true;
+                        }
+                        mActSt = nextSt;
+                        return nodeRegex.second;
+                    }
+
+                }
+                idxType += 1;
+            }
+
+            mActSt =-1;
+        }
+
+        return "";
+    }
+
+
+    std::pair<int,std::string> SeqStm::testNode(const NodeTmp node){
+
+        std::string commonTag = "";
+        //std::cout << "0\n" ;
+        if (!isStmBlocked()){
+            bool isNextStEnd = std::all_of(mTransitionMatrix[mActSt].begin(), mTransitionMatrix[mActSt].end(), [&](int x){ return x == -1; });
+            //std::cout << "1:"<< isNextStEnd <<"\n" ;
+            // if the next state is full of -1, should we really add the node to the
+            // tested set? We test it anyway, but it can no longer be validated.
+            if(!isNextStEnd){
+                mAllNodeTested.insert(node);
+            }
+            //std::cout << "2\n" ;
+            // recursion avoidance
+            if(mAllNodeValidated.find(node) == mAllNodeValidated.end()){
+
+                NodeType nodeType = getTheNodeType(node);
+                //std::cout << "3 " << nodeType << "\n" ;
+                commonTag = transitionOnNodeType(nodeType);
+                // after the transition test, if mActSt != -1 the node is valid for this stm
+                //std::cout << " mActSt = " << mActSt << "\n" ;
+                if( mActSt != -1 ){
+                    mAllNodeValidated.insert(node);
+                }
+            }else{
+                mActSt = -1;
+            }
+        }
+
+        if(commonTag != ""){
+            mAllCommonNode.insert(std::make_pair(node,commonTag));
+        }
+        return std::make_pair(mActSt,commonTag);
+    }
+
+
+void SeqStm::drawStm(){
+
+    //mTransitionMatrix
+    // Find the maximum width of each column
+    std::vector<std::size_t> max_widths(mTransitionMatrix[0].size(), 0);
+    for (std::size_t i = 0; i < mTransitionMatrix.size(); ++i)
+    {
+        for (std::size_t j = 0; j < mTransitionMatrix[i].size(); ++j)
+        {
+            std::size_t width = std::to_string(mTransitionMatrix[i][j]).length();
+            if (width > max_widths[j])
+            {
+                max_widths[j] = width;
+            }
+        }
+    }
+
+    // Print the matrix with aligned columns
+    for (std::size_t i = 0; i < mTransitionMatrix.size(); ++i)
+    {
+        for (std::size_t j = 0; j < mTransitionMatrix[i].size(); ++j)
+        {
+            int i_int = static_cast<int>(i);
+            if 
(mActSt == -1 ){ + if(mStmIsValid){ + std::cout << "\033[48;5;40m"; + }else{ + std::cout << "\033[48;5;9m"; + } + } + else if (mActSt == i_int){ + std::cout << "\033[48;5;30m"; + }else{ + std::cout << "\033[48;5;27m"; + } + + // Pad the value with spaces to align it with the maximum width + std::size_t width = std::to_string(mTransitionMatrix[i][j]).length(); + std::string padding(max_widths[j] - width, ' '); + std::cout << padding << mTransitionMatrix[i][j] << " "; + std::cout << "\033[0m"; + } + std::cout << "\n"; + } + + std::cout << "mAllNodeTested : "; + for (const auto& x : mAllNodeTested) { + std::cout << x << ", "; + } + std::cout << "\n"; + + + std::cout << "mAllNodeValidated : "; + for (const auto& x : mAllNodeValidated) { + std::cout << x << ", "; + } + std::cout << "\n"; +} + diff --git a/src/graphmatching/StmFactory.cpp b/src/graphmatching/StmFactory.cpp new file mode 100644 index 000000000..4ca9c6d25 --- /dev/null +++ b/src/graphmatching/StmFactory.cpp @@ -0,0 +1,150 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "graphmatching/StmFactory.hpp" + +using namespace Aidge; + +StmFactory::StmFactory(const std::map<std::string, NodeRegex *> &nodesRegex) + : mNodesRegex(nodesRegex) {} + +SeqStm *StmFactory::duplicateStm(SeqStm *stm) { return stm->duplicateStm(); } + +SeqStm *StmFactory::makeNewStm(const std::string &sequRegex) { + + ParsingReturn parsing = initParsingSequRegex(sequRegex); + std::vector<std::vector<int>> transitionMatrix = + initTransitionMatrix(parsing); + + std::set<NodeTmp> allNodeValidated; + std::set<NodeTmp> allNodeTested; + std::set<std::pair<NodeTmp, std::string>> allCommonNode; + + SeqStm *newStm = new SeqStm(static_cast<int>(mCmptStm), transitionMatrix, mNodesRegex, + parsing.typeToIdxTransition, 0, allNodeValidated, + allNodeTested, allCommonNode, false); + mCmptStm += 1; + + return newStm; +} + +ParsingReturn StmFactory::initParsingSequRegex(const std::string &sequRegex) { + + std::string toMatch; + std::regex re("\\s*([A-Za-z]+)(#\\d*)?([+*])?\\s*(->|;)"); + std::smatch matches; + + int idxType = 0; + // return + ParsingReturn parsing; + // std::map<std::pair<NodeType,std::string>,int> typeToIdxTransition; + // std::vector<std::pair<std::pair<NodeType,std::string>,std::string>> + // transition; + // assert + std::map<NodeType, std::string> assertCommonNodeTypes; + + for (std::size_t i = 0; i < sequRegex.length(); i++) { + toMatch += sequRegex[i]; + if (std::regex_match(toMatch, matches, re)) { + + std::string type = matches.str(1); + std::string commonTag = matches.str(2); + std::string quantification = matches.str(3); + + if ((commonTag != "") && (quantification != "")) { + throw std::runtime_error("bad commonTag and quantification"); + } + + // make the typeToIdxTransition + NodeTypeKey typeTag = std::make_pair(type, commonTag); + /*std::cout << " typeTag: " << type << " " << commonTag + << parsing.typeToIdxTransition.size() << std::endl;*/ + if (parsing.typeToIdxTransition.find(typeTag) == + parsing.typeToIdxTransition.end()) { + parsing.typeToIdxTransition[typeTag] = idxType; + idxType += 1; + } + //////////////////////////////////////////////////////////// 
+ // ASSERT + // SAME Common node in the sequ + if (commonTag != "") { + if (assertCommonNodeTypes.find(type) != assertCommonNodeTypes.end()) { + if (assertCommonNodeTypes[type] == commonTag) { + throw std::runtime_error("same common node in the sequ regex"); + } + } else { + assertCommonNodeTypes[type] = commonTag; + } + } + + // save all transition + parsing.transition.push_back(std::make_pair(typeTag, quantification)); + + /*std::cout << "Match found: " << matches.str() << std::endl; + std::cout << "Type: " << matches.str(1) << std::endl; + std::cout << "Common tag: " << matches.str(2) << std::endl; + std::cout << "Quantification: " << matches.str(3) << std::endl;*/ + + toMatch = ""; + } + } + if (parsing.transition.size() == 0) { + throw std::runtime_error("Bad Parsing SequRegex "); + } + + return parsing; +} + +std::vector<std::vector<int>> +StmFactory::initTransitionMatrix(ParsingReturn &parsing) { + + // std::pair<NodeTypeKey,std::string> + std::vector<std::vector<int>> transitionMatrix; + std::size_t numberOfType = parsing.typeToIdxTransition.size(); + + if (numberOfType == 0) { + throw std::runtime_error("Bad number Of Type "); + } + // init start st + transitionMatrix.push_back(std::vector<int>(numberOfType, -1)); + + std::size_t idxTransition = 0; + int idxState = 0; + for (const auto &pair : parsing.transition) { + const NodeTypeKey &nodeTypeKey = pair.first; + const std::string &quant = pair.second; + + /*std::cout << "Key: {" << nodeTypeKey.first << ", " << nodeTypeKey.second + << "}, Value: " << quant << std::endl; + std::cout << "idxState " << idxState << " TM: " << transitionMatrix.size() + << std::endl;*/ + std::size_t idxType = parsing.typeToIdxTransition[nodeTypeKey]; + /*std::cout << "idxType " << idxType << " TM: " << transitionMatrix[0].size() + << "type" << numberOfType << std::endl;*/ + + if (quant == "*") { + transitionMatrix[idxTransition][idxType] = idxState; + } else if (quant == "+") { + idxState += 1; + transitionMatrix[idxTransition][idxType] = idxState; + transitionMatrix.push_back(std::vector<int>(numberOfType, -1)); + idxTransition += 1; + transitionMatrix[idxTransition][idxType] = idxState; + } else { + + idxState += 1; + transitionMatrix[idxTransition][idxType] = idxState; + transitionMatrix.push_back(std::vector<int>(numberOfType, -1)); + idxTransition += 1; + } + } + return transitionMatrix; +} \ No newline at end of file diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp new file mode 100644 index 000000000..1db2feeb5 --- /dev/null +++ b/src/operator/Operator.cpp @@ -0,0 +1,44 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+
+#include "backend/OperatorImpl.hpp"
+#include "operator/Operator.hpp"
+#include "utils/Types.h"
+
+// constexpr Aidge::Operator::Operator(const char* type)
+//     : mType(type)
+// {
+//     // ctor
+// }
+
+Aidge::Operator::~Operator() = default;
+
+///////////////////////////////////////////////////////
+//        IMPLEMENTATION
+///////////////////////////////////////////////////////
+
+Aidge::NbElts_t Aidge::Operator::getNbRequiredData(Aidge::IOIndex_t inputIdx) const {
+    return mImpl->getNbRequiredData(inputIdx);
+}
+
+Aidge::NbElts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
+    return mImpl->getNbConsumedData(inputIdx);
+}
+
+Aidge::NbElts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
+    return mImpl->getNbProducedData(outputIdx);
+}
+
+void Aidge::Operator::forward() { mImpl->forward(); }
+
+void Aidge::Operator::backward() { mImpl->backward(); }
diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp
new file mode 100644
index 000000000..e37311a04
--- /dev/null
+++ b/src/recipies/FuseMulAdd.cpp
@@ -0,0 +1,80 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <set>
+#include <cassert>
+#include <memory>
+#include <string>
+
+#include "operator/FC.hpp"
+#include "utils/Recipies.hpp"
+#include "graph/GraphView.hpp"
+#include "graph/Node.hpp"
+#include "operator/Producer.hpp"
+#include "operator/GenericOperator.hpp"
+
+using namespace Aidge;
+
+/**
+ * @brief Merge MatMul and Add Nodes into an FC Node.
+ *
+ * @param nodes Strict set of Nodes to merge.
+ */
+void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
+    // Fuse MatMul & Add into FC
+    // Inputs : old nodes (pointers to the MatMul & Add nodes)
+
+    assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
+    // TODO: type information is lost after matching; find a way to keep the match information (not only the type)
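+
+    // Sketch of the rewrite performed below (tensor shapes hypothetical):
+    //   x -> MatMul(W) -> Add(b) -> y    becomes    x -> FC(W, b) -> y
+    // where W is the MatMul's input #1, b is the Add's input #1 (a Producer),
+    // and the FC output dimension is read from the bias Tensor.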
+
+    // Step 0 : Assert the node types are correct for the fusion
+    std::shared_ptr<Node> add;
+    std::shared_ptr<Node> matmul;
+    for (const auto& element : nodes) {
+        assert((element->type() == "MatMul" || element->type() == "Add") && "Wrong type for the nodes to replace");
+        if (element->type() == "MatMul"){
+            matmul = element;
+        }
+        else if (element->type() == "Add") {
+            add = element;
+        }
+    }
+
+    // Step 1 : Create FC
+    // Fetch the output dimension through the bias size
+    auto producer_add_bias = add->input(1);
+    Tensor& bias_tensor = (producer_add_bias.first)->getOperator()->output(0);
+
+    // Instantiate FC
+    //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc");
+    std::shared_ptr<Node> fc = std::make_shared<Node>(std::make_shared<FC_Op>(bias_tensor.dims()[0], false));
+
+    // Step 2 : Branch existing producers & create the others
+    // link weights & bias
+    if (matmul->getParents(1)==nullptr) {
+        matmul->getParents(0)->addChild(fc, 0, 1);
+    } else {
+        if (matmul->getParents(0)!=nullptr)
+            matmul->getParents(0)->addChild(fc, 0, 0);
+        matmul->getParents(1)->addChild(fc, 0, 1);
+    }
+    (producer_add_bias.first)->addChild(fc,0,2);
+
+
+    // Step 3 : Update all graphviews that contain at least one node to replace
+    // Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output
+    // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview
+    // Maybe create a central mechanism to update all graph views automatically, rather than having each node remember the graphviews it belongs to?
+    auto nodeToReplace = std::make_shared<GraphView>();
+    nodeToReplace->add(nodes);
+    nodeToReplace->replaceWith({fc});
+
+}
\ No newline at end of file
diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp
new file mode 100644
index 000000000..0a28c4767
--- /dev/null
+++ b/src/recipies/RemoveFlatten.cpp
@@ -0,0 +1,28 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <memory>
+
+#include "graph/Node.hpp"
+#include "graph/GraphView.hpp"
+#include "utils/Recipies.hpp"
+
+namespace Aidge {
+    void removeFlatten(std::shared_ptr<GraphView> view) {
+        for (auto& nodePtr : view->getNodes()) {
+            if (nodePtr->type() == "Flatten") {
+                auto g = std::make_shared<GraphView>();
+                g->add(std::set<std::shared_ptr<Node>>({nodePtr}));
+                g->replaceWith({});
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
new file mode 100644
index 000000000..5f4b0295b
--- /dev/null
+++ b/src/scheduler/Scheduler.cpp
@@ -0,0 +1,235 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "scheduler/Scheduler.hpp"
+
+#include <chrono>
+#include <memory>
+#include <set>
+#include <string>
+
+#include "graph/GraphView.hpp"
+#include "graph/Node.hpp"
+#include "utils/Types.h"
+
+void drawProgressBar(double progress, int barWidth, const char* additionalInfo = nullptr) {
+    putchar('[');
+    int pos = static_cast<int>(barWidth * progress);
+    for (int i = 0; i < barWidth; ++i) {
+        if (i <= pos)
+            putchar('#');
+        else
+            putchar(' ');
+    }
+    printf("] %d%% | %s\r", static_cast<int>(progress * 100), (additionalInfo ? additionalInfo : ""));
+    fflush(stdout);
+}
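+
+// Overview of the pass below (a simplified list scheduler): collect the
+// Producer nodes (plus any outside data providers feeding the graph), then
+// repeatedly pick the consumers whose every input has received at least
+// getNbRequiredData() elements, as tracked by the
+// getNbConsumedData()/getNbProducedData() bookkeeping, and schedule them,
+// until no runnable consumer remains.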
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "scheduler/Scheduler.hpp" + +#include <chrono> +#include <cstdio> +#include <cstdint> +#include <memory> +#include <set> +#include <string> + +#include "graph/GraphView.hpp" +#include "graph/Node.hpp" +#include "utils/Types.h" + +void drawProgressBar(double progress, int barWidth, const char* additionalInfo = nullptr) { + putchar('['); + int pos = static_cast<int>(barWidth * progress); + for (int i = 0; i < barWidth; ++i) { + if (i <= pos) + putchar('#'); + else + putchar(' '); + } + printf("] %d%% | %s\r", static_cast<int>(progress * 100), (additionalInfo ? additionalInfo : "")); + fflush(stdout); +} + +// TODO: handle multiple inputs/outputs +void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { + if (forwardDims) { mGraphView->forwardDims(); } + + mScheduling.clear(); + + // setup the initial producers list + // add each Producer Node + std::set<std::shared_ptr<Node>> computationOver; + std::size_t computationNumber = 0; + std::set<std::shared_ptr<Node>> producers; + for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) { + if (nodePtr->type() == "Producer") { + producers.insert(nodePtr); + } else { + ++computationNumber; + } + } + // add Data Input + // FIXME : should be changed when the real system for providing + // data is implemented + for (const std::shared_ptr<Node>& nodePtr : mGraphView->inputNodes()) { + for (const auto& parentPtr : nodePtr->getParents()) { + if ((mGraphView->getNodes()).find(parentPtr) == (mGraphView->getNodes()).end()) { + // Node not found in the graph, it's an outside producer + producers.insert(parentPtr); + } + } + } + + // setup the consumer list + // std::set<std::shared_ptr<Node>> consumers = getConsumers(producers); + + /* It may not be necessary to initialize the producers */ + std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes(); + do { + // find runnable consumers + std::set<std::shared_ptr<Node>> runnableConsumers; + if (verbose) printf("List of layers receiving data:\n"); + for (const auto& consumer : consumers) { + if (verbose) { + printf("\t- consumer: " + "\x1b[1;37m" + "%s" + "\x1b[0m" + "\n\t\tR/C:\t", + (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str()); + for (IOIndex_t inId = 0; static_cast<IONb_t>(inId) < consumer->nbInputs() - 1; ++inId) { + printf("%ld/%ld\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), + consumer->getOperator()->getNbRequiredData(inId)); + } + printf("%ld/%ld", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), + consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); + printf("\n\t\tP:\t"); + for (IOIndex_t outId = 0; static_cast<IONb_t>(outId) < consumer->nbOutputs() - 1; ++outId) { + printf("%ld\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); + } + printf("%ld", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); + printf("\n"); + } + bool isRunnable = true; + + IOIndex_t parentID = 0; // FIXME: handle this correctly + // Check that every input has enough data to run + for (const auto& consumerParent : consumer->dataInputs()) { + if (consumerParent.first && + consumer->getOperator()->getNbRequiredData(parentID++) > + consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) { + // not enough data to run + isRunnable = false; + break; + } + } + + if (isRunnable) {
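+                // every checked input has enough produced data available, so this consumer can run during the current pass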
+ runnableConsumers.insert(consumer); + } + } + + // run every runnable consumer sequentially, once + // TODO: handle memory allocation in scheduler + // TODO: optimize memory usage + for (const auto& runnable : runnableConsumers) { + if (verbose) + printf("run: %s\n", + (runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str()); + else + drawProgressBar(static_cast<float>(computationOver.size()) / static_cast<float>(computationNumber), 50, + (std::string("running ") + runnable->type() + "_" + + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))) + .c_str()); + const auto tStart = std::chrono::high_resolution_clock::now(); + runnable->forward(); + const auto tEnd = std::chrono::high_resolution_clock::now(); + mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd)); + } + + // update the producer and consumer lists + if (verbose) printf("Updating producer and consumer lists...\n"); + const auto oldConsumers = consumers; + + for (const auto& consumer : oldConsumers) { + if (verbose) { + printf("\t- consumer: %s\n\t\tR/C:\t", + (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str()); + for (IOIndex_t inId = 0; static_cast<IONb_t>(inId) < consumer->nbInputs() - 1; ++inId) { + printf("%ld/%ld\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), + consumer->getOperator()->getNbRequiredData(inId)); + } + printf("%ld/%ld", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), + consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); + printf("\n\t\tP:\t"); + for (IOIndex_t outId = 0; static_cast<IONb_t>(outId) < consumer->nbOutputs() - 1; ++outId) { + printf("%ld\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); + } + printf("%ld", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); + printf("\n"); + } + bool isStillConsumer = false; + + IOIndex_t parentID = 0; // FIXME: handle this correctly + // should we check inputs() or dataInputs()?
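+            // assumption: inputs() also covers learnable-parameter connections, while dataInputs() is restricted to data connections; iterating over inputs() is the more conservative choice here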
+ for (const auto& consumerParent : consumer->inputs()) { + if (consumerParent.first && + consumer->getOperator()->getNbConsumedData(parentID++) < + consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) { + // there is still data to consume + isStillConsumer = true; + break; + } + } + + bool computationOverForConsumer = true; + for (IOIndex_t parentIDi = 0; static_cast<IONb_t>(parentIDi) < consumer->nbInputs(); ++parentIDi) { + if (consumer->getOperator()->getNbConsumedData(parentIDi) < + consumer->getOperator()->getNbRequiredData(parentIDi)) { + computationOverForConsumer = false; + break; + } + } + if (computationOverForConsumer) { + computationOver.insert(consumer); + } + + for (IOIndex_t outId = 0; static_cast<IONb_t>(outId) < consumer->nbOutputs(); ++outId) { + if (consumer->getOperator()->getNbProducedData(outId) > 0) { + if (verbose) printf(" also producer\n"); + // make sure consumer is also a producer + producers.insert(consumer); + + const auto& childs = consumer->getChildren(); + consumers.insert(childs.begin(), childs.end()); + break; + } + } + + if (!isStillConsumer) { + if (verbose) printf(" no more consumer\n"); + // consumer is no longer a consumer, only a producer + consumers.erase(consumer); + } + } + + if (verbose) printf("*************\n"); + } while (!consumers.empty()); + if (!verbose) drawProgressBar(1.0, 50, " "); + printf("\n"); +} + +void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const { + FILE* fp = std::fopen((fileName + ".mmd").c_str(), "w"); + std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%s ms\n\n"); + + if (!mScheduling.empty()) { + const auto globalStart = mScheduling[0].start; + + for (const auto& element : mScheduling) { + std::fprintf(fp, "%s :%ld, %ld\n", + (element.node->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(element.node.get()))) + .c_str(), + std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count(), + std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count()); + } + } + + std::fprintf(fp, "\n"); + std::fclose(fp); +} + +std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers( + const std::set<std::shared_ptr<Node>>& producers) const { + std::set<std::shared_ptr<Node>> consumers; + + for (const auto& producer : producers) { + const auto& childs = producer->getChildren(); + consumers.insert(childs.begin(), childs.end()); + } + + return consumers; +} \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 000000000..fbd9ec8bb --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,25 @@ + +enable_testing() + +Include(FetchContent) + +FetchContent_Declare( + Catch2 + GIT_REPOSITORY https://github.com/catchorg/Catch2.git + GIT_TAG v3.0.1 # or a later release +) + +FetchContent_MakeAvailable(Catch2) + +file(GLOB_RECURSE src_files "*.cpp") + +add_executable(tests_core ${src_files}) + +target_link_libraries(tests_core PUBLIC core) + +target_link_libraries(tests_core PRIVATE Catch2::Catch2WithMain) + +list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras) +include(CTest) +include(Catch) +catch_discover_tests(tests_core) diff --git a/tests/graph/Test_Connector.cpp b/tests/graph/Test_Connector.cpp new file mode 100644 index 000000000..bde2c9026 --- /dev/null +++ b/tests/graph/Test_Connector.cpp @@ -0,0 +1,257 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * 
This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> + +#include "graph/Connector.hpp" +#include "graph/Node.hpp" +#include "operator/GenericOperator.hpp" +#include "graph/GraphView.hpp" +#include "graph/OpArgs.hpp" + +using namespace Aidge; + +TEST_CASE("Connector Creation", "[Connector]") { + SECTION("Empty") { + Connector x = Connector(); + REQUIRE(x.index() == gk_IODefaultIndex); + REQUIRE(x.node() == nullptr); + } + SECTION("0 output") { + std::shared_ptr<Node> node = GenericOperator("Producer",1,1,0); + Connector x = Connector(node); + REQUIRE(x.index() == gk_IODefaultIndex); + REQUIRE(x.node() == node); + } + SECTION("1 output") { + std::shared_ptr<Node> node = GenericOperator("ReLU",1,1,1); + Connector x = Connector(node); + REQUIRE(x.index() == 0); + REQUIRE(x.node() == node); + } + SECTION("Several outputs") { + std::shared_ptr<Node> node = GenericOperator("Split",1,1,2); + Connector x = Connector(node); + REQUIRE(x.index() == gk_IODefaultIndex); + REQUIRE(x.node() == node); + } +} + +TEST_CASE("Connector connections Node", "[Connector]") { + SECTION("0 input / 0 output") { + std::shared_ptr<Node> fic = GenericOperator("Display",0,0,0); + Connector x; + x = (*fic)({}); + REQUIRE(x.node() == fic); + } + SECTION("1 input / 0 output") { + std::shared_ptr<Node> fic = GenericOperator("Loss",1,1,0); + Connector x; + x = (*fic)({x}); + REQUIRE(x.node() == fic); + } + SECTION("0 input / 1 output") { // Producers + std::shared_ptr<Node> fic = GenericOperator("Producer",0,0,1); + Connector x = (*fic)({}); + REQUIRE(x.node() == fic); + } + SECTION("1 input / 1 output") { + std::shared_ptr<Node> fic = GenericOperator("Conv",1,1,1); + Connector x(GenericOperator("Producer",0,0,1)); + x = (*fic)({x}); + REQUIRE(x.node() ==fic); + } + SECTION("2+ inputs / 1 output") { // ElemWise + std::shared_ptr<Node> fic = GenericOperator("fictive",3,3,1); + Connector x1(GenericOperator("fictive",0,0,1)); + Connector x2(GenericOperator("fictive",0,0,1)); + Connector x3(GenericOperator("fictive",0,0,1)); + Connector x = (*fic)({x1, x2, x3}); + REQUIRE(x.node() ==fic); + } + SECTION("1 input / 2+ outputs") { // Slice + std::shared_ptr<Node> fic = GenericOperator("fictive",1,1,3); + + Connector x(GenericOperator("fictive2",0,0,1)); + Connector y; + REQUIRE_NOTHROW(y = (*fic)({x})); + REQUIRE(y[0].node() == fic); + REQUIRE(y[1].node() == fic); + REQUIRE(y[2].node() == fic); + } +} + +TEST_CASE("GraphGeneration from Connector", "[GraphView]") { + + auto node01 = GenericOperator("Conv",0,0,1,"g_conv1"); + auto node02 = GenericOperator("ReLU",1,1,1,"g_relu"); + auto node03 = GenericOperator("g_maxpool1", 1,1,1); + auto node04 = GenericOperator("g_conv2_par1",1,1,1); + auto node05 = GenericOperator("g_relu2_par1", 1,1,1); + auto node06 = GenericOperator("g_conv2_par2", 1,1,1); + auto node07 = GenericOperator("g_relu2_par2", 1,1,1); + auto node08 = GenericOperator("g_concat", 2,2,1); + auto node09 = GenericOperator("g_conv3", 1, 1,1); + auto node10 = GenericOperator("g_matmul1", 2,2,1); + Connector a = (*node01)({}); + Connector x = (*node02)({a}); + x = (*node03)({x}); + Connector y = (*node04)({x}); + y = (*node05)({y}); + Connector z = (*node06)({x}); + z = (*node07)({z}); + x = (*node08)({y, z}); + x= (*node09)({x}); + 
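+    // node10 consumes both the first producer's output (a) and the end of the sequential branch (x), creating a skip connection in the generated graph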
x = (*node10)({a, x}); + std::shared_ptr<GraphView> gv = generateGraph({x}); + gv->save("GraphGeneration"); +} + +TEST_CASE("Connector connection GraphView", "[Connector]") { + SECTION("1 input") { + Connector x = Connector(); + auto prod = GenericOperator("Producer",0,0,1); + auto g = Residual({ + GenericOperator("g_conv1", 1,1,1), + GenericOperator("g_relu", 1,1,1), + GenericOperator("g_maxpool1", 1,1,1), + Parallel({ + Sequential({GenericOperator("g_conv2_par1",1,1,1), GenericOperator("g_relu2_par1", 1,1,1)}), + Sequential({GenericOperator("g_conv2_par2", 1,1,1), GenericOperator("g_relu2_par2", 1,1,1)}) + }), + GenericOperator("g_concat", 2,2,1), + GenericOperator("g_conv3", 1, 1,1), + GenericOperator("g_matmul1", 2,2,1) + }); + x = (*prod)({}); + x = (*g)({x}); + std::shared_ptr<GraphView> g2 = generateGraph({x}); + std::shared_ptr<GraphView> g3 = g; + g3->add(prod); + REQUIRE(*g3 == *g2); + } + SECTION("2+ inputs") { + Connector x = (*GenericOperator("Producer",0,0,1))({}); + Connector y = (*GenericOperator("Producer",0,0,1))({}); + Connector z = (*GenericOperator("Producer",0,0,1))({}); + auto g = Sequential({GenericOperator("ElemWise", 3,3,1), + Parallel({ + Sequential({GenericOperator("g_conv2_par1",1,1,1), GenericOperator("g_relu2_par1", 1,1,1)}), + Sequential({GenericOperator("g_conv2_par2", 1,1,1), GenericOperator("g_relu2_par2", 1,1,1)}), + Sequential({GenericOperator("g_conv2_par3", 1,1,1), GenericOperator("g_relu2_par3", 1,1,1)}) + }), + GenericOperator("g_concat", 3,3,1), + GenericOperator("g_conv3", 1, 1,1) + }); + + x = (*g)({x, y, z}); + std::shared_ptr<GraphView> gv = generateGraph({x}); + gv->save("MultiInputSequentialConnector"); + REQUIRE(gv->inputNodes().size() == 0U); + } +} + +TEST_CASE("Connector Mini-graph", "[Connector]") { + Connector x = Connector(); + Connector y = Connector(); + x = (*GenericOperator("Producer",0,0,1))({}); + y = (*GenericOperator("Producer",0,0,1))({}); + for (int i = 0; i<5; ++i) { + x = (*GenericOperator("Conv",1,1,1))({x}); + } + y = (*GenericOperator("ElemWise",2,2,1))({y, x}); + std::shared_ptr<GraphView> g = generateGraph({y}); + g->save("TestGraph"); +} + +TEST_CASE("Structural description - Sequential", "[GraphView]") { + // SECTION("Empty Sequence") { + // std::shared_ptr<GraphView> g1 = Sequential(); // Not supported + // REQUIRE(g1->getNodes() == std::set<std::shared_ptr<Node>>()); + // REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>()); + // REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>()); + // } + SECTION("1-element Sequence") { + std::shared_ptr<Node> fic = GenericOperator("node1", 1,1,1); + std::shared_ptr<GraphView> g2 = Sequential({fic}); + REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic})); + REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic})); + REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic})); + } + SECTION("several-elements simple Sequence") { + std::shared_ptr<Node> fic1 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic2 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic3 = GenericOperator("node1", 1,1,1); + std::shared_ptr<GraphView> g2 = Sequential({fic1, fic2, fic3}); + REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3})); + REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic1})); + REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic3})); + } + +} + +TEST_CASE("Structural description - Parallel", "[GraphView]") { + // SECTION("Empty Parallel") { + // 
std::shared_ptr<GraphView> g1 = Parallel(); // Not supported + // REQUIRE(g1->getNodes() == std::set<std::shared_ptr<Node>>()); + // REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>()); + // REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>()); + // } + SECTION("1-element Parallel") { + std::shared_ptr<Node> fic = GenericOperator("node1", 1,1,1); + std::shared_ptr<GraphView> g2 = Parallel({fic}); + REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic})); + REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic})); + REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic})); + } + SECTION("several-elements simple Parallel") { + std::shared_ptr<Node> fic1 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic2 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic3 = GenericOperator("node1", 1,1,1); + std::shared_ptr<GraphView> g2 = Parallel({fic1, fic2, fic3}); + REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3})); + REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3})); + REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3})); + } + SECTION("1 Graph in Parallel") { + std::shared_ptr<Node> fic1 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic2 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic3 = GenericOperator("node1", 1,1,1); + std::shared_ptr<GraphView> g2 = Parallel({Sequential({fic1, fic2, fic3})}); + REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3})); + REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic1})); + REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic3})); + } + SECTION("several Sequential in Parallel") { + std::shared_ptr<Node> fic1 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic2 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic3 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic4 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic5 = GenericOperator("node1", 1,1,1); + std::shared_ptr<Node> fic6 = GenericOperator("node1", 1,1,1); + std::shared_ptr<GraphView> g2 = Parallel({Sequential({fic1, fic2, fic3}),Sequential({fic4, fic5, fic6})}); + REQUIRE(g2->getNodes() == std::set<std::shared_ptr<Node>>({fic1, fic2, fic3, fic4, fic5, fic6})); + REQUIRE(g2->inputNodes() == std::set<std::shared_ptr<Node>>({fic1, fic4})); + REQUIRE(g2->outputNodes() == std::set<std::shared_ptr<Node>>({fic3, fic6})); + } +} + +TEST_CASE("Structural Description - Complex Graph", "[GraphView]") { + std::shared_ptr<Node> firstLayer = GenericOperator("first", 1,1,1); + auto g = Sequential({firstLayer, + GenericOperator("l2",1,1,1), + Parallel({Sequential({GenericOperator("conv1",1,1,1), GenericOperator("relu1",1,1,1)}), + Sequential({GenericOperator("conv2",1,1,1), GenericOperator("relu2",1,1,1)})}), + GenericOperator("concat",2,2,1), + GenericOperator("lastLayer",1,1,1)}); + REQUIRE(g->getNodes().size() == 8U); + REQUIRE(g->inputNodes() == std::set<std::shared_ptr<Node>>({firstLayer})); +} diff --git a/tests/graph/Test_GraphView.cpp b/tests/graph/Test_GraphView.cpp new file mode 100644 index 000000000..ec5b82525 --- /dev/null +++ b/tests/graph/Test_GraphView.cpp @@ -0,0 +1,333 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is 
available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cassert> +#include <map> +#include <memory> +#include <string> + +#include <catch2/catch_test_macros.hpp> + +#include "backend/OperatorImpl.hpp" +#include "data/Tensor.hpp" +#include "graph/GraphView.hpp" +#include "operator/Conv.hpp" +#include "operator/GenericOperator.hpp" +#include "operator/Producer.hpp" + +using namespace Aidge; + +TEST_CASE("[aidge/_CORE/graph] GraphView(Constructor)") { + std::shared_ptr<GraphView> g0 = std::make_shared<GraphView>(); + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("G1"); + REQUIRE(g0 != nullptr); + REQUIRE(g1 != nullptr); +} + +TEST_CASE("[aidge/_CORE/graph] GraphView(add)") { + SECTION("Node alone") { + std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 0, 0, "Gop1"); + g->add(GOp1); + std::shared_ptr<Node> GOp2 = GenericOperator("Fictive", 0, 0, 1, "Gop2"); + g->add(GOp2); + std::shared_ptr<Node> GOp3 = GenericOperator("Fictive", 1, 1, 0, "Gop3"); + g->add(GOp3); + std::shared_ptr<Node> GOp4 = GenericOperator("Fictive", 0, 1, 0, "Gop4"); + g->add(GOp4); + std::shared_ptr<Node> GOp5 = GenericOperator("Fictive", 1, 1, 1, "Gop5"); + g->add(GOp5); + std::shared_ptr<Node> GOp6 = GenericOperator("Fictive", 1, 2, 1, "Gop6"); + g->add(GOp6); + } + + SECTION("Several Nodes") { + std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); + // should automatically add parents for learnable parameters + std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 1, 1, "Gop1"); + std::shared_ptr<Node> GOp1parent = GenericOperator("Fictive", 0, 0, 1, "Gop1parent"); + GOp1parent->addChild(GOp1, 0, 0); + g->add(GOp1); + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent})); + + // there should be no duplicates + g->add(GOp1); + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent})); + } + + SECTION("Initializer list of Nodes") { + std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 0, 0, "Gop1"); + std::shared_ptr<Node> GOp2 = GenericOperator("Fictive", 0, 0, 0, "Gop2"); + g->add({GOp1, GOp1, GOp2}); + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp2})); + } + + SECTION("another GraphView") { + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph-1"); + std::shared_ptr<GraphView> g2 = std::make_shared<GraphView>("TestGraph-2"); + auto conv = GenericOperator("Conv", 1, 1, 1, "c"); + auto conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + auto conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + auto conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); + auto conv4 = GenericOperator("Conv", 1, 1, 1, "c4"); + conv->addChild(conv1); + conv1->addChild(conv2); + conv2->addChild(conv3); + conv3->addChild(conv4); + g1->add({conv, conv1, conv2, conv3, conv4}); + g2->add(g1); + REQUIRE(((g1->getNodes() == g2->getNodes()) && (g2->getNodes() == std::set<std::shared_ptr<Node>>({conv, conv1, conv2, conv3, conv4})))); + REQUIRE(((g1->inputNodes() == g2->inputNodes()) && + (g2->inputNodes() == std::set<std::shared_ptr<Node>>({conv})))); + REQUIRE(((g1->outputNodes() == g2->outputNodes()) && + (g2->outputNodes() == std::set<std::shared_ptr<Node>>({conv4})))); + } +} + +TEST_CASE("[aidge/_CORE/graph] GraphView(addChild)") { + 
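  // Build the chain c -> c1 -> c2 -> c3 -> c3.5 -> c4 -> c5 through the different addChild() overloads, checking the view's input and output nodes after each step.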
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); + std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 1, 1, "c3.5"); + std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 1, 1, "c4"); + std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 1, 1, "c5"); + + g1->add(conv); + SECTION("add(node)") { + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv}); + } + g1->addChild(conv1, "c"); + SECTION("add(node, outputNodeName)") { + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv1}); + REQUIRE(conv->getChildren() == std::set<std::shared_ptr<Node>>({conv1})); + REQUIRE(conv1->getParents() == std::vector<std::shared_ptr<Node>>({conv})); + } + g1->addChild(conv2, "c1", 0); + SECTION("add(node, pair<outputNodeName, outID>)") { + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv2}); + REQUIRE(conv1->getChildren() == std::set<std::shared_ptr<Node>>({conv2})); + REQUIRE(conv2->getParents() == std::vector<std::shared_ptr<Node>>({conv1})); + } + g1->addChild(conv3, "c2", 0, 0); + SECTION("add(node, list(outputNodeName))") { + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv3}); + REQUIRE(conv2->getChildren() == std::set<std::shared_ptr<Node>>({conv3})); + REQUIRE(conv3->getParents() == std::vector<std::shared_ptr<Node>>({conv2})); + } + g1->addChild(conv3_5, conv3); + SECTION("add(node, list(outputNodeName))") { + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv3_5}); + REQUIRE(conv3->getChildren() == std::set<std::shared_ptr<Node>>({conv3_5})); + REQUIRE(conv3_5->getParents() == + std::vector<std::shared_ptr<Node>>({conv3})); + } + g1->addChild(conv4, conv3_5, 0); + SECTION("add(node, vector<pair<outputNodeName, outID>>)") { + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv4}); + REQUIRE(conv3_5->getChildren() == std::set<std::shared_ptr<Node>>({conv4})); + REQUIRE(conv4->getParents() == + std::vector<std::shared_ptr<Node>>({conv3_5})); + } + g1->addChild(conv5, conv4, 0, 0); + SECTION("add(node, vector<pair<outputNodeName, outID>>)") { + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv5}); + REQUIRE(conv4->getChildren() == std::set<std::shared_ptr<Node>>({conv5})); + REQUIRE(conv5->getParents() == std::vector<std::shared_ptr<Node>>({conv4})); + } + std::set<std::shared_ptr<Node>> requiredNodes = {conv, conv1, conv2, conv3, + conv3_5, conv4, conv5}; + REQUIRE(g1->getNodes() == requiredNodes); + REQUIRE(g1->getChildren(conv3) == std::set<std::shared_ptr<Node>>({conv3_5})); +} + +TEST_CASE("[aidge/_CORE/graph] GraphView(inputs)") { + auto g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = Conv(3, 32, {3, 3}); + g1->add(conv); + + REQUIRE(g1->inputs() == 
conv->inputs()); +} + +TEST_CASE("[aidge/_CORE/graph] GraphView(outputs)") { + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = Conv(3, 32, {3, 3}); + g1->add(conv); + + REQUIRE(g1->outputs() == conv->outputs()); +} + +TEST_CASE("[aidge/_CORE/graph] GraphView(save)") { + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); + std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 1, 1, "c4"); + std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 1, 1, "c5"); + + g1->add(conv); + g1->addChild(conv1, "c"); + g1->addChild(conv2, "c1", 0); + g1->addChild(conv3, "c2"); + g1->addChild(conv4, "c3", 0); + g1->addChild(conv5, "c4", 0, 0); + + g1->save("./graphExample"); + printf("File saved in ./graphExample.md\n"); +} + +TEST_CASE("[aidge/_CORE/graph] GraphView(resetConnections)") { + SECTION("disconnect data input") { + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 3, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> prod1 = GenericOperator("Prod", 0, 0, 1, "p1"); + std::shared_ptr<Node> prod2 = GenericOperator("Prod", 0, 0, 1, "p2"); + conv->addChild(conv1); + prod1->addChild(conv1,0,1); + prod2->addChild(conv1,0,2); + conv1->addChild(conv2); + + conv1->resetConnections(false); + + REQUIRE(conv->output(0).size() == 0); + for (std::size_t i = 0; i < conv1->nbDataInputs(); ++i) { + REQUIRE((conv1->input(i) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); + } + REQUIRE((conv1->input(1) == std::pair<std::shared_ptr<Node>, IOIndex_t>(prod1, 0))); + REQUIRE((conv1->input(2) == std::pair<std::shared_ptr<Node>, IOIndex_t>(prod2, 0))); + REQUIRE((conv2->input(0) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); + for (std::size_t i = 0; i < conv1->nbOutputs(); ++i) { + REQUIRE(conv->output(i).size() == 0U); + } + } + + SECTION("disconnect data input + learnable parameters") { + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 3, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> prod1 = GenericOperator("Prod", 0, 0, 1, "p1"); + std::shared_ptr<Node> prod2 = GenericOperator("Prod", 0, 0, 1, "p2"); + conv->addChild(conv1); + prod1->addChild(conv1,0,1); + prod2->addChild(conv1,0,2); + conv1->addChild(conv2); + + conv1->resetConnections(true); + + REQUIRE(conv->output(0).size() == 0); + for (std::size_t i = 0; i < conv1->nbInputs(); ++i) { + REQUIRE((conv1->input(i) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); + } + REQUIRE((conv2->input(0) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); + for (std::size_t i = 0; i < conv1->nbOutputs(); ++i) { + REQUIRE(conv->output(i).size() == 0U); + } + } +} + +TEST_CASE("Graph Forward dims", "[GraphView]") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(64, 10, {1, 1}, "conv3"); + auto g = 
std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g->add(conv1); + g->addChild(conv2, conv1, 0); + g->addChild(conv3, conv2, 0); + g->save("graphForwardDims"); + g->forwardDims(); + + SECTION("Check input-output connections") { + REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getInput(1) == g->getNode("conv1_w")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getInput(2) == g->getNode("conv1_b")->getOperator()->getOutput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(conv2->getOperator()->getInput(1) == g->getNode("conv2_w")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getInput(2) == g->getNode("conv2_b")->getOperator()->getOutput(0)); + REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv3->getOperator()->getInput(1) == g->getNode("conv3_w")->getOperator()->getOutput(0)); + REQUIRE(conv3->getOperator()->getInput(2) == g->getNode("conv3_b")->getOperator()->getOutput(0)); + } + + SECTION("Check forwarded dims") { + REQUIRE(std::static_pointer_cast<Tensor>(conv1->getOperator()->getOutput(0)) + ->dims() == std::vector<DimSize_t>({16, 32, 222, 222})); + REQUIRE(std::static_pointer_cast<Tensor>(conv2->getOperator()->getOutput(0)) + ->dims() == std::vector<DimSize_t>({16, 64, 220, 220})); + } +} + +TEST_CASE("[aidge/_CORE/graph] GraphView(replaceWith)") { + SECTION("replace small pattern") { + // create original graph + std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); + auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input"); + auto matmulWeight = GenericOperator("Producer", 0, 0, 1, "matmul_w"); + auto addBias = GenericOperator("Producer", 0, 0, 1, "add_b"); + auto other1 = GenericOperator("Other", 1, 1, 1, "other1"); + auto other2 = GenericOperator("Other", 1, 1, 1, "other2"); + auto matmul = GenericOperator("MatMul", 1, 2, 1, "matmul"); + auto add = GenericOperator("Add", 1, 2, 1, "add"); + otherInput->addChild(other1); + other1->addChild(matmul); + matmul->addChild(add); + add->addChild(other2); + matmulWeight->addChild(matmul, 0, 1); + addBias->addChild(add, 0, 1); + g->add({other1, matmul, add, other2}); + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, matmul, add})); + + // create graph to replace + std::shared_ptr<GraphView> nodeToReplace = std::make_shared<GraphView>(); + nodeToReplace->add({matmul, add}, false); + + // create replacing graph + std::shared_ptr<Node> newNode = GenericOperator("FC", 1, 3, 1, "fc"); + other1->addChild(newNode); + matmulWeight->addChild(newNode, 0, 1); + addBias->addChild(newNode, 0, 2); + + // replace + nodeToReplace->replaceWith({newNode}); + + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, newNode})); + } + SECTION("replace with nothing") { + std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); + auto r1 = GenericOperator("relu", 0, 0, 1); + auto r2 = GenericOperator("relu", 1, 1, 1); + auto r3 = GenericOperator("relu", 1, 1, 1); + auto r4 = GenericOperator("relu", 1, 1, 0); + r1->addChild(r2); + r2->addChild(r3); + r3->addChild(r4); + g->add({r1, r2, r3, r4}); + auto nodesToReplace = std::set<std::shared_ptr<Node>>({r2, r3}); + auto graphToReplace = std::make_shared<GraphView>(); + graphToReplace->add(nodesToReplace); + graphToReplace->replaceWith({}); + + 
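        // r2 and r3 have been removed from the graph; r1 should now feed r4 directly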
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4})); + REQUIRE((r1->output(0))[0].first == r4); + } +} \ No newline at end of file diff --git a/tests/graphMatching/Test_GRegex.cpp b/tests/graphMatching/Test_GRegex.cpp new file mode 100644 index 000000000..9f600ea3f --- /dev/null +++ b/tests/graphMatching/Test_GRegex.cpp @@ -0,0 +1,306 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <iostream> +#include <map> +#include <memory> +#include <vector> +#include <utility> +#include <cassert> + +#include <catch2/catch_test_macros.hpp> +//test +#include <graphmatching/GRegex.hpp> +#include <graphmatching/StmFactory.hpp> +#include <graphmatching/SeqStm.hpp> +#include <graphmatching/NodeRegex.hpp> +#include <graphmatching/Match.hpp> +//use +#include <backend/OperatorImpl.hpp> +#include <operator/GenericOperator.hpp> +#include <operator/Producer.hpp> +#include <graph/GraphView.hpp> + +using namespace Aidge; + +TEST_CASE("Create good init GRegex", "[GRegex]") { + // init all input for GRegex + // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex + // Sequential Regex vector : std::vector<std::string>& seqRegexps + + // init the Nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"A","B","C"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + + // init the Sequential Regex vector + std::vector<std::string> seqRegex; + seqRegex.push_back("A->B;"); + + // Instanciate a GRegex + GRegex GReg(nodesRegex, seqRegex); + + // Perform tests + REQUIRE(GReg.getStmInit().size() == 1); + REQUIRE(GReg.getStmFab().getNumberOfStm() == 1); +} + + +TEST_CASE("Function matchFromStartNodes | One Match of Nodes sequence", "[GRegex]") { + // init all input for GRegex + // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex + // Sequential Regex vector : std::vector<std::string>& seqRegexps + + // init the Nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"Conv","BN","ReLU"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + // init the Sequential Regex vector + std::vector<std::string> seqRegex; + seqRegex.push_back("Conv->BN->ReLU;"); + + // Instanciate a GRegex + GRegex GReg(nodesRegex, seqRegex); + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> Conv1 = GenericOperator("Conv", 1, 1, 1); + std::shared_ptr<Node> BN1 = GenericOperator("BN", 1, 1, 1); + std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); + std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1); + std::shared_ptr<Node> Random2 = GenericOperator("Random2", 1, 1, 1); + + + g1->add(Conv1); + g1->addChild(BN1, Conv1); + g1->addChild(ReLU1, BN1); + g1->addChild(Random, ReLU1); + //g1->addChild(BN1, Random2); + + std::vector<std::shared_ptr<Node>> startNodes1; + std::set<std::shared_ptr<Node>> result; + + startNodes1.push_back(Conv1); + result = GReg.matchFromStartNodes(startNodes1, g1); + + std::set<std::shared_ptr<Node>> true_result; + true_result.insert(Conv1); + 
true_result.insert(BN1); + true_result.insert(ReLU1); + + // Perform tests + REQUIRE(result == true_result); +} + +TEST_CASE("Function matchFromStartNodes | One Match of parallel branches ", "[GRegex]") { + // init all input for GRegex + // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex + // Sequential Regex vector : std::vector<std::string>& seqRegexps + + // init the Nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"Add","FC","Conv"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + + // init the Sequential Regex vector + std::vector<std::string> seqRegex; + seqRegex.push_back("Add#->Conv;"); + seqRegex.push_back("Add#->FC;"); + + // Instanciate a GRegex + GRegex GReg(nodesRegex, seqRegex); + + // Instanciate a graphView + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1); + std::shared_ptr<Node> Add1 = GenericOperator("Add", 1, 1, 1); + std::shared_ptr<Node> Conv1 = GenericOperator("Conv", 1, 1, 1); + std::shared_ptr<Node> BN1 = GenericOperator("BN", 1, 1, 1); + std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); + std::shared_ptr<Node> FC1 = GenericOperator("FC", 1, 1, 1); + std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1); + + g1->add(Random0); + g1->addChild(Add1, Random0); + g1->addChild(Conv1, Add1); + g1->addChild(BN1, Conv1); + g1->addChild(ReLU1, BN1); + g1->addChild(FC1, Add1); + g1->addChild(Random, FC1); + + // Test 1 : Find the match + std::vector<std::shared_ptr<Node>> startNodes; + std::set<std::shared_ptr<Node>> result; + + startNodes.push_back(Add1); + startNodes.push_back(Add1); + result = GReg.matchFromStartNodes(startNodes, g1); + + std::set<std::shared_ptr<Node>> true_result; + true_result.insert(Add1); + true_result.insert(Conv1); + true_result.insert(FC1); + + // Test 2 : Return an empty set when the start nodes are wrong + std::vector<std::shared_ptr<Node>> wrong_startNodes; + std::set<std::shared_ptr<Node>> wrong_start_result; + std::set<std::shared_ptr<Node>> empty_result; + + wrong_startNodes.push_back(Random0); + wrong_startNodes.push_back(Random0); + wrong_start_result = GReg.matchFromStartNodes(wrong_startNodes, g1); + + // Perform tests + REQUIRE(result == true_result); + REQUIRE(wrong_start_result == empty_result); +} + +/* +TEST_CASE("Function matchFromStartNodes | Match a sequence with quantifier ", "[GRegex]") { + // init all input for GRegex + // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex + // Sequential Regex vector : std::vector<std::string>& seqRegexps + + // init the Nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"FC"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + + // init the Sequential Regex vector + std::vector<std::string> seqRegex; + seqRegex.push_back("FC+;"); + + // Instanciate a GRegex + GRegex GReg(nodesRegex, seqRegex); + + + // Instanciate a graphView + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1); + std::shared_ptr<Node> FC1 = GenericOperator("FC", 1, 1, 1); + std::shared_ptr<Node> FC2 = GenericOperator("FC", 1, 1, 1); + std::shared_ptr<Node> FC3 = GenericOperator("FC", 1, 1, 1); + std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); + + g1->add(Random0); + 
g1->addChild(FC1, Random0); + g1->addChild(FC2, FC1); + g1->addChild(FC3, FC2); + g1->addChild(ReLU1, FC3); + + // Test 1 : Find the match + std::vector<std::shared_ptr<Node>> startNodes; + std::set<std::shared_ptr<Node>> result; + + startNodes.push_back(FC1); + result = GReg.matchFromStartNodes(startNodes, g1); + + std::set<std::shared_ptr<Node>> true_result; + true_result.insert(FC1); + true_result.insert(FC2); + true_result.insert(FC3); + + // Test 2 : Return an empty set when the start nodes are wrong + std::vector<std::shared_ptr<Node>> wrong_startNodes; + std::set<std::shared_ptr<Node>> wrong_start_result; + std::set<std::shared_ptr<Node>> empty_result; + + wrong_startNodes.push_back(Random0); + wrong_start_result = GReg.matchFromStartNodes(wrong_startNodes, g1); + + // Perform tests + REQUIRE(result == true_result); + REQUIRE(wrong_start_result == empty_result); +} +*/ + +TEST_CASE("Function match | ALL matches of Nodes sequence", "[GRegex]") { + // init all input for GRegex + // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex + // Sequential Regex vector : std::vector<std::string>& seqRegexps + + // init the Nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"GEMM"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + + // init the Sequential Regex vector + std::vector<std::string> seqRegex; + seqRegex.push_back("GEMM;"); + + // Instanciate a GRegex + GRegex GReg(nodesRegex, seqRegex); + + //init the input graph + std::shared_ptr<GraphView> graphToMatch = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1); + std::shared_ptr<Node> GEMM1 = GenericOperator("GEMM", 1, 1, 1); + std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); + std::shared_ptr<Node> GEMM2 = GenericOperator("GEMM", 1, 1, 1); + std::shared_ptr<Node> GEMM3 = GenericOperator("GEMM", 1, 1, 1); + std::shared_ptr<Node> ReLU2 = GenericOperator("ReLU", 1, 1, 1); + std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1); + + graphToMatch->add(Random0); + graphToMatch->addChild(GEMM1, Random0); + graphToMatch->addChild(ReLU1, GEMM1); + graphToMatch->addChild(GEMM2, ReLU1); + graphToMatch->addChild(GEMM3, GEMM2); + graphToMatch->addChild(ReLU2, GEMM3); + graphToMatch->addChild(Random, ReLU2); + + + //std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs = GReg.match(graphToMatch); + //std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs = GReg.match(graphToMatch); + Match matches = GReg.match(graphToMatch); + + size_t nb = matches.getNbMatch(); + std::vector<std::vector<NodeTmp>> gm_startnodes = matches.getStartNodes(); + std::vector<std::set<NodeTmp>> gm_matchnodes = matches.getMatchNodes(); + + std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs; + + for (size_t i = 0; i < nb; ++i) { + matchs.insert(std::make_pair(gm_startnodes[i], gm_matchnodes[i])); + } + + //std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> toMatchs ; + std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> toMatchs ; + // Carefull : as the assert is on a vector, the Order of match matters + std::vector<NodeTmp> startNode = {GEMM1}; + std::set<NodeTmp> matchNode = {GEMM1}; + //toMatchs.push_back(std::make_pair(startNode,matchNode)); + toMatchs.insert(std::make_pair(startNode,matchNode)); + + std::vector<NodeTmp> startNode2 = {GEMM2}; + std::set<NodeTmp> matchNode2 = {GEMM2}; + 
//toMatchs.push_back(std::make_pair(startNode2,matchNode2)); + toMatchs.insert(std::make_pair(startNode2,matchNode2)); + + std::vector<NodeTmp> startNode3 = {GEMM3}; + std::set<NodeTmp> matchNode3 = {GEMM3}; + //toMatchs.push_back(std::make_pair(startNode3,matchNode3)); + toMatchs.insert(std::make_pair(startNode3,matchNode3)); + + REQUIRE(matchs == toMatchs); + REQUIRE(nb == 3); +} + + diff --git a/tests/graphMatching/Test_NodeRegex.cpp b/tests/graphMatching/Test_NodeRegex.cpp new file mode 100644 index 000000000..fd583a040 --- /dev/null +++ b/tests/graphMatching/Test_NodeRegex.cpp @@ -0,0 +1,44 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <iostream> +#include <map> +#include <memory> +#include <cassert> + +#include <catch2/catch_test_macros.hpp> + +#include <backend/OperatorImpl.hpp> +#include <graphmatching/NodeRegex.hpp> +#include <operator/GenericOperator.hpp> + + +using namespace Aidge; + +TEST_CASE("Create Noderegex", "[Noderegex]") { + std::shared_ptr<NodeRegex> nr = std::make_shared<NodeRegex>("conv"); +} + +TEST_CASE("Test _is function", "[Noderegex]") { + // Create Noderegex with only condition on the name of the Node + // Create several operators to pass into Noderegex _is function + // Assert Noderegex._is(operators) are correct + std::shared_ptr<NodeRegex> nr = std::make_shared<NodeRegex>("Conv"); + + std::shared_ptr<Node> Conv = GenericOperator("Conv", 1, 1, 1); + std::shared_ptr<Node> FC = GenericOperator("FC", 1, 1, 1); + + REQUIRE(nr->_is(Conv) == true); + REQUIRE(nr->_is(FC) == false); + REQUIRE(nr->isA("Conv") == true); + REQUIRE(nr->isA("FC") == false); + +} \ No newline at end of file diff --git a/tests/graphMatching/Test_SeqStm.cpp b/tests/graphMatching/Test_SeqStm.cpp new file mode 100644 index 000000000..209bde758 --- /dev/null +++ b/tests/graphMatching/Test_SeqStm.cpp @@ -0,0 +1,159 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <iostream> +#include <map> +#include <memory> +#include <vector> +#include <utility> +#include <cassert> + +#include <catch2/catch_test_macros.hpp> +//test +#include <graphmatching/SeqStm.hpp> +#include <graphmatching/NodeRegex.hpp> +//use +#include <backend/OperatorImpl.hpp> +#include <operator/GenericOperator.hpp> +#include <operator/Producer.hpp> + +using namespace Aidge; + +TEST_CASE("Create good init SeqStm", "[SeqStm]") { + //init all input for SeqStm + + + int stmIdx = 0; + // transition matrix encoding B->C + std::vector<std::vector<int>> transitionMatrix { + { -1, 1, -1 }, + { -1, -1, 2 }, + { -1, -1, -1 } }; + + //std::cout << transitionMatrix.size() << "\n"; + // init the nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"A","B","C"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + // + + std::map<NodeTypeKey,int> typeToIdxTransition; + std::vector<NodeTypeKey> nodeTypeCommonTag {{"A",""},{"B",""},{"C",""}}; + //init nodeTypeCommonTag + int idx = 0; + for (const NodeTypeKey& key : nodeTypeCommonTag) { + typeToIdxTransition[key] = idx; + idx += 1; + } + + int actSt = 0; + std::set<NodeTmp> allNodeValidated; + std::set<NodeTmp> allNodeTested; + std::set<std::pair<NodeTmp,std::string>> allCommonNode; + bool stmIsValid = false; + + + SeqStm stm( + stmIdx, + transitionMatrix, + nodesRegex, + typeToIdxTransition, + actSt, + allNodeValidated, + allNodeTested, + allCommonNode, + stmIsValid); + + REQUIRE(stm.getStmIdx() == 0); + REQUIRE(stm.isValid() == false); + REQUIRE(stm.getAllCommonNode().size() == 0); + REQUIRE(stm.getAllNodeTested().size() == 0); + REQUIRE(stm.getAllNodeValidated().size() == 0); +} + +TEST_CASE("Test testNode function", "[SeqStm]") { + + int stmIdx = 0; + std::map<NodeTypeKey,int> typeToIdxTransition; + std::vector<NodeTypeKey> nodeTypeCommonTag {{"A",""},{"B",""},{"C",""}}; + //init nodeTypeCommonTag + int idx = 0; + for (const NodeTypeKey& key : nodeTypeCommonTag) { + typeToIdxTransition[key] = idx; + idx += 1; + } + // transition matrix encoding B->C + std::vector<std::vector<int>> transitionMatrix { + { -1, 1, -1 }, + { -1, -1, 2 }, + { -1, -1, -1 } }; + + //std::cout << transitionMatrix.size() << "\n"; + // init the nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"A","B","C"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + // + int actSt = 0; + std::set<NodeTmp> allNodeValidated; + std::set<NodeTmp> allNodeTested; + std::set<std::pair<NodeTmp,std::string>> allCommonNode; + bool stmIsValid = false; + + SeqStm stm( + stmIdx, + transitionMatrix, + nodesRegex, + typeToIdxTransition, + actSt, + allNodeValidated, + allNodeTested, + allCommonNode, + stmIsValid); + REQUIRE(stm.getStmIdx() == 0); + //test a node + std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1); + std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1); + + // sets used to test the state of the stm + std::set<NodeTmp> testAllNodeTested; + std::set<NodeTmp> testAllNodeValidated; + + stm.testNode(nodeB); + REQUIRE(stm.isValid() == false); + REQUIRE(stm.getState() == 1); + REQUIRE(stm.isStmBlocked() == false); + testAllNodeTested.insert(nodeB); + testAllNodeValidated.insert(nodeB); + REQUIRE(stm.getAllNodeTested() == testAllNodeTested); + REQUIRE(stm.getAllNodeValidated() == 
testAllNodeValidated); + + + stm.testNode(nodeC); + REQUIRE(stm.isValid() == true); + REQUIRE(stm.getState() == 2); + REQUIRE(stm.isStmBlocked() == false); + testAllNodeTested.insert(nodeC); + testAllNodeValidated.insert(nodeC); + REQUIRE(stm.getAllNodeTested() == testAllNodeTested); + REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); + + stm.testNode(nodeC); + REQUIRE(stm.isValid() == true); + REQUIRE(stm.getState() == -1); + REQUIRE(stm.isStmBlocked() == true); + REQUIRE(stm.getAllNodeTested() == testAllNodeTested); + REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); +} \ No newline at end of file diff --git a/tests/graphMatching/Test_StmFactory.cpp b/tests/graphMatching/Test_StmFactory.cpp new file mode 100644 index 000000000..2bc3471d3 --- /dev/null +++ b/tests/graphMatching/Test_StmFactory.cpp @@ -0,0 +1,189 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <iostream> +#include <map> +#include <memory> +#include <vector> +#include <utility> +#include <cassert> + +#include <catch2/catch_test_macros.hpp> +//test +#include <graphmatching/StmFactory.hpp> +#include <graphmatching/NodeRegex.hpp> +//use +#include <backend/OperatorImpl.hpp> +#include <operator/GenericOperator.hpp> +#include <operator/Producer.hpp> + +using namespace Aidge; + +TEST_CASE("Create good init StmFactory", "[StmFactory]") { + // init the nodes Regex map + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"A","B","C"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + StmFactory stmF(nodesRegex); + REQUIRE(stmF.getNumberOfStm() == 0); +} + +TEST_CASE("Test in makeNewStm the getStmIdx StmFactory", "[SeqStm]") { + + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"A","B","C"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + + StmFactory stmF(nodesRegex); + + std::string seq1 = "A->B+->A#;"; + SeqStm* stm = stmF.makeNewStm(seq1); + REQUIRE(stm->getStmIdx() == 0); + REQUIRE(stm->isValid() == false); + REQUIRE(stm->getAllCommonNode().size() == 0); + REQUIRE(stm->getAllNodeTested().size() == 0); + REQUIRE(stm->getAllNodeValidated().size() == 0); + + std::string seq2 = "A->B;"; + SeqStm* stm2 = stmF.makeNewStm(seq2); + REQUIRE(stm2->getStmIdx() == 1); + REQUIRE(stm2->isValid() == false); + REQUIRE(stm2->getAllCommonNode().size() == 0); + REQUIRE(stm2->getAllNodeTested().size() == 0); + REQUIRE(stm2->getAllNodeValidated().size() == 0); + + //test the number of stm + REQUIRE(stmF.getNumberOfStm() == 2); +} + +TEST_CASE("Test in makeNewStm the stm StmFactory", "[SeqStm]") { + + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"A","B","C"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + + + StmFactory stmF(nodesRegex); + std::string seq1 = "B->C;"; + SeqStm* stm = stmF.makeNewStm(seq1); + //test the number of stm + REQUIRE(stmF.getNumberOfStm() == 1); + + //std::shared_ptr<Node> nodeB = GenericOperator("B",1,1,1); + //std::shared_ptr<Node> nodeC = GenericiOperator("C",1,1,1); + 
std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1); + std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1); + //set use to test the state of the smt + std::set<NodeTmp> testAllNodeTested; + std::set<NodeTmp> testAllNodeValidated; + + REQUIRE(stm->isValid() == false); + REQUIRE(stm->getState() == 0); + REQUIRE(stm->isStmBlocked() == false); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + stm->testNode(nodeB); + REQUIRE(stm->isValid() == false); + REQUIRE(stm->getState() == 1); + REQUIRE(stm->isStmBlocked() == false); + testAllNodeTested.insert(nodeB); + testAllNodeValidated.insert(nodeB); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + + stm->testNode(nodeC); + REQUIRE(stm->isValid() == true); + REQUIRE(stm->getState() == 2); + REQUIRE(stm->isStmBlocked() == false); + testAllNodeTested.insert(nodeC); + testAllNodeValidated.insert(nodeC); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + stm->testNode(nodeC); + REQUIRE(stm->isValid() == true); + REQUIRE(stm->getState() == -1); + REQUIRE(stm->isStmBlocked() == true); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + +} + +TEST_CASE("Test in duplicateStm StmFactory", "[SeqStm]") { + + std::map<std::string,NodeRegex*> nodesRegex ; + std::vector<std::string> nodeTypeKey {"A","B","C"}; + for (const std::string& key : nodeTypeKey) { + nodesRegex[key] = new NodeRegex(key); + } + + + StmFactory stmF(nodesRegex); + std::string seq1 = "B->C;"; + SeqStm* stm = stmF.makeNewStm(seq1); + SeqStm* stmD = stmF.duplicateStm(stm); + + std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1); + std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1); + //set use to test the state of the smt + std::set<NodeTmp> testAllNodeTested; + std::set<NodeTmp> testAllNodeValidated; + + //run the stm + REQUIRE(stm->isValid() == false); + REQUIRE(stm->getState() == 0); + REQUIRE(stm->isStmBlocked() == false); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + stm->testNode(nodeB); + REQUIRE(stm->isValid() == false); + REQUIRE(stm->getState() == 1); + REQUIRE(stm->isStmBlocked() == false); + testAllNodeTested.insert(nodeB); + testAllNodeValidated.insert(nodeB); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + + stm->testNode(nodeC); + REQUIRE(stm->isValid() == true); + REQUIRE(stm->getState() == 2); + REQUIRE(stm->isStmBlocked() == false); + testAllNodeTested.insert(nodeC); + testAllNodeValidated.insert(nodeC); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + stm->testNode(nodeC); + REQUIRE(stm->isValid() == true); + REQUIRE(stm->getState() == -1); + REQUIRE(stm->isStmBlocked() == true); + REQUIRE(stm->getAllNodeTested() == testAllNodeTested); + REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + + //check if stmD not move + REQUIRE(stmD->isValid() == false); + REQUIRE(stmD->getState() == 0); + REQUIRE(stmD->isStmBlocked() == false); + REQUIRE(stmD->getAllNodeTested().size() == 0); + REQUIRE(stmD->getAllNodeValidated().size() == 0); +} + diff --git a/tests/operator/Test_GenericOperator.cpp 
b/tests/operator/Test_GenericOperator.cpp new file mode 100644 index 000000000..ef7614431 --- /dev/null +++ b/tests/operator/Test_GenericOperator.cpp @@ -0,0 +1,77 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> + +#include "operator/GenericOperator.hpp" +#include "graph/GraphView.hpp" +#include <cstddef> +#include <cstdint> + +using namespace Aidge; + +TEST_CASE("[aidge/_CORE/operators] GenericOp(add & get parameters)", "[Operator]") { + SECTION("INT") { + GenericOperator_Op Testop("TestOp", 1, 1, 1); + int value = 5; + const char* key = "intParam"; + Testop.addParameter<int>(key, value); + REQUIRE(Testop.getParameter<int>(key) == value); + } + SECTION("LONG") { + GenericOperator_Op Testop("TestOp", 1, 1, 1); + long value = 3; + const char* key = "longParam"; + Testop.addParameter<long>(key, value); + REQUIRE(Testop.getParameter<long>(key) == value); + } + SECTION("FLOAT") { + GenericOperator_Op Testop("TestOp", 1, 1, 1); + float value = 2.0; + const char* key = "floatParam"; + Testop.addParameter<float>(key, value); + REQUIRE(Testop.getParameter<float>(key) == value); + } + SECTION("VECTOR<INT>") { + GenericOperator_Op Testop("TestOp", 1, 1, 1); + std::vector<int> value = {1, 2}; + const char* key = "vect"; + Testop.addParameter<std::vector<int>>(key, value); + + REQUIRE(Testop.getParameter<std::vector<int>>(key).size() == value.size()); + for (std::size_t i=0; i < value.size(); ++i){ + REQUIRE(Testop.getParameter<std::vector<int>>(key)[i] == value[i]); + } + } + SECTION("MULTIPLE PARAMS") { + /* + Goal: test that parameter offsets are handled correctly by adding parameters of different sizes. + */ + GenericOperator_Op Testop("TestOp", 1, 1, 1); + Testop.addParameter<long>("longParam", 3); + Testop.addParameter<float>("floatParam", 2.0); + Testop.addParameter<uint8_t>("uint8Param", 5); + Testop.addParameter<long long>("llParam", 10); + REQUIRE(Testop.getParameter<long>("longParam") == 3); + REQUIRE(Testop.getParameter<float>("floatParam") == 2.0); + REQUIRE(Testop.getParameter<uint8_t>("uint8Param") == 5); + REQUIRE(Testop.getParameter<long long>("llParam") == 10); + } +} + +TEST_CASE("[aidge/_CORE/operator] GenericOp(type check)", "[.ass]") { + SECTION("WRONG TYPE FOR GETTER") { + GenericOperator_Op Testop("TestOp", 1, 1, 1); + Testop.addParameter<long>("longParam", 3); + + // This line should raise a failed assertion + REQUIRE_THROWS(Testop.getParameter<int>("longParam")); + } +} -- GitLab