
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (78)
Showing 3778 additions and 483 deletions
# Version 0.1.0 (January 23, 2024)
Initial release
......@@ -66,16 +66,39 @@ endif()
target_compile_features(${module_name} PRIVATE cxx_std_14)
if (DOSANITIZE STREQUAL "ON")
set(SANITIZE_FLAGS -fsanitize=address,leak,undefined,float-divide-by-zero -fno-omit-frame-pointer)
# TODO: the sanitizer seems buggy in some situations with MSVC, leading to linker errors; temporarily disabling it
#set(SANITIZE_MSVC_FLAGS)
set(SANITIZE_MSVC_FLAGS /fsanitize=address)
target_compile_definitions(${module_name} PUBLIC _DISABLE_VECTOR_ANNOTATION)
else()
set(SANITIZE_FLAGS)
set(SANITIZE_MSVC_FLAGS)
endif()
set(STRICT_ALIASING_FLAGS -fstrict-aliasing -Wstrict-aliasing=2)
# -fvisibility=hidden required by pybind11
target_compile_options(${module_name} PUBLIC
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:
-fvisibility=hidden>)
target_compile_options(${module_name} PRIVATE
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:
-Wall -Wextra -Wold-style-cast -Winline -pedantic -Werror=narrowing -Wshadow $<$<BOOL:${WERROR}>:-Werror>>)
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:
-Wall -Wextra -Wold-style-cast -Winline -pedantic -Werror=narrowing -Wshadow $<$<BOOL:${WERROR}>:-Werror> ${SANITIZE_FLAGS}>)
target_compile_options(${module_name} PRIVATE
$<$<CXX_COMPILER_ID:MSVC>:
/W4>)
$<$<CXX_COMPILER_ID:GNU>:${STRICT_ALIASING_FLAGS}>)
target_compile_options(${module_name} PRIVATE
$<$<CXX_COMPILER_ID:MSVC>:
/W4 /wd4477 /DWIN32 /D_WINDOWS /GR /EHsc /MP /Zc:__cplusplus /Zc:preprocessor /permissive- ${SANITIZE_MSVC_FLAGS}>)
if (DOSANITIZE STREQUAL "ON")
target_compile_options(${module_name} PRIVATE $<$<CXX_COMPILER_ID:MSVC>:/MDd>)
endif()
# TODO FIXME: I'm not sure it's a good idea to propagate this option but, at this point, it was the only way that worked to silence C4477
target_compile_options(${module_name} PUBLIC $<$<CXX_COMPILER_ID:MSVC>: /wd4477>)
target_link_options(${module_name} PUBLIC $<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:${SANITIZE_FLAGS}>)
#target_link_options(${module_name} PUBLIC $<$<CXX_COMPILER_ID:MSVC>:${SANITIZE_MSVC_FLAGS}>)
if(CMAKE_COMPILER_IS_GNUCXX AND COVERAGE)
append_coverage_compiler_flags()
......
......@@ -14,29 +14,154 @@
#include <cstddef>
#include <cstdio>
#include "aidge/data/Data.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
namespace Aidge {
/**
* This is a thin wrapper around std::any that can only hold pointers.
* It also handles the case where a U* pointer is stored and a const U* pointer
* is requested, which is legit (std::any would throw a bad_cast exception in
* this case).
* Note: not used yet, put in reserve here for possible future use.
*/
/*
class AnyPtr {
public:
template <typename T, typename = std::enable_if_t<std::is_pointer<T>::value>>
constexpr inline AnyPtr(T value) : data(value), ptrToConst(std::is_const<std::remove_pointer_t<T>>::value) {}
// Requested T is "U*"
template <typename T, typename std::enable_if<std::is_same<std::remove_pointer_t<T>, std::remove_const_t<std::remove_pointer_t<T>>>::value>::type* = nullptr>
constexpr inline T get() const {
// data has to be "U*"
return future_std::any_cast<T>(data);
}
// Requested T is "const U*"
template <typename T, typename std::enable_if<!std::is_same<std::remove_pointer_t<T>, std::remove_const_t<std::remove_pointer_t<T>>>::value>::type* = nullptr>
constexpr inline T get() const {
if (ptrToConst) {
// data is "const U*" => OK, no bad cast
return future_std::any_cast<T>(data);
}
else {
// data is "U*" => need to remove const from request to avoid bad cast
return future_std::any_cast<std::add_pointer_t<std::remove_const_t<std::remove_pointer_t<T>>>>(data);
}
}
private:
const future_std::any data;
const bool ptrToConst;
};
*/
/**
* This class manages the raw data storage of a Tensor and provides generic copy
* primitives from other devices and from/to host.
* It can own the data or not (use setRawPtr() to set an external data owner).
* It only knows the data type and data capacity, but does not handle anything else.
*/
class TensorImpl {
public:
TensorImpl() = delete;
TensorImpl(const char *backend) : mBackend(backend){};
virtual void copy(const void *src, NbElts_t length) = 0;
virtual void *rawPtr() = 0;
virtual void setRawPtr(void* /*ptr*/)
TensorImpl(const char *backend, DeviceIdx_t device = 0) : mBackend(backend), mDevice(device){};
/**
* Return the (backend, device) pair for this implementation.
*/
std::pair<std::string, DeviceIdx_t> device() const { return std::make_pair(mBackend, mDevice); }
/**
* Set the device ID for current backend.
* @param device New device ID on current backend.
*/
virtual void setDevice(DeviceIdx_t device) = 0;
/**
* Copy data from the same device.
* @param src Pointer on current implementation device.
* @param length Number of elements to copy.
* @param offset Destination offset (in number of elements).
*/
virtual void copy(const void *src, NbElts_t length, NbElts_t offset = 0) = 0;
/**
* Copy-convert data from the same device.
* @param srcDt Source data type.
* @param src Pointer on current implementation device.
* @param length Number of elements to copy.
*/
virtual void copyCast(const void *src, NbElts_t length, const DataType srcDt) = 0;
/**
* Copy data from another device on the same backend.
* @param device (backend, device) pair to copy from. The backend must match the current implementation backend.
* @param src Pointer on the current implementation backend.
* @param length Number of elements to copy.
*/
virtual void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) = 0;
/**
* Copy data from host.
* @param src Host pointer to copy from.
* @param length Number of elements to copy.
*/
virtual void copyFromHost(const void *src, NbElts_t length) = 0;
/**
* Copy data to host.
* @param dst Host pointer to copy to.
* @param length Number of elements to copy.
*/
virtual void copyToHost(void *dst, NbElts_t length) const = 0;
/**
* Return the raw device pointer.
* The raw pointer is guaranteed to be valid only on the *same* device.
* @param offset Offset, in number of elements.
*/
virtual void* rawPtr(NbElts_t offset = 0) = 0;
virtual const void* rawPtr(NbElts_t offset = 0) const = 0;
/**
* Return the host pointer.
* If the implementation does not have a valid host pointer, nullptr is returned.
* @param offset Offset, in number of elements.
*/
virtual void* hostPtr(NbElts_t /*offset*/ = 0) { return nullptr; };
virtual const void* hostPtr(NbElts_t /*offset*/ = 0) const { return nullptr; };
/**
* Sets the device pointer. The previously owned data is deleted.
* UNSAFE: directly setting the device pointer may lead to undefined behavior
* if it does not match the required storage.
* @param ptr A valid device pointer.
* @param length Storage capacity at the provided pointer
*/
virtual void setRawPtr(void* /*ptr*/, NbElts_t /*length*/)
{
printf("Cannot set raw pointer for backend %s\n", mBackend);
AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend);
};
virtual void* getRaw(std::size_t /*idx*/)=0;
virtual std::size_t size() const = 0; // Storage size
virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes)
constexpr const char *backend() const { return mBackend; }
virtual ~TensorImpl() = default;
virtual bool operator==(const TensorImpl &othImpl) const = 0;
private:
/**
* Copy from another backend.
* @param srcImpl Source TensorImpl to copy from.
* @param length Number of elements of size scalarSize() to copy
*/
void copyFrom(const TensorImpl& srcImpl, NbElts_t length);
protected:
const char *mBackend;
DeviceIdx_t mDevice;
};
} // namespace Aidge
......
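The TensorImpl interface above is backend-agnostic; a concrete backend supplies the storage plus the copy and pointer primitives. Below is a minimal sketch of how the pure virtual methods fit together for a hypothetical host-memory backend. The class name CpuScratchImpl, the backend string "scratch", the float-only storage and the TensorImpl header path are illustrative assumptions, not part of this changeset.

#include <cstring>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "aidge/backend/TensorImpl.hpp"   // assumed path of the interface above
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"

namespace Aidge {
// Hypothetical host-memory backend, for illustration only.
class CpuScratchImpl : public TensorImpl {
public:
    CpuScratchImpl(DeviceIdx_t device, NbElts_t length)
        : TensorImpl("scratch", device), mData(length) {}

    void setDevice(DeviceIdx_t device) override { mDevice = device; }

    // Same-device copy into the owned buffer, at an element offset.
    void copy(const void* src, NbElts_t length, NbElts_t offset) override {
        std::memcpy(mData.data() + offset, src, length * sizeof(float));
    }

    // A real backend would convert from srcDt; this sketch only accepts Float32.
    void copyCast(const void* src, NbElts_t length, const DataType srcDt) override {
        if (srcDt != DataType::Float32) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "copyCast: only Float32 handled in this sketch");
        }
        copy(src, length, 0);
    }

    // Host memory is the "device" here, so device/host copies all degenerate to memcpy.
    void copyFromDevice(const void* src, NbElts_t length,
                        const std::pair<std::string, DeviceIdx_t>& /*device*/) override {
        copy(src, length, 0);
    }
    void copyFromHost(const void* src, NbElts_t length) override { copy(src, length, 0); }
    void copyToHost(void* dst, NbElts_t length) const override {
        std::memcpy(dst, mData.data(), length * sizeof(float));
    }

    void* rawPtr(NbElts_t offset) override { return mData.data() + offset; }
    const void* rawPtr(NbElts_t offset) const override { return mData.data() + offset; }
    void* hostPtr(NbElts_t offset) override { return mData.data() + offset; }
    const void* hostPtr(NbElts_t offset) const override { return mData.data() + offset; }
    void* getRaw(std::size_t idx) override { return mData.data() + idx; }

    std::size_t size() const override { return mData.size(); }        // storage size, read here as an element count
    std::size_t scalarSize() const override { return sizeof(float); } // size of one scalar, in bytes

    bool operator==(const TensorImpl& othImpl) const override {
        return device() == othImpl.device() && size() == othImpl.size();
    }

private:
    std::vector<float> mData;  // owned host storage
};
} // namespace Aidge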
......@@ -12,6 +12,7 @@
#ifndef AIDGE_DATA_H_
#define AIDGE_DATA_H_
#include "aidge/data/half.hpp"
#include "aidge/utils/Attributes.hpp"
namespace Aidge {
......@@ -61,8 +62,15 @@ namespace {
template <typename T> struct NativeType { static const Aidge::DataType type; };
template <> const Aidge::DataType NativeType<double>::type = Aidge::DataType::Float64;
template <> const Aidge::DataType NativeType<float>::type = Aidge::DataType::Float32;
template <> const Aidge::DataType NativeType<long>::type = Aidge::DataType::Int64;
template <> const Aidge::DataType NativeType<int>::type = Aidge::DataType::Int32;
template <> const Aidge::DataType NativeType<half_float::half>::type = Aidge::DataType::Float16;
template <> const Aidge::DataType NativeType<int8_t>::type = Aidge::DataType::Int8;
template <> const Aidge::DataType NativeType<int16_t>::type = Aidge::DataType::Int16;
template <> const Aidge::DataType NativeType<int32_t>::type = Aidge::DataType::Int32;
template <> const Aidge::DataType NativeType<int64_t>::type = Aidge::DataType::Int64;
template <> const Aidge::DataType NativeType<uint8_t>::type = Aidge::DataType::UInt8;
template <> const Aidge::DataType NativeType<uint16_t>::type = Aidge::DataType::UInt16;
template <> const Aidge::DataType NativeType<uint32_t>::type = Aidge::DataType::UInt32;
template <> const Aidge::DataType NativeType<uint64_t>::type = Aidge::DataType::UInt64;
template <>
const char* const EnumStrings<Aidge::DataType>::data[]
......
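These added specializations let code resolve the DataType tag corresponding to a C++ scalar type through the NativeType trait. A minimal sketch, assuming the trait is reachable from code that includes aidge/data/Data.hpp (its exact namespace placement is taken from the hunk above):

#include <cstdint>

#include "aidge/data/Data.hpp"

namespace Aidge {
// Sketch: map C++ scalar types to their DataType tags via NativeType<T>::type.
inline DataType nativeTypeExamples() {
    const DataType f32 = NativeType<float>::type;         // DataType::Float32
    const DataType i8  = NativeType<std::int8_t>::type;   // DataType::Int8  (added by this change)
    const DataType u64 = NativeType<std::uint64_t>::type; // DataType::UInt64 (added by this change)
    (void)i8; (void)u64;
    return f32;
}
} // namespace Aidge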
......@@ -51,7 +51,7 @@ private:
std::vector<std::pair<NodePtr, IOIndex_t>> mOutputNodes;
public:
GraphView(std::string name="")
GraphView(const std::string& name="")
: mName(name)
{
// ctor
......@@ -62,7 +62,7 @@ public:
return mNodes == gv.mNodes;
}
NodePtr operator[](std::string name)
NodePtr operator[](const std::string& name)
{
assert(mNodeRegistry.find(name) != mNodeRegistry.end() && "Could not find Node in the GraphView.");
return mNodeRegistry.at(name);
......@@ -96,7 +96,7 @@ public:
* specified location.
* @param path
*/
void save(std::string path, bool verbose = false) const;
void save(std::string path, bool verbose = false, bool showProducers = true) const;
inline bool inView(NodePtr nodePtr) const {
return mNodes.find(nodePtr) != mNodes.end();
......@@ -203,7 +203,7 @@ public:
* If not, add a Transpose Operator.
* 4 - Propagate Tensor dimensions through the consecutive Operators.
*/
void compile(const std::string& backend, const Aidge::DataType datatype);
void compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device = 0);
/**
* @brief Compute dimensions of input/output Tensors for each Operator of the
......@@ -212,7 +212,7 @@ public:
void forwardDims();
/** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
void setBackend(const std::string &backend);
void setBackend(const std::string &backend, DeviceIdx_t device = 0);
/** @brief Set the same data type for each Operator of the GraphView object's Nodes. */
void setDataType(const DataType &datatype);
......
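Both compile() and setBackend() now accept a device index in addition to the backend name. A usage sketch, assuming a "cpu" backend plugin is registered and that the GraphView header path below is correct; the manual calls mirror only part of what compile() performs, per its documented steps:

#include <memory>

#include "aidge/data/Data.hpp"
#include "aidge/graph/GraphView.hpp"  // assumed header path

void deployGraph(std::shared_ptr<Aidge::GraphView> graph) {
    // One-shot: backend, data type, and (new in this change) device index.
    graph->compile("cpu", Aidge::DataType::Float32, /*device=*/0);

    // Step by step, using the individual setters that also gained a device index:
    graph->setBackend("cpu", /*device=*/0);
    graph->setDataType(Aidge::DataType::Float32);
    graph->forwardDims();
}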
......@@ -140,7 +140,7 @@ public:
/**
* @brief List of pairs <Parent, ID of the data input>. When an input is not
* linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
* linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
* Data inputs exclude inputs expecting parameters (weights or bias).
* @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>
*/
......
......@@ -10,164 +10,162 @@
#include "aidge/graph/Node.hpp"
namespace Aidge{
class FsmNode;
class FsmNode;
/**
* @brief A class used to save the execution context of a state machine: the current state in the FSM, the current node in the graph,
* and all nodes that have been validated, rejected, or considered common.
*/
class FsmRunTimeContext
{
private:
/**
* @brief the list of nodes rejected across all contexts
*/
static std::vector<std::set<NodePtr>> mRejectedNodes;
/**
* @brief the current state of this context (its position in the FSM graph)
*/
std::shared_ptr<FsmNode> mActState;
/**
* @brief the current node of this context (its position in the graph)
*/
NodePtr mActOpNode;
/**
* @brief the map of nodes considered common, with their common ID
* @details we need to store which nodes are considered common because, at the final
* resolution of the matching, all common nodes need to be the same in every context
*/
std::map<NodePtr,std::size_t> mCommonNodes;
/**
* @brief the map of nodes validated in this context, and the test that validated each node
*/
std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> mValidNodes;
/**
* @brief the index of this context in the rejected-node list
*/
std::size_t mLocalIdxRejeced;
public:
/**
* @brief constructor
* @param actState the current state in the FSM
* @param actOpNode the current node in the graph
* @param idxRejeced the index in the global rejected-node vector, initialized to max() as a sentinel value for undefined
*/
FsmRunTimeContext(std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ,std::size_t idxRejeced =std::numeric_limits<std::size_t>::max() );
FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime);
FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime,std::shared_ptr<FsmNode> actState ,NodePtr actOpNode );
virtual ~FsmRunTimeContext()=default;
/**
* @defgroup FsmRunTimeContextRejected Function for managing rejected nodes
*/
/**
* @ingroup FsmRunTimeContextRejected
* @brief Add a node as rejected in this context
*/
void addRejectedNode(NodePtr node);
/**
* @ingroup FsmRunTimeContextRejected
* @brief get the rejected nodes of this context
*/
inline std::set<NodePtr> getRejectedNodes(void) const {
return mRejectedNodes[mLocalIdxRejeced];
}
/**
* @defgroup FsmRunTimeContextTest Function for test the context
*/
class FsmNode;
/**
* @ingroup FsmRunTimeContextTest
* @brief test if the current state is valid
* @return bool
*/
bool isOnValidState(void);
/**
* @ingroup FsmRunTimeContextTest
* @brief test if the node is considered as common in this context
* @param node node to test
* @return bool
*/
bool isCommonDefined(NodePtr node);
/**
* @ingroup FsmRunTimeContextTest
* @brief test if the node has already been validated in this context
* @param node node to test
* @return bool
*/
bool isAlreadyValid(NodePtr node);
/**
* @ingroup FsmRunTimeContextTest
* @brief test if this context is compatible with another one
* @details two contexts are compatible if
* they do not validate the same nodes (other than the common ones)
* and the common ones have the same idx
* @param fsmContext the other context
* @return bool
*/
bool areCompatible(std::shared_ptr<FsmRunTimeContext> fsmContext);
/**
* @ingroup FsmRunTimeContextTest
* @brief test if this context is strictly equal to another one
* @param fsmContext the other context
* @return bool
*/
bool areEqual(std::shared_ptr<FsmRunTimeContext> fsmContext);
/**
* @brief A class used to save the execution context of a state machine: the current state in the FSM, the current node in the graph,
* and all nodes that have been validated, rejected, or considered common.
* @defgroup FsmRunTimeContextSet Function set context
*/
class FsmRunTimeContext
{
private:
/**
* @brief the list of nodes rejected across all contexts
*/
static std::vector<std::set<NodePtr>> mRejectedNodes;
/**
* @brief the current state of this context (its position in the FSM graph)
*/
std::shared_ptr<FsmNode> mActState;
/**
* @brief the current node of this context (its position in the graph)
*/
NodePtr mActOpNode;
/**
* @brief the map of nodes considered common, with their common ID
* @details we need to store which nodes are considered common because, at the final
* resolution of the matching, all common nodes need to be the same in every context
*/
std::map<NodePtr,std::size_t> mCommonNodes;
/**
* @brief the map of nodes validated in this context, and the test that validated each node
*/
std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> mValidNodes;
/**
* @brief the index of this context in the rejected-node list
*/
std::size_t mLocalIdxRejeced;
public:
/**
* @brief constructor
* @param actState the current state in the FSM
* @param actOpNode the current node in the graph
* @param idxRejeced the index in the global rejected-node vector, initialized to max() as a sentinel value for undefined
*/
FsmRunTimeContext(std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ,std::size_t idxRejeced =std::numeric_limits<std::size_t>::max() );
FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime);
FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime,std::shared_ptr<FsmNode> actState ,NodePtr actOpNode );
virtual ~FsmRunTimeContext()=default;
/**
* @defgroup FsmRunTimeContextRejected Function for managing rejected nodes
*/
/**
* @ingroup FsmRunTimeContextRejected
* @brief Add a node as rejected in this context
*/
void addRejectedNode(NodePtr node);
/**
* @ingroup FsmRunTimeContextRejected
* @brief get the rejected nodes of this context
*/
std::set<NodePtr> getRejectedNodes(void);
/**
* @defgroup FsmRunTimeContextTest Function for test the context
*/
/**
* @ingroup FsmRunTimeContextTest
* @brief test if the current state is valid
* @return bool
*/
bool isOnValidState(void);
/**
* @ingroup FsmRunTimeContextTest
* @brief test if the node is considered as common in this context
* @param node node to test
* @return bool
*/
bool isCommonDefined(NodePtr node);
/**
* @ingroup FsmRunTimeContextTest
* @brief test if the node has already been validated in this context
* @param node node to test
* @return bool
*/
bool isAlreadyValid(NodePtr node);
/**
* @ingroup FsmRunTimeContextTest
* @brief test if this context is compatible with another one
* @details two contexts are compatible if
* they do not validate the same nodes (other than the common ones)
* and the common ones have the same idx
* @param fsmContext the other context
* @return bool
*/
bool areCompatible(std::shared_ptr<FsmRunTimeContext> fsmContext);
/**
* @ingroup FsmRunTimeContextTest
* @brief test if this context is strictly equal to another one
* @param fsmContext the other context
* @return bool
*/
bool areEqual(std::shared_ptr<FsmRunTimeContext> fsmContext);
/**
* @defgroup FsmRunTimeContextSet Function set context
*/
void setCommon(NodePtr node,std::size_t commonIdx);
void setValid(NodePtr node,std::shared_ptr<ConditionalInterpreter> tag);
/**
* @defgroup FsmRunTimeContextGet Function get context
*/
/**
* @ingroup FsmRunTimeContextGet
* @brief get the sub idx state
* @return std::size_t
*/
std::size_t getSubStmId(void);
NodePtr getCommonNodeFromIdx(std::size_t commonIdx);
std::size_t getCommonNodeIdx(NodePtr node);
std::set<NodePtr> getCommonNodes(void);
std::map<NodePtr,std::size_t> getCommon(void);
std::set<NodePtr> getValidNodes(void);
std::set<NodePtr> getValidNodesNoCommon(void);
std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>>& getValid(void);
NodePtr getActNode(void);
std::shared_ptr<FsmNode> getActState(void);
/**
* @defgroup FsmRunTimeContextMem
*/
void rst(void);
};
}
#endif //AIDGE_CORE_FSM_RUN_TIME_CONTEXT_H_
void setCommon(NodePtr node,std::size_t commonIdx);
void setValid(NodePtr node,std::shared_ptr<ConditionalInterpreter> tag);
/**
* @defgroup FsmRunTimeContextGet Function get context
*/
/**
* @ingroup FsmRunTimeContextGet
* @brief get the sub idx state
* @return std::size_t
*/
std::size_t getSubStmId(void);
NodePtr getCommonNodeFromIdx(std::size_t commonIdx);
std::size_t getCommonNodeIdx(NodePtr node);
std::set<NodePtr> getCommonNodes(void);
std::map<NodePtr,std::size_t> getCommon(void);
std::set<NodePtr> getValidNodes(void);
std::set<NodePtr> getValidNodesNoCommon(void);
std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>>& getValid(void);
NodePtr getActNode(void);
std::shared_ptr<FsmNode> getActState(void);
/**
* @defgroup FsmRunTimeContextMem
*/
void rst(void);
};
} // namespace Aidge
#endif // AIDGE_CORE_FSM_RUN_TIME_CONTEXT_H_
......@@ -24,8 +24,10 @@ private:
const std::vector<NodePtr> mStartNode;
public:
MatchSolution() = delete;
MatchSolution(std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence,const std::string query,const std::vector<NodePtr> startNode);
inline const std::set<NodePtr>& at(const std::string key) {
inline const std::set<NodePtr>& at(const std::string& key) {
return mSolution[key];
}
const std::set<NodePtr> getAll();
......@@ -33,7 +35,6 @@ public:
inline const std::string& getQuery() const noexcept { return mQueryFrom; }
inline const std::vector<NodePtr>& getStartNode() const noexcept { return mStartNode; }
};
......@@ -60,14 +61,18 @@ private:
public:
MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm,
const std::string& query,const std::vector<NodePtr>& startNodes);
MatchResult() = delete;
MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid,
std::size_t nbSubStm,
const std::string& query,const std::vector<NodePtr>& startNodes);
/**
* @brief get the set of nodes matched for an expression
* @return the set of nodes of the graph corresponding to the expression
*/
std::shared_ptr<MatchSolution> getBiggerSolution(void);
inline std::shared_ptr<MatchSolution> getBiggerSolution(void) const noexcept {
return mSolve.empty() ? nullptr : mSolve[0];
}
inline std::vector<std::shared_ptr<MatchSolution>> getSolutions(void) const noexcept {
return mSolve;
......
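getBiggerSolution() now simply returns the first stored solution, or nullptr when the query matched nothing, so callers can branch on an empty match cheaply. A hedged usage sketch; the header path and the capture key "conv" are illustrative assumptions, and producing the MatchResult itself is outside this diff:

#include <memory>
#include <set>

#include "aidge/graphRegex/matchFsm/MatchResult.hpp"  // assumed header path

void reportMatches(const Aidge::MatchResult& result) {
    // Best (first) solution, or nullptr when nothing matched.
    std::shared_ptr<Aidge::MatchSolution> best = result.getBiggerSolution();
    if (best) {
        const std::set<Aidge::NodePtr>& captured = best->at("conv");  // "conv" is an illustrative query key
        (void)captured;
    }
    // All solutions, in the order stored by the matcher.
    for (const auto& solution : result.getSolutions()) {
        (void)solution->getQuery();
    }
}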
......@@ -76,14 +76,9 @@ public:
// }
void setBackend(const std::string& name) override {
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Add_Op>::create(name)(*this);
mOutputs[0]->setBackend(name);
// FIXME: temporary workaround
for (std::size_t i = 0; i < nbInputs(); ++i) {
getInput(i)->setBackend(name);
}
mOutputs[0]->setBackend(name, device);
}
static const std::vector<std::string> getInputsName() {
......
......@@ -136,12 +136,9 @@ public:
}
void setBackend(const std::string &name) override {
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this);
mOutputs[0]->setBackend(name);
// FIXME: temporary workaround
getInput(0)->setBackend(name);
mOutputs[0]->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -87,22 +87,22 @@ public:
if(getInput(i)->size() != nbFeatures) {
// /!\ Input size should be handled BEFORE calling this function
// This should raise an error
getInput(i)->resize(std::array<DimSize_t, 1>({getInput(0)->dims()[1]}));
getInput(i)->resize({getInput(0)->dims()[1]});
}
}
mOutputs[0]->resize(getInput(0)->dims());
}
}
void setBackend(const std::string &name) override {
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<BatchNorm_Op<DIM>>::create(name)(*this);
mOutputs[0]->setBackend(name);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(1)->setBackend(name);
getInput(2)->setBackend(name);
getInput(3)->setBackend(name);
getInput(4)->setBackend(name);
// By default, automatically set backend for scale, shift, mean and variance
getInput(1)->setBackend(name, device);
getInput(2)->setBackend(name, device);
getInput(3)->setBackend(name, device);
getInput(4)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName() {
......@@ -123,10 +123,10 @@ inline std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures,
const std::string& name = "") {
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported");
auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name);
addProducer(batchNorm, 1, std::array<DimSize_t,1>({nbFeatures}), "scale");
addProducer(batchNorm, 2, std::array<DimSize_t,1>({nbFeatures}), "shift");
addProducer(batchNorm, 3, std::array<DimSize_t,1>({nbFeatures}), "batch_mean");
addProducer(batchNorm, 4, std::array<DimSize_t,1>({nbFeatures}), "batch_variance");
addProducer(batchNorm, 1, {nbFeatures}, "scale");
addProducer(batchNorm, 2, {nbFeatures}, "shift");
addProducer(batchNorm, 3, {nbFeatures}, "batch_mean");
addProducer(batchNorm, 4, {nbFeatures}, "batch_variance");
return batchNorm;
}
} // namespace Aidge
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_CAST_H_
#define AIDGE_CORE_OPERATOR_CAST_H_
#include <cassert>
#include <memory>
#include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class Cast_Op : public OperatorTensor,
public Registrable<Cast_Op, std::string, std::unique_ptr<OperatorImpl>(const Cast_Op&)> {
public:
static const std::string Type;
Cast_Op() : OperatorTensor(Type, 1, 0, 1) {}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
Cast_Op(const Cast_Op& op)
: OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<Cast_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr;
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Cast_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<Cast_Op>(*this);
}
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
if (Registrar<Cast_Op>::exists({name})) {
mImpl = Registrar<Cast_Op>::create({name})(*this);
}
mOutputs[0]->setBackend(name, device);
}
void forward() override;
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> Cast(const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Cast_Op>(), name);
}
}
#endif /* AIDGE_CORE_OPERATOR_CAST_H_ */
\ No newline at end of file
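Cast_Op follows the same conventions as the other operators: setBackend() only instantiates an implementation when one is registered for the requested backend, and the Cast() factory builds a single-input node. A short usage sketch; the header path is an assumption:

#include <memory>

#include "aidge/operator/Cast.hpp"  // assumed header path

std::shared_ptr<Aidge::Node> makeCastNode() {
    // Factory shown above: one data input, one data output.
    auto castNode = Aidge::Cast("cast0");
    return castNode;
}

// Per the copy-constructor documentation, a clone keeps the attributes and
// output tensor(s) but has no inputs associated.
std::shared_ptr<Aidge::Operator> cloneCastOp(const Aidge::Cast_Op& op) {
    return op.clone();
}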
......@@ -101,14 +101,9 @@ public:
}
}
void setBackend(const std::string& name) override {
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Concat_Op>::create(name)(*this);
mOutputs[0]->setBackend(name);
// FIXME: temporary workaround
for (std::size_t i = 0; i < nbInputs(); ++i) {
getInput(i)->setBackend(name);
}
mOutputs[0]->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -173,13 +173,13 @@ std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> co
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet.");
}
void setBackend(const std::string &name) override {
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this);
mOutputs[0]->setBackend(name);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(1)->setBackend(name);
getInput(2)->setBackend(name);
// By default, automatically set backend for weight and bias inputs
getInput(1)->setBackend(name, device);
getInput(2)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......@@ -215,9 +215,8 @@ inline std::shared_ptr<Node> Conv(DimSize_t inChannels,
// FIXME: properly handle default w&b initialization in every case
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Conv, not supported");
auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(inChannels, outChannels, kernelDims, strideDims, dilationDims), name);
// addProducer(conv, 1, append(append(kernel_dims, in_channels), out_channels), "w");
addProducer(conv, 1, append(outChannels, append(inChannels, kernelDims)), "w");
addProducer(conv, 2, std::array<DimSize_t, 1>({outChannels}), "b");
addProducer(conv, 2, {outChannels}, "b");
return conv;
}
......
......@@ -167,13 +167,13 @@ public:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet.");
}
void setBackend(const std::string &name) override {
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<ConvDepthWise_Op<DIM>>::create(name)(*this);
mOutputs[0]->setBackend(name);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(1)->setBackend(name);
getInput(2)->setBackend(name);
// By default, automatically set backend for weight and bias inputs
getInput(1)->setBackend(name, device);
getInput(2)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......@@ -197,7 +197,7 @@ inline std::shared_ptr<Node> ConvDepthWise(const DimSize_t nbChannels,
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ConvDepthWise, not supported");
auto convDW = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(nbChannels, kernelDims, strideDims, dilationDims), name);
addProducer(convDW, 1, append(nbChannels, append(DimSize_t(1), kernelDims)), "w");
addProducer(convDW, 2, std::array<DimSize_t, 1>({nbChannels}), "b");
addProducer(convDW, 2, {nbChannels}, "b");
return convDW;
}
......
......@@ -54,13 +54,9 @@ public:
void computeOutputDims() override final;
void setBackend(const std::string& name) override {
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Div_Op>::create(name)(*this);
mOutputs[0]->setBackend(name);
// FIXME: temporary workaround
getInput(0)->setBackend(name);
getInput(1)->setBackend(name);
mOutputs[0]->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -51,12 +51,9 @@ public:
return std::make_shared<Erf_Op>(*this);
}
void setBackend(const std::string& name) override {
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Erf_Op>::create(name)(*this);
mOutputs[0]->setBackend(name);
// FIXME: temporary workaround
getInput(0)->setBackend(name);
mOutputs[0]->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -77,7 +77,7 @@ public:
}
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
if (inputIdx == 0 && getInput(0)->nbDims() == 1)
mInputs[inputIdx]->resize(std::array<DimSize_t, 2>({1, getInput(inputIdx)->size()}));
mInputs[inputIdx]->resize({1, getInput(inputIdx)->size()});
}
void computeOutputDims() override final {
......@@ -95,14 +95,13 @@ public:
}
void setBackend(const std::string& name) override {
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<FC_Op>::create(name)(*this);
mOutputs[0]->setBackend(name);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(0)->setBackend(name);
getInput(1)->setBackend(name);
getInput(2)->setBackend(name);
// By default, automatically set backend for weight and bias inputs
getInput(1)->setBackend(name, device);
getInput(2)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......@@ -116,8 +115,8 @@ public:
inline std::shared_ptr<Node> FC(DimSize_t inChannels, DimSize_t outChannels, bool noBias = false, const std::string& name = "") {
// FIXME: properly handle default w&b initialization in every case
auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(outChannels, noBias), name);
addProducer(fc, 1, std::array<DimSize_t, 2>({outChannels, inChannels}), "w");
addProducer(fc, 2, (noBias ? std::array<DimSize_t, 1>({0}) : std::array<DimSize_t, 1>({outChannels})), "b"); // already sets bias dims
addProducer(fc, 1, {outChannels, inChannels}, "w");
addProducer(fc, 2, {(noBias ? 0 : outChannels)}, "b"); // already sets bias dims
return fc;
}
} // namespace Aidge
......
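The FC() factory keeps its signature, FC(inChannels, outChannels, noBias, name), and now creates its weight producer of shape {outChannels, inChannels} and bias producer of shape {outChannels} (or {0} when noBias is set) using brace-initialized dimensions. A short usage sketch; the header path is an assumption:

#include <memory>

#include "aidge/operator/FC.hpp"  // assumed header path

std::shared_ptr<Aidge::Node> makeClassifierHead() {
    // 128 input channels -> 10 output channels, bias enabled; the factory also
    // creates the weight producer {10, 128} and the bias producer {10}.
    auto fc = Aidge::FC(128, 10, /*noBias=*/false, "fc_out");
    return fc;
}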