diff --git a/.gitlab/ci/build.gitlab-ci.yml b/.gitlab/ci/build.gitlab-ci.yml index 8d896c8ec9eb92dd87689d84cad5fc09bf03c4f1..a4579e2951ccbafc4335ae428c62eba94c0757e5 100644 --- a/.gitlab/ci/build.gitlab-ci.yml +++ b/.gitlab/ci/build.gitlab-ci.yml @@ -95,60 +95,60 @@ build:ubuntu_python: paths: - venv/ -# build:windows_cpp: -# stage: build -# needs: [] -# tags: -# - windows - -# image: buildtools -# before_script: -# # Install Chocolatey -# - Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1')) -# # Install dependencies -# - choco install cmake.install --installargs '"ADD_CMAKE_TO_PATH=System"' -Y -# - choco install git -Y -# - choco install python -Y -# # Update PATH -# - $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") -# script: -# - mkdir -p build_cpp -# - mkdir -p install_cpp -# - cd build_cpp -# - cmake -DCMAKE_INSTALL_PREFIX:PATH=../install_cpp -DCMAKE_BUILD_TYPE=Debug .. -# - cmake --build . -j2 -# - cmake --install . --config Debug - -# artifacts: -# expire_in: 1 week -# paths: -# - build_cpp/ -# - install_cpp/ - -# build:windows_python: -# stage: build -# needs: [] -# tags: -# - windows - -# image: buildtools -# before_script: -# # Install Chocolatey -# - Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1')) -# # Install dependencies -# - choco install cmake.install --installargs '"ADD_CMAKE_TO_PATH=System"' -Y -# - choco install git -Y -# - choco install python -Y -# # Update PATH -# - $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") -# script: -# - python -m pip install virtualenv -# - virtualenv venv -# - venv\Scripts\Activate.ps1 -# # Numpy dependancy for unit test -# - python -m pip install -r requirements.txt -# - python -m pip install . -# artifacts: -# expire_in: 1 week -# paths: -# - venv/ +build:windows_cpp: + stage: build + needs: [] + tags: + - windows + + image: buildtools + before_script: + # Install Chocolatey + - Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1')) + # Install dependencies + - choco install cmake.install --installargs '"ADD_CMAKE_TO_PATH=System"' -Y + - choco install git -Y + - choco install python -Y + # Update PATH + - $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") + script: + - mkdir -p build_cpp + - mkdir -p install_cpp + - cd build_cpp + - cmake -DCMAKE_INSTALL_PREFIX:PATH=../install_cpp -DCMAKE_BUILD_TYPE=Debug .. + - cmake --build . -j2 + - cmake --install . --config Debug + + artifacts: + expire_in: 1 week + paths: + - build_cpp/ + - install_cpp/ + +build:windows_python: + stage: build + needs: [] + tags: + - windows + + image: buildtools + before_script: + # Install Chocolatey + - Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1')) + # Install dependencies + - choco install cmake.install --installargs '"ADD_CMAKE_TO_PATH=System"' -Y + - choco install git -Y + - choco install python -Y + # Update PATH + - $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") + script: + - python -m pip install virtualenv + - virtualenv venv + - venv\Scripts\Activate.ps1 + # Numpy dependancy for unit test + - python -m pip install -r requirements.txt + - python -m pip install . + artifacts: + expire_in: 1 week + paths: + - venv/ diff --git a/.gitlab/ci/test.gitlab-ci.yml b/.gitlab/ci/test.gitlab-ci.yml index abe526cdf3fac882177509cade20e5ed58ed7f77..81e6ca9ac5b868287aa0ef27040c0ead785d3639 100644 --- a/.gitlab/ci/test.gitlab-ci.yml +++ b/.gitlab/ci/test.gitlab-ci.yml @@ -26,23 +26,23 @@ test:ubuntu_python: reports: junit: ${CI_PROJECT_NAME}/xmlrunner-results.xml -# test:windows_cpp: -# stage: test -# needs: ["build:windows_cpp"] -# tags: -# - windows -# image: buildtools -# before_script: -# # Install Chocolatey -# - Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1')) -# # Install dependencies -# - choco install cmake.install --installargs '"ADD_CMAKE_TO_PATH=System"' -Y -# - choco install python -Y -# # Update PATH -# - $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") -# script: -# - cd build_cpp -# - ctest --output-junit ctest-results.xml --output-on-failure -# artifacts: -# reports: -# junit: build_cpp/ctest-results.xml +test:windows_cpp: + stage: test + needs: ["build:windows_cpp"] + tags: + - windows + image: buildtools + before_script: + # Install Chocolatey + - Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1')) + # Install dependencies + - choco install cmake.install --installargs '"ADD_CMAKE_TO_PATH=System"' -Y + - choco install python -Y + # Update PATH + - $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") + script: + - cd build_cpp + - ctest --output-junit ctest-results.xml --output-on-failure + artifacts: + reports: + junit: build_cpp/ctest-results.xml diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 36947ca7f5b1862276fe3e9e5df32b2f9dbb4dfc..eda3ee34ba234cc2714d4424128efb647f45e63d 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -15,7 +15,7 @@ #include <cstring> #include <set> #include <memory> -#include <numeric> +#include <numeric> // std::accumulate #include <string> #include <vector> @@ -365,11 +365,11 @@ class Tensor : public Data, /** * @brief Change the dimensions of the Tensor object according to the given argument. - * If the overall size is not changed (meaning we actually only performed a + * If the overall size is not changed (meaning we actually only performed a * reshape), data is garanteed to remain valid. - * Otherwise, no garantee is provided regarding the validy of previous data - * (unlike std::vector). If the new overall size is larger than the previous - * one, all previous data is invalided. Otherwise, previous data may or may + * Otherwise, no garantee is provided regarding the validy of previous data + * (unlike std::vector). If the new overall size is larger than the previous + * one, all previous data is invalided. Otherwise, previous data may or may * not remain valid, depending on the backend implementation. * @tparam DIM Number of dimensions. * @param dims New dimensions @@ -381,11 +381,11 @@ class Tensor : public Data, /** * @brief Change the dimensions of the Tensor object according to the given argument. - * If the overall size is not changed (meaning we actually only performed a + * If the overall size is not changed (meaning we actually only performed a * reshape), data is garanteed to remain valid. - * Otherwise, no garantee is provided regarding the validy of previous data - * (unlike std::vector). If the new overall size is larger than the previous - * one, all previous data is invalided. Otherwise, previous data may or may + * Otherwise, no garantee is provided regarding the validy of previous data + * (unlike std::vector). If the new overall size is larger than the previous + * one, all previous data is invalided. Otherwise, previous data may or may * not remain valid, depending on the backend implementation. * @param dims New dimensions * @param strides Stride of the tensor (if not specified, "nested" stride is used) @@ -501,7 +501,7 @@ class Tensor : public Data, return std::string("?"); // To make Clang happy }; - if (dims().empty()) { return "{}"; } + if (dims().empty()) { return ptrToString(mDataType, mImpl->hostPtr(), 0); } std::string res; std::size_t dim = 0; std::size_t counter = 0; @@ -671,22 +671,22 @@ class Tensor : public Data, /** * Copy-cast data from a Tensor. * @param src Source tensor to copy-cast from. - * @param movedSrc shared_ptr to an indermediate Tensor that will - * contain the moved data if a device change should occur AND a type + * @param movedSrc shared_ptr to an indermediate Tensor that will + * contain the moved data if a device change should occur AND a type * conversion is necessary (otherwise it remains unused). - * Any data already present will be overwritten. No new memory allocation - * will occur if movedSrc has already been allocated with the right + * Any data already present will be overwritten. No new memory allocation + * will occur if movedSrc has already been allocated with the right * type/size/device. - * If required, memory is always allocated on current (destination) + * If required, memory is always allocated on current (destination) * Tensor's device. */ void copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrc); /** * Copy-cast data from a Tensor. - * In case of both a device change AND a data type conversion, an + * In case of both a device change AND a data type conversion, an * intermediate buffer on will be allocated and deallocated each time. - * If required, buffer's memory is always allocated on current (destination) + * If required, buffer's memory is always allocated on current (destination) * Tensor's device. * @param src Source tensor to copy-cast from. */ @@ -718,7 +718,7 @@ class Tensor : public Data, * The backend stays the same. * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary. * The shared_ptr does not need to be initialized. No new memory allocation - * will occur if fallback has already been allocated with the right + * will occur if fallback has already been allocated with the right * type/size/device. * @param dt The desired data type. * @return Reference to either itself or to fallback. @@ -733,7 +733,7 @@ class Tensor : public Data, * The data type stays the same. * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary. * The shared_ptr does not need to be initialized. No new memory allocation - * will occur if fallback has already been allocated with the right + * will occur if fallback has already been allocated with the right * type/size/device. * @param backend The desired backend. * @param device The desired device. @@ -746,11 +746,11 @@ class Tensor : public Data, * Return a reference to a Tensor on desired data type and backend/device: * - itself, if already with the right characteristics; * - the provided Tensor, overwritten with the copy-casted data. - * If required, fallback is always allocated on desired (destination) + * If required, fallback is always allocated on desired (destination) * device. * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary. * The shared_ptr does not need to be initialized. No new memory allocation - * will occur if fallback has already been allocated with the right + * will occur if fallback has already been allocated with the right * type/size/device. * @param dt The desired data type. * @param backend The desired backend. @@ -767,11 +767,11 @@ class Tensor : public Data, * (data type, backend/device) as targetReqs Tensor: * - itself, if already with the right characteristics; * - the provided Tensor, overwritten with the copy-casted data. - * If required, fallback is always allocated on current (destination) + * If required, fallback is always allocated on current (destination) * Tensor's device. * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary. * The shared_ptr does not need to be initialized. No new memory allocation - * will occur if fallback has already been allocated with the right + * will occur if fallback has already been allocated with the right * type/size/device. * @param targetReqs Tensor with the desired target characteristics. * @return Reference to either itself or to fallback. @@ -783,15 +783,8 @@ class Tensor : public Data, private: ///\bug not protected against overflow - std::size_t computeSize() { - if (mDims.empty()) { - mSize = DimSize_t(0); - } - else { - mSize = std::accumulate(mDims.begin(), mDims.end(), DimSize_t(1), std::multiplies<DimSize_t>()); - } - - return mSize; + void computeSize() { + mSize = std::accumulate(mDims.begin(), mDims.end(), DimSize_t(1), std::multiplies<DimSize_t>()); } }; } // namespace Aidge diff --git a/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp b/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp index 2f6066ba4cd97284c43b509c9d5eb988b65b53a5..36d09db47d23395d649a688252f2af803cb1bc9d 100644 --- a/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp +++ b/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp @@ -10,164 +10,162 @@ #include "aidge/graph/Node.hpp" - namespace Aidge{ - class FsmNode; +class FsmNode; + +/** + * @brief a class used to save the execution context of state machines, that is the actual state in the FSM, the actual node in the graph + * all node that have been Validate,Rejecte or Considered common +*/ +class FsmRunTimeContext +{ +private: + /** + * @brief the list of node rejected for all the context + */ + static std::vector<std::set<NodePtr>> mRejectedNodes; + /** + * @brief the actual state of this Context (where it's in the FSM graph) + */ + std::shared_ptr<FsmNode> mActState; + /** + * @brief the actual node of this Context (where it's in the graph) + */ + NodePtr mActOpNode; + /** + * @brief the map of the node consider as common and the common ID + * @details we need to store what node it's consider as common because of the end + * resolution of the matching, all node consider as common need to be the same in all context + */ + std::map<NodePtr,std::size_t> mCommonNodes; + /** + * @brief the map of the node that as been valid in this context , and the test that valide the node + */ + std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> mValidNodes; + /** + * @brief the index in the rejected node of this context + */ + std::size_t mLocalIdxRejeced; +public: + /** + * @brief constructor + * @param actState the actual state in the FSM + * @param actOpNode the actual node in the graph + * @param idxRejeced the idx in the global regected node vector init max() as sentinel value of undefind + */ + FsmRunTimeContext(std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ,std::size_t idxRejeced =std::numeric_limits<std::size_t>::max() ); + FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime); + FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime,std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ); + + virtual ~FsmRunTimeContext()=default; + + /** + * @defgroup FsmRunTimeContextRejected Function for managing rejected nodes + */ + + /** + * @ingroup FsmRunTimeContextRejected + * @brief Add a node as rejected in this context + */ + void addRejectedNode(NodePtr node); + + /** + * @ingroup FsmRunTimeContextRejected + * @brief get the rejected nodes of this context + */ + inline std::set<NodePtr> getRejectedNodes(void) const { + return mRejectedNodes[mLocalIdxRejeced]; + } + + + /** + * @defgroup FsmRunTimeContextTest Function for test the context + */ - class FsmNode; + /** + * @ingroup FsmRunTimeContextTest + * @brief test if the actual state is valide + * @return bool + */ + bool isOnValidState(void); + /** + * @ingroup FsmRunTimeContextTest + * @brief test if the node is considered as common in this context + * @param node node to test + * @return bool + */ + bool isCommonDefined(NodePtr node); + /** + * @ingroup FsmRunTimeContextTest + * @brief test if has already validated in this context + * @param node node to test + * @return bool + */ + bool isAlreadyValid(NodePtr node); + /** + * @ingroup FsmRunTimeContextTest + * @brief test if this context is compatible with an others + * @details to say that two contexts are compatible is to check : + * that the contexts do not validate the same nodes (other than the common ones) + * and that the common ones have the same idx + * @param fsmContext the others context + * @return bool + */ + bool areCompatible(std::shared_ptr<FsmRunTimeContext> fsmContext); + /** + * @ingroup FsmRunTimeContextTest + * @brief test if this context is strictly equal with an others + * @param fsmContext the others context + * @return bool + */ + bool areEqual(std::shared_ptr<FsmRunTimeContext> fsmContext); /** - * @brief a class used to save the execution context of state machines, that is the actual state in the FSM, the actual node in the graph - * all node that have been Validate,Rejecte or Considered common + * @defgroup FsmRunTimeContextSet Function set context */ - class FsmRunTimeContext - { - private: - /** - * @brief the list of node rejected for all the context - */ - static std::vector<std::set<NodePtr>> mRejectedNodes; - /** - * @brief the actual state of this Context (where it's in the FSM graph) - */ - std::shared_ptr<FsmNode> mActState; - /** - * @brief the actual node of this Context (where it's in the graph) - */ - NodePtr mActOpNode; - /** - * @brief the map of the node consider as common and the common ID - * @details we need to store what node it's consider as common because of the end - * resolution of the matching, all node consider as common need to be the same in all context - */ - std::map<NodePtr,std::size_t> mCommonNodes; - /** - * @brief the map of the node that as been valid in this context , and the test that valide the node - */ - std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> mValidNodes; - /** - * @brief the index in the rejected node of this context - */ - std::size_t mLocalIdxRejeced; - public: - /** - * @brief constructor - * @param actState the actual state in the FSM - * @param actOpNode the actual node in the graph - * @param idxRejeced the idx in the global regected node vector init max() as sentinel value of undefind - */ - FsmRunTimeContext(std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ,std::size_t idxRejeced =std::numeric_limits<std::size_t>::max() ); - FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime); - FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime,std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ); - - virtual ~FsmRunTimeContext()=default; - - /** - * @defgroup FsmRunTimeContextRejected Function for managing rejected nodes - */ - - /** - * @ingroup FsmRunTimeContextRejected - * @brief Add a node as rejected in this context - */ - void addRejectedNode(NodePtr node); - - /** - * @ingroup FsmRunTimeContextRejected - * @brief get the rejected nodes of this context - */ - std::set<NodePtr> getRejectedNodes(void); - - - /** - * @defgroup FsmRunTimeContextTest Function for test the context - */ - - /** - * @ingroup FsmRunTimeContextTest - * @brief test if the actual state is valide - * @return bool - */ - bool isOnValidState(void); - /** - * @ingroup FsmRunTimeContextTest - * @brief test if the node is considered as common in this context - * @param node node to test - * @return bool - */ - bool isCommonDefined(NodePtr node); - /** - * @ingroup FsmRunTimeContextTest - * @brief test if has already validated in this context - * @param node node to test - * @return bool - */ - bool isAlreadyValid(NodePtr node); - /** - * @ingroup FsmRunTimeContextTest - * @brief test if this context is compatible with an others - * @details to say that two contexts are compatible is to check : - * that the contexts do not validate the same nodes (other than the common ones) - * and that the common ones have the same idx - * @param fsmContext the others context - * @return bool - */ - bool areCompatible(std::shared_ptr<FsmRunTimeContext> fsmContext); - /** - * @ingroup FsmRunTimeContextTest - * @brief test if this context is strictly equal with an others - * @param fsmContext the others context - * @return bool - */ - bool areEqual(std::shared_ptr<FsmRunTimeContext> fsmContext); - - /** - * @defgroup FsmRunTimeContextSet Function set context - */ - - - void setCommon(NodePtr node,std::size_t commonIdx); - - - void setValid(NodePtr node,std::shared_ptr<ConditionalInterpreter> tag); - - /** - * @defgroup FsmRunTimeContextGet Function get context - */ - - - /** - * @ingroup FsmRunTimeContextGet - * @brief get the sub idx state - * @return bool - */ - std::size_t getSubStmId(void); - - NodePtr getCommonNodeFromIdx(std::size_t commonIdx); - std::size_t getCommonNodeIdx(NodePtr node); - std::set<NodePtr> getCommonNodes(void); - - std::map<NodePtr,std::size_t> getCommon(void); - std::set<NodePtr> getValidNodes(void); - - std::set<NodePtr> getValidNodesNoCommon(void); - std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>>& getValid(void); - - - NodePtr getActNode(void); - std::shared_ptr<FsmNode> getActState(void); - - - /** - * @defgroup FsmRunTimeContextMem - */ - - void rst(void); - - - }; - -} - -#endif //AIDGE_CORE_FSM_RUN_TIME_CONTEXT_H_ + + + void setCommon(NodePtr node,std::size_t commonIdx); + + + void setValid(NodePtr node,std::shared_ptr<ConditionalInterpreter> tag); + + /** + * @defgroup FsmRunTimeContextGet Function get context + */ + + + /** + * @ingroup FsmRunTimeContextGet + * @brief get the sub idx state + * @return bool + */ + std::size_t getSubStmId(void); + + NodePtr getCommonNodeFromIdx(std::size_t commonIdx); + std::size_t getCommonNodeIdx(NodePtr node); + std::set<NodePtr> getCommonNodes(void); + + std::map<NodePtr,std::size_t> getCommon(void); + std::set<NodePtr> getValidNodes(void); + + std::set<NodePtr> getValidNodesNoCommon(void); + std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>>& getValid(void); + + + NodePtr getActNode(void); + std::shared_ptr<FsmNode> getActState(void); + + + /** + * @defgroup FsmRunTimeContextMem + */ + + void rst(void); + + +}; +} // namespace Aidge + +#endif // AIDGE_CORE_FSM_RUN_TIME_CONTEXT_H_ diff --git a/include/aidge/graphRegex/matchFsm/MatchResult.hpp b/include/aidge/graphRegex/matchFsm/MatchResult.hpp index 4f7f9bf1dd9b0612e71a1f7894bfc382713c0ad0..7954e932a20940946f444cf7277e2bf359f7f15a 100644 --- a/include/aidge/graphRegex/matchFsm/MatchResult.hpp +++ b/include/aidge/graphRegex/matchFsm/MatchResult.hpp @@ -24,8 +24,10 @@ private: const std::vector<NodePtr> mStartNode; public: + MatchSolution() = delete; MatchSolution(std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence,const std::string query,const std::vector<NodePtr> startNode); - inline const std::set<NodePtr>& at(const std::string key) { + + inline const std::set<NodePtr>& at(const std::string& key) { return mSolution[key]; } const std::set<NodePtr> getAll(); @@ -33,7 +35,6 @@ public: inline const std::string& getQuery() const noexcept { return mQueryFrom; } inline const std::vector<NodePtr>& getStartNode() const noexcept { return mStartNode; } - }; @@ -60,14 +61,18 @@ private: public: - MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm, - const std::string& query,const std::vector<NodePtr>& startNodes); + MatchResult() = delete; + MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, + std::size_t nbSubStm, + const std::string& query,const std::vector<NodePtr>& startNodes); /** * @brief get the set of the node match for une expression * @return the set of node of the graph that corresponding to an expression */ - std::shared_ptr<MatchSolution> getBiggerSolution(void); + inline std::shared_ptr<MatchSolution> getBiggerSolution(void) const noexcept { + return mSolve.empty() ? nullptr : mSolve[0]; + } inline std::vector<std::shared_ptr<MatchSolution>> getSolutions(void) const noexcept { return mSolve; diff --git a/include/aidge/operator/Identity.hpp b/include/aidge/operator/Identity.hpp index 7348fa10a96c55914bae68983b5e3bd4a9c40b12..57cd20311a4e4c98966af0af98b9fe4533155ea6 100644 --- a/include/aidge/operator/Identity.hpp +++ b/include/aidge/operator/Identity.hpp @@ -40,7 +40,7 @@ public: static const std::string Type; Identity_Op() - : OperatorTensor(Type, 1, 0, 0) + : OperatorTensor(Type, 1, 0, 1) { mImpl = std::make_shared<OperatorImpl>(*this); } @@ -101,7 +101,10 @@ public: if (outputIdx >= nbInputs()) { AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbInputs()); } - return mInputs[outputIdx]; + if (mInputs[outputIdx] == nullptr){ + return mOutputs[outputIdx]; // Input is not initialized with empty tensor + } + return mInputs[outputIdx]; // Identity, so Output is Input } void setBackend(const std::string& /*name*/, DeviceIdx_t /*device*/ = 0) override final { // setBackend do nothing, Identity node has no backend it just pass the same Tensor diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index ee00ead696efe623a4e051994f470a38397777ec..fe9b044e2309eb7e724d6648b84c044d7407bafb 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -24,22 +24,32 @@ namespace Aidge { +enum class ProdAttr { Constant }; + class Producer_Op : public OperatorTensor, public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>( - const Producer_Op &)> { + const Producer_Op &)>, + public StaticAttributes<ProdAttr, bool> { public: static const std::string Type; + using Attributes_ = StaticAttributes<ProdAttr, bool>; + template <ProdAttr e> + using attr = typename Attributes_::template attr<e>; + template <std::size_t DIM> - Producer_Op(const std::array<DimSize_t, DIM>& dims) - : OperatorTensor(Type, 0, 0, 1) + Producer_Op(const std::array<DimSize_t, DIM>& dims, + bool constant = false) + : OperatorTensor(Type, 0, 0, 1), + Attributes_(attr<ProdAttr::Constant>(constant)) { mOutputs[0]->resize(dims); } - Producer_Op(const std::shared_ptr<Tensor> tensor) - : OperatorTensor(Type, 0, 0, 1) + Producer_Op(const std::shared_ptr<Tensor> tensor, bool constant = false) + : OperatorTensor(Type, 0, 0, 1), + Attributes_(attr<ProdAttr::Constant>(constant)) { mOutputs[0] = tensor; // copy the pointer of the Tensor } @@ -49,7 +59,8 @@ public: * @param op OperatorTensor to copy. */ Producer_Op(const Producer_Op& op) - : OperatorTensor(op) + : OperatorTensor(op), + Attributes_(op) { for (std::size_t i = 0; i < static_cast<std::size_t>(nbOutputs()); ++i) { mOutputs[i] = std::make_shared<Tensor>(*(op.getOutput(i))); @@ -89,28 +100,41 @@ public: } public: - void forward() override final { - printf("Basic Producer forward() function.\n"); - } - void backward() override final { - printf("Basic Producer backward() function.\n"); - } + void forward() override final { + printf("Basic Producer forward() function.\n"); + } + void backward() override final { + printf("Basic Producer backward() function.\n"); + } + void setOutput(const Aidge::IOIndex_t outputIdx, std::shared_ptr<Aidge::Data>&& data) override { + if (getAttr<ProdAttr::Constant>()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer is constant, cannot update output."); + } + OperatorTensor::setOutput(outputIdx, std::move(data)); + } + + void setOutput(const Aidge::IOIndex_t outputIdx, const std::shared_ptr<Aidge::Data>& data) override { + if (getAttr<ProdAttr::Constant>()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer is constant, cannot update output."); + } + OperatorTensor::setOutput(outputIdx, data); + } }; template <std::array<DimSize_t, 1>::size_type DIM> -inline std::shared_ptr<Node> Producer(const std::array<DimSize_t, DIM> &dims, const std::string& name = "") { +inline std::shared_ptr<Node> Producer(const std::array<DimSize_t, DIM> &dims, const std::string& name = "", bool constant = false) { static_assert(DIM<=MaxDim,"Too many tensor dimensions required by Producer, not supported"); - return std::make_shared<Node>(std::make_shared<Producer_Op>(dims), name); + return std::make_shared<Node>(std::make_shared<Producer_Op>(dims, constant), name); } // helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction template <std::size_t DIM> -inline std::shared_ptr<Node> Producer(DimSize_t const (&dims)[DIM], const std::string& name = "") { - return Producer(to_array(dims), name); +inline std::shared_ptr<Node> Producer(DimSize_t const (&dims)[DIM], const std::string& name = "", bool constant = false) { + return Producer(to_array(dims), name, constant); } -inline std::shared_ptr<Node> Producer(const std::shared_ptr<Tensor> tensor, const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Producer_Op>(tensor), name); +inline std::shared_ptr<Node> Producer(const std::shared_ptr<Tensor> tensor, const std::string& name = "", bool constant = false) { + return std::make_shared<Node>(std::make_shared<Producer_Op>(tensor, constant), name); } template <std::array<DimSize_t, 1>::size_type DIM> @@ -130,4 +154,10 @@ void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, Dim } } // namespace Aidge -#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */ \ No newline at end of file +namespace { +template <> +const char *const EnumStrings<Aidge::ProdAttr>::data[] = { + "Constant" +}; +} +#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */ diff --git a/include/aidge/recipies/GraphViewHelper.hpp b/include/aidge/recipies/GraphViewHelper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d7bcec713087054640c87c6fd229fee53d1ed4a6 --- /dev/null +++ b/include/aidge/recipies/GraphViewHelper.hpp @@ -0,0 +1,40 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_UTILS_RECIPIES_H_ +#define AIDGE_CORE_UTILS_RECIPIES_H_ + +#include <memory> +#include <set> + +#include "aidge/graph/Node.hpp" +#include "aidge/graph/GraphView.hpp" + + +namespace Aidge { + +/** + * @brief Getter for every Producer operator in a GraphView. + * @param graphview GraphView instance where Producers should be searched. + * @return std::set<std::shared_ptr<Node>> + */ +std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphView> graphview) { + std::set<std::shared_ptr<Node>> res; + const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes(); + + std::copy_if(nodes.cbegin(), + nodes.cend(), + std::inserter(res, res.begin()), + [](std::shared_ptr<Node> n){ return n->type() == "Producer"; }); + + return res; +} +} // namespace Aidge \ No newline at end of file diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index 4f760a65fcade253e1de2bea8ccddbb0369962ec..c948b1ffd414fd1b421c9a842a16982501b5b2e0 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -42,7 +42,7 @@ void addCtor(py::class_<Tensor, std::set<std::string> availableBackends = Tensor::getAvailableBackends(); if (availableBackends.find("cpu") != availableBackends.end()){ newTensor->setBackend("cpu"); - newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr), newTensor->size()); + newTensor->getImpl()->copyFromHost(static_cast<T*>(info.ptr), newTensor->size()); }else{ printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n"); } @@ -95,7 +95,9 @@ void init_Tensor(py::module& m){ case DataType::Float32: return py::cast(b.get<float>(idx)); case DataType::Int32: - return py::cast(b.get<int>(idx)); + return py::cast(b.get<std::int32_t>(idx)); + case DataType::Int64: + return py::cast(b.get<std::int64_t>(idx)); default: return py::none(); } @@ -108,7 +110,9 @@ void init_Tensor(py::module& m){ case DataType::Float32: return py::cast(b.get<float>(coordIdx)); case DataType::Int32: - return py::cast(b.get<int>(coordIdx)); + return py::cast(b.get<std::int32_t>(coordIdx)); + case DataType::Int64: + return py::cast(b.get<std::int64_t>(coordIdx)); default: return py::none(); } @@ -137,7 +141,10 @@ void init_Tensor(py::module& m){ dataFormatDescriptor = py::format_descriptor<float>::format(); break; case DataType::Int32: - dataFormatDescriptor = py::format_descriptor<int>::format(); + dataFormatDescriptor = py::format_descriptor<std::int32_t>::format(); + break; + case DataType::Int64: + dataFormatDescriptor = py::format_descriptor<std::int64_t>::format(); break; default: throw py::value_error("Unsupported data format"); @@ -155,7 +162,8 @@ void init_Tensor(py::module& m){ // TODO : If the ctor with the right data type does not exist, pybind will always convert the data to INT ! // Need to find a way to avoid this ! - addCtor<int>(pyClassTensor); + addCtor<std::int32_t>(pyClassTensor); + addCtor<std::int64_t>(pyClassTensor); addCtor<float>(pyClassTensor); // #if SIZE_MAX != 0xFFFFFFFF addCtor<double>(pyClassTensor); diff --git a/python_binding/graphRegex/pybind_GraphRegex.cpp b/python_binding/graphRegex/pybind_GraphRegex.cpp index be3cd9e9124ba1306226dcbdc13ee39748cf0606..921f204d017f90823b9c4a1a024efa271a4691e5 100644 --- a/python_binding/graphRegex/pybind_GraphRegex.cpp +++ b/python_binding/graphRegex/pybind_GraphRegex.cpp @@ -10,6 +10,7 @@ ********************************************************************************/ #include <pybind11/pybind11.h> +#include <pybind11/functional.h> #include "aidge/graphRegex/GraphRegex.hpp" namespace py = pybind11; @@ -20,7 +21,7 @@ void init_GraphRegex(py::module& m){ py::class_<GraphRegex, std::shared_ptr<GraphRegex>>(m, "GraphRegex", "GraphRegex class describes a regex to test a graph.") .def(py::init<>()) - .def("add_query", &GraphRegex::addQuery, R"mydelimiter( + .def("add_query", &GraphRegex::addQuery, py::arg("query"), py::arg("f") = nullptr, R"mydelimiter( :rtype: str )mydelimiter") @@ -47,10 +48,10 @@ void init_GraphRegex(py::module& m){ Add a node test :param key: the key of the node test to use in the query. :param conditionalExpressions: the test to do . - + )mydelimiter") - + .def("set_node_key", (void (GraphRegex::*)(const std::string, std::function<bool(NodePtr)>)) & GraphRegex::setNodeKey, @@ -59,7 +60,7 @@ void init_GraphRegex(py::module& m){ Add a node test :param key: the key of the lambda test to use in the conditional expressions. :param f: bool lambda (nodePtr) . - + )mydelimiter") diff --git a/python_binding/graphRegex/pybind_MatchSolution.cpp b/python_binding/graphRegex/pybind_MatchSolution.cpp new file mode 100644 index 0000000000000000000000000000000000000000..81d39f86ed6da5b7b63a2d43b3a4fcbb2f8e9043 --- /dev/null +++ b/python_binding/graphRegex/pybind_MatchSolution.cpp @@ -0,0 +1,40 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include "aidge/graphRegex/matchFsm/MatchResult.hpp" + +namespace py = pybind11; +namespace Aidge { +void init_MatchSolution(py::module& m){ + + + py::class_<MatchSolution, std::shared_ptr<MatchSolution>>(m, "MatchSolution", "MatchSolution class contains the result of one match and the associated key, the query and the start node.") + .def("at", &MatchSolution::at, py::arg("key"), + R"mydelimiter( + :rtype: str + )mydelimiter") + + .def("get_all", &MatchSolution::getAll, + R"mydelimiter( + )mydelimiter") + + .def("get_query", &MatchSolution::getQuery, + R"mydelimiter( + )mydelimiter") + + .def("get_start_node", &MatchSolution::getStartNode, + R"mydelimiter( + )mydelimiter") + ; +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp index f5c5145e0a86d939b96e6d2a579dfa2579f8b3a5..b043ac23c378b9d591b7d1273ebcb5d48a37394a 100644 --- a/python_binding/operator/pybind_MetaOperatorDefs.cpp +++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp @@ -122,7 +122,7 @@ void init_MetaOperatorDefs(py::module &m) { declare_PaddedMaxPoolingOp<2>(m); declare_PaddedMaxPoolingOp<3>(m); - py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, Operator>(m, "MetaOperator_Op", py::multiple_inheritance()) + py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, OperatorTensor>(m, "MetaOperator_Op", py::multiple_inheritance()) .def("get_micro_graph", &MetaOperator_Op::getMicroGraph); m.def("meta_operator", &MetaOperator, diff --git a/python_binding/operator/pybind_OperatorTensor.cpp b/python_binding/operator/pybind_OperatorTensor.cpp index ce34dea158e6df1466db415b2539962c2113d42b..386a3af6c7c6e9dfad34ec2e56189a53797b59d9 100644 --- a/python_binding/operator/pybind_OperatorTensor.cpp +++ b/python_binding/operator/pybind_OperatorTensor.cpp @@ -21,6 +21,9 @@ void init_OperatorTensor(py::module& m){ py::class_<OperatorTensor, std::shared_ptr<OperatorTensor>, Operator>(m, "OperatorTensor") .def("get_output", &OperatorTensor::getOutput, py::arg("outputIdx")) .def("get_input", &OperatorTensor::getInput, py::arg("inputIdx")) + + .def("set_output", (void (OperatorTensor::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &OperatorTensor::setOutput, py::arg("outputIdx"), py::arg("data")) + .def("set_input", (void (OperatorTensor::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &OperatorTensor::setInput, py::arg("outputIdx"), py::arg("data")) .def("output_dims_forwarded", &OperatorTensor::outputDimsForwarded) ; } diff --git a/python_binding/operator/pybind_Producer.cpp b/python_binding/operator/pybind_Producer.cpp index 3dae24b620fe99098205d7d5f23591780f1e9cb7..78d9ce3489a8309c42cc90189e588a448fd9649a 100644 --- a/python_binding/operator/pybind_Producer.cpp +++ b/python_binding/operator/pybind_Producer.cpp @@ -24,20 +24,20 @@ namespace Aidge { template <DimIdx_t DIM> void declare_Producer(py::module &m) { // m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name")); - m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&)>(&Producer), py::arg("dims"), py::arg("name") = ""); + m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&, bool)>(&Producer), py::arg("dims"), py::arg("name") = "", py::arg("constant") = false); } void init_Producer(py::module &m) { - py::class_<Producer_Op, std::shared_ptr<Producer_Op>, OperatorTensor>( + py::class_<Producer_Op, std::shared_ptr<Producer_Op>, OperatorTensor, Attributes>( m, "ProducerOp", py::multiple_inheritance()) .def("dims", &Producer_Op::dims) .def("get_inputs_name", &Producer_Op::getInputsName) .def("get_outputs_name", &Producer_Op::getOutputsName); - m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = ""); + m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&, bool)>(&Producer), py::arg("tensor"), py::arg("name") = "", py::arg("constant") = false); declare_Producer<1>(m); declare_Producer<2>(m); diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 0353953fda39cd6dc283d20a0a3e36659dd891a4..be0d357b7f73e26aad44994f407696f70617ad71 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -56,6 +56,7 @@ void init_OpArgs(py::module&); void init_Connector(py::module&); void init_GraphRegex(py::module&); +void init_MatchSolution(py::module&); void init_Recipies(py::module&); @@ -106,7 +107,9 @@ void init_Aidge(py::module& m){ init_Identity(m); init_Producer(m); + init_GraphRegex(m); + init_MatchSolution(m); init_Recipies(m); init_Scheduler(m); diff --git a/src/graphRegex/matchFsm/FsmRunTimeContext.cpp b/src/graphRegex/matchFsm/FsmRunTimeContext.cpp index ddf6a46cc7c75dc853d71ba98b051b4263a31164..7a09908e5629e299b6b264fbfaac97bdaf7fa316 100644 --- a/src/graphRegex/matchFsm/FsmRunTimeContext.cpp +++ b/src/graphRegex/matchFsm/FsmRunTimeContext.cpp @@ -1,7 +1,7 @@ #include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" #include "aidge/graphRegex/matchFsm/FsmNode.hpp" -using namespace Aidge; +using namespace Aidge; std::vector<std::set<NodePtr>> FsmRunTimeContext::mRejectedNodes; @@ -42,10 +42,6 @@ void FsmRunTimeContext::addRejectedNode(NodePtr node){ mRejectedNodes[mLocalIdxRejeced].insert(node); } -std::set<NodePtr> FsmRunTimeContext::getRejectedNodes(void){ - return mRejectedNodes[mLocalIdxRejeced]; -} - bool FsmRunTimeContext::isOnValidState(void){ return mActState->isValid(); } @@ -57,7 +53,7 @@ bool FsmRunTimeContext::isCommonDefined(NodePtr node){ for(const auto& nodeC : nodes){ if(nodeC.get() == node.get()){ return true; - } + } } return false; } @@ -68,7 +64,7 @@ bool FsmRunTimeContext::isAlreadyValid(NodePtr node){ for(const auto& nodeV : nodes){ if(nodeV.get() == node.get()){ return true; - } + } } return false; @@ -82,7 +78,7 @@ bool FsmRunTimeContext::areCompatible(std::shared_ptr<FsmRunTimeContext> fsmCont and the same idx for the common */ - //common node + //common node for (const auto& ref : getCommon()) { for (const auto& test : fsmContext->getCommon()) { @@ -97,20 +93,15 @@ bool FsmRunTimeContext::areCompatible(std::shared_ptr<FsmRunTimeContext> fsmCont //valid nodes std::set<NodePtr> commonElements; - std::set<NodePtr> A = getValidNodesNoCommon(); - std::set<NodePtr> B = fsmContext->getValidNodesNoCommon(); + std::set<NodePtr> A = getValidNodesNoCommon(); + std::set<NodePtr> B = fsmContext->getValidNodesNoCommon(); std::set_intersection( A.begin(),A.end(), B.begin(), B.end(), std::inserter(commonElements, commonElements.end()) ); - - if (!commonElements.empty()) { - return false; - } - - return true; + return (commonElements.empty()) ? true : false; } bool FsmRunTimeContext::areEqual(std::shared_ptr<FsmRunTimeContext> fsmContext){ @@ -142,7 +133,7 @@ void FsmRunTimeContext::setCommon(NodePtr node,std::size_t commonIdx){ } void FsmRunTimeContext::setValid(NodePtr node,std::shared_ptr<ConditionalInterpreter> tag){ - //we already find a node of this type + //we already find a node of this type if(mValidNodes.find(tag) != mValidNodes.end()){ if(isAlreadyValid(node) && !isCommonDefined(node) ){ throw std::runtime_error("setValid you valid tow time"); @@ -151,7 +142,7 @@ void FsmRunTimeContext::setValid(NodePtr node,std::shared_ptr<ConditionalInterpr }else{ mValidNodes[tag] = {node}; } - + } std::size_t FsmRunTimeContext::getSubStmId(void){ diff --git a/src/graphRegex/matchFsm/MatchResult.cpp b/src/graphRegex/matchFsm/MatchResult.cpp index 08be00dea66c66a46dbbf2b225efd0df3f332188..99df00e198a9a30be21e0ad18b4933a40b9b7a06 100644 --- a/src/graphRegex/matchFsm/MatchResult.cpp +++ b/src/graphRegex/matchFsm/MatchResult.cpp @@ -132,8 +132,4 @@ void Aidge::MatchResult::_generateCombination( std::size_t idxSubStm, } return; -} - -std::shared_ptr<Aidge::MatchSolution> Aidge::MatchResult::getBiggerSolution(void){ - return mSolve.empty() ? nullptr : mSolve[0]; } \ No newline at end of file diff --git a/src/operator/Producer.cpp b/src/operator/Producer.cpp index 443f2fa7d8a60cd25ccb622f2dad5b4926b88eea..7bccbe763b90f2697997a889b30b610e4b531334 100644 --- a/src/operator/Producer.cpp +++ b/src/operator/Producer.cpp @@ -13,4 +13,4 @@ #include "aidge/operator/Producer.hpp" -const std::string Aidge::Producer_Op::Type = "Producer"; \ No newline at end of file +const std::string Aidge::Producer_Op::Type = "Producer";