Commit efb68e6a authored by Thibault Allenet's avatar Thibault Allenet
Browse files

Merge branch 'main' into dataloader

parents a71e2361 f2127e1f
2 merge requests: !105 version 0.2.0, !4 Dataloader
Pipeline #37422 passed
Showing changed files with 402 additions and 260 deletions
......@@ -66,16 +66,39 @@ endif()
target_compile_features(${module_name} PRIVATE cxx_std_14)
if (DOSANITIZE STREQUAL "ON")
set(SANITIZE_FLAGS -fsanitize=address,leak,undefined,float-divide-by-zero -fno-omit-frame-pointer)
#TODO sanitizer seems buggy in some situations with msvc, leading to linker errors, temporarily deactivating it
#set(SANITIZE_MSVC_FLAGS)
set(SANITIZE_MSVC_FLAGS /fsanitize=address)
target_compile_definitions(${module_name} PUBLIC _DISABLE_VECTOR_ANNOTATION)
else()
set(SANITIZE_FLAGS)
set(SANITIZE_MSVC_FLAGS)
endif()
set(STRICT_ALIASING_FLAGS -fstrict-aliasing -Wstrict-aliasing=2)
# -fvisibility=hidden required by pybind11
target_compile_options(${module_name} PUBLIC
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:
-fvisibility=hidden>)
target_compile_options(${module_name} PRIVATE
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:
-Wall -Wextra -Wold-style-cast -Winline -pedantic -Werror=narrowing -Wshadow $<$<BOOL:${WERROR}>:-Werror>>)
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:
-Wall -Wextra -Wold-style-cast -Winline -pedantic -Werror=narrowing -Wshadow $<$<BOOL:${WERROR}>:-Werror> ${SANITIZE_FLAGS}>)
target_compile_options(${module_name} PRIVATE
$<$<CXX_COMPILER_ID:MSVC>:
/W4>)
$<$<CXX_COMPILER_ID:GNU>:${STRICT_ALIASING_FLAGS}>)
target_compile_options(${module_name} PRIVATE
$<$<CXX_COMPILER_ID:MSVC>:
/W4 /wd4477 /DWIN32 /D_WINDOWS /GR /EHsc /MP /Zc:__cplusplus /Zc:preprocessor /permissive- ${SANITIZE_MSVC_FLAGS}>)
if (DOSANITIZE STREQUAL "ON")
target_compile_options(${module_name} PRIVATE $<$<CXX_COMPILER_ID:MSVC>:/MDd>)
endif()
# TODO FIXME: I'm not sure it's a good idea to propagate this option but, at this point, it was the only way that worked to silence C4477
target_compile_options(${module_name} PUBLIC $<$<CXX_COMPILER_ID:MSVC>: /wd4477>)
target_link_options(${module_name} PUBLIC $<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:${SANITIZE_FLAGS}>)
#target_link_options(${module_name} PUBLIC $<$<CXX_COMPILER_ID:MSVC>:${SANITIZE_MSVC_FLAGS}>)
if(CMAKE_COMPILER_IS_GNUCXX AND COVERAGE)
append_coverage_compiler_flags()
......
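For context on the SANITIZE_FLAGS set above, here is a minimal standalone sketch (not part of this commit) of the kind of defect -fsanitize=address catches at runtime:

// Build: g++ -fsanitize=address,leak,undefined -fno-omit-frame-pointer demo.cpp
#include <vector>

int main() {
    std::vector<int> v(4, 0);
    // Deliberate heap-buffer-overflow: index 4 is one past the end. Under ASan
    // the program aborts with a detailed report instead of silently misbehaving.
    return v[4];
}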
......@@ -15,7 +15,7 @@
#include <cstring>
#include <set>
#include <memory>
#include <numeric>
#include <numeric> // std::accumulate
#include <string>
#include <vector>
......@@ -341,11 +341,11 @@ class Tensor : public Data,
/**
* @brief Change the dimensions of the Tensor object according to the given argument.
* If the overall size is not changed (meaning we actually only performed a
* reshape), data is guaranteed to remain valid.
* Otherwise, no guarantee is provided regarding the validity of previous data
* (unlike std::vector). If the new overall size is larger than the previous
* one, all previous data is invalidated. Otherwise, previous data may or may
* not remain valid, depending on the backend implementation.
* @tparam DIM Number of dimensions.
* @param dims New dimensions
......@@ -357,11 +357,11 @@ class Tensor : public Data,
/**
* @brief Change the dimensions of the Tensor object according to the given argument.
* If the overall size is not changed (meaning we actually only performed a
* reshape), data is guaranteed to remain valid.
* Otherwise, no guarantee is provided regarding the validity of previous data
* (unlike std::vector). If the new overall size is larger than the previous
* one, all previous data is invalidated. Otherwise, previous data may or may
* not remain valid, depending on the backend implementation.
* @param dims New dimensions
*/
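A behavior sketch of the resize() contract documented above (illustrative only; assumes Tensor is default-constructible, the cpu backend is available, and the brace list binds to the std::vector overload):

Tensor t;
t.setBackend("cpu");
t.resize({2, 4}); // 8 elements allocated
t.resize({4, 2}); // same overall size: a pure reshape, data stays valid
t.resize({4, 4}); // larger overall size: all previous data is invalidated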
......@@ -438,7 +438,7 @@ class Tensor : public Data,
return std::string("?"); // To make Clang happy
};
if (dims().empty()) { return "{}"; }
if (dims().empty()) { return ptrToString(mDataType, mImpl->hostPtr(), 0); }
std::string res;
std::size_t dim = 0;
std::size_t counter = 0;
......@@ -560,22 +560,22 @@ class Tensor : public Data,
/**
* Copy-cast data from a Tensor.
* @param src Source tensor to copy-cast from.
* @param movedSrc shared_ptr to an intermediate Tensor that will
* contain the moved data if a device change should occur AND a type
* conversion is necessary (otherwise it remains unused).
* Any data already present will be overwritten. No new memory allocation
* will occur if movedSrc has already been allocated with the right
* type/size/device.
* If required, memory is always allocated on the current (destination)
* Tensor's device.
*/
void copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrc);
/**
* Copy-cast data from a Tensor.
* In case of both a device change AND a data type conversion, an
* intermediate buffer will be allocated and deallocated each time.
* If required, the buffer's memory is always allocated on the current
* (destination) Tensor's device.
* @param src Source tensor to copy-cast from.
*/
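A usage sketch of the two copyCastFrom() overloads documented above (dst and src are hypothetical Tensors; reusing movedSrc across calls avoids repeated allocations):

std::shared_ptr<Tensor> movedSrc; // scratch buffer, reused from call to call
dst.copyCastFrom(src, movedSrc);  // may fill movedSrc on device change + type cast
dst.copyCastFrom(src);            // convenience form: temporary buffer every time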
......@@ -593,7 +593,7 @@ class Tensor : public Data,
* The backend stays the same.
* @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
* The shared_ptr does not need to be initialized. No new memory allocation
* will occur if fallback has already been allocated with the right
* type/size/device.
* @param dt The desired data type.
* @return Reference to either itself or to fallback.
......@@ -608,7 +608,7 @@ class Tensor : public Data,
* The data type stays the same.
* @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
* The shared_ptr does not need to be initialized. No new memory allocation
* will occur if fallback has already been allocated with the right
* type/size/device.
* @param backend The desired backend.
* @param device The desired device.
......@@ -621,11 +621,11 @@ class Tensor : public Data,
* Return a reference to a Tensor on desired data type and backend/device:
* - itself, if already with the right characteristics;
* - the provided Tensor, overwritten with the copy-casted data.
* If required, fallback is always allocated on the desired (destination)
* device.
* @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
* The shared_ptr does not need to be initialized. No new memory allocation
* will occur if fallback has already been allocated with the right
* type/size/device.
* @param dt The desired data type.
* @param backend The desired backend.
......@@ -642,11 +642,11 @@ class Tensor : public Data,
* (data type, backend/device) as targetReqs Tensor:
* - itself, if already with the right characteristics;
* - the provided Tensor, overwritten with the copy-casted data.
* If required, fallback is always allocated on the current (destination)
* Tensor's device.
* @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
* The shared_ptr does not need to be initialized. No new memory allocation
* will occur if fallback has already been allocated with the right
* type/size/device.
* @param targetReqs Tensor with the desired target characteristics.
* @return Reference to either itself or to fallback.
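The fallback idiom shared by these overloads, as a sketch (the method name is cut off by this hunk, so refCastFrom is an assumption on my part; tensor and targetReqs are hypothetical):

std::shared_ptr<Tensor> fallback; // may stay empty: only allocated if a copy is needed
const Tensor& ready = tensor.refCastFrom(fallback, targetReqs);
// `ready` is either `tensor` itself (already conforming) or `*fallback`,
// a copy-cast with the requested type/backend/device.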
......@@ -658,15 +658,8 @@ class Tensor : public Data,
private:
///\bug not protected against overflow
std::size_t computeSize() {
if (mDims.empty()) {
mSize = DimSize_t(0);
}
else {
mSize = std::accumulate(mDims.begin(), mDims.end(), DimSize_t(1), std::multiplies<DimSize_t>());
}
return mSize;
void computeSize() {
mSize = std::accumulate(mDims.begin(), mDims.end(), DimSize_t(1), std::multiplies<DimSize_t>());
}
};
} // namespace Aidge
......
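Note on the new computeSize() above: std::accumulate over an empty range returns its initial value, so a zero-rank Tensor now reports size 1 (scalar semantics) where the old code reported 0. A standalone check:

#include <functional>
#include <numeric>
#include <vector>

int main() {
    std::vector<std::size_t> dims; // rank 0: no dimensions
    const auto size = std::accumulate(dims.begin(), dims.end(),
                                      std::size_t(1), std::multiplies<std::size_t>());
    return size == 1 ? 0 : 1; // the empty product is the init value, 1
}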
......@@ -140,7 +140,7 @@ public:
/**
* @brief List of pairs <Parent, ID of the data input>. When an input is not
* linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
* Data inputs exclude inputs expecting parameters (weights or bias).
* @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>
*/
......
......@@ -10,164 +10,162 @@
#include "aidge/graph/Node.hpp"
namespace Aidge{
class FsmNode;
/**
* @brief A class used to save the execution context of a state machine: the current state in the FSM, the current node in the graph,
* and all nodes that have been validated, rejected, or considered common.
*/
class FsmRunTimeContext
{
private:
/**
* @brief the list of rejected nodes shared by all contexts
*/
static std::vector<std::set<NodePtr>> mRejectedNodes;
/**
* @brief the current state of this context (where it is in the FSM graph)
*/
std::shared_ptr<FsmNode> mActState;
/**
* @brief the current node of this context (where it is in the graph)
*/
NodePtr mActOpNode;
/**
* @brief the map of nodes considered common, with their common ID
* @details we need to store which nodes are considered common because, during the final
* resolution of the matching, all nodes considered common must be the same in every context
*/
std::map<NodePtr,std::size_t> mCommonNodes;
/**
* @brief the map of nodes that have been validated in this context, keyed by the test that validated each node
*/
std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> mValidNodes;
/**
* @brief the index of this context in the rejected node list
*/
std::size_t mLocalIdxRejeced;
public:
/**
* @brief constructor
* @param actState the current state in the FSM
* @param actOpNode the current node in the graph
* @param idxRejeced the index in the global rejected node vector; initialized to max() as a sentinel value for "undefined"
*/
FsmRunTimeContext(std::shared_ptr<FsmNode> actState, NodePtr actOpNode, std::size_t idxRejeced = std::numeric_limits<std::size_t>::max());
FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime);
FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime, std::shared_ptr<FsmNode> actState, NodePtr actOpNode);
virtual ~FsmRunTimeContext() = default;
/**
* @defgroup FsmRunTimeContextRejected Functions for managing rejected nodes
*/
/**
* @ingroup FsmRunTimeContextRejected
* @brief add a node as rejected in this context
*/
void addRejectedNode(NodePtr node);
/**
* @ingroup FsmRunTimeContextRejected
* @brief get the rejected nodes of this context
*/
inline std::set<NodePtr> getRejectedNodes(void) const {
return mRejectedNodes[mLocalIdxRejeced];
}
/**
* @defgroup FsmRunTimeContextTest Functions for testing the context
*/
/**
* @ingroup FsmRunTimeContextTest
* @brief test whether the current state is valid
* @return bool
*/
bool isOnValidState(void);
/**
* @ingroup FsmRunTimeContextTest
* @brief test whether the node is considered common in this context
* @param node node to test
* @return bool
*/
bool isCommonDefined(NodePtr node);
/**
* @ingroup FsmRunTimeContextTest
* @brief test whether the node has already been validated in this context
* @param node node to test
* @return bool
*/
bool isAlreadyValid(NodePtr node);
/**
* @ingroup FsmRunTimeContextTest
* @brief test whether this context is compatible with another
* @details two contexts are compatible if they do not validate the same nodes (other than the common ones)
* and their common nodes have the same index
* @param fsmContext the other context
* @return bool
*/
bool areCompatible(std::shared_ptr<FsmRunTimeContext> fsmContext);
/**
* @ingroup FsmRunTimeContextTest
* @brief test whether this context is strictly equal to another
* @param fsmContext the other context
* @return bool
*/
bool areEqual(std::shared_ptr<FsmRunTimeContext> fsmContext);
/**
* @defgroup FsmRunTimeContextSet Functions to set the context
*/
void setCommon(NodePtr node, std::size_t commonIdx);
void setValid(NodePtr node, std::shared_ptr<ConditionalInterpreter> tag);
/**
* @defgroup FsmRunTimeContextGet Functions to get the context
*/
/**
* @ingroup FsmRunTimeContextGet
* @brief get the sub-FSM index of the current state
* @return std::size_t
*/
std::size_t getSubStmId(void);
NodePtr getCommonNodeFromIdx(std::size_t commonIdx);
std::size_t getCommonNodeIdx(NodePtr node);
std::set<NodePtr> getCommonNodes(void);
std::map<NodePtr,std::size_t> getCommon(void);
std::set<NodePtr> getValidNodes(void);
std::set<NodePtr> getValidNodesNoCommon(void);
std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>>& getValid(void);
NodePtr getActNode(void);
std::shared_ptr<FsmNode> getActState(void);
/**
* @defgroup FsmRunTimeContextMem Functions to reset the context
*/
void rst(void);
};
} // namespace Aidge
#endif // AIDGE_CORE_FSM_RUN_TIME_CONTEXT_H_
......@@ -24,8 +24,10 @@ private:
const std::vector<NodePtr> mStartNode;
public:
MatchSolution() = delete;
MatchSolution(std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence,const std::string query,const std::vector<NodePtr> startNode);
inline const std::set<NodePtr>& at(const std::string key) {
inline const std::set<NodePtr>& at(const std::string& key) {
return mSolution[key];
}
const std::set<NodePtr> getAll();
......@@ -33,7 +35,6 @@ public:
inline const std::string& getQuery() const noexcept { return mQueryFrom; }
inline const std::vector<NodePtr>& getStartNode() const noexcept { return mStartNode; }
};
......@@ -60,14 +61,18 @@ private:
public:
MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm,
const std::string& query,const std::vector<NodePtr>& startNodes);
MatchResult() = delete;
MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid,
std::size_t nbSubStm,
const std::string& query,const std::vector<NodePtr>& startNodes);
/**
* @brief get the set of nodes matched for an expression
* @return the set of graph nodes corresponding to the expression
*/
std::shared_ptr<MatchSolution> getBiggerSolution(void);
inline std::shared_ptr<MatchSolution> getBiggerSolution(void) const noexcept {
return mSolve.empty() ? nullptr : mSolve[0];
}
inline std::vector<std::shared_ptr<MatchSolution>> getSolutions(void) const noexcept {
return mSolve;
......
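A consumption sketch for the now-inline accessors (matchResult and the key "A" are hypothetical):

std::shared_ptr<MatchSolution> best = matchResult->getBiggerSolution();
if (best) {
    const std::set<NodePtr>& nodes = best->at("A"); // nodes matched for query key "A"
}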
......@@ -40,7 +40,7 @@ public:
static const std::string Type;
Identity_Op()
: OperatorTensor(Type, 1, 0, 0)
: OperatorTensor(Type, 1, 0, 1)
{
mImpl = std::make_shared<OperatorImpl>(*this);
}
......@@ -101,7 +101,10 @@ public:
if (outputIdx >= nbInputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbInputs());
}
return mInputs[outputIdx];
if (mInputs[outputIdx] == nullptr){
return mOutputs[outputIdx]; // Input is not initialized with empty tensor
}
return mInputs[outputIdx]; // Identity, so Output is Input
}
void setBackend(const std::string& /*name*/, DeviceIdx_t /*device*/ = 0) override final {
// setBackend does nothing: the Identity node has no backend, it just passes the same Tensor through
......
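What the getOutput() fallback above buys, as a sketch (assumes default construction as shown in the diff):

Identity_Op op;
// Before any input is connected, mInputs[0] is null, so getOutput(0) falls back
// to the operator's own empty output tensor instead of returning a null pointer.
std::shared_ptr<Tensor> out = op.getOutput(0); // non-null either way
// Once an input is set, getOutput(0) forwards that same tensor unchanged.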
......@@ -24,22 +24,32 @@
namespace Aidge {
enum class ProdAttr { Constant };
class Producer_Op
: public OperatorTensor,
public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>(
const Producer_Op &)> {
const Producer_Op &)>,
public StaticAttributes<ProdAttr, bool> {
public:
static const std::string Type;
using Attributes_ = StaticAttributes<ProdAttr, bool>;
template <ProdAttr e>
using attr = typename Attributes_::template attr<e>;
template <std::size_t DIM>
Producer_Op(const std::array<DimSize_t, DIM>& dims)
: OperatorTensor(Type, 0, 0, 1)
Producer_Op(const std::array<DimSize_t, DIM>& dims,
bool constant = false)
: OperatorTensor(Type, 0, 0, 1),
Attributes_(attr<ProdAttr::Constant>(constant))
{
mOutputs[0]->resize(dims);
}
Producer_Op(const std::shared_ptr<Tensor> tensor)
: OperatorTensor(Type, 0, 0, 1)
Producer_Op(const std::shared_ptr<Tensor> tensor, bool constant = false)
: OperatorTensor(Type, 0, 0, 1),
Attributes_(attr<ProdAttr::Constant>(constant))
{
mOutputs[0] = tensor; // copy the pointer of the Tensor
}
......@@ -49,7 +59,8 @@ public:
* @param op OperatorTensor to copy.
*/
Producer_Op(const Producer_Op& op)
: OperatorTensor(op)
: OperatorTensor(op),
Attributes_(op)
{
for (std::size_t i = 0; i < static_cast<std::size_t>(nbOutputs()); ++i) {
mOutputs[i] = std::make_shared<Tensor>(*(op.getOutput(i)));
......@@ -89,28 +100,41 @@ public:
}
public:
void forward() override final {
printf("Basic Producer forward() function.\n");
}
void backward() override final {
printf("Basic Producer backward() function.\n");
}
void setOutput(const Aidge::IOIndex_t outputIdx, std::shared_ptr<Aidge::Data>&& data) override {
if (getAttr<ProdAttr::Constant>()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer is constant, cannot update output.");
}
OperatorTensor::setOutput(outputIdx, std::move(data));
}
void setOutput(const Aidge::IOIndex_t outputIdx, const std::shared_ptr<Aidge::Data>& data) override {
if (getAttr<ProdAttr::Constant>()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer is constant, cannot update output.");
}
OperatorTensor::setOutput(outputIdx, data);
}
};
template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> Producer(const std::array<DimSize_t, DIM> &dims, const std::string& name = "") {
inline std::shared_ptr<Node> Producer(const std::array<DimSize_t, DIM> &dims, const std::string& name = "", bool constant = false) {
static_assert(DIM<=MaxDim,"Too many tensor dimensions required by Producer, not supported");
return std::make_shared<Node>(std::make_shared<Producer_Op>(dims), name);
return std::make_shared<Node>(std::make_shared<Producer_Op>(dims, constant), name);
}
// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction
template <std::size_t DIM>
inline std::shared_ptr<Node> Producer(DimSize_t const (&dims)[DIM], const std::string& name = "") {
return Producer(to_array(dims), name);
inline std::shared_ptr<Node> Producer(DimSize_t const (&dims)[DIM], const std::string& name = "", bool constant = false) {
return Producer(to_array(dims), name, constant);
}
inline std::shared_ptr<Node> Producer(const std::shared_ptr<Tensor> tensor, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Producer_Op>(tensor), name);
inline std::shared_ptr<Node> Producer(const std::shared_ptr<Tensor> tensor, const std::string& name = "", bool constant = false) {
return std::make_shared<Node>(std::make_shared<Producer_Op>(tensor, constant), name);
}
template <std::array<DimSize_t, 1>::size_type DIM>
......@@ -130,4 +154,10 @@ void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, Dim
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::ProdAttr>::data[] = {
"Constant"
};
}
#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */
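A sketch of the new constant attribute in use (dims and names are hypothetical):

// A Producer holding weights that must never be overwritten by the scheduler.
auto weights = Producer(std::array<DimSize_t, 2>{64, 32}, "weights", /*constant=*/true);
// Replacing the output tensor via setOutput() now throws:
// "Producer is constant, cannot update output."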
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_UTILS_RECIPIES_H_
#define AIDGE_CORE_UTILS_RECIPIES_H_
#include <algorithm> // std::copy_if
#include <iterator>  // std::inserter
#include <memory>
#include <set>
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
namespace Aidge {
/**
* @brief Getter for every Producer operator in a GraphView.
* @param graphview GraphView instance where Producers should be searched.
* @return std::set<std::shared_ptr<Node>>
*/
inline std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphView> graphview) {
std::set<std::shared_ptr<Node>> res;
const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes();
std::copy_if(nodes.cbegin(),
nodes.cend(),
std::inserter(res, res.begin()),
[](std::shared_ptr<Node> n){ return n->type() == "Producer"; });
return res;
}
} // namespace Aidge
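A usage sketch for the recipe above (assumes GraphView exposes add() and that GraphView and Tensor are default-constructible):

auto graph = std::make_shared<Aidge::GraphView>();
graph->add(Aidge::Producer(std::make_shared<Aidge::Tensor>(), "w"));
for (const auto& n : Aidge::producers(graph)) {
    // every node returned satisfies n->type() == "Producer"
}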
......@@ -42,7 +42,7 @@ void addCtor(py::class_<Tensor,
std::set<std::string> availableBackends = Tensor::getAvailableBackends();
if (availableBackends.find("cpu") != availableBackends.end()){
newTensor->setBackend("cpu");
newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr), newTensor->size());
newTensor->getImpl()->copyFromHost(static_cast<T*>(info.ptr), newTensor->size());
}else{
printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n");
}
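The binding change above swaps pointer aliasing for a deep copy; a sketch of the difference (method names as they appear in the diff, host buffer hypothetical):

float host[4] = {1.f, 2.f, 3.f, 4.f};
newTensor->setBackend("cpu");
// copyFromHost copies the bytes: the Tensor stays valid even after `host`
// (or the originating numpy buffer) is gone.
newTensor->getImpl()->copyFromHost(host, 4);
// setRawPtr(host, 4) would merely alias `host`: if Python garbage-collects the
// source buffer first, the Tensor dangles; that is the bug this change fixes.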
......@@ -95,7 +95,9 @@ void init_Tensor(py::module& m){
case DataType::Float32:
return py::cast(b.get<float>(idx));
case DataType::Int32:
return py::cast(b.get<int>(idx));
return py::cast(b.get<std::int32_t>(idx));
case DataType::Int64:
return py::cast(b.get<std::int64_t>(idx));
default:
return py::none();
}
......@@ -108,7 +110,9 @@ void init_Tensor(py::module& m){
case DataType::Float32:
return py::cast(b.get<float>(coordIdx));
case DataType::Int32:
return py::cast(b.get<int>(coordIdx));
return py::cast(b.get<std::int32_t>(coordIdx));
case DataType::Int64:
return py::cast(b.get<std::int64_t>(coordIdx));
default:
return py::none();
}
......@@ -137,7 +141,10 @@ void init_Tensor(py::module& m){
dataFormatDescriptor = py::format_descriptor<float>::format();
break;
case DataType::Int32:
dataFormatDescriptor = py::format_descriptor<int>::format();
dataFormatDescriptor = py::format_descriptor<std::int32_t>::format();
break;
case DataType::Int64:
dataFormatDescriptor = py::format_descriptor<std::int64_t>::format();
break;
default:
throw py::value_error("Unsupported data format");
......@@ -155,7 +162,8 @@ void init_Tensor(py::module& m){
// TODO: if the ctor with the right data type does not exist, pybind will always convert the data to INT!
// We need to find a way to avoid this!
addCtor<int>(pyClassTensor);
addCtor<std::int32_t>(pyClassTensor);
addCtor<std::int64_t>(pyClassTensor);
addCtor<float>(pyClassTensor);
// #if SIZE_MAX != 0xFFFFFFFF
addCtor<double>(pyClassTensor);
......
......@@ -63,12 +63,20 @@ void init_Node(py::module& m) {
)mydelimiter")
.def("add_child",
(void (Node::*)(std::shared_ptr<GraphView>, const IOIndex_t,
std::pair<std::shared_ptr<Node>, IOIndex_t>)) &
Node::addChild,
[](Node &self, std::shared_ptr<GraphView> other_graph, const IOIndex_t out_id=0,
py::object other_in_id = py::none()) {
std::pair<NodePtr, IOIndex_t> cpp_other_in_id;
// Note: pybind11 on Windows does not support conversion of nullptr -> std::shared_ptr, so we use this trampoline to change the default argument to py::none(). If the signature changes, we will be able to bind the function directly.
if (other_in_id.is_none()) {
cpp_other_in_id = std::pair<NodePtr, IOIndex_t>(nullptr, gk_IODefaultIndex);
}else{
cpp_other_in_id = other_in_id.cast<std::pair<NodePtr, IOIndex_t>>();
}
self.addChild(other_graph, out_id, cpp_other_in_id);
},
py::arg("other_graph"), py::arg("out_id") = 0,
py::arg("other_in_id") =
std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex),
py::arg("other_in_id") = py::none(),
R"mydelimiter(
Link a Node from a specific GraphView to the current Node.
......
......@@ -10,6 +10,7 @@
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/functional.h>
#include "aidge/graphRegex/GraphRegex.hpp"
namespace py = pybind11;
......@@ -20,7 +21,7 @@ void init_GraphRegex(py::module& m){
py::class_<GraphRegex, std::shared_ptr<GraphRegex>>(m, "GraphRegex", "GraphRegex class describes a regex to test a graph.")
.def(py::init<>())
.def("add_query", &GraphRegex::addQuery, R"mydelimiter(
.def("add_query", &GraphRegex::addQuery, py::arg("query"), py::arg("f") = nullptr, R"mydelimiter(
:rtype: str
)mydelimiter")
......@@ -47,10 +48,10 @@ void init_GraphRegex(py::module& m){
Add a node test
:param key: the key of the node test to use in the query.
:param conditionalExpressions: the test to perform.
)mydelimiter")
.def("set_node_key",
(void (GraphRegex::*)(const std::string, std::function<bool(NodePtr)>)) &
GraphRegex::setNodeKey,
......@@ -59,7 +60,7 @@ void init_GraphRegex(py::module& m){
Add a node test
:param key: the key of the lambda test to use in the conditional expressions.
:param f: bool lambda (nodePtr).
)mydelimiter")
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/graphRegex/matchFsm/MatchResult.hpp"
namespace py = pybind11;
namespace Aidge {
void init_MatchSolution(py::module& m){
py::class_<MatchSolution, std::shared_ptr<MatchSolution>>(m, "MatchSolution", "MatchSolution class contains the result of one match and the associated key, the query and the start node.")
.def("at", &MatchSolution::at, py::arg("key"),
R"mydelimiter(
:rtype: str
)mydelimiter")
.def("get_all", &MatchSolution::getAll,
R"mydelimiter(
)mydelimiter")
.def("get_query", &MatchSolution::getQuery,
R"mydelimiter(
)mydelimiter")
.def("get_start_node", &MatchSolution::getStartNode,
R"mydelimiter(
)mydelimiter")
;
}
} // namespace Aidge
......@@ -122,7 +122,7 @@ void init_MetaOperatorDefs(py::module &m) {
declare_PaddedMaxPoolingOp<2>(m);
declare_PaddedMaxPoolingOp<3>(m);
py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, Operator>(m, "MetaOperator_Op", py::multiple_inheritance())
py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, OperatorTensor>(m, "MetaOperator_Op", py::multiple_inheritance())
.def("get_micro_graph", &MetaOperator_Op::getMicroGraph);
m.def("meta_operator", &MetaOperator,
......
......@@ -21,6 +21,9 @@ void init_OperatorTensor(py::module& m){
py::class_<OperatorTensor, std::shared_ptr<OperatorTensor>, Operator>(m, "OperatorTensor")
.def("get_output", &OperatorTensor::getOutput, py::arg("outputIdx"))
.def("get_input", &OperatorTensor::getInput, py::arg("inputIdx"))
.def("set_output", (void (OperatorTensor::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &OperatorTensor::setOutput, py::arg("outputIdx"), py::arg("data"))
.def("set_input", (void (OperatorTensor::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &OperatorTensor::setInput, py::arg("outputIdx"), py::arg("data"))
.def("output_dims_forwarded", &OperatorTensor::outputDimsForwarded)
;
}
......
......@@ -24,20 +24,20 @@ namespace Aidge {
template <DimIdx_t DIM>
void declare_Producer(py::module &m) {
// m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name"));
m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&)>(&Producer), py::arg("dims"), py::arg("name") = "");
m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&, bool)>(&Producer), py::arg("dims"), py::arg("name") = "", py::arg("constant") = false);
}
void init_Producer(py::module &m) {
py::class_<Producer_Op, std::shared_ptr<Producer_Op>, OperatorTensor>(
py::class_<Producer_Op, std::shared_ptr<Producer_Op>, OperatorTensor, Attributes>(
m,
"ProducerOp",
py::multiple_inheritance())
.def("dims", &Producer_Op::dims)
.def("get_inputs_name", &Producer_Op::getInputsName)
.def("get_outputs_name", &Producer_Op::getOutputsName);
m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = "");
m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&, bool)>(&Producer), py::arg("tensor"), py::arg("name") = "", py::arg("constant") = false);
declare_Producer<1>(m);
declare_Producer<2>(m);
......
......@@ -58,6 +58,7 @@ void init_OpArgs(py::module&);
void init_Connector(py::module&);
void init_GraphRegex(py::module&);
void init_MatchSolution(py::module&);
void init_Recipies(py::module&);
......@@ -110,7 +111,9 @@ void init_Aidge(py::module& m){
init_Identity(m);
init_Producer(m);
init_GraphRegex(m);
init_MatchSolution(m);
init_Recipies(m);
init_Scheduler(m);
......
......@@ -40,7 +40,7 @@ if ($install_reqs)
mkdir -Force build_cpp
mkdir -Force $env:AIDGE_INSTALL_PATH
Set-Location build_cpp
cmake -DCMAKE_INSTALL_PREFIX:PATH=$env:AIDGE_INSTALL_PATH -DCMAKE_BUILD_TYPE=Debug ..
cmake -DCMAKE_INSTALL_PREFIX:PATH=$env:AIDGE_INSTALL_PATH -DCMAKE_BUILD_TYPE=Debug -DDOSANITIZE=OFF ..
if(!$?) { $lastError = $LASTEXITCODE; Set-Location $PSScriptRoot; Exit $lastError }
cmake --build . -j2
if(!$?) { $lastError = $LASTEXITCODE; Set-Location $PSScriptRoot; Exit $lastError }
......
#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp"
#include "aidge/graphRegex/matchFsm/FsmNode.hpp"
using namespace Aidge;
std::vector<std::set<NodePtr>> FsmRunTimeContext::mRejectedNodes;
......@@ -42,10 +42,6 @@ void FsmRunTimeContext::addRejectedNode(NodePtr node){
mRejectedNodes[mLocalIdxRejeced].insert(node);
}
std::set<NodePtr> FsmRunTimeContext::getRejectedNodes(void){
return mRejectedNodes[mLocalIdxRejeced];
}
bool FsmRunTimeContext::isOnValidState(void){
return mActState->isValid();
}
......@@ -57,7 +53,7 @@ bool FsmRunTimeContext::isCommonDefined(NodePtr node){
for(const auto& nodeC : nodes){
if(nodeC.get() == node.get()){
return true;
}
}
return false;
}
......@@ -68,7 +64,7 @@ bool FsmRunTimeContext::isAlreadyValid(NodePtr node){
for(const auto& nodeV : nodes){
if(nodeV.get() == node.get()){
return true;
}
}
return false;
......@@ -82,7 +78,7 @@ bool FsmRunTimeContext::areCompatible(std::shared_ptr<FsmRunTimeContext> fsmCont
and the same idx for the common
*/
//common node
//common node
for (const auto& test : fsmContext->getCommon()) {
......@@ -97,20 +93,15 @@ bool FsmRunTimeContext::areCompatible(std::shared_ptr<FsmRunTimeContext> fsmCont
//valid nodes
std::set<NodePtr> commonElements;
std::set<NodePtr> A = getValidNodesNoCommon();
std::set<NodePtr> B = fsmContext->getValidNodesNoCommon();
std::set_intersection(
A.begin(),A.end(),
B.begin(), B.end(),
std::inserter(commonElements, commonElements.end())
);
return commonElements.empty();
}
bool FsmRunTimeContext::areEqual(std::shared_ptr<FsmRunTimeContext> fsmContext){
......@@ -142,7 +133,7 @@ void FsmRunTimeContext::setCommon(NodePtr node,std::size_t commonIdx){
}
void FsmRunTimeContext::setValid(NodePtr node,std::shared_ptr<ConditionalInterpreter> tag){
// we have already found a node of this type
if(mValidNodes.find(tag) != mValidNodes.end()){
if(isAlreadyValid(node) && !isCommonDefined(node) ){
throw std::runtime_error("setValid: node validated twice");
......@@ -151,7 +142,7 @@ void FsmRunTimeContext::setValid(NodePtr node,std::shared_ptr<ConditionalInterpr
}else{
mValidNodes[tag] = {node};
}
}
std::size_t FsmRunTimeContext::getSubStmId(void){
......
......@@ -132,8 +132,4 @@ void Aidge::MatchResult::_generateCombination( std::size_t idxSubStm,
}
return;
}
std::shared_ptr<Aidge::MatchSolution> Aidge::MatchResult::getBiggerSolution(void){
return mSolve.empty() ? nullptr : mSolve[0];
}
......@@ -13,4 +13,4 @@
#include "aidge/operator/Producer.hpp"
const std::string Aidge::Producer_Op::Type = "Producer";