diff --git a/aidge_core/unit_tests/test_recipies.py b/aidge_core/unit_tests/test_recipies.py index 6cf89a45fd0d4cf1dc970d199d074e886b131896..26ae544d6e05f2f9a9da371d3617f9265a037364 100644 --- a/aidge_core/unit_tests/test_recipies.py +++ b/aidge_core/unit_tests/test_recipies.py @@ -20,6 +20,18 @@ class test_recipies(unittest.TestCase): def tearDown(self): pass + def test_remove_dropout(self): + graph_view = aidge_core.sequential([ + aidge_core.GenericOperator("Conv", 1, 0, 1, "Conv0"), + aidge_core.GenericOperator("Dropout", 1, 0, 1, name="Dropout0") + ]) + old_nodes = graph_view.get_nodes() + aidge_core.remove_dropout(graph_view) + self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 1) + self.assertTrue("Dropout0" not in [i.name for i in graph_view.get_nodes()]) + + self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()])) + def test_remove_flatten(self): graph_view = aidge_core.sequential([ aidge_core.GenericOperator("Flatten", 1, 0, 1, name="Flatten0"), diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp index 4fea0f4950c0d48ed3eaadf361438e2859692e10..09ebccb7e830d6576f1c3c7d9e7eb057b91c055c 100644 --- a/include/aidge/aidge.hpp +++ b/include/aidge/aidge.hpp @@ -36,7 +36,9 @@ #include "aidge/operator/Conv.hpp" #include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/Div.hpp" +#include "aidge/operator/Erf.hpp" #include "aidge/operator/FC.hpp" +#include "aidge/operator/Gather.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/MatMul.hpp" #include "aidge/operator/MaxPooling.hpp" @@ -47,13 +49,15 @@ #include "aidge/operator/Pad.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/Pow.hpp" +#include "aidge/operator/ReduceMean.hpp" #include "aidge/operator/ReLU.hpp" +#include "aidge/operator/Reshape.hpp" #include "aidge/operator/Scaling.hpp" #include "aidge/operator/Slice.hpp" #include "aidge/operator/Softmax.hpp" #include "aidge/operator/Sqrt.hpp" #include "aidge/operator/Sub.hpp" - +#include "aidge/operator/Transpose.hpp" #include "aidge/scheduler/Scheduler.hpp" #include "aidge/stimuli/Stimuli.hpp" diff --git a/include/aidge/backend/TensorImpl.hpp b/include/aidge/backend/TensorImpl.hpp index f8d398c7801f45a0411fafa446ae7c51ce671cfc..a27f0317c59916facef970a3c1b91704fb485cd4 100644 --- a/include/aidge/backend/TensorImpl.hpp +++ b/include/aidge/backend/TensorImpl.hpp @@ -14,29 +14,154 @@ #include <cstddef> #include <cstdio> +#include "aidge/data/Data.hpp" #include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" namespace Aidge { +/** + * This is a thin wrapper around std::any that can only hold pointers. + * It also handles the case where a U* pointer is stored and a const U* pointer + * is requested, which is legit (std::any would throw a bad_cast exception in + * this case). + * Note: not used yet, put in reserve here for possible future use. 
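A minimal usage sketch of the behaviour described here, assuming the AnyPtr class quoted just below were enabled (it is currently kept commented out, so this is purely illustrative):

    float v = 3.0f;
    AnyPtr p(&v);                              // stored internally as "float*"
    float* mut = p.get<float*>();              // exact type requested: plain any_cast
    const float* cst = p.get<const float*>();  // const added on retrieval: handled explicitly, no bad_cast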
+*/ +/* +class AnyPtr { +public: + template <typename T, typename = std::enable_if_t<std::is_pointer<T>::value>> + constexpr inline AnyPtr(T value) : data(value), ptrToConst(std::is_const<std::remove_pointer_t<T>>::value) {} + + // Requested T is "U*" + template <typename T, typename std::enable_if<std::is_same<std::remove_pointer_t<T>, std::remove_const_t<std::remove_pointer_t<T>>>::value>::type* = nullptr> + constexpr inline T get() const { + // data has to be "U*" + return future_std::any_cast<T>(data); + } + + // Requested T is "const U*" + template <typename T, typename std::enable_if<!std::is_same<std::remove_pointer_t<T>, std::remove_const_t<std::remove_pointer_t<T>>>::value>::type* = nullptr> + constexpr inline T get() const { + if (ptrToConst) { + // data is "const U*" => OK, no bad cast + return future_std::any_cast<T>(data); + } + else { + // data is "U*" => need to remove const from request to avoid bad cast + return future_std::any_cast<std::add_pointer_t<std::remove_const_t<std::remove_pointer_t<T>>>>(data); + } + } + +private: + const future_std::any data; + const bool ptrToConst; +}; +*/ + +/** + * This class manages the raw data storage of a Tensor and provide generic copy + * primitives from other devices and from/to host. + * It can own the data or not (use setRawPtr() to set an external data owner). + * It only knows the data type and data capacity, but does not handle anything else. +*/ class TensorImpl { public: TensorImpl() = delete; - TensorImpl(const char *backend) : mBackend(backend){}; - virtual void copy(const void *src, NbElts_t length, std::size_t offset = 0) = 0; - virtual void *rawPtr() = 0; - virtual void setRawPtr(void* /*ptr*/) + TensorImpl(const char *backend, DeviceIdx_t device = 0) : mBackend(backend), mDevice(device){}; + + /** + * Return the (backend, device) pair for this implementation. + */ + std::pair<std::string, DeviceIdx_t> device() const { return std::make_pair(mBackend, mDevice); } + + /** + * Set the device ID for current backend. + * @param device New device ID on current backend. + */ + virtual void setDevice(DeviceIdx_t device) = 0; + + /** + * Copy data from the same device. + * @param src Pointer on current implementation device. + * @param length Number of elements to copy. + * @param offset Destination offset (in number of elements). + */ + virtual void copy(const void *src, NbElts_t length, NbElts_t offset = 0) = 0; + + /** + * Copy-convert data from the same device. + * @param srcDt Source data type. + * @param src Pointer on current implementation device. + * @param length Number of elements to copy. + */ + virtual void copyCast(const void *src, NbElts_t length, const DataType srcDt) = 0; + + /** + * Copy data from an other device on the same backend. + * @param device (backend, device) pair to copy from. The backend must match current implementation backend. + * @param src Pointer on current implementation backend. + * @param length Number of elements to copy. + */ + virtual void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) = 0; + + /** + * Copy data from host. + * @param src Host pointer to copy from. + * @param length Number of elements to copy. + */ + virtual void copyFromHost(const void *src, NbElts_t length) = 0; + + /** + * Copy data to host. + * @param src Host pointer to copy to. + * @param length Number of elements to copy. + */ + virtual void copyToHost(void *dst, NbElts_t length) const = 0; + + /** + * Return the raw device pointer. 
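A short caller-side sketch of the host transfer primitives declared above, assuming impl points to a Float32 TensorImpl (obtained, for instance, through Tensor::getImpl()):

    std::vector<float> hostData = {1.f, 2.f, 3.f};
    impl->copyFromHost(hostData.data(), hostData.size());  // host -> implementation device
    std::vector<float> back(hostData.size());
    impl->copyToHost(back.data(), back.size());            // implementation device -> host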
+ * The raw pointer is garanteed to be valid only on the *same* device. + * @param offset Offset, in number of elements. + */ + virtual void* rawPtr(NbElts_t offset = 0) = 0; + virtual const void* rawPtr(NbElts_t offset = 0) const = 0; + + /** + * Return the host pointer. + * If the implementation does not have a valid host pointer, nullptr is returned. + * @param offset Offset, in number of elements. + */ + virtual void* hostPtr(NbElts_t /*offset*/ = 0) { return nullptr; }; + virtual const void* hostPtr(NbElts_t /*offset*/ = 0) const { return nullptr; }; + + /** + * Sets the device pointer. The previously owned data is deleted. + * UNSAFE: directly setting the device pointer may lead to undefined behavior + * if it does not match the required storage. + * @param ptr A valid device pointer. + * @param length Storage capacity at the provided pointer + */ + virtual void setRawPtr(void* /*ptr*/, NbElts_t /*length*/) { - printf("Cannot set raw pointer for backend %s\n", mBackend); + AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend); }; - virtual void* getRaw(std::size_t /*idx*/)=0; - + virtual std::size_t size() const = 0; // Storage size virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes) constexpr const char *backend() const { return mBackend; } virtual ~TensorImpl() = default; virtual bool operator==(const TensorImpl &othImpl) const = 0; -private: + /** + * Copy from another backend. + * @param srcImpl Source TensorImpl to copy from. + * @param length Number of elements of size scalarSize() to copy + */ + void copyFrom(const TensorImpl& srcImpl, NbElts_t length); + +protected: const char *mBackend; + DeviceIdx_t mDevice; }; } // namespace Aidge diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp index 953f684c9ceeae52b67977460256c8ed61ce9a41..bf34860fbc4e4d6cfef8528d20de40c3e31a292b 100644 --- a/include/aidge/data/Data.hpp +++ b/include/aidge/data/Data.hpp @@ -12,6 +12,7 @@ #ifndef AIDGE_DATA_H_ #define AIDGE_DATA_H_ +#include "aidge/data/half.hpp" #include "aidge/utils/Attributes.hpp" namespace Aidge { @@ -61,12 +62,15 @@ namespace { template <typename T> struct NativeType { static const Aidge::DataType type; }; template <> const Aidge::DataType NativeType<double>::type = Aidge::DataType::Float64; template <> const Aidge::DataType NativeType<float>::type = Aidge::DataType::Float32; -template <> const Aidge::DataType NativeType<long>::type = Aidge::DataType::Int64; -template <> const Aidge::DataType NativeType<int>::type = Aidge::DataType::Int32; -template <> const Aidge::DataType NativeType<int16_t>::type = Aidge::DataType::Int16; -template <> const Aidge::DataType NativeType<uint16_t>::type = Aidge::DataType::UInt16; +template <> const Aidge::DataType NativeType<half_float::half>::type = Aidge::DataType::Float16; template <> const Aidge::DataType NativeType<int8_t>::type = Aidge::DataType::Int8; +template <> const Aidge::DataType NativeType<int16_t>::type = Aidge::DataType::Int16; +template <> const Aidge::DataType NativeType<int32_t>::type = Aidge::DataType::Int32; +template <> const Aidge::DataType NativeType<int64_t>::type = Aidge::DataType::Int64; template <> const Aidge::DataType NativeType<uint8_t>::type = Aidge::DataType::UInt8; +template <> const Aidge::DataType NativeType<uint16_t>::type = Aidge::DataType::UInt16; +template <> const Aidge::DataType NativeType<uint32_t>::type = Aidge::DataType::UInt32; +template <> const Aidge::DataType NativeType<uint64_t>::type = Aidge::DataType::UInt64; 
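These specializations are what the Tensor constructors further down rely on to infer the DataType of a literal array. A small sketch, assuming a cpu backend is registered so that Registrar<Tensor> can allocate the implementation:

    #include <cassert>
    #include <cstdint>
    #include "aidge/data/Tensor.hpp"

    // NativeType<std::int32_t>::type is DataType::Int32, so this Tensor is Int32.
    Aidge::Tensor t = Aidge::Array1D<std::int32_t, 4>{{1, 2, 3, 4}};
    assert(t.dataType() == Aidge::DataType::Int32);
    assert(t.size() == 4);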
template <> const char* const EnumStrings<Aidge::DataType>::data[] diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 021f8a88b326a08a7b96a1dd5c98832cc0d607cb..3780e0f4c001581c09af80f9e664a0c4e09c5796 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -23,114 +23,9 @@ #include "aidge/data/Data.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" +#include "aidge/utils/ArrayHelpers.hpp" namespace Aidge { - -// Helper to create default arrays -template <typename T, std::size_t ... Is> -constexpr std::array<T, sizeof...(Is)> -create_array_impl(T value, std::index_sequence<Is...>) -{ - // cast Is to void to remove the warning: unused value - return {{(static_cast<void>(Is), value)...}}; -} - -template <typename T, std::size_t N> -constexpr std::array<T, N> create_array(const T& value) -{ - return create_array_impl(value, std::make_index_sequence<N>()); -} - - -// Helper to convert vector to array -template <typename T, typename Iter, std::size_t... Is> -constexpr auto to_array(Iter &iter, std::index_sequence<Is...>) -> std::array<T, sizeof...(Is)> { - return {{((void)Is, T(*iter++))...}}; -} - -/** - * @brief Convert an object with an iterator to an std::array. - */ -template <std::size_t N, typename U = void, typename Iter, typename V = typename std::iterator_traits<Iter>::value_type, - typename T = std::conditional_t<std::is_same<U, void>{}, V, U>> -constexpr auto to_array(Iter iter) -> std::array<T, N> { - return to_array<T>(iter, std::make_index_sequence<N>{}); -} - -namespace detail { - -template <class T, std::size_t N, std::size_t... I> -constexpr std::array<std::remove_cv_t<T>, N> to_array_impl(T (&a)[N], std::index_sequence<I...>) { - return {{a[I]...}}; -} - -} // namespace detail - -/** - * @brief Convert a C-stype array into a C++ std::array. - * - * @tparam T Data type. - * @tparam N Number of elements. - * @param a C-style array to convert. - * @return constexpr std::array<std::remove_cv_t<T>, N> - */ -template <class T, std::size_t N> -constexpr std::array<std::remove_cv_t<T>, N> to_array(T (&a)[N]) { - return detail::to_array_impl(a, std::make_index_sequence<N>{}); -} - -template <typename T, std::size_t N, std::size_t... I> -constexpr std::array<T, N + 1> append(std::array<T, N> a, T t, std::index_sequence<I...>) { - return std::array<T, N + 1>{a[I]..., t}; -} - -template <typename T, std::size_t N, std::size_t... I> -constexpr std::array<T, N + 1> append(T t, std::array<T, N> a, std::index_sequence<I...>) { - return std::array<T, N + 1>{t, a[I]...}; -} - -/** - * @brief Create a new array concatenating the initial one with the value to - * add. - * @details append({1,2,7}, 3) -> {1,2,7,3} - * - * @tparam T Data type. - * @tparam N Number of elements in the initilial array. - * @param a Initial array. - * @param t Element to add. 
- * @return constexpr std::array<T, N + 1> - */ -template <typename T, std::size_t N> -constexpr std::array<T, N + 1> append(std::array<T, N> a, T t) { - return append(a, t, std::make_index_sequence<N>()); -} - -template <typename T, std::size_t N> -constexpr std::array<T, N + 1> append(T t, std::array<T, N> a) { - return append(t, a, std::make_index_sequence<N>()); -} - -// Generic helper for initializing a Tensor -template <typename T, std::size_t SIZE_0> -struct Array1D { - T data[SIZE_0]; -}; - -template <typename T, std::size_t SIZE_0, std::size_t SIZE_1> -struct Array2D { - T data[SIZE_0][SIZE_1]; -}; - -template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2> -struct Array3D { - T data[SIZE_0][SIZE_1][SIZE_2]; -}; - -template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2, std::size_t SIZE_3> -struct Array4D { - T data[SIZE_0][SIZE_1][SIZE_2][SIZE_3]; -}; - /** * @brief Description for the tensor data structure. * @details Sets the properties of the tensor without actually containing any data. @@ -145,8 +40,7 @@ class Tensor : public Data, std::shared_ptr<Tensor> mGrad; /** Pointer to the associated gradient Tensor instance. */ // Cached data - std::size_t mSize; /** Number of elements in the Tensor. */ - std::size_t mSizeM1; /** Number of elements in the N-1 first dimensions */ + std::size_t mSize = 0; /** Number of elements in the Tensor. */ public: static constexpr const char *Type = "Tensor"; @@ -157,10 +51,7 @@ class Tensor : public Data, */ Tensor(DataType dataType = DataType::Float32) : Data(Type), - mDataType(dataType), - mDims({}), - mSize(0), - mSizeM1(0) + mDataType(dataType) { // ctor } @@ -187,11 +78,12 @@ class Tensor : public Data, : Data(Type), mDataType(otherTensor.mDataType), mDims(otherTensor.mDims), - mSize(otherTensor.mSize), - mSizeM1(otherTensor.mSizeM1) + mSize(otherTensor.mSize) { if (otherTensor.hasImpl()) { mImpl = Registrar<Tensor>::create({otherTensor.mImpl->backend(), dataType()})(*this); + mImpl->setDevice(otherTensor.mImpl->device().second); + // Same backend, same device => directly use copy() mImpl->copy(otherTensor.mImpl->rawPtr(), mSize); } } @@ -207,9 +99,8 @@ class Tensor : public Data, mDataType(NativeType<T>::type), mDims({SIZE_0}), mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)), - mSize(SIZE_0), - mSizeM1(SIZE_0) { - mImpl->copy(&arr.data[0], SIZE_0); + mSize(SIZE_0) { + mImpl->copyFromHost(&arr.data[0], SIZE_0); } template <typename T, std::size_t SIZE_0> @@ -218,7 +109,7 @@ class Tensor : public Data, if (!mImpl) { mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this); } - mImpl->copy(&arr.data[0], SIZE_0); + mImpl->copyFromHost(&arr.data[0], SIZE_0); return *this; } @@ -234,9 +125,8 @@ class Tensor : public Data, mDataType(NativeType<T>::type), mDims({SIZE_0, SIZE_1}), mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)), - mSize(SIZE_0 * SIZE_1), - mSizeM1(SIZE_1) { - mImpl->copy(&arr.data[0][0], SIZE_0 * SIZE_1); + mSize(SIZE_0 * SIZE_1) { + mImpl->copyFromHost(&arr.data[0][0], SIZE_0 * SIZE_1); } template <typename T, std::size_t SIZE_0, std::size_t SIZE_1> @@ -245,7 +135,7 @@ class Tensor : public Data, if (!mImpl) { mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this); } - mImpl->copy(&arr.data[0][0], SIZE_0 * SIZE_1); + mImpl->copyFromHost(&arr.data[0][0], SIZE_0 * SIZE_1); return *this; } @@ -262,9 +152,8 @@ class Tensor : public Data, mDataType(NativeType<T>::type), mDims({SIZE_0, SIZE_1, SIZE_2}), 
mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)), - mSize(SIZE_0 * SIZE_1 * SIZE_2), - mSizeM1(SIZE_1 * SIZE_2) { - mImpl->copy(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2); + mSize(SIZE_0 * SIZE_1 * SIZE_2) { + mImpl->copyFromHost(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2); } template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2> @@ -273,7 +162,7 @@ class Tensor : public Data, if (!mImpl) { mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this); } - mImpl->copy(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2); + mImpl->copyFromHost(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2); return *this; } @@ -291,9 +180,8 @@ class Tensor : public Data, mDataType(NativeType<T>::type), mDims({SIZE_0, SIZE_1, SIZE_2, SIZE_3}), mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)), - mSize(SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3), - mSizeM1(SIZE_1 * SIZE_2 * SIZE_3) { - mImpl->copy(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3); + mSize(SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3) { + mImpl->copyFromHost(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3); } template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2, std::size_t SIZE_3> @@ -302,7 +190,7 @@ class Tensor : public Data, if (!mImpl) { mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this); } - mImpl->copy(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3); + mImpl->copyFromHost(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3); return *this; } @@ -315,8 +203,15 @@ class Tensor : public Data, resize(t.dims()); setDataType(t.dataType()); if (t.hasImpl()) { - setBackend(t.mImpl->backend()); - mImpl->copy(t.mImpl->rawPtr(), size()); + if (hasImpl()) { + copyCastFrom(t); + } + else { + mImpl = Registrar<Tensor>::create({t.mImpl->backend(), dataType()})(*this); + mImpl->setDevice(t.mImpl->device().second); + // Same backend, same device => directly use copy() + mImpl->copy(t.mImpl->rawPtr(), mSize); + } } else { mImpl = nullptr; @@ -337,21 +232,33 @@ class Tensor : public Data, } /** - * @brief Set the backend of the Tensor associated implementation - * @details Create and initialized an implementation if non was associated. - * @param name + * @brief Set the backend of the Tensor associated implementation. If there + * was no previous implementation set, data will be allocated, but it will + * not be initialized to any particular value. + * If data was already initialized in a previous backend, it will be moved + * to the new one except if copyFrom is false. + * @param name Backend name + * @param device Backend device + * @param copyFrom If true (default), move data from previous backend/device + * to the new one. Previous data is lost otherwise. 
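A short sketch of the move semantics documented here, assuming a "cpu" backend and a second, hypothetical "cuda" backend are both registered:

    Aidge::Tensor t = Aidge::Array1D<float, 3>{{1.0f, 2.0f, 3.0f}};  // allocated on ("cpu", 0)
    t.setBackend("cuda");           // new impl on ("cuda", 0); data is moved (copyFrom = true)
    t.setBackend("cuda", 1);        // same backend, other device: data moved to device 1
    t.setBackend("cpu", 0, false);  // back to cpu; previous values are NOT copied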
*/ - inline void setBackend(const std::string &name) { + inline void setBackend(const std::string &name, DeviceIdx_t device = 0, bool copyFrom = true) { if (mImpl) { - if (strcmp(mImpl->backend(), name.c_str()) != 0) { + if (mImpl->device() != std::make_pair(name, device)) { // Backend change: create new impl, copy from old to new and replace // impl std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({name, mDataType})(*this); - newImpl->copy(mImpl->rawPtr(), size()); + newImpl->setDevice(device); + if (copyFrom) { + newImpl->copyFrom(*mImpl, size()); + } mImpl = std::move(newImpl); } - } else + } + else { mImpl = Registrar<Tensor>::create({name, mDataType})(*this); + mImpl->setDevice(device); + } } /** @@ -373,16 +280,17 @@ class Tensor : public Data, /** * @brief Set the DataType of the Tensor and converts data - * if the Tensor has already been initialized. - * @param dt DataType. + * if the Tensor has already been initialized and copyCast is true. + * @param dt DataType + * @param copyCast If true (default), previous data is copy-casted. Otherwise + * previous data is lost. */ - void setDataType(const DataType dt) { + void setDataType(const DataType dt, bool copyCast = true) { if (mImpl && (dataType() != dt)) { - // get ptr before changing Tensor backend or the type difference will trigger a warning - const void *data = mImpl->rawPtr(); - mDataType = dt; std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), dt})(*this); - newImpl->copy(data, size()); // /!\ it does not cast data but reinterpret them + if (copyCast) { + newImpl->copyCast(mImpl->rawPtr(), size(), mDataType); + } mImpl = std::move(newImpl); } mDataType = dt; @@ -393,6 +301,7 @@ class Tensor : public Data, * @return constexpr const std::unique_ptr<TensorImpl>& */ constexpr const std::unique_ptr<TensorImpl> &getImpl() { return mImpl; } + constexpr const std::unique_ptr<TensorImpl> &getImpl() const { return mImpl; } /** * @brief Return if an implementaiton has been associated. @@ -431,23 +340,31 @@ class Tensor : public Data, constexpr std::size_t size() const { return mSize; } /** - * @brief Get the number of elements in the N-1 dimensions of the Tensor object. - * @return constexpr std::size_t - */ - constexpr std::size_t sizeM1() const { return mSizeM1; } - - /** - * @brief Change the shape of the Tensor object according to the given argument. - * @tparam DIM new dimensions. - * @param dims + * @brief Change the dimensions of the Tensor object according to the given argument. + * If the overall size is not changed (meaning we actually only performed a + * reshape), data is garanteed to remain valid. + * Otherwise, no garantee is provided regarding the validy of previous data + * (unlike std::vector). If the new overall size is larger than the previous + * one, all previous data is invalided. Otherwise, previous data may or may + * not remain valid, depending on the backend implementation. + * @tparam DIM Number of dimensions. + * @param dims New dimensions */ template <std::array<DimSize_t, 1>::size_type DIM> // deducing std::array size_type and declaring DIM accordingly void resize(const std::array<DimSize_t, DIM> &dims) { - static_assert(DIM<=MaxDim,"Too many tensor dimensions required by resize, not supported"); - mDims.assign(dims.begin(), dims.end()); - computeSize(); + resize(std::vector<DimSize_t>(dims.begin(), dims.end())); } + /** + * @brief Change the dimensions of the Tensor object according to the given argument. 
+ * If the overall size is not changed (meaning we actually only performed a + * reshape), data is garanteed to remain valid. + * Otherwise, no garantee is provided regarding the validy of previous data + * (unlike std::vector). If the new overall size is larger than the previous + * one, all previous data is invalided. Otherwise, previous data may or may + * not remain valid, depending on the backend implementation. + * @param dims New dimensions + */ void resize(const std::vector<DimSize_t> &dims) { mDims = dims; computeSize(); @@ -461,23 +378,23 @@ class Tensor : public Data, bool empty() const { return mDims.empty(); } template <typename expectedType> - expectedType& get(std::size_t idx){ - // TODO : add assert expected Type compatible with datatype - // TODO : add assert idx < Size - return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx)); + const expectedType& get(std::size_t idx) const { + AIDGE_ASSERT(NativeType<expectedType>::type == mDataType, "wrong data type"); + AIDGE_ASSERT(idx < mSize, "idx out of range"); + return *reinterpret_cast<expectedType *>(mImpl->hostPtr(idx)); } template <typename expectedType> - expectedType& get(std::vector<std::size_t> coordIdx){ + const expectedType& get(std::vector<std::size_t> coordIdx) const { return get<expectedType>(getIdx(coordIdx)); } template <typename expectedType> void set(std::size_t idx, expectedType value){ - // TODO : add assert expected Type compatible with datatype - // TODO : add assert idx < Size - void* dataPtr = mImpl->getRaw(idx); - std::memcpy(dataPtr, &value, sizeof(expectedType)); + AIDGE_ASSERT(NativeType<expectedType>::type == mDataType, "wrong data type"); + AIDGE_ASSERT(idx < mSize, "idx out of range"); + expectedType* dataPtr = static_cast<expectedType*>(mImpl->hostPtr(idx)); + *dataPtr = value; } template <typename expectedType> @@ -487,17 +404,46 @@ class Tensor : public Data, - std::string toString() { + std::string toString() const { + AIDGE_ASSERT(mImpl && mImpl->hostPtr() != nullptr, "tensor should have a valid host pointer"); + + // TODO: move lambda elsewhere? 
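    // For intuition, a sketch of the output toString() should produce for a small
    // 2-D tensor (exact spacing depends on the loops below):
    //     Tensor t = Array2D<int, 2, 2>{{{1, 2}, {3, 4}}};
    //     t.toString();
    // is expected to return roughly:
    //     {
    //       { 1, 2},
    //       { 3, 4}
    //     }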
+ auto ptrToString = [](DataType dt, void* ptr, size_t idx) { + switch (dt) { + case DataType::Float64: + return std::to_string(static_cast<double*>(ptr)[idx]); + case DataType::Float32: + return std::to_string(static_cast<float*>(ptr)[idx]); + case DataType::Float16: + return std::to_string(static_cast<half_float::half*>(ptr)[idx]); + case DataType::Int8: + return std::to_string(static_cast<int8_t*>(ptr)[idx]); + case DataType::Int16: + return std::to_string(static_cast<int16_t*>(ptr)[idx]); + case DataType::Int32: + return std::to_string(static_cast<int32_t*>(ptr)[idx]); + case DataType::Int64: + return std::to_string(static_cast<int64_t*>(ptr)[idx]); + case DataType::UInt8: + return std::to_string(static_cast<uint8_t*>(ptr)[idx]); + case DataType::UInt16: + return std::to_string(static_cast<uint16_t*>(ptr)[idx]); + case DataType::UInt32: + return std::to_string(static_cast<uint32_t*>(ptr)[idx]); + case DataType::UInt64: + return std::to_string(static_cast<uint64_t*>(ptr)[idx]); + default: + AIDGE_ASSERT(true, "unsupported type to convert to string"); + } + return std::string("?"); // To make Clang happy + }; + if (dims().empty()) { return "{}"; } std::string res; std::size_t dim = 0; std::size_t counter = 0; if (nbDims()>=2) { - std::size_t *dimVals = new std::size_t[nbDims()]; - for (std::size_t i = 0; i < nbDims(); ++i) { - dimVals[i] = 0; - } - // std::vector<std::size_t> dimVals = std::vector<std::size_t>(nbDims(), 0); + std::vector<std::size_t> dimVals(nbDims(), 0); res += "{\n"; while (counter < mSize) { std::string spaceString = std::string((dim+1)<<1,' '); @@ -517,31 +463,9 @@ class Tensor : public Data, for (; dimVals[dim] < static_cast<std::size_t>(dims()[dim]); ++dimVals[dim]) { res += spaceString + "{"; for (DimSize_t j = 0; j < dims()[dim + 1] - 1; ++j) { - switch (mDataType) - { - case DataType::Int32: - res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[counter++]) + ","; - break; - case DataType::Float64: - res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[counter++]) + ","; - break; - default: - res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[counter++]) + ","; - break; - } - } - switch (mDataType) - { - case DataType::Int32: - res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[counter++]) + "}"; - break; - case DataType::Float64: - res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[counter++]) + "}"; - break; - default: - res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[counter++]) + "}"; - break; + res += " " + ptrToString(mDataType, mImpl->hostPtr(), counter++) + ","; } + res += " " + ptrToString(mDataType, mImpl->hostPtr(), counter++) + "}"; if (dimVals[dim] < static_cast<std::size_t>(dims()[dim] - 1)) { res += ","; } @@ -554,7 +478,6 @@ class Tensor : public Data, dimVals[dim]++; } } - delete[] dimVals; for(int i = static_cast<int>(dim); i > 0; --i) { res += std::string((dim+1)<<1,' ') + "}\n"; @@ -562,25 +485,14 @@ class Tensor : public Data, } else { res += "{"; for (DimSize_t j = 0; j < dims()[0]; ++j) { - switch (mDataType) - { - case DataType::Int32: - res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : ""); - break; - case DataType::Float64: - res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : ""); - break; - default: - res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? 
"," : ""); - break; - } + res += " " + ptrToString(mDataType, mImpl->hostPtr(), j) + ((j < dims()[0]-1) ? "," : ""); } } res += "}"; return res; } - inline void print() { printf("%s\n", toString().c_str()); } + inline void print() const { printf("%s\n", toString().c_str()); } std::shared_ptr<Tensor> grad() { if (!mGrad) { @@ -594,9 +506,9 @@ class Tensor : public Data, } /** - * @brief From the the 1D index, return the coordinate of an element in the tensor. + * @brief From the the 1D contiguous index, return the coordinate of an element in the tensor. * - * @param flatIdx 1D index of the value considering a flatten tensor. + * @param flatIdx 1D contiguous index of the value considering a flatten, contiguous, tensor. * @return std::vector<DimSize_t> */ std::vector<std::size_t> getCoord(std::size_t flatIdx) const { @@ -611,39 +523,147 @@ class Tensor : public Data, } /** - * @brief From the coordinate returns the 1D index of an element in the tensor. + * @brief From the coordinate returns the 1D contiguous index of an element in the tensor. + * If the number of coordinates is inferior to the number of dimensions, + * the remaining coordinates are assumed to be 0. * * @param coordIdx Coordinate to an element in the tensor - * @return DimSize_t + * @return DimSize_t Contiguous index */ - std::size_t getIdx(std::vector<std::size_t> coordIdx) const { - // std::size_t flatIdx = 0; - // std::size_t stride = 1; + std::size_t getIdx(const std::vector<std::size_t>& coordIdx) const { + AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Coordinates does not match number of dimensions"); std::size_t flatIdx = 0; - assert(coordIdx.size() == mDims.size() && "Coordinates does not match number of dimensions"); std::size_t i = 0; - for(; i < mDims.size() - 1; ++i){ - assert(coordIdx[i] < mDims[i] && "Coordinates dimensions does not fit the dimensions of the tensor"); + for(; i < coordIdx.size() - 1; ++i){ + AIDGE_ASSERT(coordIdx[i] < mDims[i], "Coordinates dimensions does not fit the dimensions of the tensor"); flatIdx = (flatIdx + coordIdx[i]) * mDims[i + 1]; } return flatIdx + coordIdx[i]; } + /** + * Copy-cast data from a Tensor on the same device. + * If current tensor backend/device is set and is different from src, an + * assertion is raised. + * @param src Source tensor to copy-cast from. + */ + void copyCast(const Tensor& src); + + /** + * Copy data from a Tensor from another backend/device. + * If current tensor data type is set and is different from src, an + * assertion is raised. + * @param src Source tensor to copy from. + */ + void copyFrom(const Tensor& src); + + /** + * Copy-cast data from a Tensor. + * @param src Source tensor to copy-cast from. + * @param movedSrc shared_ptr to an indermediate Tensor that will + * contain the moved data if a device change should occur AND a type + * conversion is necessary (otherwise it remains unused). + * Any data already present will be overwritten. No new memory allocation + * will occur if movedSrc has already been allocated with the right + * type/size/device. + * If required, memory is always allocated on current (destination) + * Tensor's device. + */ + void copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrc); + + /** + * Copy-cast data from a Tensor. + * In case of both a device change AND a data type conversion, an + * intermediate buffer on will be allocated and deallocated each time. + * If required, buffer's memory is always allocated on current (destination) + * Tensor's device. + * @param src Source tensor to copy-cast from. 
+ */ + void copyCastFrom(const Tensor& src) { + // Internal buffer will be allocated and deallocated at each call + // (only if needed) + std::shared_ptr<Tensor> movedSrc; + copyCastFrom(src, movedSrc); + } + + /** + * Return a reference to a Tensor casted to the desired data type: + * - itself, if already at the right data type; + * - the provided Tensor, overwritten with the copy-casted data. + * The backend stays the same. + * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary. + * The shared_ptr does not need to be initialized. No new memory allocation + * will occur if fallback has already been allocated with the right + * type/size/device. + * @param dt The desired data type. + * @return Reference to either itself or to fallback. + */ + Tensor& refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt); + const Tensor& refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) const; + + /** + * Return a reference to a Tensor on the desired backend/device: + * - itself, if already on the right device; + * - the provided Tensor, overwritten with the copied data. + * The data type stays the same. + * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary. + * The shared_ptr does not need to be initialized. No new memory allocation + * will occur if fallback has already been allocated with the right + * type/size/device. + * @param backend The desired backend. + * @param device The desired device. + * @return Reference to either itself or to fallback. + */ + Tensor& refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, DeviceIdx_t device = 0); + const Tensor& refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, DeviceIdx_t device = 0) const; + + /** + * Return a reference to a Tensor on desired data type and backend/device: + * - itself, if already with the right characteristics; + * - the provided Tensor, overwritten with the copy-casted data. + * If required, fallback is always allocated on desired (destination) + * device. + * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary. + * The shared_ptr does not need to be initialized. No new memory allocation + * will occur if fallback has already been allocated with the right + * type/size/device. + * @param dt The desired data type. + * @param backend The desired backend. + * @param device The desired device. + * @return Reference to either itself or to fallback. + */ + Tensor& refCastFrom(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device = 0) { + // First refFrom, to ensure that fallback, if required, is also on desired device + return refFrom(fallback, backend, device).refCast(fallback, dt); + } + + /** + * Return a reference to a Tensor with same characteristics + * (data type, backend/device) as targetReqs Tensor: + * - itself, if already with the right characteristics; + * - the provided Tensor, overwritten with the copy-casted data. + * If required, fallback is always allocated on current (destination) + * Tensor's device. + * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary. + * The shared_ptr does not need to be initialized. No new memory allocation + * will occur if fallback has already been allocated with the right + * type/size/device. + * @param targetReqs Tensor with the desired target characteristics. + * @return Reference to either itself or to fallback. 
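The fallback pattern these overloads enable looks roughly as follows (a sketch; the tensor and names are illustrative, and a cpu Float32 implementation must be registered):

    Aidge::Tensor weights = Aidge::Array1D<int, 3>{{1, 2, 3}};  // Int32, on ("cpu", 0)
    std::shared_ptr<Aidge::Tensor> fallback;                    // allocated only if needed
    // w32 ends up referring to fallback, filled with the values copy-casted to Float32;
    // had weights already been Float32 on cpu, w32 would refer to weights itself.
    Aidge::Tensor& w32 = weights.refCastFrom(fallback, Aidge::DataType::Float32, "cpu");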
+ */ + Tensor& refCastFrom(std::shared_ptr<Tensor>& fallback, const Tensor& targetReqs) { + const auto& device = targetReqs.getImpl()->device(); + return refCastFrom(fallback, targetReqs.dataType(), device.first, device.second); + } + private: ///\bug not protected against overflow std::size_t computeSize() { if (mDims.empty()) { - mSizeM1 = DimSize_t(0); mSize = DimSize_t(0); } - else if (mDims.size() == 1) - { - mSizeM1 = mDims[0]; - mSize = mDims[0]; - } else { - mSizeM1 = std::accumulate(++mDims.begin(),mDims.end(), DimSize_t(1), std::multiplies<DimSize_t>()); - mSize = static_cast<std::size_t>(mSizeM1 * mDims[0]); + mSize = std::accumulate(mDims.begin(), mDims.end(), DimSize_t(1), std::multiplies<DimSize_t>()); } return mSize; diff --git a/include/aidge/data/half.hpp b/include/aidge/data/half.hpp new file mode 100644 index 0000000000000000000000000000000000000000..89df93cf3d10087833b3ad00dfbe3afd4e94c725 --- /dev/null +++ b/include/aidge/data/half.hpp @@ -0,0 +1,3067 @@ +// half - IEEE 754-based half-precision floating point library. +// +// Copyright (c) 2012-2017 Christian Rau <rauy@users.sourceforge.net> +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation +// files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, +// modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// Version 1.12.0 + +/// \file +/// Main header file for half precision functionality. + +#ifndef HALF_HALF_HPP +#define HALF_HALF_HPP + +/// Combined gcc version number. +#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__) + +//check C++11 language features +#if defined(__clang__) //clang + #if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif +/*#elif defined(__INTEL_COMPILER) //Intel C++ + #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ???????? + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? 
+ #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif*/ +#elif defined(__GNUC__) //gcc + #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L + #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif + #endif +#elif defined(_MSC_VER) //Visual C++ + #if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif + #define HALF_POP_WARNINGS 1 + #pragma warning(push) + #pragma warning(disable : 4099 4127 4146) //struct vs class, constant in if, negative unsigned +#endif + +//check C++11 library features +#include <utility> +#if defined(_LIBCPP_VERSION) //libc++ + #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 + #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #ifndef HALF_ENABLE_CPP11_CSTDINT + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #ifndef HALF_ENABLE_CPP11_CMATH + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #ifndef HALF_ENABLE_CPP11_HASH + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #endif +#elif defined(__GLIBCXX__) //libstdc++ + #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 + #ifdef __clang__ + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #else + #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #endif + #endif +#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++ + #if _CPPLIB_VER >= 520 + #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #ifndef HALF_ENABLE_CPP11_CSTDINT + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #ifndef HALF_ENABLE_CPP11_HASH + #define HALF_ENABLE_CPP11_HASH 
1 + #endif + #endif + #if _CPPLIB_VER >= 610 + #ifndef HALF_ENABLE_CPP11_CMATH + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #endif +#endif +#undef HALF_GNUC_VERSION + +//support constexpr +#if HALF_ENABLE_CPP11_CONSTEXPR + #define HALF_CONSTEXPR constexpr + #define HALF_CONSTEXPR_CONST constexpr +#else + #define HALF_CONSTEXPR + #define HALF_CONSTEXPR_CONST const +#endif + +//support noexcept +#if HALF_ENABLE_CPP11_NOEXCEPT + #define HALF_NOEXCEPT noexcept + #define HALF_NOTHROW noexcept +#else + #define HALF_NOEXCEPT + #define HALF_NOTHROW throw() +#endif + +#include <algorithm> +#include <iostream> +#include <limits> +#include <climits> +#include <cmath> +#include <cstring> +#if HALF_ENABLE_CPP11_TYPE_TRAITS + #include <type_traits> +#endif +#if HALF_ENABLE_CPP11_CSTDINT + #include <cstdint> +#endif +#if HALF_ENABLE_CPP11_HASH + #include <functional> +#endif + + +/// Default rounding mode. +/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s and `float`s as well as +/// for the half_cast() if not specifying a rounding mode explicitly. It can be redefined (before including half.hpp) to one +/// of the standard rounding modes using their respective constants or the equivalent values of `std::float_round_style`: +/// +/// `std::float_round_style` | value | rounding +/// ---------------------------------|-------|------------------------- +/// `std::round_indeterminate` | -1 | fastest (default) +/// `std::round_toward_zero` | 0 | toward zero +/// `std::round_to_nearest` | 1 | to nearest +/// `std::round_toward_infinity` | 2 | toward positive infinity +/// `std::round_toward_neg_infinity` | 3 | toward negative infinity +/// +/// By default this is set to `-1` (`std::round_indeterminate`), which uses truncation (round toward zero, but with overflows +/// set to infinity) and is the fastest rounding mode possible. It can even be set to `std::numeric_limits<float>::round_style` +/// to synchronize the rounding mode with that of the underlying single-precision implementation. +#ifndef HALF_ROUND_STYLE + #define HALF_ROUND_STYLE -1 // = std::round_indeterminate +#endif + +/// Tie-breaking behaviour for round to nearest. +/// This specifies if ties in round to nearest should be resolved by rounding to the nearest even value. By default this is +/// defined to `0` resulting in the faster but slightly more biased behaviour of rounding away from zero in half-way cases (and +/// thus equal to the round() function), but can be redefined to `1` (before including half.hpp) if more IEEE-conformant +/// behaviour is needed. +#ifndef HALF_ROUND_TIES_TO_EVEN + #define HALF_ROUND_TIES_TO_EVEN 0 // ties away from zero +#endif + +/// Value signaling overflow. +/// In correspondence with `HUGE_VAL[F|L]` from `<cmath>` this symbol expands to a positive value signaling the overflow of an +/// operation, in particular it just evaluates to positive infinity. +#define HUGE_VALH std::numeric_limits<half_float::half>::infinity() + +/// Fast half-precision fma function. +/// This symbol is only defined if the fma() function generally executes as fast as, or faster than, a separate +/// half-precision multiplication followed by an addition. Due to the internal single-precision implementation of all +/// arithmetic operations, this is in fact always the case. 
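The rounding behaviour described above can be tuned before the header is included; a minimal sketch selecting IEEE round-to-nearest-even instead of the default truncation:

    #define HALF_ROUND_STYLE 1           // std::round_to_nearest
    #define HALF_ROUND_TIES_TO_EVEN 1    // resolve ties to even
    #include "aidge/data/half.hpp"

    half_float::half h(0.1f);  // float -> half conversion now rounds to nearest
    float f = h;               // and converts back implicitly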
+#define FP_FAST_FMAH 1 + +#ifndef FP_ILOGB0 + #define FP_ILOGB0 INT_MIN +#endif +#ifndef FP_ILOGBNAN + #define FP_ILOGBNAN INT_MAX +#endif +#ifndef FP_SUBNORMAL + #define FP_SUBNORMAL 0 +#endif +#ifndef FP_ZERO + #define FP_ZERO 1 +#endif +#ifndef FP_NAN + #define FP_NAN 2 +#endif +#ifndef FP_INFINITE + #define FP_INFINITE 3 +#endif +#ifndef FP_NORMAL + #define FP_NORMAL 4 +#endif + + +/// Main namespace for half precision functionality. +/// This namespace contains all the functionality provided by the library. +namespace half_float +{ + class half; + +#if HALF_ENABLE_CPP11_USER_LITERALS + /// Library-defined half-precision literals. + /// Import this namespace to enable half-precision floating point literals: + /// ~~~~{.cpp} + /// using namespace half_float::literal; + /// half_float::half = 4.2_h; + /// ~~~~ + namespace literal + { + half operator"" _h(long double); + } +#endif + + /// \internal + /// \brief Implementation details. + namespace detail + { + #if HALF_ENABLE_CPP11_TYPE_TRAITS + /// Conditional type. + template<bool B,typename T,typename F> struct conditional : std::conditional<B,T,F> {}; + + /// Helper for tag dispatching. + template<bool B> struct bool_type : std::integral_constant<bool,B> {}; + using std::true_type; + using std::false_type; + + /// Type traits for floating point types. + template<typename T> struct is_float : std::is_floating_point<T> {}; + #else + /// Conditional type. + template<bool,typename T,typename> struct conditional { typedef T type; }; + template<typename T,typename F> struct conditional<false,T,F> { typedef F type; }; + + /// Helper for tag dispatching. + template<bool> struct bool_type {}; + typedef bool_type<true> true_type; + typedef bool_type<false> false_type; + + /// Type traits for floating point types. + template<typename> struct is_float : false_type {}; + template<typename T> struct is_float<const T> : is_float<T> {}; + template<typename T> struct is_float<volatile T> : is_float<T> {}; + template<typename T> struct is_float<const volatile T> : is_float<T> {}; + template<> struct is_float<float> : true_type {}; + template<> struct is_float<double> : true_type {}; + template<> struct is_float<long double> : true_type {}; + #endif + + /// Type traits for floating point bits. + template<typename T> struct bits { typedef unsigned char type; }; + template<typename T> struct bits<const T> : bits<T> {}; + template<typename T> struct bits<volatile T> : bits<T> {}; + template<typename T> struct bits<const volatile T> : bits<T> {}; + + #if HALF_ENABLE_CPP11_CSTDINT + /// Unsigned integer of (at least) 16 bits width. + typedef std::uint_least16_t uint16; + + /// Unsigned integer of (at least) 32 bits width. + template<> struct bits<float> { typedef std::uint_least32_t type; }; + + /// Unsigned integer of (at least) 64 bits width. + template<> struct bits<double> { typedef std::uint_least64_t type; }; + #else + /// Unsigned integer of (at least) 16 bits width. + typedef unsigned short uint16; + + /// Unsigned integer of (at least) 32 bits width. + template<> struct bits<float> : conditional<std::numeric_limits<unsigned int>::digits>=32,unsigned int,unsigned long> {}; + + #if HALF_ENABLE_CPP11_LONG_LONG + /// Unsigned integer of (at least) 64 bits width. + template<> struct bits<double> : conditional<std::numeric_limits<unsigned long>::digits>=64,unsigned long,unsigned long long> {}; + #else + /// Unsigned integer of (at least) 64 bits width. 
+ template<> struct bits<double> { typedef unsigned long type; }; + #endif + #endif + + /// Tag type for binary construction. + struct binary_t {}; + + /// Tag for binary construction. + HALF_CONSTEXPR_CONST binary_t binary = binary_t(); + + /// Temporary half-precision expression. + /// This class represents a half-precision expression which just stores a single-precision value internally. + struct expr + { + /// Conversion constructor. + /// \param f single-precision value to convert + explicit HALF_CONSTEXPR expr(float f) HALF_NOEXCEPT : value_(f) {} + + /// Conversion to single-precision. + /// \return single precision value representing expression value + HALF_CONSTEXPR operator float() const HALF_NOEXCEPT { return value_; } + + private: + /// Internal expression value stored in single-precision. + float value_; + }; + + /// SFINAE helper for generic half-precision functions. + /// This class template has to be specialized for each valid combination of argument types to provide a corresponding + /// `type` member equivalent to \a T. + /// \tparam T type to return + template<typename T,typename,typename=void,typename=void> struct enable {}; + template<typename T> struct enable<T,half,void,void> { typedef T type; }; + template<typename T> struct enable<T,expr,void,void> { typedef T type; }; + template<typename T> struct enable<T,half,half,void> { typedef T type; }; + template<typename T> struct enable<T,half,expr,void> { typedef T type; }; + template<typename T> struct enable<T,expr,half,void> { typedef T type; }; + template<typename T> struct enable<T,expr,expr,void> { typedef T type; }; + template<typename T> struct enable<T,half,half,half> { typedef T type; }; + template<typename T> struct enable<T,half,half,expr> { typedef T type; }; + template<typename T> struct enable<T,half,expr,half> { typedef T type; }; + template<typename T> struct enable<T,half,expr,expr> { typedef T type; }; + template<typename T> struct enable<T,expr,half,half> { typedef T type; }; + template<typename T> struct enable<T,expr,half,expr> { typedef T type; }; + template<typename T> struct enable<T,expr,expr,half> { typedef T type; }; + template<typename T> struct enable<T,expr,expr,expr> { typedef T type; }; + + /// Return type for specialized generic 2-argument half-precision functions. + /// This class template has to be specialized for each valid combination of argument types to provide a corresponding + /// `type` member denoting the appropriate return type. + /// \tparam T first argument type + /// \tparam U first argument type + template<typename T,typename U> struct result : enable<expr,T,U> {}; + template<> struct result<half,half> { typedef half type; }; + + /// \name Classification helpers + /// \{ + + /// Check for infinity. + /// \tparam T argument type (builtin floating point type) + /// \param arg value to query + /// \retval true if infinity + /// \retval false else + template<typename T> bool builtin_isinf(T arg) + { + #if HALF_ENABLE_CPP11_CMATH + return std::isinf(arg); + #elif defined(_MSC_VER) + return !::_finite(static_cast<double>(arg)) && !::_isnan(static_cast<double>(arg)); + #else + return arg == std::numeric_limits<T>::infinity() || arg == -std::numeric_limits<T>::infinity(); + #endif + } + + /// Check for NaN. 
+ /// \tparam T argument type (builtin floating point type) + /// \param arg value to query + /// \retval true if not a number + /// \retval false else + template<typename T> bool builtin_isnan(T arg) + { + #if HALF_ENABLE_CPP11_CMATH + return std::isnan(arg); + #elif defined(_MSC_VER) + return ::_isnan(static_cast<double>(arg)) != 0; + #else + return arg != arg; + #endif + } + + /// Check sign. + /// \tparam T argument type (builtin floating point type) + /// \param arg value to query + /// \retval true if signbit set + /// \retval false else + template<typename T> bool builtin_signbit(T arg) + { + #if HALF_ENABLE_CPP11_CMATH + return std::signbit(arg); + #else + return arg < T() || (arg == T() && T(1)/arg < T()); + #endif + } + + /// \} + /// \name Conversion + /// \{ + + /// Convert IEEE single-precision to half-precision. + /// Credit for this goes to [Jeroen van der Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). + /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \param value single-precision value + /// \return binary representation of half-precision value + template<std::float_round_style R> uint16 float2half_impl(float value, true_type) + { + typedef bits<float>::type uint32; + uint32 bits;// = *reinterpret_cast<uint32*>(&value); //violating strict aliasing! + std::memcpy(&bits, &value, sizeof(float)); +/* uint16 hbits = (bits>>16) & 0x8000; + bits &= 0x7FFFFFFF; + int exp = bits >> 23; + if(exp == 255) + return hbits | 0x7C00 | (0x3FF&-static_cast<unsigned>((bits&0x7FFFFF)!=0)); + if(exp > 142) + { + if(R == std::round_toward_infinity) + return hbits | 0x7C00 - (hbits>>15); + if(R == std::round_toward_neg_infinity) + return hbits | 0x7BFF + (hbits>>15); + return hbits | 0x7BFF + (R!=std::round_toward_zero); + } + int g, s; + if(exp > 112) + { + g = (bits>>12) & 1; + s = (bits&0xFFF) != 0; + hbits |= ((exp-112)<<10) | ((bits>>13)&0x3FF); + } + else if(exp > 101) + { + int i = 125 - exp; + bits = (bits&0x7FFFFF) | 0x800000; + g = (bits>>i) & 1; + s = (bits&((1L<<i)-1)) != 0; + hbits |= bits >> (i+1); + } + else + { + g = 0; + s = bits != 0; + } + if(R == std::round_to_nearest) + #if HALF_ROUND_TIES_TO_EVEN + hbits += g & (s|hbits); + #else + hbits += g; + #endif + else if(R == std::round_toward_infinity) + hbits += ~(hbits>>15) & (s|g); + else if(R == std::round_toward_neg_infinity) + hbits += (hbits>>15) & (g|s); +*/ static const uint16 base_table[512] = { + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, 0x0080, 0x0100, + 0x0200, 0x0400, 0x0800, 0x0C00, 0x1000, 0x1400, 0x1800, 0x1C00, 0x2000, 0x2400, 0x2800, 0x2C00, 0x3000, 0x3400, 0x3800, 0x3C00, + 0x4000, 0x4400, 0x4800, 0x4C00, 0x5000, 
0x5400, 0x5800, 0x5C00, 0x6000, 0x6400, 0x6800, 0x6C00, 0x7000, 0x7400, 0x7800, 0x7C00, + 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, + 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, + 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, + 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, + 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, + 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, + 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8001, 0x8002, 0x8004, 0x8008, 0x8010, 0x8020, 0x8040, 0x8080, 0x8100, + 0x8200, 0x8400, 0x8800, 0x8C00, 0x9000, 0x9400, 0x9800, 0x9C00, 0xA000, 0xA400, 0xA800, 0xAC00, 0xB000, 0xB400, 0xB800, 0xBC00, + 0xC000, 0xC400, 0xC800, 0xCC00, 0xD000, 0xD400, 0xD800, 0xDC00, 0xE000, 0xE400, 0xE800, 0xEC00, 0xF000, 0xF400, 0xF800, 0xFC00, + 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, + 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, + 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, + 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, + 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, + 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, + 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00 }; + static const unsigned char shift_table[512] = { + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 23, 
22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13 }; + uint16 hbits = base_table[bits>>23] + static_cast<uint16>((bits&0x7FFFFF)>>shift_table[bits>>23]); + if(R == std::round_to_nearest) + hbits += (((bits&0x7FFFFF)>>(shift_table[bits>>23]-1))|(((bits>>23)&0xFF)==102)) & ((hbits&0x7C00)!=0x7C00) + #if HALF_ROUND_TIES_TO_EVEN + & (((((static_cast<uint32>(1)<<(shift_table[bits>>23]-1))-1)&bits)!=0)|hbits) + #endif + ; + else if(R == std::round_toward_zero) + hbits -= ((hbits&0x7FFF)==0x7C00) & ~shift_table[bits>>23]; + else if(R == std::round_toward_infinity) + hbits += ((((bits&0x7FFFFF&((static_cast<uint32>(1)<<(shift_table[bits>>23]))-1))!=0)|(((bits>>23)<=102)& + ((bits>>23)!=0)))&(hbits<0x7C00)) - ((hbits==0xFC00)&((bits>>23)!=511)); + else if(R == std::round_toward_neg_infinity) + hbits += ((((bits&0x7FFFFF&((static_cast<uint32>(1)<<(shift_table[bits>>23]))-1))!=0)|(((bits>>23)<=358)& + ((bits>>23)!=256)))&(hbits<0xFC00)&(hbits>>15)) - ((hbits==0x7C00)&((bits>>23)!=255)); + return hbits; + } + + /// Convert IEEE double-precision to half-precision. + /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \param value double-precision value + /// \return binary representation of half-precision value + template<std::float_round_style R> uint16 float2half_impl(double value, true_type) + { + typedef bits<float>::type uint32; + typedef bits<double>::type uint64; + uint64 bits;// = *reinterpret_cast<uint64*>(&value); //violating strict aliasing! 
+ std::memcpy(&bits, &value, sizeof(double)); + uint32 hi = bits >> 32, lo = bits & 0xFFFFFFFF; + uint16 hbits = (hi>>16) & 0x8000; + hi &= 0x7FFFFFFF; + int exp = hi >> 20; + if(exp == 2047) + return hbits | 0x7C00 | (0x3FF&-static_cast<unsigned>((bits&0xFFFFFFFFFFFFF)!=0)); + if(exp > 1038) + { + if(R == std::round_toward_infinity) + return hbits | (0x7C00 - (hbits>>15)); + if(R == std::round_toward_neg_infinity) + return hbits | (0x7BFF + (hbits>>15)); + return hbits | (0x7BFF + (R!=std::round_toward_zero)); + } + int g, s = lo != 0; + if(exp > 1008) + { + g = (hi>>9) & 1; + s |= (hi&0x1FF) != 0; + hbits |= ((exp-1008)<<10) | ((hi>>10)&0x3FF); + } + else if(exp > 997) + { + int i = 1018 - exp; + hi = (hi&0xFFFFF) | 0x100000; + g = (hi>>i) & 1; + s |= (hi&((1L<<i)-1)) != 0; + hbits |= hi >> (i+1); + } + else + { + g = 0; + s |= hi != 0; + } + if(R == std::round_to_nearest) + #if HALF_ROUND_TIES_TO_EVEN + hbits += g & (s|hbits); + #else + hbits += g; + #endif + else if(R == std::round_toward_infinity) + hbits += ~(hbits>>15) & (s|g); + else if(R == std::round_toward_neg_infinity) + hbits += (hbits>>15) & (g|s); + return hbits; + } + + /// Convert non-IEEE floating point to half-precision. + /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \tparam T source type (builtin floating point type) + /// \param value floating point value + /// \return binary representation of half-precision value + template<std::float_round_style R,typename T> uint16 float2half_impl(T value, ...) + { + uint16 hbits = static_cast<unsigned>(builtin_signbit(value)) << 15; + if(value == T()) + return hbits; + if(builtin_isnan(value)) + return hbits | 0x7FFF; + if(builtin_isinf(value)) + return hbits | 0x7C00; + int exp; + std::frexp(value, &exp); + if(exp > 16) + { + if(R == std::round_toward_infinity) + return hbits | (0x7C00 - (hbits>>15)); + else if(R == std::round_toward_neg_infinity) + return hbits | (0x7BFF + (hbits>>15)); + return hbits | (0x7BFF + (R!=std::round_toward_zero)); + } + if(exp < -13) + value = std::ldexp(value, 24); + else + { + value = std::ldexp(value, 11-exp); + hbits |= ((exp+13)<<10); + } + T ival, frac = std::modf(value, &ival); + hbits += static_cast<uint16>(std::abs(static_cast<int>(ival))); + if(R == std::round_to_nearest) + { + frac = std::abs(frac); + #if HALF_ROUND_TIES_TO_EVEN + hbits += (frac>T(0.5)) | ((frac==T(0.5))&hbits); + #else + hbits += frac >= T(0.5); + #endif + } + else if(R == std::round_toward_infinity) + hbits += frac > T(); + else if(R == std::round_toward_neg_infinity) + hbits += frac < T(); + return hbits; + } + + /// Convert floating point to half-precision. + /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \tparam T source type (builtin floating point type) + /// \param value floating point value + /// \return binary representation of half-precision value + template<std::float_round_style R,typename T> uint16 float2half(T value) + { + return float2half_impl<R>(value, bool_type<std::numeric_limits<T>::is_iec559&&sizeof(typename bits<T>::type)==sizeof(T)>()); + } + + /// Convert integer to half-precision floating point. 
+ /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \tparam S `true` if value negative, `false` else + /// \tparam T type to convert (builtin integer type) + /// \param value non-negative integral value + /// \return binary representation of half-precision value + template<std::float_round_style R,bool S,typename T> uint16 int2half_impl(T value) + { + #if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_integral<T>::value, "int to half conversion only supports builtin integer types"); + #endif + if(S) + value = -value; + uint16 bits = S << 15; + if(value > 0xFFFF) + { + if(R == std::round_toward_infinity) + bits |= 0x7C00 - S; + else if(R == std::round_toward_neg_infinity) + bits |= 0x7BFF + S; + else + bits |= 0x7BFF + (R!=std::round_toward_zero); + } + else if(value) + { + unsigned int m = value, exp = 24; + for(; m<0x400; m<<=1,--exp) ; + for(; m>0x7FF; m>>=1,++exp) ; + bits |= (exp<<10) + m; + if(exp > 24) + { + if(R == std::round_to_nearest) + bits += (value>>(exp-25)) & 1 + #if HALF_ROUND_TIES_TO_EVEN + & (((((1<<(exp-25))-1)&value)!=0)|bits) + #endif + ; + else if(R == std::round_toward_infinity) + bits += ((value&((1<<(exp-24))-1))!=0) & !S; + else if(R == std::round_toward_neg_infinity) + bits += ((value&((1<<(exp-24))-1))!=0) & S; + } + } + return bits; + } + + /// Convert integer to half-precision floating point. + /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding + /// \tparam T type to convert (builtin integer type) + /// \param value integral value + /// \return binary representation of half-precision value + template<std::float_round_style R,typename T> uint16 int2half(T value) + { + return (value<0) ? int2half_impl<R,true>(value) : int2half_impl<R,false>(value); + } + + /// Convert half-precision to IEEE single-precision. + /// Credit for this goes to [Jeroen van der Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). 
+ /// \param value binary representation of half-precision value + /// \return single-precision value + inline float half2float_impl(uint16 value, float, true_type) + { + typedef bits<float>::type uint32; +/* uint32 bits = static_cast<uint32>(value&0x8000) << 16; + int abs = value & 0x7FFF; + if(abs) + { + bits |= 0x38000000 << static_cast<unsigned>(abs>=0x7C00); + for(; abs<0x400; abs<<=1,bits-=0x800000) ; + bits += static_cast<uint32>(abs) << 13; + } +*/ static const uint32 mantissa_table[2048] = { + 0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34A00000, 0x34C00000, 0x34E00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000, 0x35400000, 0x35500000, 0x35600000, 0x35700000, + 0x35800000, 0x35880000, 0x35900000, 0x35980000, 0x35A00000, 0x35A80000, 0x35B00000, 0x35B80000, 0x35C00000, 0x35C80000, 0x35D00000, 0x35D80000, 0x35E00000, 0x35E80000, 0x35F00000, 0x35F80000, + 0x36000000, 0x36040000, 0x36080000, 0x360C0000, 0x36100000, 0x36140000, 0x36180000, 0x361C0000, 0x36200000, 0x36240000, 0x36280000, 0x362C0000, 0x36300000, 0x36340000, 0x36380000, 0x363C0000, + 0x36400000, 0x36440000, 0x36480000, 0x364C0000, 0x36500000, 0x36540000, 0x36580000, 0x365C0000, 0x36600000, 0x36640000, 0x36680000, 0x366C0000, 0x36700000, 0x36740000, 0x36780000, 0x367C0000, + 0x36800000, 0x36820000, 0x36840000, 0x36860000, 0x36880000, 0x368A0000, 0x368C0000, 0x368E0000, 0x36900000, 0x36920000, 0x36940000, 0x36960000, 0x36980000, 0x369A0000, 0x369C0000, 0x369E0000, + 0x36A00000, 0x36A20000, 0x36A40000, 0x36A60000, 0x36A80000, 0x36AA0000, 0x36AC0000, 0x36AE0000, 0x36B00000, 0x36B20000, 0x36B40000, 0x36B60000, 0x36B80000, 0x36BA0000, 0x36BC0000, 0x36BE0000, + 0x36C00000, 0x36C20000, 0x36C40000, 0x36C60000, 0x36C80000, 0x36CA0000, 0x36CC0000, 0x36CE0000, 0x36D00000, 0x36D20000, 0x36D40000, 0x36D60000, 0x36D80000, 0x36DA0000, 0x36DC0000, 0x36DE0000, + 0x36E00000, 0x36E20000, 0x36E40000, 0x36E60000, 0x36E80000, 0x36EA0000, 0x36EC0000, 0x36EE0000, 0x36F00000, 0x36F20000, 0x36F40000, 0x36F60000, 0x36F80000, 0x36FA0000, 0x36FC0000, 0x36FE0000, + 0x37000000, 0x37010000, 0x37020000, 0x37030000, 0x37040000, 0x37050000, 0x37060000, 0x37070000, 0x37080000, 0x37090000, 0x370A0000, 0x370B0000, 0x370C0000, 0x370D0000, 0x370E0000, 0x370F0000, + 0x37100000, 0x37110000, 0x37120000, 0x37130000, 0x37140000, 0x37150000, 0x37160000, 0x37170000, 0x37180000, 0x37190000, 0x371A0000, 0x371B0000, 0x371C0000, 0x371D0000, 0x371E0000, 0x371F0000, + 0x37200000, 0x37210000, 0x37220000, 0x37230000, 0x37240000, 0x37250000, 0x37260000, 0x37270000, 0x37280000, 0x37290000, 0x372A0000, 0x372B0000, 0x372C0000, 0x372D0000, 0x372E0000, 0x372F0000, + 0x37300000, 0x37310000, 0x37320000, 0x37330000, 0x37340000, 0x37350000, 0x37360000, 0x37370000, 0x37380000, 0x37390000, 0x373A0000, 0x373B0000, 0x373C0000, 0x373D0000, 0x373E0000, 0x373F0000, + 0x37400000, 0x37410000, 0x37420000, 0x37430000, 0x37440000, 0x37450000, 0x37460000, 0x37470000, 0x37480000, 0x37490000, 0x374A0000, 0x374B0000, 0x374C0000, 0x374D0000, 0x374E0000, 0x374F0000, + 0x37500000, 0x37510000, 0x37520000, 0x37530000, 0x37540000, 0x37550000, 0x37560000, 0x37570000, 0x37580000, 0x37590000, 0x375A0000, 0x375B0000, 0x375C0000, 0x375D0000, 0x375E0000, 0x375F0000, + 0x37600000, 0x37610000, 0x37620000, 0x37630000, 0x37640000, 0x37650000, 0x37660000, 0x37670000, 0x37680000, 0x37690000, 0x376A0000, 0x376B0000, 0x376C0000, 0x376D0000, 0x376E0000, 0x376F0000, + 0x37700000, 0x37710000, 0x37720000, 0x37730000, 0x37740000, 0x37750000, 0x37760000, 0x37770000, 0x37780000, 0x37790000, 0x377A0000, 
0x377B0000, 0x377C0000, 0x377D0000, 0x377E0000, 0x377F0000, + 0x37800000, 0x37808000, 0x37810000, 0x37818000, 0x37820000, 0x37828000, 0x37830000, 0x37838000, 0x37840000, 0x37848000, 0x37850000, 0x37858000, 0x37860000, 0x37868000, 0x37870000, 0x37878000, + 0x37880000, 0x37888000, 0x37890000, 0x37898000, 0x378A0000, 0x378A8000, 0x378B0000, 0x378B8000, 0x378C0000, 0x378C8000, 0x378D0000, 0x378D8000, 0x378E0000, 0x378E8000, 0x378F0000, 0x378F8000, + 0x37900000, 0x37908000, 0x37910000, 0x37918000, 0x37920000, 0x37928000, 0x37930000, 0x37938000, 0x37940000, 0x37948000, 0x37950000, 0x37958000, 0x37960000, 0x37968000, 0x37970000, 0x37978000, + 0x37980000, 0x37988000, 0x37990000, 0x37998000, 0x379A0000, 0x379A8000, 0x379B0000, 0x379B8000, 0x379C0000, 0x379C8000, 0x379D0000, 0x379D8000, 0x379E0000, 0x379E8000, 0x379F0000, 0x379F8000, + 0x37A00000, 0x37A08000, 0x37A10000, 0x37A18000, 0x37A20000, 0x37A28000, 0x37A30000, 0x37A38000, 0x37A40000, 0x37A48000, 0x37A50000, 0x37A58000, 0x37A60000, 0x37A68000, 0x37A70000, 0x37A78000, + 0x37A80000, 0x37A88000, 0x37A90000, 0x37A98000, 0x37AA0000, 0x37AA8000, 0x37AB0000, 0x37AB8000, 0x37AC0000, 0x37AC8000, 0x37AD0000, 0x37AD8000, 0x37AE0000, 0x37AE8000, 0x37AF0000, 0x37AF8000, + 0x37B00000, 0x37B08000, 0x37B10000, 0x37B18000, 0x37B20000, 0x37B28000, 0x37B30000, 0x37B38000, 0x37B40000, 0x37B48000, 0x37B50000, 0x37B58000, 0x37B60000, 0x37B68000, 0x37B70000, 0x37B78000, + 0x37B80000, 0x37B88000, 0x37B90000, 0x37B98000, 0x37BA0000, 0x37BA8000, 0x37BB0000, 0x37BB8000, 0x37BC0000, 0x37BC8000, 0x37BD0000, 0x37BD8000, 0x37BE0000, 0x37BE8000, 0x37BF0000, 0x37BF8000, + 0x37C00000, 0x37C08000, 0x37C10000, 0x37C18000, 0x37C20000, 0x37C28000, 0x37C30000, 0x37C38000, 0x37C40000, 0x37C48000, 0x37C50000, 0x37C58000, 0x37C60000, 0x37C68000, 0x37C70000, 0x37C78000, + 0x37C80000, 0x37C88000, 0x37C90000, 0x37C98000, 0x37CA0000, 0x37CA8000, 0x37CB0000, 0x37CB8000, 0x37CC0000, 0x37CC8000, 0x37CD0000, 0x37CD8000, 0x37CE0000, 0x37CE8000, 0x37CF0000, 0x37CF8000, + 0x37D00000, 0x37D08000, 0x37D10000, 0x37D18000, 0x37D20000, 0x37D28000, 0x37D30000, 0x37D38000, 0x37D40000, 0x37D48000, 0x37D50000, 0x37D58000, 0x37D60000, 0x37D68000, 0x37D70000, 0x37D78000, + 0x37D80000, 0x37D88000, 0x37D90000, 0x37D98000, 0x37DA0000, 0x37DA8000, 0x37DB0000, 0x37DB8000, 0x37DC0000, 0x37DC8000, 0x37DD0000, 0x37DD8000, 0x37DE0000, 0x37DE8000, 0x37DF0000, 0x37DF8000, + 0x37E00000, 0x37E08000, 0x37E10000, 0x37E18000, 0x37E20000, 0x37E28000, 0x37E30000, 0x37E38000, 0x37E40000, 0x37E48000, 0x37E50000, 0x37E58000, 0x37E60000, 0x37E68000, 0x37E70000, 0x37E78000, + 0x37E80000, 0x37E88000, 0x37E90000, 0x37E98000, 0x37EA0000, 0x37EA8000, 0x37EB0000, 0x37EB8000, 0x37EC0000, 0x37EC8000, 0x37ED0000, 0x37ED8000, 0x37EE0000, 0x37EE8000, 0x37EF0000, 0x37EF8000, + 0x37F00000, 0x37F08000, 0x37F10000, 0x37F18000, 0x37F20000, 0x37F28000, 0x37F30000, 0x37F38000, 0x37F40000, 0x37F48000, 0x37F50000, 0x37F58000, 0x37F60000, 0x37F68000, 0x37F70000, 0x37F78000, + 0x37F80000, 0x37F88000, 0x37F90000, 0x37F98000, 0x37FA0000, 0x37FA8000, 0x37FB0000, 0x37FB8000, 0x37FC0000, 0x37FC8000, 0x37FD0000, 0x37FD8000, 0x37FE0000, 0x37FE8000, 0x37FF0000, 0x37FF8000, + 0x38000000, 0x38004000, 0x38008000, 0x3800C000, 0x38010000, 0x38014000, 0x38018000, 0x3801C000, 0x38020000, 0x38024000, 0x38028000, 0x3802C000, 0x38030000, 0x38034000, 0x38038000, 0x3803C000, + 0x38040000, 0x38044000, 0x38048000, 0x3804C000, 0x38050000, 0x38054000, 0x38058000, 0x3805C000, 0x38060000, 0x38064000, 0x38068000, 0x3806C000, 0x38070000, 0x38074000, 0x38078000, 0x3807C000, + 
0x38080000, 0x38084000, 0x38088000, 0x3808C000, 0x38090000, 0x38094000, 0x38098000, 0x3809C000, 0x380A0000, 0x380A4000, 0x380A8000, 0x380AC000, 0x380B0000, 0x380B4000, 0x380B8000, 0x380BC000, + 0x380C0000, 0x380C4000, 0x380C8000, 0x380CC000, 0x380D0000, 0x380D4000, 0x380D8000, 0x380DC000, 0x380E0000, 0x380E4000, 0x380E8000, 0x380EC000, 0x380F0000, 0x380F4000, 0x380F8000, 0x380FC000, + 0x38100000, 0x38104000, 0x38108000, 0x3810C000, 0x38110000, 0x38114000, 0x38118000, 0x3811C000, 0x38120000, 0x38124000, 0x38128000, 0x3812C000, 0x38130000, 0x38134000, 0x38138000, 0x3813C000, + 0x38140000, 0x38144000, 0x38148000, 0x3814C000, 0x38150000, 0x38154000, 0x38158000, 0x3815C000, 0x38160000, 0x38164000, 0x38168000, 0x3816C000, 0x38170000, 0x38174000, 0x38178000, 0x3817C000, + 0x38180000, 0x38184000, 0x38188000, 0x3818C000, 0x38190000, 0x38194000, 0x38198000, 0x3819C000, 0x381A0000, 0x381A4000, 0x381A8000, 0x381AC000, 0x381B0000, 0x381B4000, 0x381B8000, 0x381BC000, + 0x381C0000, 0x381C4000, 0x381C8000, 0x381CC000, 0x381D0000, 0x381D4000, 0x381D8000, 0x381DC000, 0x381E0000, 0x381E4000, 0x381E8000, 0x381EC000, 0x381F0000, 0x381F4000, 0x381F8000, 0x381FC000, + 0x38200000, 0x38204000, 0x38208000, 0x3820C000, 0x38210000, 0x38214000, 0x38218000, 0x3821C000, 0x38220000, 0x38224000, 0x38228000, 0x3822C000, 0x38230000, 0x38234000, 0x38238000, 0x3823C000, + 0x38240000, 0x38244000, 0x38248000, 0x3824C000, 0x38250000, 0x38254000, 0x38258000, 0x3825C000, 0x38260000, 0x38264000, 0x38268000, 0x3826C000, 0x38270000, 0x38274000, 0x38278000, 0x3827C000, + 0x38280000, 0x38284000, 0x38288000, 0x3828C000, 0x38290000, 0x38294000, 0x38298000, 0x3829C000, 0x382A0000, 0x382A4000, 0x382A8000, 0x382AC000, 0x382B0000, 0x382B4000, 0x382B8000, 0x382BC000, + 0x382C0000, 0x382C4000, 0x382C8000, 0x382CC000, 0x382D0000, 0x382D4000, 0x382D8000, 0x382DC000, 0x382E0000, 0x382E4000, 0x382E8000, 0x382EC000, 0x382F0000, 0x382F4000, 0x382F8000, 0x382FC000, + 0x38300000, 0x38304000, 0x38308000, 0x3830C000, 0x38310000, 0x38314000, 0x38318000, 0x3831C000, 0x38320000, 0x38324000, 0x38328000, 0x3832C000, 0x38330000, 0x38334000, 0x38338000, 0x3833C000, + 0x38340000, 0x38344000, 0x38348000, 0x3834C000, 0x38350000, 0x38354000, 0x38358000, 0x3835C000, 0x38360000, 0x38364000, 0x38368000, 0x3836C000, 0x38370000, 0x38374000, 0x38378000, 0x3837C000, + 0x38380000, 0x38384000, 0x38388000, 0x3838C000, 0x38390000, 0x38394000, 0x38398000, 0x3839C000, 0x383A0000, 0x383A4000, 0x383A8000, 0x383AC000, 0x383B0000, 0x383B4000, 0x383B8000, 0x383BC000, + 0x383C0000, 0x383C4000, 0x383C8000, 0x383CC000, 0x383D0000, 0x383D4000, 0x383D8000, 0x383DC000, 0x383E0000, 0x383E4000, 0x383E8000, 0x383EC000, 0x383F0000, 0x383F4000, 0x383F8000, 0x383FC000, + 0x38400000, 0x38404000, 0x38408000, 0x3840C000, 0x38410000, 0x38414000, 0x38418000, 0x3841C000, 0x38420000, 0x38424000, 0x38428000, 0x3842C000, 0x38430000, 0x38434000, 0x38438000, 0x3843C000, + 0x38440000, 0x38444000, 0x38448000, 0x3844C000, 0x38450000, 0x38454000, 0x38458000, 0x3845C000, 0x38460000, 0x38464000, 0x38468000, 0x3846C000, 0x38470000, 0x38474000, 0x38478000, 0x3847C000, + 0x38480000, 0x38484000, 0x38488000, 0x3848C000, 0x38490000, 0x38494000, 0x38498000, 0x3849C000, 0x384A0000, 0x384A4000, 0x384A8000, 0x384AC000, 0x384B0000, 0x384B4000, 0x384B8000, 0x384BC000, + 0x384C0000, 0x384C4000, 0x384C8000, 0x384CC000, 0x384D0000, 0x384D4000, 0x384D8000, 0x384DC000, 0x384E0000, 0x384E4000, 0x384E8000, 0x384EC000, 0x384F0000, 0x384F4000, 0x384F8000, 0x384FC000, + 0x38500000, 0x38504000, 0x38508000, 0x3850C000, 0x38510000, 
0x38514000, 0x38518000, 0x3851C000, 0x38520000, 0x38524000, 0x38528000, 0x3852C000, 0x38530000, 0x38534000, 0x38538000, 0x3853C000, + 0x38540000, 0x38544000, 0x38548000, 0x3854C000, 0x38550000, 0x38554000, 0x38558000, 0x3855C000, 0x38560000, 0x38564000, 0x38568000, 0x3856C000, 0x38570000, 0x38574000, 0x38578000, 0x3857C000, + 0x38580000, 0x38584000, 0x38588000, 0x3858C000, 0x38590000, 0x38594000, 0x38598000, 0x3859C000, 0x385A0000, 0x385A4000, 0x385A8000, 0x385AC000, 0x385B0000, 0x385B4000, 0x385B8000, 0x385BC000, + 0x385C0000, 0x385C4000, 0x385C8000, 0x385CC000, 0x385D0000, 0x385D4000, 0x385D8000, 0x385DC000, 0x385E0000, 0x385E4000, 0x385E8000, 0x385EC000, 0x385F0000, 0x385F4000, 0x385F8000, 0x385FC000, + 0x38600000, 0x38604000, 0x38608000, 0x3860C000, 0x38610000, 0x38614000, 0x38618000, 0x3861C000, 0x38620000, 0x38624000, 0x38628000, 0x3862C000, 0x38630000, 0x38634000, 0x38638000, 0x3863C000, + 0x38640000, 0x38644000, 0x38648000, 0x3864C000, 0x38650000, 0x38654000, 0x38658000, 0x3865C000, 0x38660000, 0x38664000, 0x38668000, 0x3866C000, 0x38670000, 0x38674000, 0x38678000, 0x3867C000, + 0x38680000, 0x38684000, 0x38688000, 0x3868C000, 0x38690000, 0x38694000, 0x38698000, 0x3869C000, 0x386A0000, 0x386A4000, 0x386A8000, 0x386AC000, 0x386B0000, 0x386B4000, 0x386B8000, 0x386BC000, + 0x386C0000, 0x386C4000, 0x386C8000, 0x386CC000, 0x386D0000, 0x386D4000, 0x386D8000, 0x386DC000, 0x386E0000, 0x386E4000, 0x386E8000, 0x386EC000, 0x386F0000, 0x386F4000, 0x386F8000, 0x386FC000, + 0x38700000, 0x38704000, 0x38708000, 0x3870C000, 0x38710000, 0x38714000, 0x38718000, 0x3871C000, 0x38720000, 0x38724000, 0x38728000, 0x3872C000, 0x38730000, 0x38734000, 0x38738000, 0x3873C000, + 0x38740000, 0x38744000, 0x38748000, 0x3874C000, 0x38750000, 0x38754000, 0x38758000, 0x3875C000, 0x38760000, 0x38764000, 0x38768000, 0x3876C000, 0x38770000, 0x38774000, 0x38778000, 0x3877C000, + 0x38780000, 0x38784000, 0x38788000, 0x3878C000, 0x38790000, 0x38794000, 0x38798000, 0x3879C000, 0x387A0000, 0x387A4000, 0x387A8000, 0x387AC000, 0x387B0000, 0x387B4000, 0x387B8000, 0x387BC000, + 0x387C0000, 0x387C4000, 0x387C8000, 0x387CC000, 0x387D0000, 0x387D4000, 0x387D8000, 0x387DC000, 0x387E0000, 0x387E4000, 0x387E8000, 0x387EC000, 0x387F0000, 0x387F4000, 0x387F8000, 0x387FC000, + 0x38000000, 0x38002000, 0x38004000, 0x38006000, 0x38008000, 0x3800A000, 0x3800C000, 0x3800E000, 0x38010000, 0x38012000, 0x38014000, 0x38016000, 0x38018000, 0x3801A000, 0x3801C000, 0x3801E000, + 0x38020000, 0x38022000, 0x38024000, 0x38026000, 0x38028000, 0x3802A000, 0x3802C000, 0x3802E000, 0x38030000, 0x38032000, 0x38034000, 0x38036000, 0x38038000, 0x3803A000, 0x3803C000, 0x3803E000, + 0x38040000, 0x38042000, 0x38044000, 0x38046000, 0x38048000, 0x3804A000, 0x3804C000, 0x3804E000, 0x38050000, 0x38052000, 0x38054000, 0x38056000, 0x38058000, 0x3805A000, 0x3805C000, 0x3805E000, + 0x38060000, 0x38062000, 0x38064000, 0x38066000, 0x38068000, 0x3806A000, 0x3806C000, 0x3806E000, 0x38070000, 0x38072000, 0x38074000, 0x38076000, 0x38078000, 0x3807A000, 0x3807C000, 0x3807E000, + 0x38080000, 0x38082000, 0x38084000, 0x38086000, 0x38088000, 0x3808A000, 0x3808C000, 0x3808E000, 0x38090000, 0x38092000, 0x38094000, 0x38096000, 0x38098000, 0x3809A000, 0x3809C000, 0x3809E000, + 0x380A0000, 0x380A2000, 0x380A4000, 0x380A6000, 0x380A8000, 0x380AA000, 0x380AC000, 0x380AE000, 0x380B0000, 0x380B2000, 0x380B4000, 0x380B6000, 0x380B8000, 0x380BA000, 0x380BC000, 0x380BE000, + 0x380C0000, 0x380C2000, 0x380C4000, 0x380C6000, 0x380C8000, 0x380CA000, 0x380CC000, 0x380CE000, 0x380D0000, 0x380D2000, 
0x380D4000, 0x380D6000, 0x380D8000, 0x380DA000, 0x380DC000, 0x380DE000, + 0x380E0000, 0x380E2000, 0x380E4000, 0x380E6000, 0x380E8000, 0x380EA000, 0x380EC000, 0x380EE000, 0x380F0000, 0x380F2000, 0x380F4000, 0x380F6000, 0x380F8000, 0x380FA000, 0x380FC000, 0x380FE000, + 0x38100000, 0x38102000, 0x38104000, 0x38106000, 0x38108000, 0x3810A000, 0x3810C000, 0x3810E000, 0x38110000, 0x38112000, 0x38114000, 0x38116000, 0x38118000, 0x3811A000, 0x3811C000, 0x3811E000, + 0x38120000, 0x38122000, 0x38124000, 0x38126000, 0x38128000, 0x3812A000, 0x3812C000, 0x3812E000, 0x38130000, 0x38132000, 0x38134000, 0x38136000, 0x38138000, 0x3813A000, 0x3813C000, 0x3813E000, + 0x38140000, 0x38142000, 0x38144000, 0x38146000, 0x38148000, 0x3814A000, 0x3814C000, 0x3814E000, 0x38150000, 0x38152000, 0x38154000, 0x38156000, 0x38158000, 0x3815A000, 0x3815C000, 0x3815E000, + 0x38160000, 0x38162000, 0x38164000, 0x38166000, 0x38168000, 0x3816A000, 0x3816C000, 0x3816E000, 0x38170000, 0x38172000, 0x38174000, 0x38176000, 0x38178000, 0x3817A000, 0x3817C000, 0x3817E000, + 0x38180000, 0x38182000, 0x38184000, 0x38186000, 0x38188000, 0x3818A000, 0x3818C000, 0x3818E000, 0x38190000, 0x38192000, 0x38194000, 0x38196000, 0x38198000, 0x3819A000, 0x3819C000, 0x3819E000, + 0x381A0000, 0x381A2000, 0x381A4000, 0x381A6000, 0x381A8000, 0x381AA000, 0x381AC000, 0x381AE000, 0x381B0000, 0x381B2000, 0x381B4000, 0x381B6000, 0x381B8000, 0x381BA000, 0x381BC000, 0x381BE000, + 0x381C0000, 0x381C2000, 0x381C4000, 0x381C6000, 0x381C8000, 0x381CA000, 0x381CC000, 0x381CE000, 0x381D0000, 0x381D2000, 0x381D4000, 0x381D6000, 0x381D8000, 0x381DA000, 0x381DC000, 0x381DE000, + 0x381E0000, 0x381E2000, 0x381E4000, 0x381E6000, 0x381E8000, 0x381EA000, 0x381EC000, 0x381EE000, 0x381F0000, 0x381F2000, 0x381F4000, 0x381F6000, 0x381F8000, 0x381FA000, 0x381FC000, 0x381FE000, + 0x38200000, 0x38202000, 0x38204000, 0x38206000, 0x38208000, 0x3820A000, 0x3820C000, 0x3820E000, 0x38210000, 0x38212000, 0x38214000, 0x38216000, 0x38218000, 0x3821A000, 0x3821C000, 0x3821E000, + 0x38220000, 0x38222000, 0x38224000, 0x38226000, 0x38228000, 0x3822A000, 0x3822C000, 0x3822E000, 0x38230000, 0x38232000, 0x38234000, 0x38236000, 0x38238000, 0x3823A000, 0x3823C000, 0x3823E000, + 0x38240000, 0x38242000, 0x38244000, 0x38246000, 0x38248000, 0x3824A000, 0x3824C000, 0x3824E000, 0x38250000, 0x38252000, 0x38254000, 0x38256000, 0x38258000, 0x3825A000, 0x3825C000, 0x3825E000, + 0x38260000, 0x38262000, 0x38264000, 0x38266000, 0x38268000, 0x3826A000, 0x3826C000, 0x3826E000, 0x38270000, 0x38272000, 0x38274000, 0x38276000, 0x38278000, 0x3827A000, 0x3827C000, 0x3827E000, + 0x38280000, 0x38282000, 0x38284000, 0x38286000, 0x38288000, 0x3828A000, 0x3828C000, 0x3828E000, 0x38290000, 0x38292000, 0x38294000, 0x38296000, 0x38298000, 0x3829A000, 0x3829C000, 0x3829E000, + 0x382A0000, 0x382A2000, 0x382A4000, 0x382A6000, 0x382A8000, 0x382AA000, 0x382AC000, 0x382AE000, 0x382B0000, 0x382B2000, 0x382B4000, 0x382B6000, 0x382B8000, 0x382BA000, 0x382BC000, 0x382BE000, + 0x382C0000, 0x382C2000, 0x382C4000, 0x382C6000, 0x382C8000, 0x382CA000, 0x382CC000, 0x382CE000, 0x382D0000, 0x382D2000, 0x382D4000, 0x382D6000, 0x382D8000, 0x382DA000, 0x382DC000, 0x382DE000, + 0x382E0000, 0x382E2000, 0x382E4000, 0x382E6000, 0x382E8000, 0x382EA000, 0x382EC000, 0x382EE000, 0x382F0000, 0x382F2000, 0x382F4000, 0x382F6000, 0x382F8000, 0x382FA000, 0x382FC000, 0x382FE000, + 0x38300000, 0x38302000, 0x38304000, 0x38306000, 0x38308000, 0x3830A000, 0x3830C000, 0x3830E000, 0x38310000, 0x38312000, 0x38314000, 0x38316000, 0x38318000, 0x3831A000, 0x3831C000, 
0x3831E000, + 0x38320000, 0x38322000, 0x38324000, 0x38326000, 0x38328000, 0x3832A000, 0x3832C000, 0x3832E000, 0x38330000, 0x38332000, 0x38334000, 0x38336000, 0x38338000, 0x3833A000, 0x3833C000, 0x3833E000, + 0x38340000, 0x38342000, 0x38344000, 0x38346000, 0x38348000, 0x3834A000, 0x3834C000, 0x3834E000, 0x38350000, 0x38352000, 0x38354000, 0x38356000, 0x38358000, 0x3835A000, 0x3835C000, 0x3835E000, + 0x38360000, 0x38362000, 0x38364000, 0x38366000, 0x38368000, 0x3836A000, 0x3836C000, 0x3836E000, 0x38370000, 0x38372000, 0x38374000, 0x38376000, 0x38378000, 0x3837A000, 0x3837C000, 0x3837E000, + 0x38380000, 0x38382000, 0x38384000, 0x38386000, 0x38388000, 0x3838A000, 0x3838C000, 0x3838E000, 0x38390000, 0x38392000, 0x38394000, 0x38396000, 0x38398000, 0x3839A000, 0x3839C000, 0x3839E000, + 0x383A0000, 0x383A2000, 0x383A4000, 0x383A6000, 0x383A8000, 0x383AA000, 0x383AC000, 0x383AE000, 0x383B0000, 0x383B2000, 0x383B4000, 0x383B6000, 0x383B8000, 0x383BA000, 0x383BC000, 0x383BE000, + 0x383C0000, 0x383C2000, 0x383C4000, 0x383C6000, 0x383C8000, 0x383CA000, 0x383CC000, 0x383CE000, 0x383D0000, 0x383D2000, 0x383D4000, 0x383D6000, 0x383D8000, 0x383DA000, 0x383DC000, 0x383DE000, + 0x383E0000, 0x383E2000, 0x383E4000, 0x383E6000, 0x383E8000, 0x383EA000, 0x383EC000, 0x383EE000, 0x383F0000, 0x383F2000, 0x383F4000, 0x383F6000, 0x383F8000, 0x383FA000, 0x383FC000, 0x383FE000, + 0x38400000, 0x38402000, 0x38404000, 0x38406000, 0x38408000, 0x3840A000, 0x3840C000, 0x3840E000, 0x38410000, 0x38412000, 0x38414000, 0x38416000, 0x38418000, 0x3841A000, 0x3841C000, 0x3841E000, + 0x38420000, 0x38422000, 0x38424000, 0x38426000, 0x38428000, 0x3842A000, 0x3842C000, 0x3842E000, 0x38430000, 0x38432000, 0x38434000, 0x38436000, 0x38438000, 0x3843A000, 0x3843C000, 0x3843E000, + 0x38440000, 0x38442000, 0x38444000, 0x38446000, 0x38448000, 0x3844A000, 0x3844C000, 0x3844E000, 0x38450000, 0x38452000, 0x38454000, 0x38456000, 0x38458000, 0x3845A000, 0x3845C000, 0x3845E000, + 0x38460000, 0x38462000, 0x38464000, 0x38466000, 0x38468000, 0x3846A000, 0x3846C000, 0x3846E000, 0x38470000, 0x38472000, 0x38474000, 0x38476000, 0x38478000, 0x3847A000, 0x3847C000, 0x3847E000, + 0x38480000, 0x38482000, 0x38484000, 0x38486000, 0x38488000, 0x3848A000, 0x3848C000, 0x3848E000, 0x38490000, 0x38492000, 0x38494000, 0x38496000, 0x38498000, 0x3849A000, 0x3849C000, 0x3849E000, + 0x384A0000, 0x384A2000, 0x384A4000, 0x384A6000, 0x384A8000, 0x384AA000, 0x384AC000, 0x384AE000, 0x384B0000, 0x384B2000, 0x384B4000, 0x384B6000, 0x384B8000, 0x384BA000, 0x384BC000, 0x384BE000, + 0x384C0000, 0x384C2000, 0x384C4000, 0x384C6000, 0x384C8000, 0x384CA000, 0x384CC000, 0x384CE000, 0x384D0000, 0x384D2000, 0x384D4000, 0x384D6000, 0x384D8000, 0x384DA000, 0x384DC000, 0x384DE000, + 0x384E0000, 0x384E2000, 0x384E4000, 0x384E6000, 0x384E8000, 0x384EA000, 0x384EC000, 0x384EE000, 0x384F0000, 0x384F2000, 0x384F4000, 0x384F6000, 0x384F8000, 0x384FA000, 0x384FC000, 0x384FE000, + 0x38500000, 0x38502000, 0x38504000, 0x38506000, 0x38508000, 0x3850A000, 0x3850C000, 0x3850E000, 0x38510000, 0x38512000, 0x38514000, 0x38516000, 0x38518000, 0x3851A000, 0x3851C000, 0x3851E000, + 0x38520000, 0x38522000, 0x38524000, 0x38526000, 0x38528000, 0x3852A000, 0x3852C000, 0x3852E000, 0x38530000, 0x38532000, 0x38534000, 0x38536000, 0x38538000, 0x3853A000, 0x3853C000, 0x3853E000, + 0x38540000, 0x38542000, 0x38544000, 0x38546000, 0x38548000, 0x3854A000, 0x3854C000, 0x3854E000, 0x38550000, 0x38552000, 0x38554000, 0x38556000, 0x38558000, 0x3855A000, 0x3855C000, 0x3855E000, + 0x38560000, 0x38562000, 0x38564000, 0x38566000, 
0x38568000, 0x3856A000, 0x3856C000, 0x3856E000, 0x38570000, 0x38572000, 0x38574000, 0x38576000, 0x38578000, 0x3857A000, 0x3857C000, 0x3857E000, + 0x38580000, 0x38582000, 0x38584000, 0x38586000, 0x38588000, 0x3858A000, 0x3858C000, 0x3858E000, 0x38590000, 0x38592000, 0x38594000, 0x38596000, 0x38598000, 0x3859A000, 0x3859C000, 0x3859E000, + 0x385A0000, 0x385A2000, 0x385A4000, 0x385A6000, 0x385A8000, 0x385AA000, 0x385AC000, 0x385AE000, 0x385B0000, 0x385B2000, 0x385B4000, 0x385B6000, 0x385B8000, 0x385BA000, 0x385BC000, 0x385BE000, + 0x385C0000, 0x385C2000, 0x385C4000, 0x385C6000, 0x385C8000, 0x385CA000, 0x385CC000, 0x385CE000, 0x385D0000, 0x385D2000, 0x385D4000, 0x385D6000, 0x385D8000, 0x385DA000, 0x385DC000, 0x385DE000, + 0x385E0000, 0x385E2000, 0x385E4000, 0x385E6000, 0x385E8000, 0x385EA000, 0x385EC000, 0x385EE000, 0x385F0000, 0x385F2000, 0x385F4000, 0x385F6000, 0x385F8000, 0x385FA000, 0x385FC000, 0x385FE000, + 0x38600000, 0x38602000, 0x38604000, 0x38606000, 0x38608000, 0x3860A000, 0x3860C000, 0x3860E000, 0x38610000, 0x38612000, 0x38614000, 0x38616000, 0x38618000, 0x3861A000, 0x3861C000, 0x3861E000, + 0x38620000, 0x38622000, 0x38624000, 0x38626000, 0x38628000, 0x3862A000, 0x3862C000, 0x3862E000, 0x38630000, 0x38632000, 0x38634000, 0x38636000, 0x38638000, 0x3863A000, 0x3863C000, 0x3863E000, + 0x38640000, 0x38642000, 0x38644000, 0x38646000, 0x38648000, 0x3864A000, 0x3864C000, 0x3864E000, 0x38650000, 0x38652000, 0x38654000, 0x38656000, 0x38658000, 0x3865A000, 0x3865C000, 0x3865E000, + 0x38660000, 0x38662000, 0x38664000, 0x38666000, 0x38668000, 0x3866A000, 0x3866C000, 0x3866E000, 0x38670000, 0x38672000, 0x38674000, 0x38676000, 0x38678000, 0x3867A000, 0x3867C000, 0x3867E000, + 0x38680000, 0x38682000, 0x38684000, 0x38686000, 0x38688000, 0x3868A000, 0x3868C000, 0x3868E000, 0x38690000, 0x38692000, 0x38694000, 0x38696000, 0x38698000, 0x3869A000, 0x3869C000, 0x3869E000, + 0x386A0000, 0x386A2000, 0x386A4000, 0x386A6000, 0x386A8000, 0x386AA000, 0x386AC000, 0x386AE000, 0x386B0000, 0x386B2000, 0x386B4000, 0x386B6000, 0x386B8000, 0x386BA000, 0x386BC000, 0x386BE000, + 0x386C0000, 0x386C2000, 0x386C4000, 0x386C6000, 0x386C8000, 0x386CA000, 0x386CC000, 0x386CE000, 0x386D0000, 0x386D2000, 0x386D4000, 0x386D6000, 0x386D8000, 0x386DA000, 0x386DC000, 0x386DE000, + 0x386E0000, 0x386E2000, 0x386E4000, 0x386E6000, 0x386E8000, 0x386EA000, 0x386EC000, 0x386EE000, 0x386F0000, 0x386F2000, 0x386F4000, 0x386F6000, 0x386F8000, 0x386FA000, 0x386FC000, 0x386FE000, + 0x38700000, 0x38702000, 0x38704000, 0x38706000, 0x38708000, 0x3870A000, 0x3870C000, 0x3870E000, 0x38710000, 0x38712000, 0x38714000, 0x38716000, 0x38718000, 0x3871A000, 0x3871C000, 0x3871E000, + 0x38720000, 0x38722000, 0x38724000, 0x38726000, 0x38728000, 0x3872A000, 0x3872C000, 0x3872E000, 0x38730000, 0x38732000, 0x38734000, 0x38736000, 0x38738000, 0x3873A000, 0x3873C000, 0x3873E000, + 0x38740000, 0x38742000, 0x38744000, 0x38746000, 0x38748000, 0x3874A000, 0x3874C000, 0x3874E000, 0x38750000, 0x38752000, 0x38754000, 0x38756000, 0x38758000, 0x3875A000, 0x3875C000, 0x3875E000, + 0x38760000, 0x38762000, 0x38764000, 0x38766000, 0x38768000, 0x3876A000, 0x3876C000, 0x3876E000, 0x38770000, 0x38772000, 0x38774000, 0x38776000, 0x38778000, 0x3877A000, 0x3877C000, 0x3877E000, + 0x38780000, 0x38782000, 0x38784000, 0x38786000, 0x38788000, 0x3878A000, 0x3878C000, 0x3878E000, 0x38790000, 0x38792000, 0x38794000, 0x38796000, 0x38798000, 0x3879A000, 0x3879C000, 0x3879E000, + 0x387A0000, 0x387A2000, 0x387A4000, 0x387A6000, 0x387A8000, 0x387AA000, 0x387AC000, 0x387AE000, 0x387B0000, 
0x387B2000, 0x387B4000, 0x387B6000, 0x387B8000, 0x387BA000, 0x387BC000, 0x387BE000, + 0x387C0000, 0x387C2000, 0x387C4000, 0x387C6000, 0x387C8000, 0x387CA000, 0x387CC000, 0x387CE000, 0x387D0000, 0x387D2000, 0x387D4000, 0x387D6000, 0x387D8000, 0x387DA000, 0x387DC000, 0x387DE000, + 0x387E0000, 0x387E2000, 0x387E4000, 0x387E6000, 0x387E8000, 0x387EA000, 0x387EC000, 0x387EE000, 0x387F0000, 0x387F2000, 0x387F4000, 0x387F6000, 0x387F8000, 0x387FA000, 0x387FC000, 0x387FE000 }; + static const uint32 exponent_table[64] = { + 0x00000000, 0x00800000, 0x01000000, 0x01800000, 0x02000000, 0x02800000, 0x03000000, 0x03800000, 0x04000000, 0x04800000, 0x05000000, 0x05800000, 0x06000000, 0x06800000, 0x07000000, 0x07800000, + 0x08000000, 0x08800000, 0x09000000, 0x09800000, 0x0A000000, 0x0A800000, 0x0B000000, 0x0B800000, 0x0C000000, 0x0C800000, 0x0D000000, 0x0D800000, 0x0E000000, 0x0E800000, 0x0F000000, 0x47800000, + 0x80000000, 0x80800000, 0x81000000, 0x81800000, 0x82000000, 0x82800000, 0x83000000, 0x83800000, 0x84000000, 0x84800000, 0x85000000, 0x85800000, 0x86000000, 0x86800000, 0x87000000, 0x87800000, + 0x88000000, 0x88800000, 0x89000000, 0x89800000, 0x8A000000, 0x8A800000, 0x8B000000, 0x8B800000, 0x8C000000, 0x8C800000, 0x8D000000, 0x8D800000, 0x8E000000, 0x8E800000, 0x8F000000, 0xC7800000 }; + static const unsigned short offset_table[64] = { + 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024 }; + uint32 bits = mantissa_table[offset_table[value>>10]+(value&0x3FF)] + exponent_table[value>>10]; +// return *reinterpret_cast<float*>(&bits); //violating strict aliasing! + float out; + std::memcpy(&out, &bits, sizeof(float)); + return out; + } + + /// Convert half-precision to IEEE double-precision. + /// \param value binary representation of half-precision value + /// \return double-precision value + inline double half2float_impl(uint16 value, double, true_type) + { + typedef bits<float>::type uint32; + typedef bits<double>::type uint64; + uint32 hi = static_cast<uint32>(value&0x8000) << 16; + int abs = value & 0x7FFF; + if(abs) + { + hi |= 0x3F000000 << static_cast<unsigned>(abs>=0x7C00); + for(; abs<0x400; abs<<=1,hi-=0x100000) ; + hi += static_cast<uint32>(abs) << 10; + } + uint64 bits = static_cast<uint64>(hi) << 32; +// return *reinterpret_cast<double*>(&bits); //violating strict aliasing! + double out; + std::memcpy(&out, &bits, sizeof(double)); + return out; + } + + /// Convert half-precision to non-IEEE floating point. + /// \tparam T type to convert to (builtin integer type) + /// \param value binary representation of half-precision value + /// \return floating point value + template<typename T> T half2float_impl(uint16 value, T, ...) + { + T out; + int abs = value & 0x7FFF; + if(abs > 0x7C00) + out = std::numeric_limits<T>::has_quiet_NaN ? std::numeric_limits<T>::quiet_NaN() : T(); + else if(abs == 0x7C00) + out = std::numeric_limits<T>::has_infinity ? std::numeric_limits<T>::infinity() : std::numeric_limits<T>::max(); + else if(abs > 0x3FF) + out = std::ldexp(static_cast<T>((abs&0x3FF)|0x400), (abs>>10)-25); + else + out = std::ldexp(static_cast<T>(abs), -24); + return (value&0x8000) ? -out : out; + } + + /// Convert half-precision to floating point. 
+ /// \tparam T type to convert to (builtin floating point type)
+ /// \param value binary representation of half-precision value
+ /// \return floating point value
+ template<typename T> T half2float(uint16 value)
+ {
+ return half2float_impl(value, T(), bool_type<std::numeric_limits<T>::is_iec559&&sizeof(typename bits<T>::type)==sizeof(T)>());
+ }
+
+ /// Convert half-precision floating point to integer.
+ /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
+ /// \tparam E `true` for round to even, `false` for round away from zero
+ /// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
+ /// \param value binary representation of half-precision value
+ /// \return integral value
+ template<std::float_round_style R,bool E,typename T> T half2int_impl(uint16 value)
+ {
+ #if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS
+ static_assert(std::is_integral<T>::value, "half to int conversion only supports builtin integer types");
+ #endif
+ unsigned int e = value & 0x7FFF;
+ if(e >= 0x7C00)
+ return (value&0x8000) ? std::numeric_limits<T>::min() : std::numeric_limits<T>::max();
+ if(e < 0x3800)
+ {
+ if(R == std::round_toward_infinity)
+ return T(~(value>>15)&(e!=0));
+ else if(R == std::round_toward_neg_infinity)
+ return -T(value>0x8000);
+ return T();
+ }
+ unsigned int m = (value&0x3FF) | 0x400;
+ e >>= 10;
+ if(e < 25)
+ {
+ if(R == std::round_to_nearest)
+ m += (1<<(24-e)) - (~(m>>(25-e))&E);
+ else if(R == std::round_toward_infinity)
+ m += ((value>>15)-1) & ((1<<(25-e))-1U);
+ else if(R == std::round_toward_neg_infinity)
+ m += -(value>>15) & ((1<<(25-e))-1U);
+ m >>= 25 - e;
+ }
+ else
+ m <<= e - 25;
+ return (value&0x8000) ? -static_cast<T>(m) : static_cast<T>(m);
+ }
+
+ /// Convert half-precision floating point to integer.
+ /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
+ /// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
+ /// \param value binary representation of half-precision value
+ /// \return integral value
+ template<std::float_round_style R,typename T> T half2int(uint16 value) { return half2int_impl<R,HALF_ROUND_TIES_TO_EVEN,T>(value); }
+
+ /// Convert half-precision floating point to integer using round-to-nearest-away-from-zero.
+ /// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
+ /// \param value binary representation of half-precision value
+ /// \return integral value
+ template<typename T> T half2int_up(uint16 value) { return half2int_impl<std::round_to_nearest,0,T>(value); }
+
+ /// Round half-precision number to nearest integer value.
+ /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
+ /// \tparam E `true` for round to even, `false` for round away from zero
+ /// \param value binary representation of half-precision value
+ /// \return half-precision bits for nearest integral value
+ template<std::float_round_style R,bool E> uint16 round_half_impl(uint16 value)
+ {
+ unsigned int e = value & 0x7FFF;
+ uint16 result = value;
+ if(e < 0x3C00)
+ {
+ result &= 0x8000;
+ if(R == std::round_to_nearest)
+ result |= 0x3C00U & -(e>=(0x3800+E));
+ else if(R == std::round_toward_infinity)
+ result |= 0x3C00U & -(~(value>>15)&(e!=0));
+ else if(R == std::round_toward_neg_infinity)
+ result |= 0x3C00U & -(value>0x8000);
+ }
+ else if(e < 0x6400)
+ {
+ e = 25 - (e>>10);
+ unsigned int mask = (1<<e) - 1;
+ if(R == std::round_to_nearest)
+ result += (1<<(e-1)) - (~(result>>e)&E);
+ else if(R == std::round_toward_infinity)
+ result += mask & ((value>>15)-1);
+ else if(R == std::round_toward_neg_infinity)
+ result += mask & -(value>>15);
+ result &= ~mask;
+ }
+ return result;
+ }
+
+ /// Round half-precision number to nearest integer value.
+ /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
+ /// \param value binary representation of half-precision value
+ /// \return half-precision bits for nearest integral value
+ template<std::float_round_style R> uint16 round_half(uint16 value) { return round_half_impl<R,HALF_ROUND_TIES_TO_EVEN>(value); }
+
+ /// Round half-precision number to nearest integer value using round-to-nearest-away-from-zero.
+ /// \param value binary representation of half-precision value
+ /// \return half-precision bits for nearest integral value
+ inline uint16 round_half_up(uint16 value) { return round_half_impl<std::round_to_nearest,0>(value); }
+ /// \}
+
+ struct functions;
+ template<typename> struct unary_specialized;
+ template<typename,typename> struct binary_specialized;
+ template<typename,typename,std::float_round_style> struct half_caster;
+ }
+
+ /// Half-precision floating point type.
+ /// This class implements an IEEE-conformant half-precision floating point type with the usual arithmetic operators and
+ /// conversions. It is implicitly convertible to single-precision floating point, which causes arithmetic expressions and
+ /// functions with mixed-type operands to be of the most precise operand type. Additionally, all arithmetic operations
+ /// (and many mathematical functions) are carried out in single-precision internally. All conversions from single- to
+ /// half-precision are done using the library's default rounding mode, but temporary results inside chained arithmetic
+ /// expressions are kept in single-precision as long as possible (while of course still maintaining a strong half-precision type).
+ ///
+ /// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's less strict and
+ /// extended definitions it is both a standard layout type and a trivially copyable type (even if not a POD type), which
+ /// means it can be standard-conformantly copied using raw binary copies. In this context, some more words about the
+ /// actual size of the type are in order. Although the half represents an IEEE 16-bit type, it does not necessarily have to
+ /// be exactly 16 bits in size.
+ /// But on any reasonable implementation the actual binary representation of this type will most
+ /// probably not involve any additional "magic" or padding beyond the simple binary representation of the underlying 16-bit
+ /// IEEE number, even if not strictly guaranteed by the standard. But even then it only has an actual size of 16 bits if
+ /// your C++ implementation supports an unsigned integer type exactly 16 bits wide, which should be the case on
+ /// nearly any reasonable platform.
+ ///
+ /// So if your C++ implementation is not totally exotic or imposes special alignment requirements, it is a reasonable
+ /// assumption that the data of a half is just comprised of the 2 bytes of the underlying IEEE representation.
+ class half
+ {
+ friend struct detail::functions;
+ friend struct detail::unary_specialized<half>;
+ friend struct detail::binary_specialized<half,half>;
+ template<typename,typename,std::float_round_style> friend struct detail::half_caster;
+ friend class std::numeric_limits<half>;
+ #if HALF_ENABLE_CPP11_HASH
+ friend struct std::hash<half>;
+ #endif
+ #if HALF_ENABLE_CPP11_USER_LITERALS
+ friend half literal::operator"" _h(long double);
+ #endif
+
+ public:
+ /// Default constructor.
+ /// This initializes the half to 0. Although this does not match the builtin types' default-initialization semantics
+ /// and may be less efficient than no initialization, it is needed to provide proper value-initialization semantics.
+ HALF_CONSTEXPR half() HALF_NOEXCEPT : data_() {}
+
+ /// Copy constructor.
+ /// \tparam T type of concrete half expression
+ /// \param rhs half expression to copy from
+ half(detail::expr rhs) : data_(detail::float2half<round_style>(static_cast<float>(rhs))) {}
+
+ /// Conversion constructor.
+ /// \param rhs float to convert
+ explicit half(float rhs) : data_(detail::float2half<round_style>(rhs)) {}
+
+ /// Conversion to single-precision.
+ /// \return single precision value representing expression value
+ operator float() const { return detail::half2float<float>(data_); }
+
+ /// Assignment operator.
+ /// \tparam T type of concrete half expression
+ /// \param rhs half expression to copy from
+ /// \return reference to this half
+ half& operator=(detail::expr rhs) { return *this = static_cast<float>(rhs); }
+
+ /// Arithmetic assignment.
+ /// \tparam T type of concrete half expression
+ /// \param rhs half expression to add
+ /// \return reference to this half
+ template<typename T> typename detail::enable<half&,T>::type operator+=(T rhs) { return *this += static_cast<float>(rhs); }
+
+ /// Arithmetic assignment.
+ /// \tparam T type of concrete half expression
+ /// \param rhs half expression to subtract
+ /// \return reference to this half
+ template<typename T> typename detail::enable<half&,T>::type operator-=(T rhs) { return *this -= static_cast<float>(rhs); }
+
+ /// Arithmetic assignment.
+ /// \tparam T type of concrete half expression
+ /// \param rhs half expression to multiply with
+ /// \return reference to this half
+ template<typename T> typename detail::enable<half&,T>::type operator*=(T rhs) { return *this *= static_cast<float>(rhs); }
+
+ /// Arithmetic assignment.
+ /// \tparam T type of concrete half expression
+ /// \param rhs half expression to divide by
+ /// \return reference to this half
+ template<typename T> typename detail::enable<half&,T>::type operator/=(T rhs) { return *this /= static_cast<float>(rhs); }
+
+ /// Assignment operator.
+ /// \param rhs single-precision value to copy from + /// \return reference to this half + half& operator=(float rhs) { data_ = detail::float2half<round_style>(rhs); return *this; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to add + /// \return reference to this half + half& operator+=(float rhs) { data_ = detail::float2half<round_style>(detail::half2float<float>(data_)+rhs); return *this; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to subtract + /// \return reference to this half + half& operator-=(float rhs) { data_ = detail::float2half<round_style>(detail::half2float<float>(data_)-rhs); return *this; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to multiply with + /// \return reference to this half + half& operator*=(float rhs) { data_ = detail::float2half<round_style>(detail::half2float<float>(data_)*rhs); return *this; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to divide by + /// \return reference to this half + half& operator/=(float rhs) { data_ = detail::float2half<round_style>(detail::half2float<float>(data_)/rhs); return *this; } + + /// Prefix increment. + /// \return incremented half value + half& operator++() { return *this += 1.0f; } + + /// Prefix decrement. + /// \return decremented half value + half& operator--() { return *this -= 1.0f; } + + /// Postfix increment. + /// \return non-incremented half value + half operator++(int) { half out(*this); ++*this; return out; } + + /// Postfix decrement. + /// \return non-decremented half value + half operator--(int) { half out(*this); --*this; return out; } + + private: + /// Rounding mode to use + static const std::float_round_style round_style = static_cast<std::float_round_style>(HALF_ROUND_STYLE); + + /// Constructor. + /// \param bits binary representation to set half to + HALF_CONSTEXPR half(detail::binary_t, detail::uint16 bits) HALF_NOEXCEPT : data_(bits) {} + + /// Internal binary representation + detail::uint16 data_; + }; + +#if HALF_ENABLE_CPP11_USER_LITERALS + namespace literal + { + /// Half literal. + /// While this returns an actual half-precision value, half literals can unfortunately not be constant expressions due + /// to rather involved conversions. + /// \param value literal value + /// \return half with given value (if representable) + inline half operator"" _h(long double value) { return half(detail::binary, detail::float2half<half::round_style>(value)); } + } +#endif + + namespace detail + { + /// Wrapper implementing unspecialized half-precision functions. + struct functions + { + /// Addition implementation. + /// \param x first operand + /// \param y second operand + /// \return Half-precision sum stored in single-precision + static expr plus(float x, float y) { return expr(x+y); } + + /// Subtraction implementation. + /// \param x first operand + /// \param y second operand + /// \return Half-precision difference stored in single-precision + static expr minus(float x, float y) { return expr(x-y); } + + /// Multiplication implementation. + /// \param x first operand + /// \param y second operand + /// \return Half-precision product stored in single-precision + static expr multiplies(float x, float y) { return expr(x*y); } + + /// Division implementation. + /// \param x first operand + /// \param y second operand + /// \return Half-precision quotient stored in single-precision + static expr divides(float x, float y) { return expr(x/y); } + + /// Output implementation. 
+ /// \param out stream to write to + /// \param arg value to write + /// \return reference to stream + template<typename charT,typename traits> static std::basic_ostream<charT,traits>& write(std::basic_ostream<charT,traits> &out, float arg) { return out << arg; } + + /// Input implementation. + /// \param in stream to read from + /// \param arg half to read into + /// \return reference to stream + template<typename charT,typename traits> static std::basic_istream<charT,traits>& read(std::basic_istream<charT,traits> &in, half &arg) + { + float f; + if(in >> f) + arg = f; + return in; + } + + /// Modulo implementation. + /// \param x first operand + /// \param y second operand + /// \return Half-precision division remainder stored in single-precision + static expr fmod(float x, float y) { return expr(std::fmod(x, y)); } + + /// Remainder implementation. + /// \param x first operand + /// \param y second operand + /// \return Half-precision division remainder stored in single-precision + static expr remainder(float x, float y) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::remainder(x, y)); + #else + if(builtin_isnan(x) || builtin_isnan(y)) + return expr(std::numeric_limits<float>::quiet_NaN()); + float ax = std::fabs(x), ay = std::fabs(y); + if(ax >= 65536.0f || ay < std::ldexp(1.0f, -24)) + return expr(std::numeric_limits<float>::quiet_NaN()); + if(ay >= 65536.0f) + return expr(x); + if(ax == ay) + return expr(builtin_signbit(x) ? -0.0f : 0.0f); + ax = std::fmod(ax, ay+ay); + float y2 = 0.5f * ay; + if(ax > y2) + { + ax -= ay; + if(ax >= y2) + ax -= ay; + } + return expr(builtin_signbit(x) ? -ax : ax); + #endif + } + + /// Remainder implementation. + /// \param x first operand + /// \param y second operand + /// \param quo address to store quotient bits at + /// \return Half-precision division remainder stored in single-precision + static expr remquo(float x, float y, int *quo) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::remquo(x, y, quo)); + #else + if(builtin_isnan(x) || builtin_isnan(y)) + return expr(std::numeric_limits<float>::quiet_NaN()); + bool sign = builtin_signbit(x), qsign = static_cast<bool>(sign^builtin_signbit(y)); + float ax = std::fabs(x), ay = std::fabs(y); + if(ax >= 65536.0f || ay < std::ldexp(1.0f, -24)) + return expr(std::numeric_limits<float>::quiet_NaN()); + if(ay >= 65536.0f) + return expr(x); + if(ax == ay) + return *quo = qsign ? -1 : 1, expr(sign ? -0.0f : 0.0f); + ax = std::fmod(ax, 8.0f*ay); + int cquo = 0; + if(ax >= 4.0f * ay) + { + ax -= 4.0f * ay; + cquo += 4; + } + if(ax >= 2.0f * ay) + { + ax -= 2.0f * ay; + cquo += 2; + } + float y2 = 0.5f * ay; + if(ax > y2) + { + ax -= ay; + ++cquo; + if(ax >= y2) + { + ax -= ay; + ++cquo; + } + } + return *quo = qsign ? -cquo : cquo, expr(sign ? -ax : ax); + #endif + } + + /// Positive difference implementation. + /// \param x first operand + /// \param y second operand + /// \return Positive difference stored in single-precision + static expr fdim(float x, float y) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::fdim(x, y)); + #else + return expr((x<=y) ? 0.0f : (x-y)); + #endif + } + + /// Fused multiply-add implementation. + /// \param x first operand + /// \param y second operand + /// \param z third operand + /// \return \a x * \a y + \a z stored in single-precision + static expr fma(float x, float y, float z) + { + #if HALF_ENABLE_CPP11_CMATH && defined(FP_FAST_FMAF) + return expr(std::fma(x, y, z)); + #else + return expr(x*y+z); + #endif + } + + /// Get NaN. 
+ /// \return Half-precision quiet NaN + static half nanh() { return half(binary, 0x7FFF); } + + /// Exponential implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr exp(float arg) { return expr(std::exp(arg)); } + + /// Exponential implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr expm1(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::expm1(arg)); + #else + return expr(static_cast<float>(std::exp(static_cast<double>(arg))-1.0)); + #endif + } + + /// Binary exponential implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr exp2(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::exp2(arg)); + #else + return expr(static_cast<float>(std::exp(arg*0.69314718055994530941723212145818))); + #endif + } + + /// Logarithm implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr log(float arg) { return expr(std::log(arg)); } + + /// Common logarithm implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr log10(float arg) { return expr(std::log10(arg)); } + + /// Logarithm implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr log1p(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::log1p(arg)); + #else + return expr(static_cast<float>(std::log(1.0+arg))); + #endif + } + + /// Binary logarithm implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr log2(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::log2(arg)); + #else + return expr(static_cast<float>(std::log(static_cast<double>(arg))*1.4426950408889634073599246810019)); + #endif + } + + /// Square root implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr sqrt(float arg) { return expr(std::sqrt(arg)); } + + /// Cubic root implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr cbrt(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::cbrt(arg)); + #else + if(builtin_isnan(arg) || builtin_isinf(arg)) + return expr(arg); + return expr(builtin_signbit(arg) ? -static_cast<float>(std::pow(-static_cast<double>(arg), 1.0/3.0)) : + static_cast<float>(std::pow(static_cast<double>(arg), 1.0/3.0))); + #endif + } + + /// Hypotenuse implementation. + /// \param x first argument + /// \param y second argument + /// \return function value stored in single-preicision + static expr hypot(float x, float y) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::hypot(x, y)); + #else + return expr((builtin_isinf(x) || builtin_isinf(y)) ? std::numeric_limits<float>::infinity() : + static_cast<float>(std::sqrt(static_cast<double>(x)*x+static_cast<double>(y)*y))); + #endif + } + + /// Power implementation. + /// \param base value to exponentiate + /// \param exp power to expontiate to + /// \return function value stored in single-preicision + static expr pow(float base, float exp) { return expr(std::pow(base, exp)); } + + /// Sine implementation. 
+ /// \param arg function argument + /// \return function value stored in single-preicision + static expr sin(float arg) { return expr(std::sin(arg)); } + + /// Cosine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr cos(float arg) { return expr(std::cos(arg)); } + + /// Tan implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr tan(float arg) { return expr(std::tan(arg)); } + + /// Arc sine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr asin(float arg) { return expr(std::asin(arg)); } + + /// Arc cosine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr acos(float arg) { return expr(std::acos(arg)); } + + /// Arc tangent implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr atan(float arg) { return expr(std::atan(arg)); } + + /// Arc tangent implementation. + /// \param x first argument + /// \param y second argument + /// \return function value stored in single-preicision + static expr atan2(float x, float y) { return expr(std::atan2(x, y)); } + + /// Hyperbolic sine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr sinh(float arg) { return expr(std::sinh(arg)); } + + /// Hyperbolic cosine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr cosh(float arg) { return expr(std::cosh(arg)); } + + /// Hyperbolic tangent implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr tanh(float arg) { return expr(std::tanh(arg)); } + + /// Hyperbolic area sine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr asinh(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::asinh(arg)); + #else + return expr((arg==-std::numeric_limits<float>::infinity()) ? arg : static_cast<float>(std::log(arg+std::sqrt(arg*arg+1.0)))); + #endif + } + + /// Hyperbolic area cosine implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr acosh(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::acosh(arg)); + #else + return expr((arg<-1.0f) ? std::numeric_limits<float>::quiet_NaN() : static_cast<float>(std::log(arg+std::sqrt(arg*arg-1.0)))); + #endif + } + + /// Hyperbolic area tangent implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr atanh(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::atanh(arg)); + #else + return expr(static_cast<float>(0.5*std::log((1.0+arg)/(1.0-arg)))); + #endif + } + + /// Error function implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr erf(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::erf(arg)); + #else + return expr(static_cast<float>(erf(static_cast<double>(arg)))); + #endif + } + + /// Complementary implementation. 
+ /// \param arg function argument + /// \return function value stored in single-preicision + static expr erfc(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::erfc(arg)); + #else + return expr(static_cast<float>(1.0-erf(static_cast<double>(arg)))); + #endif + } + + /// Gamma logarithm implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr lgamma(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::lgamma(arg)); + #else + if(builtin_isinf(arg)) + return expr(std::numeric_limits<float>::infinity()); + if(arg < 0.0f) + { + float i, f = std::modf(-arg, &i); + if(f == 0.0f) + return expr(std::numeric_limits<float>::infinity()); + return expr(static_cast<float>(1.1447298858494001741434273513531- + std::log(std::abs(std::sin(3.1415926535897932384626433832795*f)))-lgamma(1.0-arg))); + } + return expr(static_cast<float>(lgamma(static_cast<double>(arg)))); + #endif + } + + /// Gamma implementation. + /// \param arg function argument + /// \return function value stored in single-preicision + static expr tgamma(float arg) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::tgamma(arg)); + #else + if(arg == 0.0f) + return builtin_signbit(arg) ? expr(-std::numeric_limits<float>::infinity()) : expr(std::numeric_limits<float>::infinity()); + if(arg < 0.0f) + { + float i, f = std::modf(-arg, &i); + if(f == 0.0f) + return expr(std::numeric_limits<float>::quiet_NaN()); + double value = 3.1415926535897932384626433832795 / (std::sin(3.1415926535897932384626433832795*f)*std::exp(lgamma(1.0-arg))); + return expr(static_cast<float>((std::fmod(i, 2.0f)==0.0f) ? -value : value)); + } + if(builtin_isinf(arg)) + return expr(arg); + return expr(static_cast<float>(std::exp(lgamma(static_cast<double>(arg))))); + #endif + } + + /// Floor implementation. + /// \param arg value to round + /// \return rounded value + static half floor(half arg) { return half(binary, round_half<std::round_toward_neg_infinity>(arg.data_)); } + + /// Ceiling implementation. + /// \param arg value to round + /// \return rounded value + static half ceil(half arg) { return half(binary, round_half<std::round_toward_infinity>(arg.data_)); } + + /// Truncation implementation. + /// \param arg value to round + /// \return rounded value + static half trunc(half arg) { return half(binary, round_half<std::round_toward_zero>(arg.data_)); } + + /// Nearest integer implementation. + /// \param arg value to round + /// \return rounded value + static half round(half arg) { return half(binary, round_half_up(arg.data_)); } + + /// Nearest integer implementation. + /// \param arg value to round + /// \return rounded value + static long lround(half arg) { return detail::half2int_up<long>(arg.data_); } + + /// Nearest integer implementation. + /// \param arg value to round + /// \return rounded value + static half rint(half arg) { return half(binary, round_half<half::round_style>(arg.data_)); } + + /// Nearest integer implementation. + /// \param arg value to round + /// \return rounded value + static long lrint(half arg) { return detail::half2int<half::round_style,long>(arg.data_); } + + #if HALF_ENABLE_CPP11_LONG_LONG + /// Nearest integer implementation. + /// \param arg value to round + /// \return rounded value + static long long llround(half arg) { return detail::half2int_up<long long>(arg.data_); } + + /// Nearest integer implementation. 
+ /// \param arg value to round + /// \return rounded value + static long long llrint(half arg) { return detail::half2int<half::round_style,long long>(arg.data_); } + #endif + + /// Decompression implementation. + /// \param arg number to decompress + /// \param exp address to store exponent at + /// \return normalized significant + static half frexp(half arg, int *exp) + { + int m = arg.data_ & 0x7FFF, e = -14; + if(m >= 0x7C00 || !m) + return *exp = 0, arg; + for(; m<0x400; m<<=1,--e) ; + return *exp = e+(m>>10), half(binary, (arg.data_&0x8000)|0x3800|(m&0x3FF)); + } + + /// Decompression implementation. + /// \param arg number to decompress + /// \param iptr address to store integer part at + /// \return fractional part + static half modf(half arg, half *iptr) + { + unsigned int e = arg.data_ & 0x7FFF; + if(e >= 0x6400) + return *iptr = arg, half(binary, arg.data_&(0x8000U|-(e>0x7C00))); + if(e < 0x3C00) + return iptr->data_ = arg.data_ & 0x8000, arg; + e >>= 10; + unsigned int mask = (1<<(25-e)) - 1, m = arg.data_ & mask; + iptr->data_ = arg.data_ & ~mask; + if(!m) + return half(binary, arg.data_&0x8000); + for(; m<0x400; m<<=1,--e) ; + return half(binary, static_cast<uint16>((arg.data_&0x8000)|(e<<10)|(m&0x3FF))); + } + + /// Scaling implementation. + /// \param arg number to scale + /// \param exp power of two to scale by + /// \return scaled number + static half scalbln(half arg, long exp) + { + unsigned int m = arg.data_ & 0x7FFF; + if(m >= 0x7C00 || !m) + return arg; + for(; m<0x400; m<<=1,--exp) ; + exp += m >> 10; + uint16 value = arg.data_ & 0x8000; + if(exp > 30) + { + if(half::round_style == std::round_toward_zero) + value |= 0x7BFF; + else if(half::round_style == std::round_toward_infinity) + value |= 0x7C00 - (value>>15); + else if(half::round_style == std::round_toward_neg_infinity) + value |= 0x7BFF + (value>>15); + else + value |= 0x7C00; + } + else if(exp > 0) + value |= (exp<<10) | (m&0x3FF); + else if(exp > -11) + { + m = (m&0x3FF) | 0x400; + if(half::round_style == std::round_to_nearest) + { + m += 1 << -exp; + #if HALF_ROUND_TIES_TO_EVEN + m -= (m>>(1-exp)) & 1; + #endif + } + else if(half::round_style == std::round_toward_infinity) + m += ((value>>15)-1) & ((1<<(1-exp))-1U); + else if(half::round_style == std::round_toward_neg_infinity) + m += -(value>>15) & ((1<<(1-exp))-1U); + value |= m >> (1-exp); + } + else if(half::round_style == std::round_toward_infinity) + value -= (value>>15) - 1; + else if(half::round_style == std::round_toward_neg_infinity) + value += value >> 15; + return half(binary, value); + } + + /// Exponent implementation. + /// \param arg number to query + /// \return floating point exponent + static int ilogb(half arg) + { + int abs = arg.data_ & 0x7FFF; + if(!abs) + return FP_ILOGB0; + if(abs < 0x7C00) + { + int exp = (abs>>10) - 15; + if(abs < 0x400) + for(; abs<0x200; abs<<=1,--exp) ; + return exp; + } + if(abs > 0x7C00) + return FP_ILOGBNAN; + return INT_MAX; + } + + /// Exponent implementation. 
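The frexp/modf/scalbln overloads above operate on the half bit pattern directly (1 sign bit, 5 exponent bits, 10 mantissa bits) instead of round-tripping through float. A minimal round-trip sketch; the "half.hpp" include path and a C++11 toolchain are assumptions, not part of this patch:

    #include <cassert>
    #include "half.hpp"   // assumed include path for the vendored header

    int main() {
        using half_float::half;
        half h = half_float::half_cast<half>(6.5f);
        int e = 0;
        half m = frexp(h, &e);   // significand in [0.5, 1): here m == 0.8125, e == 3
        assert(half_float::half_cast<float>(scalbln(m, e)) == 6.5f);   // m * 2^e reproduces h exactly
        return 0;
    }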
+ /// \param arg number to query + /// \return floating point exponent + static half logb(half arg) + { + int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(binary, 0xFC00); + if(abs < 0x7C00) + { + int exp = (abs>>10) - 15; + if(abs < 0x400) + for(; abs<0x200; abs<<=1,--exp) ; + uint16 bits = (exp<0) << 15; + if(exp) + { + unsigned int m = std::abs(exp) << 6, e = 18; + for(; m<0x400; m<<=1,--e) ; + bits |= (e<<10) + m; + } + return half(binary, bits); + } + if(abs > 0x7C00) + return arg; + return half(binary, 0x7C00); + } + + /// Enumeration implementation. + /// \param from number to increase/decrease + /// \param to direction to enumerate into + /// \return next representable number + static half nextafter(half from, half to) + { + uint16 fabs = from.data_ & 0x7FFF, tabs = to.data_ & 0x7FFF; + if(fabs > 0x7C00) + return from; + if(tabs > 0x7C00 || from.data_ == to.data_ || !(fabs|tabs)) + return to; + if(!fabs) + return half(binary, (to.data_&0x8000)+1); + bool lt = ((fabs==from.data_) ? static_cast<int>(fabs) : -static_cast<int>(fabs)) < + ((tabs==to.data_) ? static_cast<int>(tabs) : -static_cast<int>(tabs)); + return half(binary, from.data_+(((from.data_>>15)^static_cast<unsigned>(lt))<<1)-1); + } + + /// Enumeration implementation. + /// \param from number to increase/decrease + /// \param to direction to enumerate into + /// \return next representable number + static half nexttoward(half from, long double to) + { + if(isnan(from)) + return from; + long double lfrom = static_cast<long double>(from); + if(builtin_isnan(to) || lfrom == to) + return half(static_cast<float>(to)); + if(!(from.data_&0x7FFF)) + return half(binary, (static_cast<detail::uint16>(builtin_signbit(to))<<15)+1); + return half(binary, from.data_+(((from.data_>>15)^static_cast<unsigned>(lfrom<to))<<1)-1); + } + + /// Sign implementation + /// \param x first operand + /// \param y second operand + /// \return composed value + static half copysign(half x, half y) { return half(binary, x.data_^((x.data_^y.data_)&0x8000)); } + + /// Classification implementation. + /// \param arg value to classify + /// \retval true if infinite number + /// \retval false else + static int fpclassify(half arg) + { + unsigned int abs = arg.data_ & 0x7FFF; + return abs ? ((abs>0x3FF) ? ((abs>=0x7C00) ? ((abs>0x7C00) ? FP_NAN : FP_INFINITE) : FP_NORMAL) :FP_SUBNORMAL) : FP_ZERO; + } + + /// Classification implementation. + /// \param arg value to classify + /// \retval true if finite number + /// \retval false else + static bool isfinite(half arg) { return (arg.data_&0x7C00) != 0x7C00; } + + /// Classification implementation. + /// \param arg value to classify + /// \retval true if infinite number + /// \retval false else + static bool isinf(half arg) { return (arg.data_&0x7FFF) == 0x7C00; } + + /// Classification implementation. + /// \param arg value to classify + /// \retval true if not a number + /// \retval false else + static bool isnan(half arg) { return (arg.data_&0x7FFF) > 0x7C00; } + + /// Classification implementation. + /// \param arg value to classify + /// \retval true if normal number + /// \retval false else + static bool isnormal(half arg) { return ((arg.data_&0x7C00)!=0) & ((arg.data_&0x7C00)!=0x7C00); } + + /// Sign bit implementation. + /// \param arg value to check + /// \retval true if signed + /// \retval false if unsigned + static bool signbit(half arg) { return (arg.data_&0x8000) != 0; } + + /// Comparison implementation. 
+ /// \param x first operand + /// \param y second operand + /// \retval true if operands equal + /// \retval false else + static bool isequal(half x, half y) { return (x.data_==y.data_ || !((x.data_|y.data_)&0x7FFF)) && !isnan(x); } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if operands not equal + /// \retval false else + static bool isnotequal(half x, half y) { return (x.data_!=y.data_ && ((x.data_|y.data_)&0x7FFF)) || isnan(x); } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x > \a y + /// \retval false else + static bool isgreater(half x, half y) + { + int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF; + return xabs<=0x7C00 && yabs<=0x7C00 && (((xabs==x.data_) ? xabs : -xabs) > ((yabs==y.data_) ? yabs : -yabs)); + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x >= \a y + /// \retval false else + static bool isgreaterequal(half x, half y) + { + int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF; + return xabs<=0x7C00 && yabs<=0x7C00 && (((xabs==x.data_) ? xabs : -xabs) >= ((yabs==y.data_) ? yabs : -yabs)); + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x < \a y + /// \retval false else + static bool isless(half x, half y) + { + int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF; + return xabs<=0x7C00 && yabs<=0x7C00 && (((xabs==x.data_) ? xabs : -xabs) < ((yabs==y.data_) ? yabs : -yabs)); + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x <= \a y + /// \retval false else + static bool islessequal(half x, half y) + { + int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF; + return xabs<=0x7C00 && yabs<=0x7C00 && (((xabs==x.data_) ? xabs : -xabs) <= ((yabs==y.data_) ? yabs : -yabs)); + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if either \a x > \a y nor \a x < \a y + /// \retval false else + static bool islessgreater(half x, half y) + { + int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF; + if(xabs > 0x7C00 || yabs > 0x7C00) + return false; + int a = (xabs==x.data_) ? xabs : -xabs, b = (yabs==y.data_) ? yabs : -yabs; + return a < b || a > b; + } + + /// Comparison implementation. + /// \param x first operand + /// \param y second operand + /// \retval true if operand unordered + /// \retval false else + static bool isunordered(half x, half y) { return isnan(x) || isnan(y); } + + private: + static double erf(double arg) + { + if(builtin_isinf(arg)) + return (arg<0.0) ? -1.0 : 1.0; + double x2 = arg * arg, ax2 = 0.147 * x2, value = std::sqrt(1.0-std::exp(-x2*(1.2732395447351626861510701069801+ax2)/(1.0+ax2))); + return builtin_signbit(arg) ? 
-value : value; + } + + static double lgamma(double arg) + { + double v = 1.0; + for(; arg<8.0; ++arg) v *= arg; + double w = 1.0 / (arg*arg); + return (((((((-0.02955065359477124183006535947712*w+0.00641025641025641025641025641026)*w+ + -0.00191752691752691752691752691753)*w+8.4175084175084175084175084175084e-4)*w+ + -5.952380952380952380952380952381e-4)*w+7.9365079365079365079365079365079e-4)*w+ + -0.00277777777777777777777777777778)*w+0.08333333333333333333333333333333)/arg + + 0.91893853320467274178032973640562 - std::log(v) - arg + (arg-0.5) * std::log(arg); + } + }; + + /// Wrapper for unary half-precision functions needing specialization for individual argument types. + /// \tparam T argument type + template<typename T> struct unary_specialized + { + /// Negation implementation. + /// \param arg value to negate + /// \return negated value + static HALF_CONSTEXPR half negate(half arg) { return half(binary, arg.data_^0x8000); } + + /// Absolute value implementation. + /// \param arg function argument + /// \return absolute value + static half fabs(half arg) { return half(binary, arg.data_&0x7FFF); } + }; + template<> struct unary_specialized<expr> + { + static HALF_CONSTEXPR expr negate(float arg) { return expr(-arg); } + static expr fabs(float arg) { return expr(std::fabs(arg)); } + }; + + /// Wrapper for binary half-precision functions needing specialization for individual argument types. + /// \tparam T first argument type + /// \tparam U first argument type + template<typename T,typename U> struct binary_specialized + { + /// Minimum implementation. + /// \param x first operand + /// \param y second operand + /// \return minimum value + static expr fmin(float x, float y) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::fmin(x, y)); + #else + if(builtin_isnan(x)) + return expr(y); + if(builtin_isnan(y)) + return expr(x); + return expr(std::min(x, y)); + #endif + } + + /// Maximum implementation. + /// \param x first operand + /// \param y second operand + /// \return maximum value + static expr fmax(float x, float y) + { + #if HALF_ENABLE_CPP11_CMATH + return expr(std::fmax(x, y)); + #else + if(builtin_isnan(x)) + return expr(y); + if(builtin_isnan(y)) + return expr(x); + return expr(std::max(x, y)); + #endif + } + }; + template<> struct binary_specialized<half,half> + { + static half fmin(half x, half y) + { + int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF; + if(xabs > 0x7C00) + return y; + if(yabs > 0x7C00) + return x; + return (((xabs==x.data_) ? xabs : -xabs) > ((yabs==y.data_) ? yabs : -yabs)) ? y : x; + } + static half fmax(half x, half y) + { + int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF; + if(xabs > 0x7C00) + return y; + if(yabs > 0x7C00) + return x; + return (((xabs==x.data_) ? xabs : -xabs) < ((yabs==y.data_) ? yabs : -yabs)) ? y : x; + } + }; + + /// Helper class for half casts. + /// This class template has to be specialized for all valid cast argument to define an appropriate static `cast` member + /// function and a corresponding `type` member denoting its return type. 
+ /// \tparam T destination type + /// \tparam U source type + /// \tparam R rounding mode to use + template<typename T,typename U,std::float_round_style R=static_cast<std::float_round_style>(HALF_ROUND_STYLE)> struct half_caster {}; + template<typename U,std::float_round_style R> struct half_caster<half,U,R> + { + #if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic<U>::value, "half_cast from non-arithmetic type unsupported"); + #endif + + static half cast(U arg) { return cast_impl(arg, is_float<U>()); }; + + private: + static half cast_impl(U arg, true_type) { return half(binary, float2half<R>(arg)); } + static half cast_impl(U arg, false_type) { return half(binary, int2half<R>(arg)); } + }; + template<typename T,std::float_round_style R> struct half_caster<T,half,R> + { + #if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic<T>::value, "half_cast to non-arithmetic type unsupported"); + #endif + + static T cast(half arg) { return cast_impl(arg, is_float<T>()); } + + private: + static T cast_impl(half arg, true_type) { return half2float<T>(arg.data_); } + static T cast_impl(half arg, false_type) { return half2int<R,T>(arg.data_); } + }; + template<typename T,std::float_round_style R> struct half_caster<T,expr,R> + { + #if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic<T>::value, "half_cast to non-arithmetic type unsupported"); + #endif + + static T cast(expr arg) { return cast_impl(arg, is_float<T>()); } + + private: + static T cast_impl(float arg, true_type) { return static_cast<T>(arg); } + static T cast_impl(half arg, false_type) { return half2int<R,T>(arg.data_); } + }; + template<std::float_round_style R> struct half_caster<half,half,R> + { + static half cast(half arg) { return arg; } + }; + template<std::float_round_style R> struct half_caster<half,expr,R> : half_caster<half,half,R> {}; + + /// \name Comparison operators + /// \{ + + /// Comparison for equality. + /// \param x first operand + /// \param y second operand + /// \retval true if operands equal + /// \retval false else + template<typename T,typename U> typename enable<bool,T,U>::type operator==(T x, U y) { return functions::isequal(x, y); } + + /// Comparison for inequality. + /// \param x first operand + /// \param y second operand + /// \retval true if operands not equal + /// \retval false else + template<typename T,typename U> typename enable<bool,T,U>::type operator!=(T x, U y) { return functions::isnotequal(x, y); } + + /// Comparison for less than. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less than \a y + /// \retval false else + template<typename T,typename U> typename enable<bool,T,U>::type operator<(T x, U y) { return functions::isless(x, y); } + + /// Comparison for greater than. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater than \a y + /// \retval false else + template<typename T,typename U> typename enable<bool,T,U>::type operator>(T x, U y) { return functions::isgreater(x, y); } + + /// Comparison for less equal. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less equal \a y + /// \retval false else + template<typename T,typename U> typename enable<bool,T,U>::type operator<=(T x, U y) { return functions::islessequal(x, y); } + + /// Comparison for greater equal. 
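The relational operators defined above delegate to functions::isequal and friends, so any comparison involving a NaN operand is unordered (false, except for !=) and the two signed zeros compare equal. A small sketch, with an assumed "half.hpp" include path:

    #include <limits>
    #include "half.hpp"   // assumed include path

    int main() {
        using half_float::half;
        const half one = half_float::half_cast<half>(1.0f);
        const half nan = std::numeric_limits<half>::quiet_NaN();
        bool eq = (one == nan);                            // false: NaN never compares equal
        bool lt = (one < nan);                             // false: orderings with NaN are false
        bool ne = (one != nan);                            // true
        bool zz = (half_float::half_cast<half>(0.0f) ==
                   half_float::half_cast<half>(-0.0f));    // true: +0 and -0 compare equal
        return (!eq && !lt && ne && zz) ? 0 : 1;
    }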
+ /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater equal \a y + /// \retval false else + template<typename T,typename U> typename enable<bool,T,U>::type operator>=(T x, U y) { return functions::isgreaterequal(x, y); } + + /// \} + /// \name Arithmetic operators + /// \{ + + /// Add halfs. + /// \param x left operand + /// \param y right operand + /// \return sum of half expressions + template<typename T,typename U> typename enable<expr,T,U>::type operator+(T x, U y) { return functions::plus(x, y); } + + /// Subtract halfs. + /// \param x left operand + /// \param y right operand + /// \return difference of half expressions + template<typename T,typename U> typename enable<expr,T,U>::type operator-(T x, U y) { return functions::minus(x, y); } + + /// Multiply halfs. + /// \param x left operand + /// \param y right operand + /// \return product of half expressions + template<typename T,typename U> typename enable<expr,T,U>::type operator*(T x, U y) { return functions::multiplies(x, y); } + + /// Divide halfs. + /// \param x left operand + /// \param y right operand + /// \return quotient of half expressions + template<typename T,typename U> typename enable<expr,T,U>::type operator/(T x, U y) { return functions::divides(x, y); } + + /// Identity. + /// \param arg operand + /// \return uncahnged operand + template<typename T> HALF_CONSTEXPR typename enable<T,T>::type operator+(T arg) { return arg; } + + /// Negation. + /// \param arg operand + /// \return negated operand + template<typename T> HALF_CONSTEXPR typename enable<T,T>::type operator-(T arg) { return unary_specialized<T>::negate(arg); } + + /// \} + /// \name Input and output + /// \{ + + /// Output operator. + /// \param out output stream to write into + /// \param arg half expression to write + /// \return reference to output stream + template<typename T,typename charT,typename traits> typename enable<std::basic_ostream<charT,traits>&,T>::type + operator<<(std::basic_ostream<charT,traits> &out, T arg) { return functions::write(out, arg); } + + /// Input operator. + /// \param in input stream to read from + /// \param arg half to read into + /// \return reference to input stream + template<typename charT,typename traits> std::basic_istream<charT,traits>& + operator>>(std::basic_istream<charT,traits> &in, half &arg) { return functions::read(in, arg); } + + /// \} + /// \name Basic mathematical operations + /// \{ + + /// Absolute value. + /// \param arg operand + /// \return absolute value of \a arg +// template<typename T> typename enable<T,T>::type abs(T arg) { return unary_specialized<T>::fabs(arg); } + inline half abs(half arg) { return unary_specialized<half>::fabs(arg); } + inline expr abs(expr arg) { return unary_specialized<expr>::fabs(arg); } + + /// Absolute value. + /// \param arg operand + /// \return absolute value of \a arg +// template<typename T> typename enable<T,T>::type fabs(T arg) { return unary_specialized<T>::fabs(arg); } + inline half fabs(half arg) { return unary_specialized<half>::fabs(arg); } + inline expr fabs(expr arg) { return unary_specialized<expr>::fabs(arg); } + + /// Remainder of division. + /// \param x first operand + /// \param y second operand + /// \return remainder of floating point division. 
+// template<typename T,typename U> typename enable<expr,T,U>::type fmod(T x, U y) { return functions::fmod(x, y); } + inline expr fmod(half x, half y) { return functions::fmod(x, y); } + inline expr fmod(half x, expr y) { return functions::fmod(x, y); } + inline expr fmod(expr x, half y) { return functions::fmod(x, y); } + inline expr fmod(expr x, expr y) { return functions::fmod(x, y); } + + /// Remainder of division. + /// \param x first operand + /// \param y second operand + /// \return remainder of floating point division. +// template<typename T,typename U> typename enable<expr,T,U>::type remainder(T x, U y) { return functions::remainder(x, y); } + inline expr remainder(half x, half y) { return functions::remainder(x, y); } + inline expr remainder(half x, expr y) { return functions::remainder(x, y); } + inline expr remainder(expr x, half y) { return functions::remainder(x, y); } + inline expr remainder(expr x, expr y) { return functions::remainder(x, y); } + + /// Remainder of division. + /// \param x first operand + /// \param y second operand + /// \param quo address to store some bits of quotient at + /// \return remainder of floating point division. +// template<typename T,typename U> typename enable<expr,T,U>::type remquo(T x, U y, int *quo) { return functions::remquo(x, y, quo); } + inline expr remquo(half x, half y, int *quo) { return functions::remquo(x, y, quo); } + inline expr remquo(half x, expr y, int *quo) { return functions::remquo(x, y, quo); } + inline expr remquo(expr x, half y, int *quo) { return functions::remquo(x, y, quo); } + inline expr remquo(expr x, expr y, int *quo) { return functions::remquo(x, y, quo); } + + /// Fused multiply add. + /// \param x first operand + /// \param y second operand + /// \param z third operand + /// \return ( \a x * \a y ) + \a z rounded as one operation. +// template<typename T,typename U,typename V> typename enable<expr,T,U,V>::type fma(T x, U y, V z) { return functions::fma(x, y, z); } + inline expr fma(half x, half y, half z) { return functions::fma(x, y, z); } + inline expr fma(half x, half y, expr z) { return functions::fma(x, y, z); } + inline expr fma(half x, expr y, half z) { return functions::fma(x, y, z); } + inline expr fma(half x, expr y, expr z) { return functions::fma(x, y, z); } + inline expr fma(expr x, half y, half z) { return functions::fma(x, y, z); } + inline expr fma(expr x, half y, expr z) { return functions::fma(x, y, z); } + inline expr fma(expr x, expr y, half z) { return functions::fma(x, y, z); } + inline expr fma(expr x, expr y, expr z) { return functions::fma(x, y, z); } + + /// Maximum of half expressions. + /// \param x first operand + /// \param y second operand + /// \return maximum of operands +// template<typename T,typename U> typename result<T,U>::type fmax(T x, U y) { return binary_specialized<T,U>::fmax(x, y); } + inline half fmax(half x, half y) { return binary_specialized<half,half>::fmax(x, y); } + inline expr fmax(half x, expr y) { return binary_specialized<half,expr>::fmax(x, y); } + inline expr fmax(expr x, half y) { return binary_specialized<expr,half>::fmax(x, y); } + inline expr fmax(expr x, expr y) { return binary_specialized<expr,expr>::fmax(x, y); } + + /// Minimum of half expressions. 
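For two half arguments, the binary_specialized<half,half> versions of fmin/fmax shown earlier return the non-NaN operand when exactly one argument is NaN, mirroring the C99 semantics. Illustrative sketch (include path assumed):

    #include <limits>
    #include "half.hpp"   // assumed include path

    int main() {
        using half_float::half;
        const half nan = std::numeric_limits<half>::quiet_NaN();
        const half two = half_float::half_cast<half>(2.0f);
        half lo = fmin(nan, two);   // 2.0: the single NaN operand is ignored
        half hi = fmax(two, nan);   // 2.0
        return (half_float::half_cast<float>(lo) == 2.0f &&
                half_float::half_cast<float>(hi) == 2.0f) ? 0 : 1;
    }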
+ /// \param x first operand
+ /// \param y second operand
+ /// \return minimum of operands
+// template<typename T,typename U> typename result<T,U>::type fmin(T x, U y) { return binary_specialized<T,U>::fmin(x, y); }
+ inline half fmin(half x, half y) { return binary_specialized<half,half>::fmin(x, y); }
+ inline expr fmin(half x, expr y) { return binary_specialized<half,expr>::fmin(x, y); }
+ inline expr fmin(expr x, half y) { return binary_specialized<expr,half>::fmin(x, y); }
+ inline expr fmin(expr x, expr y) { return binary_specialized<expr,expr>::fmin(x, y); }
+
+ /// Positive difference.
+ /// \param x first operand
+ /// \param y second operand
+ /// \return \a x - \a y or 0 if difference negative
+// template<typename T,typename U> typename enable<expr,T,U>::type fdim(T x, U y) { return functions::fdim(x, y); }
+ inline expr fdim(half x, half y) { return functions::fdim(x, y); }
+ inline expr fdim(half x, expr y) { return functions::fdim(x, y); }
+ inline expr fdim(expr x, half y) { return functions::fdim(x, y); }
+ inline expr fdim(expr x, expr y) { return functions::fdim(x, y); }
+
+ /// Get NaN value.
+ /// \return quiet NaN
+ inline half nanh(const char*) { return functions::nanh(); }
+
+ /// \}
+ /// \name Exponential functions
+ /// \{
+
+ /// Exponential function.
+ /// \param arg function argument
+ /// \return e raised to \a arg
+// template<typename T> typename enable<expr,T>::type exp(T arg) { return functions::exp(arg); }
+ inline expr exp(half arg) { return functions::exp(arg); }
+ inline expr exp(expr arg) { return functions::exp(arg); }
+
+ /// Exponential minus one.
+ /// \param arg function argument
+ /// \return e raised to \a arg subtracted by 1
+// template<typename T> typename enable<expr,T>::type expm1(T arg) { return functions::expm1(arg); }
+ inline expr expm1(half arg) { return functions::expm1(arg); }
+ inline expr expm1(expr arg) { return functions::expm1(arg); }
+
+ /// Binary exponential.
+ /// \param arg function argument
+ /// \return 2 raised to \a arg
+// template<typename T> typename enable<expr,T>::type exp2(T arg) { return functions::exp2(arg); }
+ inline expr exp2(half arg) { return functions::exp2(arg); }
+ inline expr exp2(expr arg) { return functions::exp2(arg); }
+
+ /// Natural logarithm.
+ /// \param arg function argument
+ /// \return logarithm of \a arg to base e
+// template<typename T> typename enable<expr,T>::type log(T arg) { return functions::log(arg); }
+ inline expr log(half arg) { return functions::log(arg); }
+ inline expr log(expr arg) { return functions::log(arg); }
+
+ /// Common logarithm.
+ /// \param arg function argument
+ /// \return logarithm of \a arg to base 10
+// template<typename T> typename enable<expr,T>::type log10(T arg) { return functions::log10(arg); }
+ inline expr log10(half arg) { return functions::log10(arg); }
+ inline expr log10(expr arg) { return functions::log10(arg); }
+
+ /// Natural logarithm.
+ /// \param arg function argument
+ /// \return logarithm of \a arg plus 1 to base e
+// template<typename T> typename enable<expr,T>::type log1p(T arg) { return functions::log1p(arg); }
+ inline expr log1p(half arg) { return functions::log1p(arg); }
+ inline expr log1p(expr arg) { return functions::log1p(arg); }
+
+ /// Binary logarithm.
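These wrappers return detail::expr, a single-precision intermediate: assigning the result to a half rounds it once, while assigning it to a float keeps the full float value. A sketch that assumes the library's implicit half(expr) conversion constructor and a "half.hpp" include path:

    #include "half.hpp"   // assumed include path

    int main() {
        using half_float::half;
        half x = half_float::half_cast<half>(3.0f);
        half y = exp2(x);    // float intermediate 8.0f, rounded to half on assignment
        float f = log2(y);   // expr converts to float: 3.0f
        return (f == 3.0f) ? 0 : 1;
    }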
+ /// \param arg function argument + /// \return logarithm of \a arg to base 2 +// template<typename T> typename enable<expr,T>::type log2(T arg) { return functions::log2(arg); } + inline expr log2(half arg) { return functions::log2(arg); } + inline expr log2(expr arg) { return functions::log2(arg); } + + /// \} + /// \name Power functions + /// \{ + + /// Square root. + /// \param arg function argument + /// \return square root of \a arg +// template<typename T> typename enable<expr,T>::type sqrt(T arg) { return functions::sqrt(arg); } + inline expr sqrt(half arg) { return functions::sqrt(arg); } + inline expr sqrt(expr arg) { return functions::sqrt(arg); } + + /// Cubic root. + /// \param arg function argument + /// \return cubic root of \a arg +// template<typename T> typename enable<expr,T>::type cbrt(T arg) { return functions::cbrt(arg); } + inline expr cbrt(half arg) { return functions::cbrt(arg); } + inline expr cbrt(expr arg) { return functions::cbrt(arg); } + + /// Hypotenuse function. + /// \param x first argument + /// \param y second argument + /// \return square root of sum of squares without internal over- or underflows +// template<typename T,typename U> typename enable<expr,T,U>::type hypot(T x, U y) { return functions::hypot(x, y); } + inline expr hypot(half x, half y) { return functions::hypot(x, y); } + inline expr hypot(half x, expr y) { return functions::hypot(x, y); } + inline expr hypot(expr x, half y) { return functions::hypot(x, y); } + inline expr hypot(expr x, expr y) { return functions::hypot(x, y); } + + /// Power function. + /// \param base first argument + /// \param exp second argument + /// \return \a base raised to \a exp +// template<typename T,typename U> typename enable<expr,T,U>::type pow(T base, U exp) { return functions::pow(base, exp); } + inline expr pow(half base, half exp) { return functions::pow(base, exp); } + inline expr pow(half base, expr exp) { return functions::pow(base, exp); } + inline expr pow(expr base, half exp) { return functions::pow(base, exp); } + inline expr pow(expr base, expr exp) { return functions::pow(base, exp); } + + /// \} + /// \name Trigonometric functions + /// \{ + + /// Sine function. + /// \param arg function argument + /// \return sine value of \a arg +// template<typename T> typename enable<expr,T>::type sin(T arg) { return functions::sin(arg); } + inline expr sin(half arg) { return functions::sin(arg); } + inline expr sin(expr arg) { return functions::sin(arg); } + + /// Cosine function. + /// \param arg function argument + /// \return cosine value of \a arg +// template<typename T> typename enable<expr,T>::type cos(T arg) { return functions::cos(arg); } + inline expr cos(half arg) { return functions::cos(arg); } + inline expr cos(expr arg) { return functions::cos(arg); } + + /// Tangent function. + /// \param arg function argument + /// \return tangent value of \a arg +// template<typename T> typename enable<expr,T>::type tan(T arg) { return functions::tan(arg); } + inline expr tan(half arg) { return functions::tan(arg); } + inline expr tan(expr arg) { return functions::tan(arg); } + + /// Arc sine. + /// \param arg function argument + /// \return arc sine value of \a arg +// template<typename T> typename enable<expr,T>::type asin(T arg) { return functions::asin(arg); } + inline expr asin(half arg) { return functions::asin(arg); } + inline expr asin(expr arg) { return functions::asin(arg); } + + /// Arc cosine function. 
+ /// \param arg function argument + /// \return arc cosine value of \a arg +// template<typename T> typename enable<expr,T>::type acos(T arg) { return functions::acos(arg); } + inline expr acos(half arg) { return functions::acos(arg); } + inline expr acos(expr arg) { return functions::acos(arg); } + + /// Arc tangent function. + /// \param arg function argument + /// \return arc tangent value of \a arg +// template<typename T> typename enable<expr,T>::type atan(T arg) { return functions::atan(arg); } + inline expr atan(half arg) { return functions::atan(arg); } + inline expr atan(expr arg) { return functions::atan(arg); } + + /// Arc tangent function. + /// \param x first argument + /// \param y second argument + /// \return arc tangent value +// template<typename T,typename U> typename enable<expr,T,U>::type atan2(T x, U y) { return functions::atan2(x, y); } + inline expr atan2(half x, half y) { return functions::atan2(x, y); } + inline expr atan2(half x, expr y) { return functions::atan2(x, y); } + inline expr atan2(expr x, half y) { return functions::atan2(x, y); } + inline expr atan2(expr x, expr y) { return functions::atan2(x, y); } + + /// \} + /// \name Hyperbolic functions + /// \{ + + /// Hyperbolic sine. + /// \param arg function argument + /// \return hyperbolic sine value of \a arg +// template<typename T> typename enable<expr,T>::type sinh(T arg) { return functions::sinh(arg); } + inline expr sinh(half arg) { return functions::sinh(arg); } + inline expr sinh(expr arg) { return functions::sinh(arg); } + + /// Hyperbolic cosine. + /// \param arg function argument + /// \return hyperbolic cosine value of \a arg +// template<typename T> typename enable<expr,T>::type cosh(T arg) { return functions::cosh(arg); } + inline expr cosh(half arg) { return functions::cosh(arg); } + inline expr cosh(expr arg) { return functions::cosh(arg); } + + /// Hyperbolic tangent. + /// \param arg function argument + /// \return hyperbolic tangent value of \a arg +// template<typename T> typename enable<expr,T>::type tanh(T arg) { return functions::tanh(arg); } + inline expr tanh(half arg) { return functions::tanh(arg); } + inline expr tanh(expr arg) { return functions::tanh(arg); } + + /// Hyperbolic area sine. + /// \param arg function argument + /// \return area sine value of \a arg +// template<typename T> typename enable<expr,T>::type asinh(T arg) { return functions::asinh(arg); } + inline expr asinh(half arg) { return functions::asinh(arg); } + inline expr asinh(expr arg) { return functions::asinh(arg); } + + /// Hyperbolic area cosine. + /// \param arg function argument + /// \return area cosine value of \a arg +// template<typename T> typename enable<expr,T>::type acosh(T arg) { return functions::acosh(arg); } + inline expr acosh(half arg) { return functions::acosh(arg); } + inline expr acosh(expr arg) { return functions::acosh(arg); } + + /// Hyperbolic area tangent. + /// \param arg function argument + /// \return area tangent value of \a arg +// template<typename T> typename enable<expr,T>::type atanh(T arg) { return functions::atanh(arg); } + inline expr atanh(half arg) { return functions::atanh(arg); } + inline expr atanh(expr arg) { return functions::atanh(arg); } + + /// \} + /// \name Error and gamma functions + /// \{ + + /// Error function. 
+ /// \param arg function argument + /// \return error function value of \a arg +// template<typename T> typename enable<expr,T>::type erf(T arg) { return functions::erf(arg); } + inline expr erf(half arg) { return functions::erf(arg); } + inline expr erf(expr arg) { return functions::erf(arg); } + + /// Complementary error function. + /// \param arg function argument + /// \return 1 minus error function value of \a arg +// template<typename T> typename enable<expr,T>::type erfc(T arg) { return functions::erfc(arg); } + inline expr erfc(half arg) { return functions::erfc(arg); } + inline expr erfc(expr arg) { return functions::erfc(arg); } + + /// Natural logarithm of gamma function. + /// \param arg function argument + /// \return natural logarith of gamma function for \a arg +// template<typename T> typename enable<expr,T>::type lgamma(T arg) { return functions::lgamma(arg); } + inline expr lgamma(half arg) { return functions::lgamma(arg); } + inline expr lgamma(expr arg) { return functions::lgamma(arg); } + + /// Gamma function. + /// \param arg function argument + /// \return gamma function value of \a arg +// template<typename T> typename enable<expr,T>::type tgamma(T arg) { return functions::tgamma(arg); } + inline expr tgamma(half arg) { return functions::tgamma(arg); } + inline expr tgamma(expr arg) { return functions::tgamma(arg); } + + /// \} + /// \name Rounding + /// \{ + + /// Nearest integer not less than half value. + /// \param arg half to round + /// \return nearest integer not less than \a arg +// template<typename T> typename enable<half,T>::type ceil(T arg) { return functions::ceil(arg); } + inline half ceil(half arg) { return functions::ceil(arg); } + inline half ceil(expr arg) { return functions::ceil(arg); } + + /// Nearest integer not greater than half value. + /// \param arg half to round + /// \return nearest integer not greater than \a arg +// template<typename T> typename enable<half,T>::type floor(T arg) { return functions::floor(arg); } + inline half floor(half arg) { return functions::floor(arg); } + inline half floor(expr arg) { return functions::floor(arg); } + + /// Nearest integer not greater in magnitude than half value. + /// \param arg half to round + /// \return nearest integer not greater in magnitude than \a arg +// template<typename T> typename enable<half,T>::type trunc(T arg) { return functions::trunc(arg); } + inline half trunc(half arg) { return functions::trunc(arg); } + inline half trunc(expr arg) { return functions::trunc(arg); } + + /// Nearest integer. + /// \param arg half to round + /// \return nearest integer, rounded away from zero in half-way cases +// template<typename T> typename enable<half,T>::type round(T arg) { return functions::round(arg); } + inline half round(half arg) { return functions::round(arg); } + inline half round(expr arg) { return functions::round(arg); } + + /// Nearest integer. + /// \param arg half to round + /// \return nearest integer, rounded away from zero in half-way cases +// template<typename T> typename enable<long,T>::type lround(T arg) { return functions::lround(arg); } + inline long lround(half arg) { return functions::lround(arg); } + inline long lround(expr arg) { return functions::lround(arg); } + + /// Nearest integer using half's internal rounding mode. 
+ /// \param arg half expression to round + /// \return nearest integer using default rounding mode +// template<typename T> typename enable<half,T>::type nearbyint(T arg) { return functions::nearbyint(arg); } + inline half nearbyint(half arg) { return functions::rint(arg); } + inline half nearbyint(expr arg) { return functions::rint(arg); } + + /// Nearest integer using half's internal rounding mode. + /// \param arg half expression to round + /// \return nearest integer using default rounding mode +// template<typename T> typename enable<half,T>::type rint(T arg) { return functions::rint(arg); } + inline half rint(half arg) { return functions::rint(arg); } + inline half rint(expr arg) { return functions::rint(arg); } + + /// Nearest integer using half's internal rounding mode. + /// \param arg half expression to round + /// \return nearest integer using default rounding mode +// template<typename T> typename enable<long,T>::type lrint(T arg) { return functions::lrint(arg); } + inline long lrint(half arg) { return functions::lrint(arg); } + inline long lrint(expr arg) { return functions::lrint(arg); } + #if HALF_ENABLE_CPP11_LONG_LONG + /// Nearest integer. + /// \param arg half to round + /// \return nearest integer, rounded away from zero in half-way cases +// template<typename T> typename enable<long long,T>::type llround(T arg) { return functions::llround(arg); } + inline long long llround(half arg) { return functions::llround(arg); } + inline long long llround(expr arg) { return functions::llround(arg); } + + /// Nearest integer using half's internal rounding mode. + /// \param arg half expression to round + /// \return nearest integer using default rounding mode +// template<typename T> typename enable<long long,T>::type llrint(T arg) { return functions::llrint(arg); } + inline long long llrint(half arg) { return functions::llrint(arg); } + inline long long llrint(expr arg) { return functions::llrint(arg); } + #endif + + /// \} + /// \name Floating point manipulation + /// \{ + + /// Decompress floating point number. + /// \param arg number to decompress + /// \param exp address to store exponent at + /// \return significant in range [0.5, 1) +// template<typename T> typename enable<half,T>::type frexp(T arg, int *exp) { return functions::frexp(arg, exp); } + inline half frexp(half arg, int *exp) { return functions::frexp(arg, exp); } + inline half frexp(expr arg, int *exp) { return functions::frexp(arg, exp); } + + /// Multiply by power of two. + /// \param arg number to modify + /// \param exp power of two to multiply with + /// \return \a arg multplied by 2 raised to \a exp +// template<typename T> typename enable<half,T>::type ldexp(T arg, int exp) { return functions::scalbln(arg, exp); } + inline half ldexp(half arg, int exp) { return functions::scalbln(arg, exp); } + inline half ldexp(expr arg, int exp) { return functions::scalbln(arg, exp); } + + /// Extract integer and fractional parts. + /// \param arg number to decompress + /// \param iptr address to store integer part at + /// \return fractional part +// template<typename T> typename enable<half,T>::type modf(T arg, half *iptr) { return functions::modf(arg, iptr); } + inline half modf(half arg, half *iptr) { return functions::modf(arg, iptr); } + inline half modf(expr arg, half *iptr) { return functions::modf(arg, iptr); } + + /// Multiply by power of two. 
+ /// \param arg number to modify
+ /// \param exp power of two to multiply with
+ /// \return \a arg multiplied by 2 raised to \a exp
+// template<typename T> typename enable<half,T>::type scalbn(T arg, int exp) { return functions::scalbln(arg, exp); }
+ inline half scalbn(half arg, int exp) { return functions::scalbln(arg, exp); }
+ inline half scalbn(expr arg, int exp) { return functions::scalbln(arg, exp); }
+
+ /// Multiply by power of two.
+ /// \param arg number to modify
+ /// \param exp power of two to multiply with
+ /// \return \a arg multiplied by 2 raised to \a exp
+// template<typename T> typename enable<half,T>::type scalbln(T arg, long exp) { return functions::scalbln(arg, exp); }
+ inline half scalbln(half arg, long exp) { return functions::scalbln(arg, exp); }
+ inline half scalbln(expr arg, long exp) { return functions::scalbln(arg, exp); }
+
+ /// Extract exponent.
+ /// \param arg number to query
+ /// \return floating point exponent
+ /// \retval FP_ILOGB0 for zero
+ /// \retval FP_ILOGBNAN for NaN
+ /// \retval INT_MAX for infinity
+// template<typename T> typename enable<int,T>::type ilogb(T arg) { return functions::ilogb(arg); }
+ inline int ilogb(half arg) { return functions::ilogb(arg); }
+ inline int ilogb(expr arg) { return functions::ilogb(arg); }
+
+ /// Extract exponent.
+ /// \param arg number to query
+ /// \return floating point exponent
+// template<typename T> typename enable<half,T>::type logb(T arg) { return functions::logb(arg); }
+ inline half logb(half arg) { return functions::logb(arg); }
+ inline half logb(expr arg) { return functions::logb(arg); }
+
+ /// Next representable value.
+ /// \param from value to compute next representable value for
+ /// \param to direction towards which to compute next value
+ /// \return next representable value after \a from in direction towards \a to
+// template<typename T,typename U> typename enable<half,T,U>::type nextafter(T from, U to) { return functions::nextafter(from, to); }
+ inline half nextafter(half from, half to) { return functions::nextafter(from, to); }
+ inline half nextafter(half from, expr to) { return functions::nextafter(from, to); }
+ inline half nextafter(expr from, half to) { return functions::nextafter(from, to); }
+ inline half nextafter(expr from, expr to) { return functions::nextafter(from, to); }
+
+ /// Next representable value.
+ /// \param from value to compute next representable value for
+ /// \param to direction towards which to compute next value
+ /// \return next representable value after \a from in direction towards \a to
+// template<typename T> typename enable<half,T>::type nexttoward(T from, long double to) { return functions::nexttoward(from, to); }
+ inline half nexttoward(half from, long double to) { return functions::nexttoward(from, to); }
+ inline half nexttoward(expr from, long double to) { return functions::nexttoward(from, to); }
+
+ /// Take sign.
+ /// \param x value to change sign for + /// \param y value to take sign from + /// \return value equal to \a x in magnitude and to \a y in sign +// template<typename T,typename U> typename enable<half,T,U>::type copysign(T x, U y) { return functions::copysign(x, y); } + inline half copysign(half x, half y) { return functions::copysign(x, y); } + inline half copysign(half x, expr y) { return functions::copysign(x, y); } + inline half copysign(expr x, half y) { return functions::copysign(x, y); } + inline half copysign(expr x, expr y) { return functions::copysign(x, y); } + + /// \} + /// \name Floating point classification + /// \{ + + + /// Classify floating point value. + /// \param arg number to classify + /// \retval FP_ZERO for positive and negative zero + /// \retval FP_SUBNORMAL for subnormal numbers + /// \retval FP_INFINITY for positive and negative infinity + /// \retval FP_NAN for NaNs + /// \retval FP_NORMAL for all other (normal) values +// template<typename T> typename enable<int,T>::type fpclassify(T arg) { return functions::fpclassify(arg); } + inline int fpclassify(half arg) { return functions::fpclassify(arg); } + inline int fpclassify(expr arg) { return functions::fpclassify(arg); } + + /// Check if finite number. + /// \param arg number to check + /// \retval true if neither infinity nor NaN + /// \retval false else +// template<typename T> typename enable<bool,T>::type isfinite(T arg) { return functions::isfinite(arg); } + inline bool isfinite(half arg) { return functions::isfinite(arg); } + inline bool isfinite(expr arg) { return functions::isfinite(arg); } + + /// Check for infinity. + /// \param arg number to check + /// \retval true for positive or negative infinity + /// \retval false else +// template<typename T> typename enable<bool,T>::type isinf(T arg) { return functions::isinf(arg); } + inline bool isinf(half arg) { return functions::isinf(arg); } + inline bool isinf(expr arg) { return functions::isinf(arg); } + + /// Check for NaN. + /// \param arg number to check + /// \retval true for NaNs + /// \retval false else +// template<typename T> typename enable<bool,T>::type isnan(T arg) { return functions::isnan(arg); } + inline bool isnan(half arg) { return functions::isnan(arg); } + inline bool isnan(expr arg) { return functions::isnan(arg); } + + /// Check if normal number. + /// \param arg number to check + /// \retval true if normal number + /// \retval false if either subnormal, zero, infinity or NaN +// template<typename T> typename enable<bool,T>::type isnormal(T arg) { return functions::isnormal(arg); } + inline bool isnormal(half arg) { return functions::isnormal(arg); } + inline bool isnormal(expr arg) { return functions::isnormal(arg); } + + /// Check sign. + /// \param arg number to check + /// \retval true for negative number + /// \retval false for positive number +// template<typename T> typename enable<bool,T>::type signbit(T arg) { return functions::signbit(arg); } + inline bool signbit(half arg) { return functions::signbit(arg); } + inline bool signbit(expr arg) { return functions::signbit(arg); } + + /// \} + /// \name Comparison + /// \{ + + /// Comparison for greater than. 
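These classification helpers inspect the half bit pattern directly, so they also report values as subnormal that are subnormal in half even though the same magnitudes are normal as float. Sketch (include path assumed):

    #include <cmath>      // FP_SUBNORMAL and related macros
    #include <limits>
    #include "half.hpp"   // assumed include path

    int main() {
        using half_float::half;
        const half inf = std::numeric_limits<half>::infinity();
        const half den = std::numeric_limits<half>::denorm_min();   // 2^-24
        bool a = isinf(inf);                                        // true
        bool b = isfinite(inf);                                     // false
        bool c = (fpclassify(den) == FP_SUBNORMAL);                 // true
        return (a && !b && c) ? 0 : 1;
    }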
+ /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater than \a y + /// \retval false else +// template<typename T,typename U> typename enable<bool,T,U>::type isgreater(T x, U y) { return functions::isgreater(x, y); } + inline bool isgreater(half x, half y) { return functions::isgreater(x, y); } + inline bool isgreater(half x, expr y) { return functions::isgreater(x, y); } + inline bool isgreater(expr x, half y) { return functions::isgreater(x, y); } + inline bool isgreater(expr x, expr y) { return functions::isgreater(x, y); } + + /// Comparison for greater equal. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater equal \a y + /// \retval false else +// template<typename T,typename U> typename enable<bool,T,U>::type isgreaterequal(T x, U y) { return functions::isgreaterequal(x, y); } + inline bool isgreaterequal(half x, half y) { return functions::isgreaterequal(x, y); } + inline bool isgreaterequal(half x, expr y) { return functions::isgreaterequal(x, y); } + inline bool isgreaterequal(expr x, half y) { return functions::isgreaterequal(x, y); } + inline bool isgreaterequal(expr x, expr y) { return functions::isgreaterequal(x, y); } + + /// Comparison for less than. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less than \a y + /// \retval false else +// template<typename T,typename U> typename enable<bool,T,U>::type isless(T x, U y) { return functions::isless(x, y); } + inline bool isless(half x, half y) { return functions::isless(x, y); } + inline bool isless(half x, expr y) { return functions::isless(x, y); } + inline bool isless(expr x, half y) { return functions::isless(x, y); } + inline bool isless(expr x, expr y) { return functions::isless(x, y); } + + /// Comparison for less equal. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less equal \a y + /// \retval false else +// template<typename T,typename U> typename enable<bool,T,U>::type islessequal(T x, U y) { return functions::islessequal(x, y); } + inline bool islessequal(half x, half y) { return functions::islessequal(x, y); } + inline bool islessequal(half x, expr y) { return functions::islessequal(x, y); } + inline bool islessequal(expr x, half y) { return functions::islessequal(x, y); } + inline bool islessequal(expr x, expr y) { return functions::islessequal(x, y); } + + /// Comarison for less or greater. + /// \param x first operand + /// \param y second operand + /// \retval true if either less or greater + /// \retval false else +// template<typename T,typename U> typename enable<bool,T,U>::type islessgreater(T x, U y) { return functions::islessgreater(x, y); } + inline bool islessgreater(half x, half y) { return functions::islessgreater(x, y); } + inline bool islessgreater(half x, expr y) { return functions::islessgreater(x, y); } + inline bool islessgreater(expr x, half y) { return functions::islessgreater(x, y); } + inline bool islessgreater(expr x, expr y) { return functions::islessgreater(x, y); } + + /// Check if unordered. 
+ /// \param x first operand + /// \param y second operand + /// \retval true if unordered (one or two NaN operands) + /// \retval false else +// template<typename T,typename U> typename enable<bool,T,U>::type isunordered(T x, U y) { return functions::isunordered(x, y); } + inline bool isunordered(half x, half y) { return functions::isunordered(x, y); } + inline bool isunordered(half x, expr y) { return functions::isunordered(x, y); } + inline bool isunordered(expr x, half y) { return functions::isunordered(x, y); } + inline bool isunordered(expr x, expr y) { return functions::isunordered(x, y); } + + /// \name Casting + /// \{ + + /// Cast to or from half-precision floating point number. + /// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values are converted + /// directly using the given rounding mode, without any roundtrip over `float` that a `static_cast` would otherwise do. + /// It uses the default rounding mode. + /// + /// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any of the two types + /// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) results in a compiler + /// error and casting between [half](\ref half_float::half)s is just a no-op. + /// \tparam T destination type (half or built-in arithmetic type) + /// \tparam U source type (half or built-in arithmetic type) + /// \param arg value to cast + /// \return \a arg converted to destination type + template<typename T,typename U> T half_cast(U arg) { return half_caster<T,U>::cast(arg); } + + /// Cast to or from half-precision floating point number. + /// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values are converted + /// directly using the given rounding mode, without any roundtrip over `float` that a `static_cast` would otherwise do. + /// + /// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any of the two types + /// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) results in a compiler + /// error and casting between [half](\ref half_float::half)s is just a no-op. + /// \tparam T destination type (half or built-in arithmetic type) + /// \tparam R rounding mode to use. 
+ /// \tparam U source type (half or built-in arithmetic type) + /// \param arg value to cast + /// \return \a arg converted to destination type + template<typename T,std::float_round_style R,typename U> T half_cast(U arg) { return half_caster<T,U,R>::cast(arg); } + /// \} + } + + using detail::operator==; + using detail::operator!=; + using detail::operator<; + using detail::operator>; + using detail::operator<=; + using detail::operator>=; + using detail::operator+; + using detail::operator-; + using detail::operator*; + using detail::operator/; + using detail::operator<<; + using detail::operator>>; + + using detail::abs; + using detail::fabs; + using detail::fmod; + using detail::remainder; + using detail::remquo; + using detail::fma; + using detail::fmax; + using detail::fmin; + using detail::fdim; + using detail::nanh; + using detail::exp; + using detail::expm1; + using detail::exp2; + using detail::log; + using detail::log10; + using detail::log1p; + using detail::log2; + using detail::sqrt; + using detail::cbrt; + using detail::hypot; + using detail::pow; + using detail::sin; + using detail::cos; + using detail::tan; + using detail::asin; + using detail::acos; + using detail::atan; + using detail::atan2; + using detail::sinh; + using detail::cosh; + using detail::tanh; + using detail::asinh; + using detail::acosh; + using detail::atanh; + using detail::erf; + using detail::erfc; + using detail::lgamma; + using detail::tgamma; + using detail::ceil; + using detail::floor; + using detail::trunc; + using detail::round; + using detail::lround; + using detail::nearbyint; + using detail::rint; + using detail::lrint; +#if HALF_ENABLE_CPP11_LONG_LONG + using detail::llround; + using detail::llrint; +#endif + using detail::frexp; + using detail::ldexp; + using detail::modf; + using detail::scalbn; + using detail::scalbln; + using detail::ilogb; + using detail::logb; + using detail::nextafter; + using detail::nexttoward; + using detail::copysign; + using detail::fpclassify; + using detail::isfinite; + using detail::isinf; + using detail::isnan; + using detail::isnormal; + using detail::signbit; + using detail::isgreater; + using detail::isgreaterequal; + using detail::isless; + using detail::islessequal; + using detail::islessgreater; + using detail::isunordered; + + using detail::half_cast; +} + + +/// Extensions to the C++ standard library. +namespace std +{ + /// Numeric limits for half-precision floats. + /// Because of the underlying single-precision implementation of many operations, it inherits some properties from + /// `std::numeric_limits<float>`. + template<> class numeric_limits<half_float::half> : public numeric_limits<float> + { + public: + /// Supports signed values. + static HALF_CONSTEXPR_CONST bool is_signed = true; + + /// Is not exact. + static HALF_CONSTEXPR_CONST bool is_exact = false; + + /// Doesn't provide modulo arithmetic. + static HALF_CONSTEXPR_CONST bool is_modulo = false; + + /// IEEE conformant. + static HALF_CONSTEXPR_CONST bool is_iec559 = true; + + /// Supports infinity. + static HALF_CONSTEXPR_CONST bool has_infinity = true; + + /// Supports quiet NaNs. + static HALF_CONSTEXPR_CONST bool has_quiet_NaN = true; + + /// Supports subnormal values. + static HALF_CONSTEXPR_CONST float_denorm_style has_denorm = denorm_present; + + /// Rounding mode. 
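half_cast, declared just above, converts between half and built-in arithmetic types in a single rounding step, optionally with an explicit rounding mode as the second template argument. Usage sketch (include path assumed):

    #include <limits>     // std::float_round_style constants
    #include "half.hpp"   // assumed include path

    int main() {
        using half_float::half;
        using half_float::half_cast;
        half a = half_cast<half>(3.14159f);                          // default rounding mode
        half b = half_cast<half, std::round_toward_zero>(3.14159f);  // explicit rounding mode
        int  i = half_cast<int>(a);                                  // 3, no intermediate float rounding
        (void)b;
        return (i == 3) ? 0 : 1;
    }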
+ /// Due to the mix of internal single-precision computations (using the rounding mode of the underlying + /// single-precision implementation) with the rounding mode of the single-to-half conversions, the actual rounding + /// mode might be `std::round_indeterminate` if the default half-precision rounding mode doesn't match the + /// single-precision rounding mode. + static HALF_CONSTEXPR_CONST float_round_style round_style = (std::numeric_limits<float>::round_style== + half_float::half::round_style) ? half_float::half::round_style : round_indeterminate; + + /// Significant digits. + static HALF_CONSTEXPR_CONST int digits = 11; + + /// Significant decimal digits. + static HALF_CONSTEXPR_CONST int digits10 = 3; + + /// Required decimal digits to represent all possible values. + static HALF_CONSTEXPR_CONST int max_digits10 = 5; + + /// Number base. + static HALF_CONSTEXPR_CONST int radix = 2; + + /// One more than smallest exponent. + static HALF_CONSTEXPR_CONST int min_exponent = -13; + + /// Smallest normalized representable power of 10. + static HALF_CONSTEXPR_CONST int min_exponent10 = -4; + + /// One more than largest exponent + static HALF_CONSTEXPR_CONST int max_exponent = 16; + + /// Largest finitely representable power of 10. + static HALF_CONSTEXPR_CONST int max_exponent10 = 4; + + /// Smallest positive normal value. + static HALF_CONSTEXPR half_float::half min() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x0400); } + + /// Smallest finite value. + static HALF_CONSTEXPR half_float::half lowest() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0xFBFF); } + + /// Largest finite value. + static HALF_CONSTEXPR half_float::half max() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7BFF); } + + /// Difference between one and next representable value. + static HALF_CONSTEXPR half_float::half epsilon() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x1400); } + + /// Maximum rounding error. + static HALF_CONSTEXPR half_float::half round_error() HALF_NOTHROW + { return half_float::half(half_float::detail::binary, (round_style==std::round_to_nearest) ? 0x3800 : 0x3C00); } + + /// Positive infinity. + static HALF_CONSTEXPR half_float::half infinity() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7C00); } + + /// Quiet NaN. + static HALF_CONSTEXPR half_float::half quiet_NaN() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7FFF); } + + /// Signalling NaN. + static HALF_CONSTEXPR half_float::half signaling_NaN() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7DFF); } + + /// Smallest positive subnormal value. + static HALF_CONSTEXPR half_float::half denorm_min() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x0001); } + }; + +#if HALF_ENABLE_CPP11_HASH + /// Hash function for half-precision floats. + /// This is only defined if C++11 `std::hash` is supported and enabled. + template<> struct hash<half_float::half> //: unary_function<half_float::half,size_t> + { + /// Type of function argument. + typedef half_float::half argument_type; + + /// Function return type. + typedef size_t result_type; + + /// Compute hash function. 
+ /// \param arg half to hash + /// \return hash value + result_type operator()(argument_type arg) const + { return hash<half_float::detail::uint16>()(static_cast<unsigned>(arg.data_)&-(arg.data_!=0x8000)); } + }; +#endif +} + + +#undef HALF_CONSTEXPR +#undef HALF_CONSTEXPR_CONST +#undef HALF_NOEXCEPT +#undef HALF_NOTHROW +#ifdef HALF_POP_WARNINGS + #pragma warning(pop) + #undef HALF_POP_WARNINGS +#endif + +#endif diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 6b23cda0d86a77487af7d63b3e7a0dfeae57bb37..0fe66e4b64e4113901db2bcd525e1895e642c6de 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -35,37 +35,34 @@ private: /// @brief Name of the graphview std::string mName; + /// @brief GraphView root node + NodePtr mRootNode; + /// @brief Set of nodes included in the GraphView std::set<NodePtr> mNodes; /// @brief Set of nodes included in the graphview with names std::map<std::string, NodePtr> mNodeRegistry; - /// @brief Nodes without input link - std::set<NodePtr> mInputNodes; + /// @brief GraphView inputs + std::vector<std::pair<NodePtr, IOIndex_t>> mInputNodes; - /// @brief Nodes without output link - std::set<NodePtr> mOutputNodes; + /// @brief GraphView outputs + std::vector<std::pair<NodePtr, IOIndex_t>> mOutputNodes; public: - GraphView(std::string name="") + GraphView(const std::string& name="") : mName(name) { // ctor } - // GraphView(std::set<NodePtr> nodes, std::string name="") - // : mName(name) - // { - // add(nodes); - // } - bool operator==(const GraphView &gv) const { return mNodes == gv.mNodes; } - NodePtr operator[](std::string name) + NodePtr operator[](const std::string& name) { assert(mNodeRegistry.find(name) != mNodeRegistry.end() && "Could not find Node in the GraphView."); return mNodeRegistry.at(name); @@ -105,57 +102,88 @@ public: return mNodes.find(nodePtr) != mNodes.end(); } + NodePtr getRootNode() { + return mRootNode; + } + /////////////////////////////////////////////////////// // TENSOR MANAGEMENT /////////////////////////////////////////////////////// public: /** @brief Get reference to the set of input Nodes. */ - inline const std::set<NodePtr>& inputNodes() const noexcept { return mInputNodes; } + inline std::set<NodePtr> inputNodes() const noexcept { + std::set<NodePtr> nodes; + for (auto node : mInputNodes) { + nodes.insert(node.first); + } + return nodes; + } /** @brief Get reference to the set of output Nodes. */ - inline const std::set<NodePtr>& outputNodes() const noexcept { return mOutputNodes; } - + inline std::set<NodePtr> outputNodes() const noexcept { + std::set<NodePtr> nodes; + for (auto node : mOutputNodes) { + nodes.insert(node.first); + } + return nodes; + } /** @brief Assess if the given Node is an input Node of the GraphView object. */ inline bool isInputNode(NodePtr nodePtr) const { - return (mInputNodes.find(nodePtr) != mInputNodes.end()) ? true : false; + const auto nodes = inputNodes(); + return (nodes.find(nodePtr) != nodes.end()) ? true : false; } /** @brief Assess if the given Node is an output Node of the GraphView object. */ inline bool isOutputNode(NodePtr nodePtr) const { - return (mOutputNodes.find(nodePtr) != mOutputNodes.end()) ? true : false; + const auto nodes = outputNodes(); + return (nodes.find(nodePtr) != nodes.end()) ? 
true : false; } + void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs); + void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs); + + inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() { return mInputNodes; }; + inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() { return mOutputNodes; }; + /** - * @brief List outside dataInput connections of the GraphView object's inputNodes. + * @brief List outside data input connections of the GraphView. + * Data inputs exclude inputs expecting parameters (weights or bias). + * The vector size is guaranteed to match the number of outside data inputs of the GraphView. If there is + * no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; /** - * @brief List dataInput connections of the GraphView object's inputNodes. + * @brief List all dataInput connections (within and outside) of the specified GraphView node named "name". + * Data inputs exclude inputs expecting parameters (weights or bias). * @param name Name of the Node. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); } /** - * @brief List outside input connections of the GraphView object's inputNodes. + * @brief List outside input connections of the GraphView. The vector + * size is guaranteed to match the number of outside inputs of the GraphView. If there is + * no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; /** - * @brief List input connections of the specified GraphView object's inputNode. + * @brief List all input connections (within and outside) of the specified GraphView node named "name". * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs(std::string name) const; /** - * @brief List output connections of the GraphView object's outputNodes. + * @brief List outside output connections of the GraphView. The vector + * size is guaranteed to match the number of outputs of the GraphView. If there is + * no connection to a given output, the corresponding sub-vector will be empty. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const; /** - * @brief Specific i-th output connection of the GraphView object. + * @brief List all output connections (within and outside) of the specified GraphView node named "name". * @param nodeName Name of the Node of which to show the output. * @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> */ @@ -175,7 +203,7 @@ public: * If not, add a Transpose Operator. * 4 - Propagate Tensor dimensions through the consecutive Operators. */ - void compile(const std::string& backend, const Aidge::DataType datatype); + void compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device = 0); /** * @brief Compute dimensions of input/output Tensors for each Operator of the @@ -184,7 +212,7 @@ public: void forwardDims(); /** @brief Set the same backend for each Operator of the GraphView object's Nodes.
*/ - void setBackend(const std::string &backend); + void setBackend(const std::string &backend, DeviceIdx_t device = 0); /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ void setDataType(const DataType &datatype); @@ -229,7 +257,7 @@ public: * @brief Get the operator with the corresponding name if it is in the * GraphView. * @param nodeName Name of the node. - * @return NodePtr returns a new empty node if the one asked for + * @return NodePtr returns a nullptr if the one asked for * was not found. */ NodePtr getNode(const std::string& nodeName) const; @@ -252,20 +280,34 @@ public: * in the GraphView automatically. Default: true. */ void add(NodePtr otherNode, bool includeLearnableParam = true); + + /** + * @brief Include a set of Nodes to the current GraphView object. + * @param otherNodes + * @param includeLearnableParam + * @return true if graph ordering is unique (meaning inputs/outputs order is well defined). + */ + bool add(std::set<NodePtr> otherNodes, + bool includeLearnableParam = true); + /** * @brief Include a set of Nodes to the current GraphView object. + * The first element of the otherNodes pair is the start node and + * the second is the remaining nodes to add. * @param otherNodes * @param includeLearnableParam + * @return true if graph ordering is unique (meaning inputs/outputs order is well defined). */ - void add(std::set<NodePtr> otherNodes, + bool add(std::pair<NodePtr, std::set<NodePtr>> otherNodes, bool includeLearnableParam = true); /** * @brief Include every Node inside another GraphView to the current * GraphView. * @param other_graph GraphView containing the Nodes to include. + * @return true if graph ordering is unique (meaning inputs/outputs order is well defined). */ - void add(std::shared_ptr<GraphView> otherGraph); + bool add(std::shared_ptr<GraphView> otherGraph); /** * @brief Include a Node in the current GraphView and link it to another @@ -350,26 +392,27 @@ public: IOIndex_t newParentInputTensorIdx, IOIndex_t newParentOutputTensorIdx); - /** * @brief Replace a set of Nodes in every available GraphView with a new set of Nodes if possible. * Both sets should include all the necessary Producers. - * @details Replaced Nodes are removed from any GraphView pointing at them all. - * The oldNodes set should have only one input/output - * Tensor for automatic connections of newNodes set. - * @param oldNodes actual set of shared_ptr<Node> to replace. - * @param newNodes new set of shared_ptr<Node>. - * @return true - * @return false + * @details There are 3 cases of replacement: + * Case 1: same number of input/output connections for oldNodes and newNodes sets. + * - input/output connections are replicated according to their IDs. + * Case 2: different number of input/output connections for oldNodes and newNodes sets. + * - only a single parent/child node for the newNodes set, every input/output is + * connected to it.
+ * - several parents/children nodes for newNodes set => impossible to know, return false + * Case 3: newNodes set is empty + * - same number of input/output connections in oldNodes, parents and children are linked according + * to these connection IDs + * - different number of input/output connections in oldNodes => return false + * @param oldNodes + * @param newNodes + * @return true replacement has been performed + * @return false no replacement has been performed */ static bool replace(const std::set<NodePtr>& oldNodes, const std::set<NodePtr>& newNodes); - void updateInputNodes(); - /** - * @brief Process from zero the set of output Nodes. - */ - void updateOutputNodes(); - /** * @brief Clone the GraphView with shared Operators. It is a new GraphView, with cloned Nodes, but the new Nodes refer to the same Operators as the original ones. * @return std::shared_ptr<GraphView> @@ -403,6 +446,7 @@ public: /** * @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object. + * Data inputs exclude inputs expecting parameters (weights or bias). * @return IOIndex_t */ IOIndex_t getNbFreeDataInputs() const; @@ -413,33 +457,34 @@ private: /////////////////////////////////////////////////////// /** - * @brief Get the sum of the number of dataInput Nodes for all inputNodes of the GraphView object. + * @brief Get the number of dataInputs that are outside the GraphView. + * Data inputs exclude inputs expecting parameters (weights or bias). + * This number matches the size of the vector returned by GraphView::dataInputs(). * @return IOIndex_t */ IOIndex_t getNbDataInputs() const; /** - * @brief Update the set of inputNodes with a new Node, checking if it can be - * added and removing any Node not part of mInputNode anymore. + * @brief Automatically update GraphView inputs/outputs with a new Node, checking if + * this Node becomes an input/output for the graph and if previous inputs are still + * inputs/outputs after adding this node. * @param nodePtr */ - void updateInputNodes(NodePtr node); + void updateInputsOutputsNew(NodePtr newNode); /** - * @brief Update the set of outputNodes with a new Node, checking if it can be - * added and removing any Node not part of mOutputNode anymore. + * @brief Automatically update GraphView inputs/outputs with a Node removed, checking if + * this Node was an input/output for the graph and if this node's children become new inputs/outputs + * for the graph. * @param nodePtr */ - void updateOutputNodes(NodePtr node); + void updateInputsOutputsDelete(NodePtr deletedNode); /////////////////////////////////////////////////////// // TOPOLOGY /////////////////////////////////////////////////////// void _forwardDims(std::set<NodePtr> listNodes); - - void removeInputNode(const std::string nodeName); - void removeOutputNode(const std::string nodeName); }; } // namespace Aidge diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index 118d925e1e5b7c4fcd0c353236998ff831f7e42d..5ae4eb5d893244fa842e6bb0435c0a8ab3bc0ac5 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -140,7 +140,8 @@ public: /** * @brief List of pair <Parent, ID of the data input>. When an input is not - * linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. + * linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. + * Data inputs exclude inputs expecting parameters (weights or bias).
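To make the data-input/parameter-input distinction concrete, here is a minimal sketch (the FC factory call and the 64/32 channel sizes are purely illustrative) of how dataInputs() relates to nbInputs():

    #include <cassert>
    #include "aidge/aidge.hpp"

    void dataInputsExample() {
        // An FC node has three inputs, [data, weight, bias]; only the first one is a data input.
        auto fc = Aidge::FC(64, 32, /*noBias=*/false, "fc0");
        assert(fc->nbInputs() == 3);
        // dataInputs() returns one pair per data input; an input with no parent yields
        // the pair <nullptr, gk_IODefaultIndex> described above.
        const auto dataIns = fc->dataInputs();
        assert(dataIns.size() == 1);
    }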
* @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; @@ -167,6 +168,7 @@ public: /** * @brief Get the lowest index in the InputData Parent list equal to the * nullptr. + * Data inputs exclude inputs expecting parameters (weights or bias). * @return std::size_t */ inline IOIndex_t getFirstFreeDataInput() const { @@ -180,7 +182,9 @@ public: IOIndex_t getNbFreeDataInputs() const; /** - * @brief List input ids of children linked to outputs of the node + * @brief List input ids of children linked to outputs of the node. The vector + * size is guaranteed to match the number of outputs of the node. If there is + * no connection to a given output, the corresponding sub-vector will be empty. * @return std::vector<std::vector<std::pair<std::shared_ptr<Node>, * IOIndex_t>>> */ @@ -203,7 +207,8 @@ public: inline IOIndex_t nbInputs() const noexcept { return getOperator()->nbInputs(); } /** - * @brief Number of input specifically for data + * @brief Number of inputs specifically for data. + * Data inputs exclude inputs expecting parameters (weights or bias). * @details [data, data, weight, bias] => 2 * @return IOIndex_t */ diff --git a/include/aidge/graph/Testing.hpp b/include/aidge/graph/Testing.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ecacdf66298cb83c919ad447c82463206836a3e9 --- /dev/null +++ b/include/aidge/graph/Testing.hpp @@ -0,0 +1,67 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_GRAPH_TESTING_H_ +#define AIDGE_CORE_GRAPH_TESTING_H_ + +#include <cstddef> +#include <vector> +#include <set> +#include <random> // std::mt19937::result_type +#include <utility> // std::pair + +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +/** + * Random (directed) graph generator +*/ +struct RandomGraph { + /// @brief If true, the generated graph is a DAG (no cycle) + bool acyclic = false; + /// @brief Connection density (between 0 and 1) + float density = 0.5; + /// @brief Max number of inputs per node (regardless of whether they are connected or not) + std::size_t maxIn = 5; + /// @brief Average number of inputs per node (regardless of whether they are connected or not) + float avgIn = 1.5; + /// @brief Max number of outputs per node (regardless of whether they are connected or not) + std::size_t maxOut = 2; + /// @brief Average number of outputs per node (regardless of whether they are connected or not) + float avgOut = 1.1; + /// @brief List of node types that should be generated in the graph (as GenericOperator) + std::vector<std::string> types = {"Fictive"}; + /// @brief Weights of each node type, used to compute the probability of generating this type + std::vector<float> typesWeights = {1.0}; + /// @brief Type of node that should be omitted from the generated topology + std::string omitType; + + /** + * Generate a DAG according to the parameters of the class. + * @param seed Random seed. For an identical seed, an identical topology is + * generated, but with a random node ordering in the returned set of nodes. + * @param nbNodes Number of nodes to generate.
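As a minimal usage sketch of this generator (the node types, weights and sizes are illustrative; gen() is the member declared right after this comment block, and the pair it returns feeds the new pair-based GraphView::add() overload):

    #include <memory>
    #include "aidge/graph/Testing.hpp"
    #include "aidge/graph/GraphView.hpp"

    void randomGraphExample() {
        Aidge::RandomGraph rg;
        rg.acyclic = true;                  // request a cycle-free topology
        rg.types = {"Conv", "ReLU"};        // generated as GenericOperator nodes
        rg.typesWeights = {0.7f, 0.3f};
        // gen() returns the start node together with the full set of generated nodes.
        const auto nodes = rg.gen(/*seed=*/42, /*nbNodes=*/10);
        auto gv = std::make_shared<Aidge::GraphView>("random");
        gv->add(nodes);                     // pair<NodePtr, set<NodePtr>> overload of add()
    }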
+ */ + std::pair<NodePtr, std::set<NodePtr>> gen(std::mt19937::result_type seed, std::size_t nbNodes) const; +}; + +std::string nodePtrToType(NodePtr node); +std::string nodePtrToName(NodePtr node); +std::set<std::string> nodePtrTo(const std::set<NodePtr>& nodes, + std::string(*nodeTo)(NodePtr) = nodePtrToType); +std::vector<std::pair<std::string, IOIndex_t>> nodePtrTo( + const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes, + std::string(*nodeTo)(NodePtr) = nodePtrToType); + +} // namespace Aidge + +#endif /* AIDGE_CORE_GRAPH_TESTING_H_ */ diff --git a/include/aidge/graphRegex/GraphStrInterpreter.hpp b/include/aidge/graphRegex/GraphStrInterpreter.hpp index 98dca0e9f84de0be2614aed0e47c9d86ae552674..38e89b3733e1a07062661fa520485f92fbd7f026 100644 --- a/include/aidge/graphRegex/GraphStrInterpreter.hpp +++ b/include/aidge/graphRegex/GraphStrInterpreter.hpp @@ -1,7 +1,6 @@ #ifndef AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ #define AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ -#include <iostream> #include <sstream> #include <memory> #include <algorithm> diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index f5521a1d12728a7957cb67c09861ee673e21cbae..9aed8299a67ab719141b6fe199ebf3f52fb7d387 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -30,7 +30,7 @@ namespace Aidge { class Add_Op : public OperatorTensor, public Registrable<Add_Op, std::string, std::unique_ptr<OperatorImpl>(const Add_Op&)> { public: - static constexpr const char* Type = "Add"; + static const std::string Type; Add_Op(const IOIndex_t nbIn) : OperatorTensor(Type, nbIn, 0, 1) @@ -76,20 +76,15 @@ public: // } - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Add_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - for (std::size_t i = 0; i < nbInputs(); ++i) { - getInput(i)->setBackend(name); - } + mOutputs[0]->setBackend(name, device); } - static const std::vector<std::string> getInputsName(){ + static const std::vector<std::string> getInputsName() { return {"data_input_0", "data_input_n"}; } - static const std::vector<std::string> getOutputsName(){ + static const std::vector<std::string> getOutputsName() { return {"data_output"}; } }; diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp index 5fb1d5b16c55f7f5b6cea4db02d3aa955831e08b..a2098ff36b40b78eb12a36fe28793e8dd73d9d9c 100644 --- a/include/aidge/operator/AvgPooling.hpp +++ b/include/aidge/operator/AvgPooling.hpp @@ -36,7 +36,7 @@ class AvgPooling_Op : public OperatorTensor, std::array<DimSize_t, DIM>> { public: - static constexpr const char *Type = "AvgPooling"; + static const std::string Type; AvgPooling_Op() = delete; @@ -94,22 +94,24 @@ public: } - std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> - computeReceptiveField(const std::size_t firstIdx, + std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> + computeReceptiveField(const std::vector<DimSize_t>& firstEltDims, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const override final { if (outputIdx != 0) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Conv_Op Operator has got only one output Tensor."); } + if (firstEltDims.size() != outputDims.size()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "outputDims and firstEltDims should have the size of the output Tensor dimensions."); + } if ((outputDims.size() == (DIM+2)) && 
outputDimsForwarded()) { // Offset - const auto outputIdxDims = mOutputs[0]->getCoord(firstIdx); - std::vector<DimSize_t> inputIdxDims = outputIdxDims; + std::vector<DimSize_t> inputIdxDims = firstEltDims; for (DimIdx_t i = 0; i < (DIM+2); ++i) { - if (((outputDims[i] + outputIdxDims[i]) > mOutputs[0]->template dims<DIM+2>()[i]) || (outputDims[i] == 0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), outputIdxDims[i], outputDims[i]); + if (((outputDims[i] + firstEltDims[i]) > mOutputs[0]->template dims<DIM+2>()[i]) || (outputDims[i] == 0)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), firstEltDims[i], outputDims[i]); } } @@ -126,20 +128,17 @@ public: + (this->template getAttr<AvgPoolingAttr::KernelDims>()[static_cast<std::size_t>(i)] - 1)); inputIdxDims[2+i] *= this->template getAttr<AvgPoolingAttr::StrideDims>()[static_cast<std::size_t>(i)]; } - std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> res; - res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(mInputs[0]->getIdx(inputIdxDims), inputDims)); + std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> res; + res.push_back(std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>(inputIdxDims, inputDims)); return res; } AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet."); } - void setBackend(const std::string &name) override { + void setBackend(const std::string &name, DeviceIdx_t device = 0) override { mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ @@ -150,6 +149,9 @@ public: } }; +template <DimIdx_t DIM> +const std::string AvgPooling_Op<DIM>::Type = "AvgPooling"; + template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> AvgPooling(const std::array<DimSize_t, DIM> &kernel_dims, const std::string& name = "", diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index be850d377e5a1781b2cb04b5040c257ecc30cd92..4a0f40c034c7738a33eb8a9569fac4aa2fff465d 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -33,7 +33,7 @@ class BatchNorm_Op : public OperatorTensor, public Registrable<BatchNorm_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const BatchNorm_Op<DIM> &)>, public StaticAttributes<BatchNormAttr, float, float> { public: - static constexpr const char *Type = "BatchNorm"; + static const std::string Type; BatchNorm_Op() = delete; @@ -82,27 +82,27 @@ public: associated &= !(getInput(i)->empty()); } if (associated) { - const DimSize_t nbChannels = getInput(0)->dims()[1]; + const DimSize_t nbFeatures = getInput(0)->dims()[1]; for (std::size_t i = nbData(); i < nbInputs(); ++i) { - if(getInput(i)->size() != nbChannels) { + if(getInput(i)->size() != nbFeatures) { // /!\ Input size should be handled BEFORE calling this function // This should raise an error - getInput(i)->resize(std::array<DimSize_t, 1>({getInput(0)->dims()[1]})); + getInput(i)->resize({getInput(0)->dims()[1]}); } } mOutputs[0]->resize(getInput(0)->dims()); } } - void setBackend(const std::string &name) override { + void setBackend(const std::string &name, DeviceIdx_t device = 0) 
override { mImpl = Registrar<BatchNorm_Op<DIM>>::create(name)(*this); - mOutputs[0]->setBackend(name); + mOutputs[0]->setBackend(name, device); - // FIXME: temporary workaround - getInput(1)->setBackend(name); - getInput(2)->setBackend(name); - getInput(3)->setBackend(name); - getInput(4)->setBackend(name); + // By default, automatically set backend for scale, shift, mean and variance + getInput(1)->setBackend(name, device); + getInput(2)->setBackend(name, device); + getInput(3)->setBackend(name, device); + getInput(4)->setBackend(name, device); } static const std::vector<std::string> getInputsName() { @@ -113,16 +113,20 @@ public: } }; +template <DimIdx_t DIM> +const std::string BatchNorm_Op<DIM>::Type = "BatchNorm"; + template <DimSize_t DIM> -inline std::shared_ptr<Node> BatchNorm(const float epsilon = 1.0e-5F, +inline std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures, + const float epsilon = 1.0e-5F, const float momentum = 0.1F, const std::string& name = "") { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported"); auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name); - addProducer(batchNorm, 1, std::array<DimSize_t,0>({}), "scale"); - addProducer(batchNorm, 2, std::array<DimSize_t,0>({}), "shift"); - addProducer(batchNorm, 3, std::array<DimSize_t,0>({}), "batch_mean"); - addProducer(batchNorm, 4, std::array<DimSize_t,0>({}), "batch_variance"); + addProducer(batchNorm, 1, {nbFeatures}, "scale"); + addProducer(batchNorm, 2, {nbFeatures}, "shift"); + addProducer(batchNorm, 3, {nbFeatures}, "batch_mean"); + addProducer(batchNorm, 4, {nbFeatures}, "batch_variance"); return batchNorm; } } // namespace Aidge diff --git a/include/aidge/operator/Cast.hpp b/include/aidge/operator/Cast.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7cc3985674219daf087381049d3a845299b3e250 --- /dev/null +++ b/include/aidge/operator/Cast.hpp @@ -0,0 +1,75 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_CAST_H_ +#define AIDGE_CORE_OPERATOR_CAST_H_ + +#include <cassert> +#include <memory> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +class Cast_Op : public OperatorTensor, + public Registrable<Cast_Op, std::string, std::unique_ptr<OperatorImpl>(const Cast_Op&)> { +public: + static const std::string Type; + + Cast_Op() : OperatorTensor(Type, 1, 0, 1) {} + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Cast_Op(const Cast_Op& op) + : OperatorTensor(op) + { + mImpl = op.mImpl ? Registrar<Cast_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
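Tying back to the BatchNorm factory change above (the added nbFeatures argument): a minimal call sketch, with the feature count and name chosen purely for illustration:

    #include "aidge/aidge.hpp"

    void batchNormExample() {
        // BatchNorm now takes the feature count up front so that the scale, shift,
        // mean and variance producers can be created with shape {nbFeatures}.
        auto bn = Aidge::BatchNorm<2>(/*nbFeatures=*/64, /*epsilon=*/1.0e-5F, /*momentum=*/0.1F, "bn1");
    }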
+ * @see Operator::Cast_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Cast_Op>(*this); + } + + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { + if (Registrar<Cast_Op>::exists({name})) { + mImpl = Registrar<Cast_Op>::create({name})(*this); + } + mOutputs[0]->setBackend(name, device); + } + + void forward() override; + + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Cast(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Cast_Op>(), name); +} +} + +#endif /* AIDGE_CORE_OPERATOR_CAST_H_ */ \ No newline at end of file diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp index 78e21f85250c361053857e27c582e1487aeec64e..06cc468bd7266bbcfeb6802f274c536ec09867fc 100644 --- a/include/aidge/operator/Concat.hpp +++ b/include/aidge/operator/Concat.hpp @@ -32,7 +32,7 @@ class Concat_Op : public OperatorTensor, public Registrable<Concat_Op, std::string, std::unique_ptr<OperatorImpl>(const Concat_Op&)>, public StaticAttributes<ConcatAttr, DimSize_t> { public: - static constexpr const char* Type = "Concat"; + static const std::string Type; using Attributes_ = StaticAttributes<ConcatAttr, DimSize_t>; template <ConcatAttr e> @@ -101,14 +101,9 @@ public: } } - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Concat_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - for (std::size_t i = 0; i < nbInputs(); ++i) { - getInput(i)->setBackend(name); - } + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index b62d393bc37859f24c4f54f8ce1ba4458bf11ab4..be5fb3e393ced7ee7a53e27426b4247e48b478e8 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -36,7 +36,7 @@ class Conv_Op : public OperatorTensor, DimSize_t, std::array<DimSize_t, DIM>> { public: - static constexpr const char *Type = "Conv"; + static const std::string Type; Conv_Op() = delete; @@ -119,19 +119,21 @@ public: } -std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveField(const std::size_t firstIdx, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const override { +std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> computeReceptiveField(const std::vector<DimSize_t>& firstEltDims, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const override { if (outputIdx != 0) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Conv_Op Operator has got only one output Tensor."); } + if (firstEltDims.size() != outputDims.size()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "outputDims and firstEltDims should have the size of the output Tensor dimensions."); + } if ((outputDims.size() == (DIM+2)) && outputDimsForwarded()) { // Offset - const auto outputIdxDims = mOutputs[0]->getCoord(firstIdx); - auto inputIdxDims = outputIdxDims; // batch idx is the same + auto inputIdxDims = firstEltDims; // batch idx is the same inputIdxDims[1] = 0; // each channel is used so start with the first one for (DimIdx_t i = 0; i < (DIM+2); ++i) { - if (((outputDims[i] + outputIdxDims[i]) > mOutputs[0]->template 
dims<DIM+2>()[i]) || (outputDims[i] == 0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), outputIdxDims[i], outputDims[i]); + if (((outputDims[i] + firstEltDims[i]) > mOutputs[0]->template dims<DIM+2>()[i]) || (outputDims[i] == 0)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), firstEltDims[i], outputDims[i]); } } @@ -155,29 +157,29 @@ std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveFiel weightDims.push_back(this->template getAttr<ConvAttr::KernelDims>()[i]); } std::vector<DimSize_t> weightIdxDims = std::vector<DimSize_t>(DIM+2, 0); - weightIdxDims[0] = outputIdxDims[1]; + weightIdxDims[0] = firstEltDims[1]; // Bias const std::vector<DimSize_t> biasDims{outputDims[1]}; // the number of output channel - const std::vector<DimSize_t> biasIdxDims{outputIdxDims[1]}; + const std::vector<DimSize_t> biasIdxDims{firstEltDims[1]}; // Result - std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> res; - res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(getInput(0)->getIdx(inputIdxDims), inputDims)); - res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(getInput(1)->getIdx(weightIdxDims), weightDims)); - res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(getInput(2)->getIdx(biasIdxDims), biasDims)); + std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> res; + res.push_back(std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>(inputIdxDims, inputDims)); + res.push_back(std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>(weightIdxDims, weightDims)); + res.push_back(std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>(biasIdxDims, biasDims)); return res; } AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet."); } - void setBackend(const std::string &name) override { + void setBackend(const std::string &name, DeviceIdx_t device = 0) override { mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this); - mOutputs[0]->setBackend(name); + mOutputs[0]->setBackend(name, device); - // FIXME: temporary workaround - getInput(1)->setBackend(name); - getInput(2)->setBackend(name); + // By default, automatically set backend for weight and bias inputs + getInput(1)->setBackend(name, device); + getInput(2)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ @@ -188,6 +190,21 @@ std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveFiel } }; +template <DimIdx_t DIM> +const std::string Conv_Op<DIM>::Type = "Conv"; + +/** + * @brief Perform a convolution on the input Tensor. + * + * @tparam DIM Number of dimensions for the feature map. + * @param inChannels Number of input channels. + * @param outChannels Number of output channels. + * @param kernelDims Dimensions of the kernel. Must be the same number of dimensions as the feature map. + * @param name Name of the operator. + * @param strideDims Dimensions of the stride attribute. Must be the same number of dimensions as the feature map. + * @param dilationDims Dimensions of the dilation attribute. Must be the same number of dimensions as the feature map. + * @return std::shared_ptr<Node> A Node containing the operator. 
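A minimal call sketch for the factory documented here (its declaration follows just below); the channel counts and kernel size are illustrative:

    #include <array>
    #include "aidge/aidge.hpp"

    void convExample() {
        // 3 input channels, 16 output channels, 3x3 kernel; stride and dilation keep
        // their defaults. The DIM template parameter is deduced from the std::array argument.
        auto conv = Aidge::Conv(3, 16, std::array<Aidge::DimSize_t, 2>{3, 3}, "conv1");
        // The factory also attaches "w" and "b" Producer nodes, as shown in its body below.
    }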
+ */ template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> Conv(DimSize_t inChannels, DimSize_t outChannels, @@ -198,9 +215,8 @@ inline std::shared_ptr<Node> Conv(DimSize_t inChannels, // FIXME: properly handle default w&b initialization in every cases static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Conv, not supported"); auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(inChannels, outChannels, kernelDims, strideDims, dilationDims), name); - // addProducer(conv, 1, append(append(kernel_dims, in_channels), out_channels), "w"); addProducer(conv, 1, append(outChannels, append(inChannels, kernelDims)), "w"); - addProducer(conv, 2, std::array<DimSize_t, 1>({outChannels}), "b"); + addProducer(conv, 2, {outChannels}, "b"); return conv; } diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp index c95315f6d63e817354fc82dded4e3cfb4ed1b704..9d0c0bf408a2f634f96881cd339c330340d5e344 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -37,7 +37,7 @@ class ConvDepthWise_Op : public OperatorTensor, DimSize_t, std::array<DimSize_t, DIM>> { public: - static constexpr const char *Type = "ConvDepthWise"; + static const std::string Type; ConvDepthWise_Op() = delete; @@ -115,18 +115,20 @@ public: } } - std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveField(const std::size_t firstIdx, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const override { + std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> computeReceptiveField(const std::vector<DimSize_t>& firstEltDims, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const override { if (outputIdx != 0) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Conv_Op Operator has got only one output Tensor."); } + if (firstEltDims.size() != outputDims.size()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "outputDims and firstEltDims should have the size of the output Tensor dimensions."); + } if ((outputDims.size() == (DIM+2)) && outputDimsForwarded()) { // Offset - const auto outputIdxDims = mOutputs[0]->getCoord(firstIdx); - auto inputIdxDims = outputIdxDims; // batch idx is the same + auto inputIdxDims = firstEltDims; // batch idx is the same for (DimIdx_t i = 0; i < (DIM+2); ++i) { - if (((outputDims[i] + outputIdxDims[i]) > mOutputs[0]->template dims<DIM+2>()[i]) || (outputDims[i] == 0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), outputIdxDims[i], outputDims[i]); + if (((outputDims[i] + firstEltDims[i]) > mOutputs[0]->template dims<DIM+2>()[i]) || (outputDims[i] == 0)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), firstEltDims[i], outputDims[i]); } } @@ -149,29 +151,29 @@ public: weightDims.push_back(this->template getAttr<ConvDepthWiseAttr::KernelDims>()[i]); } std::vector<DimSize_t> weightIdxDims = std::vector<DimSize_t>(DIM+2, 0); - weightIdxDims[0] = outputIdxDims[1]; + weightIdxDims[0] = firstEltDims[1]; // Bias const std::vector<DimSize_t> biasDims{outputDims[1]}; // the number of output channel - const std::vector<DimSize_t> biasIdxDims{outputIdxDims[1]}; + const std::vector<DimSize_t> biasIdxDims{firstEltDims[1]}; // Result - std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> res; - 
res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(getInput(0)->getIdx(inputIdxDims), inputDims)); - res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(getInput(1)->getIdx(weightIdxDims), weightDims)); - res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(getInput(2)->getIdx(biasIdxDims), biasDims)); + std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> res; + res.push_back(std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>(inputIdxDims, inputDims)); + res.push_back(std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>(weightIdxDims, weightDims)); + res.push_back(std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>(biasIdxDims, biasDims)); return res; } AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet."); } - void setBackend(const std::string &name) override { + void setBackend(const std::string &name, DeviceIdx_t device = 0) override { mImpl = Registrar<ConvDepthWise_Op<DIM>>::create(name)(*this); - mOutputs[0]->setBackend(name); + mOutputs[0]->setBackend(name, device); - // FIXME: temporary workaround - getInput(1)->setBackend(name); - getInput(2)->setBackend(name); + // By default, automatically set backend for weight and bias inputs + getInput(1)->setBackend(name, device); + getInput(2)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ @@ -182,6 +184,9 @@ public: } }; +template <DimIdx_t DIM> +const std::string ConvDepthWise_Op<DIM>::Type = "ConvDepthWise"; + template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> ConvDepthWise(const DimSize_t nbChannels, const std::array<DimSize_t, DIM> &kernelDims, @@ -192,7 +197,7 @@ inline std::shared_ptr<Node> ConvDepthWise(const DimSize_t nbChannels, static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ConvDepthWise, not supported"); auto convDW = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(nbChannels, kernelDims, strideDims, dilationDims), name); addProducer(convDW, 1, append(nbChannels, append(DimSize_t(1), kernelDims)), "w"); - addProducer(convDW, 2, std::array<DimSize_t, 1>({nbChannels}), "b"); + addProducer(convDW, 2, {nbChannels}, "b"); return convDW; } diff --git a/include/aidge/operator/Div.hpp b/include/aidge/operator/Div.hpp index fcdb03a6be36bc9e1be7d69d01005f92b535d00c..94b755e0fdb0f76d54cd4f046fb8b08dda05b6b2 100644 --- a/include/aidge/operator/Div.hpp +++ b/include/aidge/operator/Div.hpp @@ -29,7 +29,7 @@ class Div_Op : public OperatorTensor, public Registrable<Div_Op, std::string, std::unique_ptr<OperatorImpl>(const Div_Op&)> { public: - static constexpr const char* Type = "Div"; + static const std::string Type; Div_Op() : OperatorTensor(Type, 2, 0, 1) {} @@ -54,13 +54,9 @@ public: void computeOutputDims() override final; - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Div_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); - getInput(1)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Erf.hpp b/include/aidge/operator/Erf.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6395756f3b08c5838d390ab45d38fa9c03cb91cb --- /dev/null +++ b/include/aidge/operator/Erf.hpp @@ -0,0 +1,75 @@ 
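Before the new Erf header, a brief sketch of the device-aware calls that this patch threads through GraphView::compile()/setBackend() and the operator setBackend() overrides (the "cpu" backend string and device index 0 are only illustrative):

    #include <memory>
    #include "aidge/aidge.hpp"

    void deviceSelectionExample(std::shared_ptr<Aidge::GraphView> model) {
        // The DeviceIdx_t argument defaults to 0, so existing single-device call sites
        // keep compiling; multi-device backends can now target a specific device.
        model->setBackend("cpu", /*device=*/0);
        model->compile("cpu", Aidge::DataType::Float32, /*device=*/0);
    }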
+/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_ERF_H_ +#define AIDGE_CORE_OPERATOR_ERF_H_ + +#include <cassert> +#include <memory> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/data/Data.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +class Erf_Op : public OperatorTensor, + public Registrable<Erf_Op, std::string, std::unique_ptr<OperatorImpl>(const Erf_Op&)> { +public: + static const std::string Type; + + Erf_Op() : OperatorTensor(Type, 1, 0, 1) {} + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Erf_Op(const Erf_Op& op) + : OperatorTensor(op) + { + mImpl = op.mImpl ? Registrar<Erf_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Erf_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Erf_Op>(*this); + } + + void setBackend(const std::string& name) override { + mImpl = Registrar<Erf_Op>::create(name)(*this); + mOutputs[0]->setBackend(name); + + // FIXME: temporary workaround + getInput(0)->setBackend(name); + } + + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Erf(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Erf_Op>(), name); +} +} + +#endif /* AIDGE_CORE_OPERATOR_ERF_H_ */ diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 8dea38335dd052f2dbf7d0aa7fc4f7fe84741a06..a73734ad20e10fe2a3e1d0d12d40e584b4540fb4 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -35,7 +35,7 @@ class FC_Op : public OperatorTensor, std::unique_ptr<OperatorImpl>(const FC_Op &)>, public StaticAttributes<FCAttr, DimSize_t, bool> { public: - static constexpr const char* Type = "FC"; + static const std::string Type; FC_Op() = delete; @@ -77,7 +77,7 @@ public: } mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); if (inputIdx == 0 && getInput(0)->nbDims() == 1) - mInputs[inputIdx]->resize(std::array<DimSize_t, 2>({1, getInput(inputIdx)->size()})); + mInputs[inputIdx]->resize({1, getInput(inputIdx)->size()}); } void computeOutputDims() override final { @@ -95,14 +95,13 @@ public: } - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<FC_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); + mOutputs[0]->setBackend(name, device); - // FIXME: temporary workaround - getInput(0)->setBackend(name); - getInput(1)->setBackend(name); - getInput(2)->setBackend(name); + // By default, automatically set backend for weight and bias inputs + getInput(1)->setBackend(name, device); 
+ getInput(2)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ @@ -116,8 +115,8 @@ public: inline std::shared_ptr<Node> FC(DimSize_t inChannels, DimSize_t outChannels, bool noBias = false, const std::string& name = "") { // FIXME: properly handle default w&b initialization in every cases auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(outChannels, noBias), name); - addProducer(fc, 1, std::array<DimSize_t, 2>({outChannels, inChannels}), "w"); - addProducer(fc, 2, (noBias ? std::array<DimSize_t, 1>({0}) : std::array<DimSize_t, 1>({outChannels})), "b"); // already sets bias dims + addProducer(fc, 1, {outChannels, inChannels}, "w"); + addProducer(fc, 2, {(noBias ? 0 : outChannels)}, "b"); // already sets bias dims return fc; } } // namespace Aidge diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f8276222811f6cc02c062d85e7ae99d72edead7a --- /dev/null +++ b/include/aidge/operator/Gather.hpp @@ -0,0 +1,100 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_GATHER_H_ +#define AIDGE_CORE_OPERATOR_GATHER_H_ + +#include <cassert> +#include <memory> +#include <vector> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/data/Data.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +enum class GatherAttr { Axis }; + +class Gather_Op : public OperatorTensor, + public Registrable<Gather_Op, + std::string, + std::unique_ptr<OperatorImpl>(const Gather_Op&)>, + public StaticAttributes<GatherAttr, int> { + +public: + static const std::string Type; + + Gather_Op() = delete; + + + using Attributes_ = StaticAttributes<GatherAttr, int>; + template <GatherAttr e> using attr = typename Attributes_::template attr<e>; + Gather_Op(int axis) + : OperatorTensor(Type, 2, 0, 1), + Attributes_( + attr<GatherAttr::Axis>(axis)) + {} + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Gather_Op(const Gather_Op& op) + : OperatorTensor(op), + Attributes_(op) + { + mImpl = op.mImpl ? Registrar<Gather_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::Gather_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Gather_Op>(*this); + } + + void computeOutputDims() override final; + + void setBackend(const std::string& name) override { + mImpl = Registrar<Gather_Op>::create(name)(*this); + mOutputs[0]->setBackend(name); + + // FIXME: temporary workaround + getInput(0)->setBackend(name); + getInput(1)->setBackend(name); + } + + static const std::vector<std::string> getInputsName(){ + return {"data_input", "indexes"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Gather(int axis = 0, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Gather_Op>(axis), name); +} +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Axis"}; +} + +#endif /* AIDGE_CORE_OPERATOR_GATHER_H_ */ diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 505c5344990453c8f4ab84fa3893e75b216d7a54..c966b5f5c1bb4914f3e46f96493da87a6707b1ff 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -36,7 +36,7 @@ private: ComputeDimsFunc mComputeOutputDims; public: - GenericOperator_Op(const char *type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut) + GenericOperator_Op(const std::string& type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut) : OperatorTensor(type, nbData, nbParam, nbOut) {} @@ -97,7 +97,7 @@ public: ~GenericOperator_Op() = default; - void setBackend(const std::string & /*name*/) override { printf("setBackend: not available yet.\n"); } + void setBackend(const std::string & /*name*/, DeviceIdx_t /*device*/ = 0) override { printf("setBackend: not available yet.\n"); } void setDataType(const DataType& /*datatype*/) const override { printf("setDataType: not available yet.\n"); } void forward() override final { if(mImpl){ @@ -125,7 +125,7 @@ public: * @param name (optional) name of the Operator. * @return std::shared_ptr<Node> Node associated with the Generic Operator. 
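A one-line usage sketch of the factory whose signature change follows; the "Flatten" type string, I/O counts and node name are illustrative:

    #include "aidge/aidge.hpp"

    void genericOperatorExample() {
        // The factory now accepts std::string for the type name instead of const char*.
        auto flatten = Aidge::GenericOperator("Flatten", /*nbData=*/1, /*nbParam=*/0, /*nbOut=*/1, "flatten0");
    }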
*/ -inline std::shared_ptr<Node> GenericOperator(const char *type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut, +inline std::shared_ptr<Node> GenericOperator(const std::string& type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut, const std::string& name = "") { return std::make_shared<Node>(std::make_shared<GenericOperator_Op>(type, nbData, nbParam, nbOut), name); } diff --git a/include/aidge/operator/Identity.hpp b/include/aidge/operator/Identity.hpp index c5cd9bb62e0097c9a0e646caaf14cddd73bf512d..7348fa10a96c55914bae68983b5e3bd4a9c40b12 100644 --- a/include/aidge/operator/Identity.hpp +++ b/include/aidge/operator/Identity.hpp @@ -37,7 +37,7 @@ namespace Aidge { class Identity_Op : public OperatorTensor, public Registrable<Identity_Op, std::string, std::unique_ptr<OperatorImpl>(const Identity_Op&)> { public: - static constexpr const char* Type = "Identity"; + static const std::string Type; Identity_Op() : OperatorTensor(Type, 1, 0, 0) @@ -103,10 +103,10 @@ public: } return mInputs[outputIdx]; } - void setBackend(const std::string& name) override final { + void setBackend(const std::string& /*name*/, DeviceIdx_t /*device*/ = 0) override final { // setBackend do nothing, Identity node has no backend it just pass the same Tensor } - void setDataType(const DataType& dataType) const override final { + void setDataType(const DataType& /*dataType*/) const override final { // setDatatype do nothing, Identity node has no backend it just pass the same Tensor } diff --git a/include/aidge/operator/LeakyReLU.hpp b/include/aidge/operator/LeakyReLU.hpp index 2474e2e5af4139b77cace03b27b603fb66b7699a..5976f1d88d70ae7fb716f4038e57da95242c3551 100644 --- a/include/aidge/operator/LeakyReLU.hpp +++ b/include/aidge/operator/LeakyReLU.hpp @@ -33,7 +33,7 @@ class LeakyReLU_Op : public OperatorTensor, public Registrable<LeakyReLU_Op, std::string, std::unique_ptr<OperatorImpl>(const LeakyReLU_Op&)>, public StaticAttributes<LeakyReLUAttr, float> { public: - static constexpr const char* Type = "LeakyReLU"; + static const std::string Type; LeakyReLU_Op() = delete; @@ -67,12 +67,9 @@ public: - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<LeakyReLU_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/MatMul.hpp b/include/aidge/operator/MatMul.hpp index 90930dd22a36f84a7479e245eb09d9c28dfd031d..3d80193be3f669b00e5a138470269e52d0715780 100644 --- a/include/aidge/operator/MatMul.hpp +++ b/include/aidge/operator/MatMul.hpp @@ -35,7 +35,7 @@ class MatMul_Op : public OperatorTensor, std::unique_ptr<OperatorImpl>(const MatMul_Op &)>, public StaticAttributes<MatMulAttr, DimSize_t> { public: - static constexpr const char* Type = "MatMul"; + static const std::string Type; MatMul_Op() = delete; @@ -83,13 +83,9 @@ public: } - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<MatMul_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); - getInput(1)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ @@ -103,7 +99,7 @@ public: inline std::shared_ptr<Node> MatMul(DimSize_t 
inChannels, DimSize_t outChannels, const std::string& name = "") { // FIXME: properly handle default w initialization in every cases auto matmul = std::make_shared<Node>(std::make_shared<MatMul_Op>(outChannels), name); - addProducer(matmul, 1, std::array<DimSize_t, 2>({outChannels, inChannels}), "w"); + addProducer(matmul, 1, {outChannels, inChannels}, "w"); return matmul; } } // namespace Aidge diff --git a/include/aidge/operator/MaxPooling.hpp b/include/aidge/operator/MaxPooling.hpp index c46ddb3797e2303ee27814c96ef060156bdc9108..467a69d73c98a21c85e956acf42536e197833cbd 100644 --- a/include/aidge/operator/MaxPooling.hpp +++ b/include/aidge/operator/MaxPooling.hpp @@ -36,7 +36,7 @@ class MaxPooling_Op : public OperatorTensor, std::array<DimSize_t, DIM>, bool> { public: - static constexpr const char *Type = "MaxPooling"; + static const std::string Type; MaxPooling_Op() = delete; @@ -104,12 +104,9 @@ public: } - void setBackend(const std::string &name) override { + void setBackend(const std::string &name, DeviceIdx_t device = 0) override { mImpl = Registrar<MaxPooling_Op<DIM>>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ @@ -120,6 +117,9 @@ public: } }; +template <DimIdx_t DIM> +const std::string MaxPooling_Op<DIM>::Type = "MaxPooling"; + template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> MaxPooling(const std::array<DimSize_t, DIM> &kernel_dims, const std::string& name = "", diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 4c8feb46c3e3db33bd380302e3e0683f1b8734f5..1fe050b295e102bcdd4e5bd3651d126754b79618 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -25,16 +25,9 @@ public: // Micro-graph handling: std::shared_ptr<GraphView> mGraph; // Meta operator micro-graph std::shared_ptr<SequentialScheduler> mScheduler; - // Need to store an ordored list of input/output operators for the micro-graph, - // because input/output nodes in a GraphView are unordered. - // TODO: refactor GraphView to handle ordered input/output? - std::vector<std::pair<std::shared_ptr<OperatorTensor>, IOIndex_t>> mInputOps; - std::vector<std::pair<std::shared_ptr<OperatorTensor>, IOIndex_t>> mOutputOps; public: - MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph, - std::vector<NodePtr> inputNodes = std::vector<NodePtr>(), - std::vector<NodePtr> outputNodes = std::vector<NodePtr>()); + MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph); /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). @@ -47,7 +40,7 @@ public: /** * @brief Clone the operator using its copy-constructor. 
- * @see Operator::MatMul_Op + * @see Operator::MetaOperator_Op */ std::shared_ptr<Operator> clone() const override { return std::make_shared<MetaOperator_Op>(*this); @@ -64,8 +57,8 @@ public: void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final { assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); - const auto& inputOp = mInputOps[inputIdx]; - inputOp.first->associateInput(inputOp.second, data); + const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + inputOp.first->getOperator()->associateInput(inputOp.second, data); // Associate inputs for custom implementation mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); @@ -77,7 +70,7 @@ public: } - void setBackend(const std::string &name) override { + void setBackend(const std::string &name, DeviceIdx_t device = 0) override { if (Registrar<MetaOperator_Op>::exists({name, type()})) { // A custom implementation exists for this meta operator mImpl = Registrar<MetaOperator_Op>::create({name, type()})(*this); @@ -86,7 +79,7 @@ public: // The micro-graph should always be set to the right backend, since it // shares input/output tensors. // Input/output tensors backend are updated here. - mGraph->setBackend(name); + mGraph->setBackend(name, device); } void setDataType(const DataType &datatype) const override { @@ -106,15 +99,15 @@ public: assert(false && "not implemented"); } + inline bool isAtomic() const noexcept override final { return false; } + }; inline std::shared_ptr<Node> MetaOperator(const char *type, const std::shared_ptr<GraphView>& graph, - const std::string& name = "", - std::vector<NodePtr> inputNodes = std::vector<NodePtr>(), - std::vector<NodePtr> outputNodes = std::vector<NodePtr>()) + const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<MetaOperator_Op>(type, graph, inputNodes, outputNodes), name); + return std::make_shared<Node>(std::make_shared<MetaOperator_Op>(type, graph), name); } } // namespace Aidge diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp index 9ec6cdb928cdfa433b04ea23c69344133a3c7064..2832f9fce005e0ae9d2bab98bf764c68f93e3cda 100644 --- a/include/aidge/operator/MetaOperatorDefs.hpp +++ b/include/aidge/operator/MetaOperatorDefs.hpp @@ -32,10 +32,8 @@ inline std::shared_ptr<Node> PaddedConv(DimSize_t in_channels, // Construct micro-graph auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0); auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : ""); - // Need to specify the ordered list of input operators - const std::vector<NodePtr> orderedInputNodes = {pad, conv}; - auto metaOp = MetaOperator("PaddedConv", Sequential({pad, conv}), name, orderedInputNodes); + auto metaOp = MetaOperator("PaddedConv", Sequential({pad, conv}), name); addProducer(metaOp, 1, append(out_channels, append(in_channels, kernel_dims)), "w"); addProducer(metaOp, 2, {out_channels}, "b"); return metaOp; @@ -66,12 +64,10 @@ inline std::shared_ptr<Node> PaddedConvDepthWise(const DimSize_t nb_channels, // Construct micro-graph auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? 
name + "_pad" : "", PadBorderType::Constant, 0.0); auto conv = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(nb_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : ""); - // Need to specify the ordered list of input operators - const std::vector<NodePtr> orderedInputNodes = {pad, conv}; - auto metaOp = MetaOperator("PaddedConvDepthWise", Sequential({pad, conv}), name, orderedInputNodes); - addProducer(metaOp, 1, std::array<DimSize_t,0>({}), "w"); - addProducer(metaOp, 2, std::array<DimSize_t,0>({}), "b"); + auto metaOp = MetaOperator("PaddedConvDepthWise", Sequential({pad, conv}), name); + addProducer(metaOp, 1, append(nb_channels, append(DimSize_t(1), kernel_dims)), "w"); + addProducer(metaOp, 2, {nb_channels}, "b"); return metaOp; } diff --git a/include/aidge/operator/Move.hpp b/include/aidge/operator/Move.hpp new file mode 100644 index 0000000000000000000000000000000000000000..62fb9897384673c695895b54557b4cf637aa2447 --- /dev/null +++ b/include/aidge/operator/Move.hpp @@ -0,0 +1,75 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_MOVE_H_ +#define AIDGE_CORE_OPERATOR_MOVE_H_ + +#include <cassert> +#include <memory> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +class Move_Op : public OperatorTensor, + public Registrable<Move_Op, std::tuple<std::string, std::string>, std::unique_ptr<OperatorImpl>(const Move_Op&)> { +public: + static const std::string Type; + + Move_Op() : OperatorTensor(Type, 1, 0, 1) {} + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Move_Op(const Move_Op& op) + : OperatorTensor(op) + { + mImpl = op.mImpl ? Registrar<Move_Op>::create({mInputs[0]->getImpl()->backend(), mOutputs[0]->getImpl()->backend()})(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::Move_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Move_Op>(*this); + } + + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { + if (mInputs[0]->getImpl() && Registrar<Move_Op>::exists({mInputs[0]->getImpl()->backend(), name})) { + mImpl = Registrar<Move_Op>::create({mInputs[0]->getImpl()->backend(), name})(*this); + } + mOutputs[0]->setBackend(name, device); + } + + void forward() override; + + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Move(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Move_Op>(), name); +} +} + +#endif /* AIDGE_CORE_OPERATOR_MOVE_H_ */ \ No newline at end of file diff --git a/include/aidge/operator/Mul.hpp b/include/aidge/operator/Mul.hpp index 337fe6e65cc040e67ee033516731a7ba8de86d2d..78b2fa5f98c9dae66ae291769f2de08d7805a738 100644 --- a/include/aidge/operator/Mul.hpp +++ b/include/aidge/operator/Mul.hpp @@ -31,7 +31,7 @@ namespace Aidge { class Mul_Op : public OperatorTensor, public Registrable<Mul_Op, std::string, std::unique_ptr<OperatorImpl>(const Mul_Op&)> { public: - static constexpr const char* Type = "Mul"; + static const std::string Type; Mul_Op() : OperatorTensor(Type, 2, 0, 1) {} @@ -56,13 +56,9 @@ public: void computeOutputDims() override final; - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Mul_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); - getInput(1)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 1f4cdd23f9a765924305ebeb43e3e6ee1ad73496..dd4ad16441f536fd786036672d57817b892cf155 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -44,7 +44,7 @@ private: public: Operator() = delete; - Operator(const char* type, const IOIndex_t nbData, const IOIndex_t nbParam, const IOIndex_t nbOut, const OperatorType operatorType = OperatorType::Data) + Operator(const std::string& type, const IOIndex_t nbData, const IOIndex_t nbParam, const IOIndex_t nbOut, const OperatorType operatorType = OperatorType::Data) : mType(type), mOperatorType(operatorType), mNbData(nbData), @@ -105,7 +105,7 @@ public: // IMPLEMENTATION /////////////////////////////////////////////////////// - virtual void setBackend(const std::string& name) = 0; + virtual void setBackend(const std::string& name, DeviceIdx_t device = 0) = 0; virtual void setDataType(const DataType& dataType) const = 0; /** @@ -157,6 +157,8 @@ public: return mOperatorType; } + virtual inline bool isAtomic() const noexcept { return true; } + inline IOIndex_t nbInputs() const noexcept { return mNbData+mNbParam; }; inline IOIndex_t nbData() const noexcept { return mNbData; }; inline IOIndex_t nbParam() const noexcept { return mNbParam; }; diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp index 126e5d467d0f341a8c5b8c5d16d188ebe92135d0..504a416488651d43126a60981cd8afe0f95821f2 100644 --- a/include/aidge/operator/OperatorTensor.hpp +++ b/include/aidge/operator/OperatorTensor.hpp @@ -40,7 +40,7 @@ protected: public: 
OperatorTensor() = delete; - OperatorTensor(const char* type, const IOIndex_t nbData, const IOIndex_t nbParam, + OperatorTensor(const std::string& type, const IOIndex_t nbData, const IOIndex_t nbParam, const IOIndex_t nbOut) : Operator(type, nbData, nbParam, nbOut, OperatorType::Tensor), mInputs(std::vector<std::shared_ptr<Tensor>>(nbData + nbParam, nullptr)), @@ -100,7 +100,7 @@ public: * @return std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> * For each dataInput Tensor of the Operator, the first index and dimensions of the feature area. */ - virtual std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveField(const std::size_t firstIdx, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const; + virtual std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> computeReceptiveField(const std::vector<DimSize_t>& firstEltDims, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const; virtual void computeOutputDims(); virtual bool outputDimsForwarded() const; /////////////////////////////////////////////////// diff --git a/include/aidge/operator/Pad.hpp b/include/aidge/operator/Pad.hpp index 279b8b3d2c173d18c65c17e50385954a88fde77e..56245dd2dfd62d4dc765de6e3d43b08c144cc62b 100644 --- a/include/aidge/operator/Pad.hpp +++ b/include/aidge/operator/Pad.hpp @@ -37,7 +37,7 @@ class Pad_Op : public OperatorTensor, PadBorderType, double> { public: - static constexpr const char *Type = "Pad"; + static const std::string Type; Pad_Op() = delete; @@ -97,12 +97,9 @@ public: } } - void setBackend(const std::string &name) override { + void setBackend(const std::string &name, DeviceIdx_t device = 0) override { mImpl = Registrar<Pad_Op<DIM>>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ @@ -113,6 +110,9 @@ public: } }; +template <DimIdx_t DIM> +const std::string Pad_Op<DIM>::Type = "Pad"; + template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> Pad(const std::array<DimSize_t, 2*DIM> &beginEndTuples, const std::string& name = "", diff --git a/include/aidge/operator/Pow.hpp b/include/aidge/operator/Pow.hpp index a5cd3a9b047f9a32665cc2de1ead4f2221fed4aa..d498cacc7c5b2ddc3269f3ebc77707aead8eb52d 100644 --- a/include/aidge/operator/Pow.hpp +++ b/include/aidge/operator/Pow.hpp @@ -29,7 +29,7 @@ namespace Aidge { class Pow_Op : public OperatorTensor, public Registrable<Pow_Op, std::string, std::unique_ptr<OperatorImpl>(const Pow_Op&)> { public: - static constexpr const char* Type = "Pow"; + static const std::string Type; Pow_Op() : OperatorTensor(Type, 2, 0, 1) {} @@ -54,13 +54,9 @@ public: void computeOutputDims() override final; - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Pow_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); - getInput(1)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index a3f6e085ce3849c1b057f0fdb043093b338b48a1..ee00ead696efe623a4e051994f470a38397777ec 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -29,7 +29,7 @@ class 
Producer_Op public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>( const Producer_Op &)> { public: - static constexpr const char* Type = "Producer"; + static const std::string Type; template <std::size_t DIM> Producer_Op(const std::array<DimSize_t, DIM>& dims) @@ -76,9 +76,9 @@ public: inline const std::vector<DimSize_t> dims() const noexcept { return mOutputs[0]->dims(); } - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Producer_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/ReLU.hpp b/include/aidge/operator/ReLU.hpp index 15dec9be8516f71f5f4dfd0aec6a2985671da53d..0bb7cdffe421b973ae7c86b4569e7464b3cf6da4 100644 --- a/include/aidge/operator/ReLU.hpp +++ b/include/aidge/operator/ReLU.hpp @@ -28,7 +28,7 @@ namespace Aidge { class ReLU_Op : public OperatorTensor, public Registrable<ReLU_Op, std::string, std::unique_ptr<OperatorImpl>(const ReLU_Op&)> { public: - static constexpr const char* Type = "ReLU"; + static const std::string Type; ReLU_Op() : OperatorTensor(Type, 1, 0, 1) {} @@ -51,12 +51,9 @@ public: } - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<ReLU_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/ReduceMean.hpp b/include/aidge/operator/ReduceMean.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0acd21b28fac54e7e6d30e8219ead0e04ef777f6 --- /dev/null +++ b/include/aidge/operator/ReduceMean.hpp @@ -0,0 +1,146 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_REDUCEMEAN_H_ +#define AIDGE_CORE_OPERATOR_REDUCEMEAN_H_ + +#include <array> +#include <cmath> +#include <numeric> +#include <vector> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +enum class ReduceMeanAttr { Axes, KeepDims }; + +template <DimIdx_t DIM> +class ReduceMean_Op : public OperatorTensor, + public Registrable<ReduceMean_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const ReduceMean_Op<DIM> &)>, + public StaticAttributes<ReduceMeanAttr, std::array<int, DIM>, DimSize_t> { + + public: + static const std::string Type; + + ReduceMean_Op() = delete; + + using Attributes_ = StaticAttributes<ReduceMeanAttr, std::array<int, DIM>, DimSize_t>; + template <ReduceMeanAttr e> + using attr = typename Attributes_::template attr<e>; + + constexpr ReduceMean_Op(const std::array<int, DIM> &axes, DimSize_t keep_dims) + : OperatorTensor(Type, 1, 0, 1), + Attributes_(attr<ReduceMeanAttr::Axes>(axes), + attr<ReduceMeanAttr::KeepDims>(keep_dims)) {} + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + ReduceMean_Op(const ReduceMean_Op<DIM>& op) + : OperatorTensor(op), + Attributes_(op) + { + mImpl = op.mImpl ? Registrar<ReduceMean_Op<DIM>>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::ReduceMean_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<ReduceMean_Op<DIM>>(*this); + } + + void computeOutputDims() override final { + if (!getInput(0)->empty()) { + std::vector<DimSize_t> outDims; + for(std::size_t d=0; d<getInput(0)->dims().size(); ++d) + { + bool reducedDim = false; + for(std::size_t i=0; i<DIM; ++i) + { + int axis_ = this->template getAttr<ReduceMeanAttr::Axes>()[i]; + std::size_t axis= axis_>=0? 
axis_: axis_ + getInput(0)->nbDims(); + if(axis == d) + { + reducedDim = true; + break; + } + } + if(reducedDim) + { + if(this->template getAttr<ReduceMeanAttr::KeepDims>()) + outDims.push_back(1); + } + else + outDims.push_back(getInput(0)->dims()[d]); + } + if(outDims.size()>0) + mOutputs[0]->resize(outDims); + else + mOutputs[0]->resize({1}); + } + } + + void setBackend(const std::string &name, DeviceIdx_t device = 0) override { + mImpl = Registrar<ReduceMean_Op<DIM>>::create(name)(*this); + mOutputs[0]->setBackend(name, device); + + // FIXME: temporary workaround + getInput(0)->setBackend(name); + } + + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> ReduceMean(const std::array<int, DIM> &axes, + DimSize_t keep_dims=1, + const std::string& name = "") { + // FIXME: properly handle default w&b initialization in every cases + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ReduceMean, not supported"); + return std::make_shared<Node>(std::make_shared<ReduceMean_Op<static_cast<DimIdx_t>(DIM)>>(axes, keep_dims), name); + +} + +// helper with C-style array instead of std::array for axes to allow automatic template DIM deduction +template <DimSize_t DIM> +inline std::shared_ptr<Node> ReduceMean( + int const (&axes)[DIM], + DimSize_t keep_dims = 1, + const std::string& name = "") { + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ReduceMean, not supported"); + return ReduceMean(to_array(axes), keep_dims, name); +} + +template <DimIdx_t DIM> +const std::string ReduceMean_Op<DIM>::Type = "ReduceMean"; + +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::ReduceMeanAttr>::data[] = {"Axes", "KeepDims"}; +} + +#endif /* AIDGE_CORE_OPERATOR_REDUCEMEAN_H_ */ diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1ffa045960037f35167ae2d6e8904c49e2c55560 --- /dev/null +++ b/include/aidge/operator/Reshape.hpp @@ -0,0 +1,97 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_RESHAPE_H_ +#define AIDGE_CORE_OPERATOR_RESHAPE_H_ + +#include <cassert> +#include <memory> +#include <vector> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +enum class ReshapeAttr { Shape }; + +class Reshape_Op : public OperatorTensor, + public Registrable<Reshape_Op, std::string, std::unique_ptr<OperatorImpl>(const Reshape_Op&)>, + public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>> { + +public: + static const std::string Type; + + Reshape_Op() = delete; + + using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>>; + template <ReshapeAttr e> + using attr = typename Attributes_::template attr<e>; + + Reshape_Op(const std::vector<std::int64_t>& shape) + : OperatorTensor(Type, 1, 0, 1), + Attributes_(attr<ReshapeAttr::Shape>(shape)) + {} + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Reshape_Op(const Reshape_Op& op) + : OperatorTensor(op), + Attributes_(op) + { + mImpl = op.mImpl ? Registrar<Reshape_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Reshape_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Reshape_Op>(*this); + } + + void computeOutputDims() override final; + + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { + mImpl = Registrar<Reshape_Op>::create(name)(*this); + mOutputs[0]->setBackend(name, device); + + // FIXME: temporary workaround + getInput(0)->setBackend(name); + } + + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape, + const std::string &name = "") { + // FIXME: properly handle default w&b initialization in every cases + return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name); +} +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::ReshapeAttr>::data[] = { "Shape" }; +} + +#endif /* AIDGE_CORE_OPERATOR_RESHAPE_H_ */ diff --git a/include/aidge/operator/Scaling.hpp b/include/aidge/operator/Scaling.hpp index 98e082ac27f7cdf90d5d0464d811f116ae9f59ae..54f1d98d2f61d18dd821c9f0a6b574bb52b0c9f0 100644 --- a/include/aidge/operator/Scaling.hpp +++ b/include/aidge/operator/Scaling.hpp @@ -32,7 +32,7 @@ class Scaling_Op : public OperatorTensor, public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>, public StaticAttributes<ScalingAttr, float, size_t, bool> { public: - static constexpr const char* Type = "Scaling"; + static const std::string Type; Scaling_Op() = delete; @@ -66,11 +66,9 @@ public: return std::make_shared<Scaling_Op>(*this); } - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Scaling_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - // FIXME: temporary workaround - 
mInputs[0]->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName() { diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index b92c1818d49b53d4a2eda9a8d2704a06ca2980ca..12a7425f3339b7fbc0ae010639aacf23d97b0f5f 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -24,25 +24,26 @@ #include "aidge/utils/Types.h" namespace Aidge { -enum class SliceAttr { Beginning, SliceDims }; +enum class SliceAttr { Starts, Ends, Axes }; class Slice_Op : public OperatorTensor, public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>, - public StaticAttributes<SliceAttr, std::size_t, std::vector<DimSize_t>> { + public StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>> { public: - static constexpr const char *Type = "Slice"; + static const std::string Type; Slice_Op() = delete; - using Attributes_ = StaticAttributes<SliceAttr, std::size_t, std::vector<DimSize_t>>; + using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>>; template <SliceAttr e> using attr = typename Attributes_::template attr<e>; - Slice_Op(const std::size_t beginningPos, const std::vector<DimSize_t> sliceDims) + Slice_Op(const std::vector<std::int32_t>& starts, const std::vector<std::int32_t>& ends, const std::vector<std::int32_t>& axes) : OperatorTensor(Type, 1, 0, 1), - Attributes_(attr<SliceAttr::Beginning>(beginningPos), - attr<SliceAttr::SliceDims>(sliceDims)) + Attributes_(attr<SliceAttr::Starts>(starts), + attr<SliceAttr::Ends>(ends), + attr<SliceAttr::Axes>(axes)) {} /** @@ -65,37 +66,11 @@ public: */ std::shared_ptr<Operator> clone() const override { return std::make_shared<Slice_Op>(*this); } - void computeOutputDims() override final { - if (!getInput(0) || (getInput(0)->empty())) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); - } - std::vector<DimSize_t> outputDims = std::vector<DimSize_t>(getInput(0)->nbDims()); - const std::vector<DimSize_t> inputDims = getInput(0)->dims(); - - // Check that the sliced Tensor is actually part of the input Tensor - // For a 5*5 tensor ('x') and a 3*3 slice kernel ('o'): - // xxxxx xxxxx - // xxxxx xxxxx - // xxooo --> ok xxxoo --> out of bound - // xxooo xxxoo - // xxooo xxxoo - std::vector<std::size_t> beginningCoords = mInputs[0]->getCoord(this->template getAttr<SliceAttr::Beginning>()); - for (std::size_t i = 0; i < getInput(0)->nbDims(); ++i) { - if (beginningCoords[i] + this->template getAttr<SliceAttr::SliceDims>()[i] > inputDims[i]) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds"); - } else { - outputDims[i] = this->template getAttr<SliceAttr::SliceDims>()[i]; - } - } - mOutputs[0]->resize(outputDims); - } + void computeOutputDims() override final; - void setBackend(const std::string &name) override { + void setBackend(const std::string &name, DeviceIdx_t device = 0) override { mImpl = Registrar<Slice_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ @@ -106,17 +81,31 @@ public: } }; - -inline std::shared_ptr<Node> Slice(const std::size_t beginningPos, const std::vector<DimSize_t> sliceDims, +/** + * @brief Extract a sub-Tensor from a bigger original 
Tensor. + * @param starts Indexes for each dimension of the first element. + * Can be a negative value. Negative values start their reference from the last index. + * ``-1`` refers to the last index of a dimension. + * @param ends Indexes for each dimension of the last element. + * Can be a negative value. Negative values start their reference from the last index. + * ``-1`` refers to the last index of a dimension. + * @param axes Dimensions for which start/end indexes apply. Not specifying a dimension + * means the whole dimension is extracted. + * @param name Name of the Operator. + * @return std::shared_ptr<Node> A Node containing the Operator. + */ +inline std::shared_ptr<Node> Slice(const std::vector<std::int32_t> starts, + const std::vector<std::int32_t> ends, + const std::vector<std::int32_t> axes, const std::string &name = "") { // FIXME: properly handle default w&b initialization in every cases - return std::make_shared<Node>(std::make_shared<Slice_Op>(beginningPos, sliceDims), name); + return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name); } } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Beginning", "SliceDims" }; +const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes" }; } #endif /* AIDGE_CORE_OPERATOR_RELU_H_ */ diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp index d5c91945e83469dc9c6fef2b5adef026790b568d..ed6689dc97ef17276df260cd90649f2a75b10007 100644 --- a/include/aidge/operator/Softmax.hpp +++ b/include/aidge/operator/Softmax.hpp @@ -16,29 +16,44 @@ #include <memory> #include <vector> -#include "aidge/utils/Registrar.hpp" -#include "aidge/operator/OperatorTensor.hpp" + #include "aidge/backend/OperatorImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/data/Data.hpp" #include "aidge/graph/Node.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Types.h" namespace Aidge { +enum class SoftmaxAttr { AxisIdx }; class Softmax_Op : public OperatorTensor, - public Registrable<Softmax_Op, std::string, std::unique_ptr<OperatorImpl>(const Softmax_Op&)> { + public Registrable<Softmax_Op, + std::string, + std::unique_ptr<OperatorImpl>(const Softmax_Op&)>, + public StaticAttributes<SoftmaxAttr, int> { + public: - static constexpr const char* Type = "Softmax"; + static const std::string Type; + + Softmax_Op() = delete; - Softmax_Op() : OperatorTensor(Type, 1, 0, 1) {} + using Attributes_ = StaticAttributes<SoftmaxAttr, int>; + template <SoftmaxAttr e> using attr = typename Attributes_::template attr<e>; + Softmax_Op(int axis) + : OperatorTensor(Type, 1, 0, 1), + Attributes_(attr<SoftmaxAttr::AxisIdx>(axis)) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @param op Operator to copy. */ Softmax_Op(const Softmax_Op& op) - : OperatorTensor(op) + : OperatorTensor(op), + Attributes_(op) { mImpl = op.mImpl ? 
Registrar<Softmax_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; } @@ -51,12 +66,9 @@ public: return std::make_shared<Softmax_Op>(*this); } - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Softmax_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ @@ -67,9 +79,14 @@ public: } }; -inline std::shared_ptr<Node> Softmax(const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Softmax_Op>(), name); +inline std::shared_ptr<Node> Softmax(int axis, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Softmax_Op>(axis), name); } +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::SoftmaxAttr>::data[] = {"Axis"}; } #endif /* AIDGE_CORE_OPERATOR_SOFTMAX_H_ */ diff --git a/include/aidge/operator/Sqrt.hpp b/include/aidge/operator/Sqrt.hpp index 1fe609fc2913afcda735ba2859126188aad4de5f..32adfdb93db1e9da857f4147efdcfe64bbb34475 100644 --- a/include/aidge/operator/Sqrt.hpp +++ b/include/aidge/operator/Sqrt.hpp @@ -34,7 +34,7 @@ public: const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); public: - static constexpr const char* Type = "Sqrt"; + static const std::string Type; Sqrt_Op() : OperatorTensor(Type, 1, 0, 1) {} @@ -56,12 +56,9 @@ public: return std::make_shared<Sqrt_Op>(*this); } - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Sqrt_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Sub.hpp b/include/aidge/operator/Sub.hpp index d141ad42015838e89e6d59c22bcefe56e795170c..ee5efa24dc24ebcd5ad4c45491c968caf691eee9 100644 --- a/include/aidge/operator/Sub.hpp +++ b/include/aidge/operator/Sub.hpp @@ -34,7 +34,7 @@ public: const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); public: - static constexpr const char* Type = "Sub"; + static const std::string Type; Sub_Op() : OperatorTensor(Type, 2, 0, 1) {} @@ -59,13 +59,9 @@ public: void computeOutputDims() override final; - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Sub_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); - getInput(1)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Transpose.hpp b/include/aidge/operator/Transpose.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f111be76cd712265e92e2e4c3e0220f79e13b1f7 --- /dev/null +++ b/include/aidge/operator/Transpose.hpp @@ -0,0 +1,125 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_TRANSPOSE_H_ +#define AIDGE_CORE_OPERATOR_TRANSPOSE_H_ + +#include <array> +#include <cmath> +#include <numeric> +#include <vector> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +enum class TransposeAttr { OutputDimsOrder }; + +template <DimIdx_t DIM> +class Transpose_Op : public OperatorTensor, + public Registrable<Transpose_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Transpose_Op<DIM> &)>, + public StaticAttributes<TransposeAttr, + std::array<DimSize_t, DIM>> { + + public: + static const std::string Type; + + Transpose_Op() = delete; + + using Attributes_ = StaticAttributes<TransposeAttr, + std::array<DimSize_t, DIM>>; + template <TransposeAttr e> + using attr = typename Attributes_::template attr<e>; + + constexpr Transpose_Op(const std::array<DimSize_t, DIM> &output_dims_order) + : OperatorTensor(Type, 1, 0, 1), + Attributes_(attr<TransposeAttr::OutputDimsOrder>(output_dims_order)) { } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Transpose_Op(const Transpose_Op<DIM>& op) + : OperatorTensor(op), + Attributes_(op) + { + mImpl = op.mImpl ? Registrar<Transpose_Op<DIM>>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::Transpose_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Transpose_Op<DIM>>(*this); + } + + void computeOutputDims() override final { + if (!getInput(0)->empty()) { + auto attr = (this)->getStaticAttributes(); + const std::array<DimSize_t, DIM>& outDimsOrder = static_cast<const std::array<DimSize_t, DIM>&>(std::get<0>(attr)); + std::vector<DimSize_t> outputDims; + for (std::size_t i = 0; i < DIM; ++i) { + outputDims.push_back(getInput(0)->dims()[outDimsOrder[i]]); + } + mOutputs[0]->resize(outputDims); + } + } + + void setBackend(const std::string &name, DeviceIdx_t device = 0) override { + mImpl = Registrar<Transpose_Op<DIM>>::create(name)(*this); + mOutputs[0]->setBackend(name, device); + + // FIXME: temporary workaround + getInput(0)->setBackend(name); + } + + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> Transpose(const std::array<DimSize_t, DIM> &output_dims_order, + const std::string& name = "") { + // FIXME: properly handle default w&b initialization in every cases + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Transpose, not supported"); + return std::make_shared<Node>(std::make_shared<Transpose_Op<static_cast<DimIdx_t>(DIM)>>(output_dims_order), name); +} + +// helper with C-style array instead of std::array for output_dims_order to allow automatic template DIM deduction +template <DimSize_t DIM> +inline std::shared_ptr<Node> Transpose( + DimSize_t const (&output_dims_order)[DIM], + const std::string& name = "") { + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Transpose, not supported"); + return Transpose(to_array(output_dims_order), name); +} + +template <DimIdx_t DIM> +const std::string Transpose_Op<DIM>::Type = "Transpose"; + +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::TransposeAttr>::data[] = {"OutputDimsOrder"}; +} + +#endif /* AIDGE_CORE_OPERATOR_TRANSPOSE_H_ */ diff --git a/include/aidge/recipies/Recipies.hpp b/include/aidge/recipies/Recipies.hpp index a17ead8f8f5fa5106c375050ef5b82e6f149535a..fb4bc22c69ec2b4e8dcc6178c9fcda0a85190f78 100644 --- a/include/aidge/recipies/Recipies.hpp +++ b/include/aidge/recipies/Recipies.hpp @@ -38,10 +38,28 @@ void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add); /** * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node. * - * @param graphView Graph view to use graph matching on, in order to apply transfomrations. + * @param graphView Graph view to use graph matching on, in order to apply transformations. */ void fuseMulAdd(std::shared_ptr<GraphView> graphView); +// REMOVE Dropout + +/** + * @brief Remove ``Dropout`` Node. + * + * @param dropout Node to remove. + */ +void removeDropout(std::shared_ptr<Node> dropout); + + +void removeDropout(std::shared_ptr<MatchSolution> solution); + +/** + * @brief Remove ``Dropout`` Node. + * + * @param graphView Graph view to use graph matching on, in order to apply transformations. +*/ +void removeDropout(std::shared_ptr<GraphView> graphView); // REMOVE FLATTEN + FC -> FC @@ -58,7 +76,7 @@ void removeFlatten(std::shared_ptr<MatchSolution> solution); /** * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. * - * @param graphView Graph view to use graph matching on, in order to apply transfomrations. 
+ * @param graphView Graph view to use graph matching on, in order to apply transformations. */ void removeFlatten(std::shared_ptr<GraphView> graphView); @@ -80,7 +98,7 @@ void fuseBatchNorm(std::shared_ptr<MatchSolution> solution); /** * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes. * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ * - * @param graphView Graph view to use graph matching on, in order to apply transfomrations. + * @param graphView Graph view to use graph matching on, in order to apply transformations. */ void fuseBatchNorm(std::shared_ptr<GraphView> graphView); @@ -89,6 +107,14 @@ std::set<std::shared_ptr<Node>> getConvHorizontalTiling(const std::shared_ptr<No // std::set<std::shared_ptr<Node>> getHorizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices); // void horizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices); + +/** + * Add Convert operators where needed to ensure no conversion needs to be done + * at the Operator level. +*/ +void explicitCastMove(std::shared_ptr<GraphView> graphView); + + } // namespace Aidge #endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */ diff --git a/include/aidge/utils/ArrayHelpers.hpp b/include/aidge/utils/ArrayHelpers.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b0db3ca11c10c10a3ce63c3c4809cf7ae09173da --- /dev/null +++ b/include/aidge/utils/ArrayHelpers.hpp @@ -0,0 +1,125 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_UTILS_ARRAYHELPERS_H_ +#define AIDGE_CORE_UTILS_ARRAYHELPERS_H_ + +#include <array> + +namespace Aidge { + +// Helper to create default arrays +template <typename T, std::size_t ... Is> +constexpr std::array<T, sizeof...(Is)> +create_array_impl(T value, std::index_sequence<Is...>) +{ + // cast Is to void to remove the warning: unused value + return {{(static_cast<void>(Is), value)...}}; +} + +template <typename T, std::size_t N> +constexpr std::array<T, N> create_array(const T& value) +{ + return create_array_impl(value, std::make_index_sequence<N>()); +} + + +// Helper to convert vector to array +template <typename T, typename Iter, std::size_t... Is> +constexpr auto to_array(Iter &iter, std::index_sequence<Is...>) -> std::array<T, sizeof...(Is)> { + return {{((void)Is, T(*iter++))...}}; +} + +/** + * @brief Convert an object with an iterator to an std::array. + */ +template <std::size_t N, typename U = void, typename Iter, typename V = typename std::iterator_traits<Iter>::value_type, + typename T = std::conditional_t<std::is_same<U, void>{}, V, U>> +constexpr auto to_array(Iter iter) -> std::array<T, N> { + return to_array<T>(iter, std::make_index_sequence<N>{}); +} + +namespace detail { + +template <class T, std::size_t N, std::size_t... I> +constexpr std::array<std::remove_cv_t<T>, N> to_array_impl(T (&a)[N], std::index_sequence<I...>) { + return {{a[I]...}}; +} + +} // namespace detail + +/** + * @brief Convert a C-style array into a C++ std::array. + * + * @tparam T Data type. + * @tparam N Number of elements. 
+ * @param a C-style array to convert. + * @return constexpr std::array<std::remove_cv_t<T>, N> + */ +template <class T, std::size_t N> +constexpr std::array<std::remove_cv_t<T>, N> to_array(T (&a)[N]) { + return detail::to_array_impl(a, std::make_index_sequence<N>{}); +} + +template <typename T, std::size_t N, std::size_t... I> +constexpr std::array<T, N + 1> append(std::array<T, N> a, T t, std::index_sequence<I...>) { + return std::array<T, N + 1>{a[I]..., t}; +} + +template <typename T, std::size_t N, std::size_t... I> +constexpr std::array<T, N + 1> append(T t, std::array<T, N> a, std::index_sequence<I...>) { + return std::array<T, N + 1>{t, a[I]...}; +} + +/** + * @brief Create a new array concatenating the initial one with the value to + * add. + * @details append({1,2,7}, 3) -> {1,2,7,3} + * + * @tparam T Data type. + * @tparam N Number of elements in the initial array. + * @param a Initial array. + * @param t Element to add. + * @return constexpr std::array<T, N + 1> + */ +template <typename T, std::size_t N> +constexpr std::array<T, N + 1> append(std::array<T, N> a, T t) { + return append(a, t, std::make_index_sequence<N>()); +} + +template <typename T, std::size_t N> +constexpr std::array<T, N + 1> append(T t, std::array<T, N> a) { + return append(t, a, std::make_index_sequence<N>()); +} + +// Generic helper for initializing a Tensor +template <typename T, std::size_t SIZE_0> +struct Array1D { + T data[SIZE_0]; +}; + +template <typename T, std::size_t SIZE_0, std::size_t SIZE_1> +struct Array2D { + T data[SIZE_0][SIZE_1]; +}; + +template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2> +struct Array3D { + T data[SIZE_0][SIZE_1][SIZE_2]; +}; + +template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2, std::size_t SIZE_3> +struct Array4D { + T data[SIZE_0][SIZE_1][SIZE_2][SIZE_3]; +}; +} + +#endif /* AIDGE_CORE_UTILS_ARRAYHELPERS_H_ */ diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index ece74509d466800c870d73d1e0bbe1d639f8bf54..66a07eb0ce21354b20f1ca416cc68d26d9bd6280 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -51,26 +51,29 @@ public: template <class C> struct Registrar { - Registrar(const typename C::registrar_key& key, typename C::registrar_type func) { + typedef typename C::registrar_key registrar_key; + typedef typename C::registrar_type registrar_type; + + Registrar(const registrar_key& key, registrar_type func) { //printf("REGISTRAR: %s\n", key.c_str()); bool newInsert; std::tie(std::ignore, newInsert) = C::registry().insert(std::make_pair(key, func)); //assert(newInsert && "registrar already exists"); } - static bool exists(const typename C::registrar_key& key) { + static bool exists(const registrar_key& key) { const auto it = C::registry().find(key); return (it != C::registry().end()); } - static auto create(const typename C::registrar_key& key){ + static auto create(const registrar_key& key){ const auto it = C::registry().find(key); assert(it != C::registry().end() && "invalid registrar key"); return (*it).second; } - static std::vector<typename C::registrar_key> getKeys(){ - std::vector<typename C::registrar_key> keys; + static std::vector<registrar_key> getKeys(){ + std::vector<registrar_key> keys; for(auto keyValue : C::registry()) keys.push_back(keyValue.first); return keys; diff --git a/include/aidge/utils/StaticAttributes.hpp b/include/aidge/utils/StaticAttributes.hpp index 
50ed0895e82bb468dee57264534f0ec3a486a815..a90a08b01915c461bc8951c08ee2dbd979b957de 100644 --- a/include/aidge/utils/StaticAttributes.hpp +++ b/include/aidge/utils/StaticAttributes.hpp @@ -16,6 +16,7 @@ #include <cassert> #include <cstddef> #include <typeinfo> +#include <array> #include "aidge/utils/Attributes.hpp" #include "aidge/utils/ErrorHandling.hpp" diff --git a/include/aidge/utils/TensorUtils.hpp b/include/aidge/utils/TensorUtils.hpp index 6387619546c66922e48cf95a8a56487d4b0d0641..1bfe0929bf67bb0c6d3b893f3dbaf6993dcfd6ff 100644 --- a/include/aidge/utils/TensorUtils.hpp +++ b/include/aidge/utils/TensorUtils.hpp @@ -14,6 +14,7 @@ #include <cmath> // std::abs #include "aidge/data/Tensor.hpp" +namespace Aidge { /** * @brief Compare two :cpp:class:`Aidge::Tensor` value wise. The comparison function is: * @@ -31,22 +32,23 @@ * @param absolute absolute error allowed (shoulmd be positive) * @return true if both tensor are approximately equal and have the datatype, shape. Else return false */ -template <typename T> -bool approxEq(Aidge::Tensor t1, Aidge::Tensor t2, float relative, float absolute){ - assert(t1.dataType() == t2.dataType()); - assert(t1.dataType() == NativeType<T>::type); +template <typename T1, typename T2 = T1> +bool approxEq(const Tensor& t1, const Tensor& t2, float relative = 1e-5f, float absolute = 1e-8f){ + assert(t1.dataType() == NativeType<T1>::type); + assert(t2.dataType() == NativeType<T2>::type); assert(relative >= 0); assert(absolute >= 0 && absolute<=1); if (t1.size() != t2.size()){ return false; } - for(size_t i; i < t1.size(); ++i){ - if (static_cast<float>(std::abs(t1.get<T>(i) - t2.get<T>(i))) > (absolute + (relative * static_cast<float>(std::abs(t2.get<T>(i)))))){ + for(size_t i = 0; i < t1.size(); ++i){ + if (static_cast<float>(std::abs(t1.get<T1>(i) - t2.get<T2>(i))) > (absolute + (relative * static_cast<float>(std::abs(t2.get<T2>(i)))))){ return false; } } return true; } +} #endif /* AIDGE_CORE_UTILS_TENSOR_UTILS_H_s */ diff --git a/include/aidge/utils/Types.h b/include/aidge/utils/Types.h index d65279f1f4d36498ea7653428332690fc99a5def..b601df1cb8f8fa81cd2339e7eb393f7297e63499 100644 --- a/include/aidge/utils/Types.h +++ b/include/aidge/utils/Types.h @@ -24,6 +24,10 @@ namespace Aidge /// Tensor ////////////////////////////////////// +/// @brief Device index in a given backend +using DeviceIdx_t = std::uint8_t; +constexpr DeviceIdx_t MaxDeviceIdx = std::numeric_limits<DeviceIdx_t>::max(); + /// @brief Number of elements used for scheduling using NbElts_t = std::size_t; constexpr NbElts_t MaxElts = std::numeric_limits<NbElts_t>::max(); diff --git a/include/aidge/utils/future_std/span.hpp b/include/aidge/utils/future_std/span.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ba8d6c0317135ac9a934891a8510b844fbb0dc85 --- /dev/null +++ b/include/aidge/utils/future_std/span.hpp @@ -0,0 +1,618 @@ + +/* +This is an implementation of C++20's std::span +http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/n4820.pdf +*/ + +// Copyright Tristan Brindle 2018. +// Distributed under the Boost Software License, Version 1.0. 
+// (See accompanying file ../../LICENSE_1_0.txt or copy at +// https://www.boost.org/LICENSE_1_0.txt) + +#ifndef AIDGE_CORE_UTILS_FUTURE_STD_SPAN_H_ +#define AIDGE_CORE_UTILS_FUTURE_STD_SPAN_H_ + +#include <array> +#include <cstddef> +#include <cstdint> +#include <type_traits> + +#ifndef TCB_SPAN_NO_EXCEPTIONS +// Attempt to discover whether we're being compiled with exception support +#if !(defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)) +#define TCB_SPAN_NO_EXCEPTIONS +#endif +#endif + +#ifndef TCB_SPAN_NO_EXCEPTIONS +#include <cstdio> +#include <stdexcept> +#endif + +// Various feature test macros + +#ifndef TCB_SPAN_NAMESPACE_NAME +#define TCB_SPAN_NAMESPACE_NAME future_std +#endif + +#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) +#define TCB_SPAN_HAVE_CPP17 +#endif + +#if __cplusplus >= 201402L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201402L) +#define TCB_SPAN_HAVE_CPP14 +#endif + +namespace TCB_SPAN_NAMESPACE_NAME { + +// Establish default contract checking behavior +#if !defined(TCB_SPAN_THROW_ON_CONTRACT_VIOLATION) && \ + !defined(TCB_SPAN_TERMINATE_ON_CONTRACT_VIOLATION) && \ + !defined(TCB_SPAN_NO_CONTRACT_CHECKING) +#if defined(NDEBUG) || !defined(TCB_SPAN_HAVE_CPP14) +#define TCB_SPAN_NO_CONTRACT_CHECKING +#else +#define TCB_SPAN_TERMINATE_ON_CONTRACT_VIOLATION +#endif +#endif + +#if defined(TCB_SPAN_THROW_ON_CONTRACT_VIOLATION) +struct contract_violation_error : std::logic_error { + explicit contract_violation_error(const char* msg) : std::logic_error(msg) + {} +}; + +inline void contract_violation(const char* msg) +{ + throw contract_violation_error(msg); +} + +#elif defined(TCB_SPAN_TERMINATE_ON_CONTRACT_VIOLATION) +[[noreturn]] inline void contract_violation(const char* /*unused*/) +{ + std::terminate(); +} +#endif + +#if !defined(TCB_SPAN_NO_CONTRACT_CHECKING) +#define TCB_SPAN_STRINGIFY(cond) #cond +#define TCB_SPAN_EXPECT(cond) \ + cond ? 
(void) 0 : contract_violation("Expected " TCB_SPAN_STRINGIFY(cond)) +#else +#define TCB_SPAN_EXPECT(cond) +#endif + +#if defined(TCB_SPAN_HAVE_CPP17) || defined(__cpp_inline_variables) +#define TCB_SPAN_INLINE_VAR inline +#else +#define TCB_SPAN_INLINE_VAR +#endif + +#if defined(TCB_SPAN_HAVE_CPP14) || \ + (defined(__cpp_constexpr) && __cpp_constexpr >= 201304) +#define TCB_SPAN_HAVE_CPP14_CONSTEXPR +#endif + +#if defined(TCB_SPAN_HAVE_CPP14_CONSTEXPR) +#define TCB_SPAN_CONSTEXPR14 constexpr +#else +#define TCB_SPAN_CONSTEXPR14 +#endif + +#if defined(TCB_SPAN_HAVE_CPP14_CONSTEXPR) && \ + (!defined(_MSC_VER) || _MSC_VER > 1900) +#define TCB_SPAN_CONSTEXPR_ASSIGN constexpr +#else +#define TCB_SPAN_CONSTEXPR_ASSIGN +#endif + +#if defined(TCB_SPAN_NO_CONTRACT_CHECKING) +#define TCB_SPAN_CONSTEXPR11 constexpr +#else +#define TCB_SPAN_CONSTEXPR11 TCB_SPAN_CONSTEXPR14 +#endif + +#if defined(TCB_SPAN_HAVE_CPP17) || defined(__cpp_deduction_guides) +#define TCB_SPAN_HAVE_DEDUCTION_GUIDES +#endif + +#if defined(TCB_SPAN_HAVE_CPP17) || defined(__cpp_lib_byte) +#define TCB_SPAN_HAVE_STD_BYTE +#endif + +#if defined(TCB_SPAN_HAVE_CPP17) || defined(__cpp_lib_array_constexpr) +#define TCB_SPAN_HAVE_CONSTEXPR_STD_ARRAY_ETC +#endif + +#if defined(TCB_SPAN_HAVE_CONSTEXPR_STD_ARRAY_ETC) +#define TCB_SPAN_ARRAY_CONSTEXPR constexpr +#else +#define TCB_SPAN_ARRAY_CONSTEXPR +#endif + +#ifdef TCB_SPAN_HAVE_STD_BYTE +using byte = std::byte; +#else +using byte = unsigned char; +#endif + +#if defined(TCB_SPAN_HAVE_CPP17) +#define TCB_SPAN_NODISCARD [[nodiscard]] +#else +#define TCB_SPAN_NODISCARD +#endif + +TCB_SPAN_INLINE_VAR constexpr std::size_t dynamic_extent = SIZE_MAX; + +template <typename ElementType, std::size_t Extent = dynamic_extent> +class span; + +namespace detail { + +template <typename E, std::size_t S> +struct span_storage { + constexpr span_storage() noexcept = default; + + constexpr span_storage(E* p_ptr, std::size_t /*unused*/) noexcept + : ptr(p_ptr) + {} + + E* ptr = nullptr; + static constexpr std::size_t size = S; +}; + +template <typename E> +struct span_storage<E, dynamic_extent> { + constexpr span_storage() noexcept = default; + + constexpr span_storage(E* p_ptr, std::size_t p_size) noexcept + : ptr(p_ptr), size(p_size) + {} + + E* ptr = nullptr; + std::size_t size = 0; +}; + +// Reimplementation of C++17 std::size() and std::data() +#if defined(TCB_SPAN_HAVE_CPP17) || \ + defined(__cpp_lib_nonmember_container_access) +using std::data; +using std::size; +#else +template <class C> +constexpr auto size(const C& c) -> decltype(c.size()) +{ + return c.size(); +} + +template <class T, std::size_t N> +constexpr std::size_t size(const T (&)[N]) noexcept +{ + return N; +} + +template <class C> +constexpr auto data(C& c) -> decltype(c.data()) +{ + return c.data(); +} + +template <class C> +constexpr auto data(const C& c) -> decltype(c.data()) +{ + return c.data(); +} + +template <class T, std::size_t N> +constexpr T* data(T (&array)[N]) noexcept +{ + return array; +} + +template <class E> +constexpr const E* data(std::initializer_list<E> il) noexcept +{ + return il.begin(); +} +#endif // TCB_SPAN_HAVE_CPP17 + +#if defined(TCB_SPAN_HAVE_CPP17) || defined(__cpp_lib_void_t) +using std::void_t; +#else +template <typename...> +using void_t = void; +#endif + +template <typename T> +using uncvref_t = + typename std::remove_cv<typename std::remove_reference<T>::type>::type; + +template <typename> +struct is_span : std::false_type {}; + +template <typename T, std::size_t S> +struct is_span<span<T, S>> : 
std::true_type {}; + +template <typename> +struct is_std_array : std::false_type {}; + +template <typename T, std::size_t N> +struct is_std_array<std::array<T, N>> : std::true_type {}; + +template <typename, typename = void> +struct has_size_and_data : std::false_type {}; + +template <typename T> +struct has_size_and_data<T, void_t<decltype(detail::size(std::declval<T>())), + decltype(detail::data(std::declval<T>()))>> + : std::true_type {}; + +template <typename C, typename U = uncvref_t<C>> +struct is_container { + static constexpr bool value = + !is_span<U>::value && !is_std_array<U>::value && + !std::is_array<U>::value && has_size_and_data<C>::value; +}; + +template <typename T> +using remove_pointer_t = typename std::remove_pointer<T>::type; + +template <typename, typename, typename = void> +struct is_container_element_type_compatible : std::false_type {}; + +template <typename T, typename E> +struct is_container_element_type_compatible< + T, E, + typename std::enable_if< + !std::is_same< + typename std::remove_cv<decltype(detail::data(std::declval<T>()))>::type, + void>::value && + std::is_convertible< + remove_pointer_t<decltype(detail::data(std::declval<T>()))> (*)[], + E (*)[]>::value + >::type> + : std::true_type {}; + +template <typename, typename = size_t> +struct is_complete : std::false_type {}; + +template <typename T> +struct is_complete<T, decltype(sizeof(T))> : std::true_type {}; + +} // namespace detail + +template <typename ElementType, std::size_t Extent> +class span { + static_assert(std::is_object<ElementType>::value, + "A span's ElementType must be an object type (not a " + "reference type or void)"); + static_assert(detail::is_complete<ElementType>::value, + "A span's ElementType must be a complete type (not a forward " + "declaration)"); + static_assert(!std::is_abstract<ElementType>::value, + "A span's ElementType cannot be an abstract class type"); + + using storage_type = detail::span_storage<ElementType, Extent>; + +public: + // constants and types + using element_type = ElementType; + using value_type = typename std::remove_cv<ElementType>::type; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using pointer = element_type*; + using const_pointer = const element_type*; + using reference = element_type&; + using const_reference = const element_type&; + using iterator = pointer; + using reverse_iterator = std::reverse_iterator<iterator>; + + static constexpr size_type extent = Extent; + + // [span.cons], span constructors, copy, assignment, and destructor + template < + std::size_t E = Extent, + typename std::enable_if<(E == dynamic_extent || E <= 0), int>::type = 0> + constexpr span() noexcept + {} + + TCB_SPAN_CONSTEXPR11 span(pointer ptr, size_type count) + : storage_(ptr, count) + { + TCB_SPAN_EXPECT(extent == dynamic_extent || count == extent); + } + + TCB_SPAN_CONSTEXPR11 span(pointer first_elem, pointer last_elem) + : storage_(first_elem, last_elem - first_elem) + { + TCB_SPAN_EXPECT(extent == dynamic_extent || + last_elem - first_elem == + static_cast<std::ptrdiff_t>(extent)); + } + + template <std::size_t N, std::size_t E = Extent, + typename std::enable_if< + (E == dynamic_extent || N == E) && + detail::is_container_element_type_compatible< + element_type (&)[N], ElementType>::value, + int>::type = 0> + constexpr span(element_type (&arr)[N]) noexcept : storage_(arr, N) + {} + + template <typename T, std::size_t N, std::size_t E = Extent, + typename std::enable_if< + (E == dynamic_extent || N == E) && + 
detail::is_container_element_type_compatible< + std::array<T, N>&, ElementType>::value, + int>::type = 0> + TCB_SPAN_ARRAY_CONSTEXPR span(std::array<T, N>& arr) noexcept + : storage_(arr.data(), N) + {} + + template <typename T, std::size_t N, std::size_t E = Extent, + typename std::enable_if< + (E == dynamic_extent || N == E) && + detail::is_container_element_type_compatible< + const std::array<T, N>&, ElementType>::value, + int>::type = 0> + TCB_SPAN_ARRAY_CONSTEXPR span(const std::array<T, N>& arr) noexcept + : storage_(arr.data(), N) + {} + + template < + typename Container, std::size_t E = Extent, + typename std::enable_if< + E == dynamic_extent && detail::is_container<Container>::value && + detail::is_container_element_type_compatible< + Container&, ElementType>::value, + int>::type = 0> + constexpr span(Container& cont) + : storage_(detail::data(cont), detail::size(cont)) + {} + + template < + typename Container, std::size_t E = Extent, + typename std::enable_if< + E == dynamic_extent && detail::is_container<Container>::value && + detail::is_container_element_type_compatible< + const Container&, ElementType>::value, + int>::type = 0> + constexpr span(const Container& cont) + : storage_(detail::data(cont), detail::size(cont)) + {} + + constexpr span(const span& other) noexcept = default; + + template <typename OtherElementType, std::size_t OtherExtent, + typename std::enable_if< + (Extent == dynamic_extent || OtherExtent == dynamic_extent || + Extent == OtherExtent) && + std::is_convertible<OtherElementType (*)[], + ElementType (*)[]>::value, + int>::type = 0> + constexpr span(const span<OtherElementType, OtherExtent>& other) noexcept + : storage_(other.data(), other.size()) + {} + + ~span() noexcept = default; + + TCB_SPAN_CONSTEXPR_ASSIGN span& + operator=(const span& other) noexcept = default; + + // [span.sub], span subviews + template <std::size_t Count> + TCB_SPAN_CONSTEXPR11 span<element_type, Count> first() const + { + TCB_SPAN_EXPECT(Count <= size()); + return {data(), Count}; + } + + template <std::size_t Count> + TCB_SPAN_CONSTEXPR11 span<element_type, Count> last() const + { + TCB_SPAN_EXPECT(Count <= size()); + return {data() + (size() - Count), Count}; + } + + template <std::size_t Offset, std::size_t Count = dynamic_extent> + using subspan_return_t = + span<ElementType, Count != dynamic_extent + ? Count + : (Extent != dynamic_extent ? Extent - Offset + : dynamic_extent)>; + + template <std::size_t Offset, std::size_t Count = dynamic_extent> + TCB_SPAN_CONSTEXPR11 subspan_return_t<Offset, Count> subspan() const + { + TCB_SPAN_EXPECT(Offset <= size() && + (Count == dynamic_extent || Offset + Count <= size())); + return {data() + Offset, + Count != dynamic_extent ? Count : size() - Offset}; + } + + TCB_SPAN_CONSTEXPR11 span<element_type, dynamic_extent> + first(size_type count) const + { + TCB_SPAN_EXPECT(count <= size()); + return {data(), count}; + } + + TCB_SPAN_CONSTEXPR11 span<element_type, dynamic_extent> + last(size_type count) const + { + TCB_SPAN_EXPECT(count <= size()); + return {data() + (size() - count), count}; + } + + TCB_SPAN_CONSTEXPR11 span<element_type, dynamic_extent> + subspan(size_type offset, size_type count = dynamic_extent) const + { + TCB_SPAN_EXPECT(offset <= size() && + (count == dynamic_extent || offset + count <= size())); + return {data() + offset, + count == dynamic_extent ? 
size() - offset : count}; + } + + // [span.obs], span observers + constexpr size_type size() const noexcept { return storage_.size; } + + constexpr size_type size_bytes() const noexcept + { + return size() * sizeof(element_type); + } + + TCB_SPAN_NODISCARD constexpr bool empty() const noexcept + { + return size() == 0; + } + + // [span.elem], span element access + TCB_SPAN_CONSTEXPR11 reference operator[](size_type idx) const + { + TCB_SPAN_EXPECT(idx < size()); + return *(data() + idx); + } + + TCB_SPAN_CONSTEXPR11 reference front() const + { + TCB_SPAN_EXPECT(!empty()); + return *data(); + } + + TCB_SPAN_CONSTEXPR11 reference back() const + { + TCB_SPAN_EXPECT(!empty()); + return *(data() + (size() - 1)); + } + + constexpr pointer data() const noexcept { return storage_.ptr; } + + // [span.iterators], span iterator support + constexpr iterator begin() const noexcept { return data(); } + + constexpr iterator end() const noexcept { return data() + size(); } + + TCB_SPAN_ARRAY_CONSTEXPR reverse_iterator rbegin() const noexcept + { + return reverse_iterator(end()); + } + + TCB_SPAN_ARRAY_CONSTEXPR reverse_iterator rend() const noexcept + { + return reverse_iterator(begin()); + } + +private: + storage_type storage_{}; +}; + +#ifdef TCB_SPAN_HAVE_DEDUCTION_GUIDES + +/* Deduction Guides */ +template <class T, size_t N> +span(T (&)[N])->span<T, N>; + +template <class T, size_t N> +span(std::array<T, N>&)->span<T, N>; + +template <class T, size_t N> +span(const std::array<T, N>&)->span<const T, N>; + +template <class Container> +span(Container&)->span<typename std::remove_reference< + decltype(*detail::data(std::declval<Container&>()))>::type>; + +template <class Container> +span(const Container&)->span<const typename Container::value_type>; + +#endif // TCB_HAVE_DEDUCTION_GUIDES + +template <typename ElementType, std::size_t Extent> +constexpr span<ElementType, Extent> +make_span(span<ElementType, Extent> s) noexcept +{ + return s; +} + +template <typename T, std::size_t N> +constexpr span<T, N> make_span(T (&arr)[N]) noexcept +{ + return {arr}; +} + +template <typename T, std::size_t N> +TCB_SPAN_ARRAY_CONSTEXPR span<T, N> make_span(std::array<T, N>& arr) noexcept +{ + return {arr}; +} + +template <typename T, std::size_t N> +TCB_SPAN_ARRAY_CONSTEXPR span<const T, N> +make_span(const std::array<T, N>& arr) noexcept +{ + return {arr}; +} + +template <typename Container> +constexpr span<typename std::remove_reference< + decltype(*detail::data(std::declval<Container&>()))>::type> +make_span(Container& cont) +{ + return {cont}; +} + +template <typename Container> +constexpr span<const typename Container::value_type> +make_span(const Container& cont) +{ + return {cont}; +} + +template <typename ElementType, std::size_t Extent> +span<const byte, ((Extent == dynamic_extent) ? dynamic_extent + : sizeof(ElementType) * Extent)> +as_bytes(span<ElementType, Extent> s) noexcept +{ + return {reinterpret_cast<const byte*>(s.data()), s.size_bytes()}; +} + +template < + class ElementType, size_t Extent, + typename std::enable_if<!std::is_const<ElementType>::value, int>::type = 0> +span<byte, ((Extent == dynamic_extent) ? 
dynamic_extent + : sizeof(ElementType) * Extent)> +as_writable_bytes(span<ElementType, Extent> s) noexcept +{ + return {reinterpret_cast<byte*>(s.data()), s.size_bytes()}; +} + +template <std::size_t N, typename E, std::size_t S> +constexpr auto get(span<E, S> s) -> decltype(s[N]) +{ + return s[N]; +} + +} // namespace TCB_SPAN_NAMESPACE_NAME + +namespace std { + +template <typename ElementType, size_t Extent> +class tuple_size<TCB_SPAN_NAMESPACE_NAME::span<ElementType, Extent>> + : public integral_constant<size_t, Extent> {}; + +template <typename ElementType> +class tuple_size<TCB_SPAN_NAMESPACE_NAME::span< + ElementType, TCB_SPAN_NAMESPACE_NAME::dynamic_extent>>; // not defined + +template <size_t I, typename ElementType, size_t Extent> +class tuple_element<I, TCB_SPAN_NAMESPACE_NAME::span<ElementType, Extent>> { +public: + static_assert(Extent != TCB_SPAN_NAMESPACE_NAME::dynamic_extent && + I < Extent, + ""); + using type = ElementType; +}; + +} // end namespace std + +#endif // AIDGE_CORE_UTILS_FUTURE_STD_SPAN_H_ diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index babc534bdc23e87e17e21312d18b51b04baee7ca..fa109a9af4b1146b60f0fffc80b8dfc6e4a2c256 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -42,7 +42,7 @@ void addCtor(py::class_<Tensor, std::set<std::string> availableBackends = Tensor::getAvailableBackends(); if (availableBackends.find("cpu") != availableBackends.end()){ newTensor->setBackend("cpu"); - newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr)); + newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr), newTensor->size()); }else{ printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n"); } @@ -71,7 +71,8 @@ void init_Tensor(py::module& m){ (m,"Tensor", py::multiple_inheritance(), py::buffer_protocol()); pyClassTensor.def(py::init<>()) - .def("set_backend", &Tensor::setBackend, py::arg("name")) + .def("set_datatype", &Tensor::setDataType, py::arg("datatype"), py::arg("copyCast") = true) + .def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true) .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims) .def("dtype", &Tensor::dataType) .def("size", &Tensor::size) diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 61392470adaeb7db8812a3063edc5f8eee1d3083..32151a66a46f7d7da73473c90effa760ebc93891 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -56,7 +56,7 @@ void init_GraphView(py::module& m) { :type include_learnable_parameters: bool, optional )mydelimiter") - .def("add", (void (GraphView::*)(std::shared_ptr<GraphView>)) & GraphView::add, + .def("add", (bool (GraphView::*)(std::shared_ptr<GraphView>)) & GraphView::add, py::arg("other_graph"), R"mydelimiter( Include a GraphView to the current GraphView object. 
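A minimal Python sketch of the updated Tensor bindings above, assuming the aidge_core module is built from this patch and a backend plugin providing "cpu" has been imported first; note that the raw numpy buffer is now registered together with its element count (setRawPtr(ptr, size)):

    import numpy as np
    import aidge_core  # a backend plugin (e.g. the CPU one) must be imported as well

    t = aidge_core.Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
    t.set_backend("cpu", 0)      # `device` is the new optional argument (defaults to 0)
    assert t.dims() == [2, 3] and t.size() == 6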
@@ -97,9 +97,10 @@ void init_GraphView(py::module& m) { .def("get_nodes", &GraphView::getNodes) .def("get_node", &GraphView::getNode, py::arg("node_name")) .def("forward_dims", &GraphView::forwardDims) + .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype")) .def("__call__", &GraphView::operator(), py::arg("connectors")) .def("set_datatype", &GraphView::setDataType, py::arg("datatype")) - .def("set_backend", &GraphView::setBackend, py::arg("backend")) + .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0) // .def("__getitem__", [](Tensor& b, size_t idx)-> py::object { // // TODO : Should return error if backend not compatible with get // if (idx >= b.size()) throw py::index_error(); diff --git a/python_binding/graph/pybind_Node.cpp b/python_binding/graph/pybind_Node.cpp index aa5c21372730536662106a035307d885fa011107..1f655b50a38dddf597f51879411535ff655ed694 100644 --- a/python_binding/graph/pybind_Node.cpp +++ b/python_binding/graph/pybind_Node.cpp @@ -137,6 +137,8 @@ void init_Node(py::module& m) { :rtype: int )mydelimiter") + .def("get_parent", &Node::getParent, py::arg("in_id")) + .def("get_parents", &Node::getParents, R"mydelimiter( Get parents. diff --git a/python_binding/operator/pybind_BatchNorm.cpp b/python_binding/operator/pybind_BatchNorm.cpp index ff0b9e0dfcb0d1c5e5567a938b1ca74faf242bed..411a2e1b6ae78065a79b92f25c23dac13e341997 100644 --- a/python_binding/operator/pybind_BatchNorm.cpp +++ b/python_binding/operator/pybind_BatchNorm.cpp @@ -25,7 +25,7 @@ void declare_BatchNormOp(py::module& m) { .def("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName) .def("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName); - m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = ""); + m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nbFeatures"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = ""); } void init_BatchNorm(py::module &m) { diff --git a/python_binding/operator/pybind_Concat.cpp b/python_binding/operator/pybind_Concat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2b7e5d6b99194e914e48dc6263d0bdcd6a4a8a2f --- /dev/null +++ b/python_binding/operator/pybind_Concat.cpp @@ -0,0 +1,28 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
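The compile() and device-aware set_backend() bindings added above can be exercised with a short sketch; the Conv2D factory signature and the DataType enum spelling are assumptions, not part of this patch:

    import aidge_core

    model = aidge_core.sequential([
        aidge_core.Conv2D(3, 32, [3, 3], name="conv0"),   # assumed (in_ch, out_ch, kernel_dims, name)
        aidge_core.BatchNorm2D(32, name="bn0"),           # nbFeatures is now the first, mandatory argument
    ])
    # Sets the backend/device and data type for every node in the graph;
    # producers/inputs must be connected before dimensions can be propagated.
    model.compile("cpu", aidge_core.DataType.Float32)
    parent = model.get_node("bn0").get_parent(0)          # new single-parent accessor -> the "conv0" node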
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <string> + +#include "aidge/operator/Concat.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Concat(py::module& m) { + py::class_<Concat_Op, std::shared_ptr<Concat_Op>, OperatorTensor, Attributes>(m, "ConcatOp", py::multiple_inheritance()) + .def("get_inputs_name", &Concat_Op::getInputsName) + .def("get_outputs_name", &Concat_Op::getOutputsName); + + m.def("Concat", &Concat, py::arg("nbIn"), py::arg("axis"), py::arg("name") = ""); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_Conv.cpp b/python_binding/operator/pybind_Conv.cpp index 71231b8218ac6af28c97ec29039301bc25b2d195..2200cd3fec1450011d6e0b5197f8b99b4dfeb4c3 100644 --- a/python_binding/operator/pybind_Conv.cpp +++ b/python_binding/operator/pybind_Conv.cpp @@ -11,7 +11,6 @@ #include <pybind11/pybind11.h> #include <pybind11/stl.h> -#include <iostream> #include <string> #include <vector> #include <array> diff --git a/python_binding/operator/pybind_Erf.cpp b/python_binding/operator/pybind_Erf.cpp new file mode 100644 index 0000000000000000000000000000000000000000..806867f61c3580543c184d529edc2856ee8d7a6c --- /dev/null +++ b/python_binding/operator/pybind_Erf.cpp @@ -0,0 +1,27 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/operator/Erf.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Erf(py::module& m) { + py::class_<Erf_Op, std::shared_ptr<Erf_Op>, OperatorTensor>(m, "ErfOp", py::multiple_inheritance()) + .def("get_inputs_name", &Erf_Op::getInputsName) + .def("get_outputs_name", &Erf_Op::getOutputsName); + + m.def("Erf", &Erf, py::arg("name") = ""); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_Gather.cpp b/python_binding/operator/pybind_Gather.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f9768e38fbdceef4a15cc74430bc2205bb32cb6a --- /dev/null +++ b/python_binding/operator/pybind_Gather.cpp @@ -0,0 +1,28 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <string> + +#include "aidge/operator/Gather.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Gather(py::module& m) { + py::class_<Gather_Op, std::shared_ptr<Gather_Op>, OperatorTensor, Attributes>(m, "GatherOp", py::multiple_inheritance()) + .def("get_inputs_name", &Gather_Op::getInputsName) + .def("get_outputs_name", &Gather_Op::getOutputsName); + + m.def("Gather", &Gather, py::arg("axis"), py::arg("name") = ""); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp index d1eff7b387f9b339e6641a8049e020a7e8a4f021..f5c5145e0a86d939b96e6d2a579dfa2579f8b3a5 100644 --- a/python_binding/operator/pybind_MetaOperatorDefs.cpp +++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp @@ -128,9 +128,7 @@ void init_MetaOperatorDefs(py::module &m) { m.def("meta_operator", &MetaOperator, py::arg("type"), py::arg("graph"), - py::arg("name") = "", - py::arg("input_nodes") = std::vector<NodePtr>(), - py::arg("output_nodes") = std::vector<NodePtr>() + py::arg("name") = "" ); } diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index f9482eda2f93b5492cfcc89175da69d140f23df8..79a85cb92cf27c7edb745c36eefe61ae86c66786 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -20,6 +20,7 @@ namespace Aidge { void init_Operator(py::module& m){ py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") .def("set_output", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data")) + .def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data")) .def("get_raw_output", &Operator::getRawOutput, py::arg("outputIdx")) .def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data")) .def("get_raw_input", &Operator::getRawInput, py::arg("inputIdx")) @@ -29,7 +30,7 @@ void init_Operator(py::module& m){ .def("nb_outputs", &Operator::nbOutputs) .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) .def("set_datatype", &Operator::setDataType, py::arg("dataType")) - .def("set_backend", &Operator::setBackend, py::arg("name")) + .def("set_backend", &Operator::setBackend, py::arg("name"), py::arg("device") = 0) .def("forward", &Operator::forward) // py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected ! .def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>()) diff --git a/python_binding/operator/pybind_ReduceMean.cpp b/python_binding/operator/pybind_ReduceMean.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5de98b69adde5133dde302f7306bc8a5c471eef --- /dev/null +++ b/python_binding/operator/pybind_ReduceMean.cpp @@ -0,0 +1,54 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
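The Concat, Erf and Gather factories bound above can be combined like any other operator node; a hedged sketch, with axes chosen only for illustration:

    import aidge_core

    gather = aidge_core.Gather(axis=1, name="gather0")      # gather along axis 1
    erf    = aidge_core.Erf(name="erf0")                    # element-wise error function
    concat = aidge_core.Concat(2, axis=0, name="concat0")   # nbIn = 2 inputs, concatenated on axis 0
    branch = aidge_core.sequential([gather, erf])           # Concat still needs its two inputs wired separately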
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include <string> +#include <vector> +#include <array> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/ReduceMean.hpp" +#include "aidge/utils/Types.h" + +namespace py = pybind11; +namespace Aidge { + +template <DimIdx_t DIM> void declare_ReduceMeanOp(py::module &m) { + py::class_<ReduceMean_Op<DIM>, std::shared_ptr<ReduceMean_Op<DIM>>, OperatorTensor, Attributes>( + m, ("ReduceMeanOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) + .def("get_inputs_name", &ReduceMean_Op<DIM>::getInputsName) + .def("get_outputs_name", &ReduceMean_Op<DIM>::getOutputsName) + ; + + m.def(("ReduceMean" + std::to_string(DIM) + "D").c_str(), [](const std::vector<int>& axes, + DimSize_t keepDims, + const std::string& name) { + AIDGE_ASSERT(axes.size() == DIM, "axes size [%ld] does not match DIM [%d]", axes.size(), DIM); + + return ReduceMean<DIM>(to_array<DIM>(axes.begin()), keepDims, name); + }, py::arg("axes"), + py::arg("keep_dims") = 1, + py::arg("name") = ""); +} + + +void init_ReduceMean(py::module &m) { + declare_ReduceMeanOp<1>(m); + declare_ReduceMeanOp<2>(m); + declare_ReduceMeanOp<3>(m); + + // FIXME: + // m.def("ReduceMean1D", static_cast<NodeAPI(*)(const char*, int, int, int const + // (&)[1])>(&ReduceMean)); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_Reshape.cpp b/python_binding/operator/pybind_Reshape.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d34a411c719bdbb1144edaa65b50050d705e0d90 --- /dev/null +++ b/python_binding/operator/pybind_Reshape.cpp @@ -0,0 +1,27 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/operator/Reshape.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Reshape(py::module& m) { + py::class_<Reshape_Op, std::shared_ptr<Reshape_Op>, OperatorTensor>(m, "ReshapeOp", py::multiple_inheritance()) + .def("get_inputs_name", &Reshape_Op::getInputsName) + .def("get_outputs_name", &Reshape_Op::getOutputsName); + + m.def("Reshape", &Reshape, py::arg("shape"), py::arg("name") = ""); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_Slice.cpp b/python_binding/operator/pybind_Slice.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7bfd1b4f00579ed29658db73b71f2c596048fe75 --- /dev/null +++ b/python_binding/operator/pybind_Slice.cpp @@ -0,0 +1,27 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/operator/Slice.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Slice(py::module& m) { + py::class_<Slice_Op, std::shared_ptr<Slice_Op>, OperatorTensor>(m, "SliceOp", py::multiple_inheritance()) + .def("get_inputs_name", &Slice_Op::getInputsName) + .def("get_outputs_name", &Slice_Op::getOutputsName); + + m.def("Slice", &Slice, py::arg("starts"), py::arg("ends"), py::arg("axes"), py::arg("name") = ""); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_Softmax.cpp b/python_binding/operator/pybind_Softmax.cpp index dc29e2171ff6f0fbbb5c80183778d8f20cbe085b..04e92d39971a731931397e943aba6e296a81a14d 100644 --- a/python_binding/operator/pybind_Softmax.cpp +++ b/python_binding/operator/pybind_Softmax.cpp @@ -19,10 +19,10 @@ namespace py = pybind11; namespace Aidge { void init_Softmax(py::module& m) { - py::class_<Softmax_Op, std::shared_ptr<Softmax_Op>, OperatorTensor>(m, "SoftmaxOp", py::multiple_inheritance()) + py::class_<Softmax_Op, std::shared_ptr<Softmax_Op>, OperatorTensor, Attributes>(m, "SoftmaxOp", py::multiple_inheritance()) .def("get_inputs_name", &Softmax_Op::getInputsName) .def("get_outputs_name", &Softmax_Op::getOutputsName); - m.def("Softmax", &Softmax, py::arg("name") = ""); + m.def("Softmax", &Softmax, py::arg("axis"), py::arg("name") = ""); } } // namespace Aidge diff --git a/python_binding/operator/pybind_Transpose.cpp b/python_binding/operator/pybind_Transpose.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e92e9c2aaafe2d20220da053a2b9d799fbe8466d --- /dev/null +++ b/python_binding/operator/pybind_Transpose.cpp @@ -0,0 +1,52 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
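The shape and reduction operators bound above (ReduceMean, Reshape, Slice and the now axis-aware Softmax) follow the same factory pattern; a hedged sketch with illustrative axes and shapes:

    import aidge_core

    rm  = aidge_core.ReduceMean2D([1, 2], keep_dims=1, name="rm0")   # len(axes) must equal the "2D" suffix
    rsh = aidge_core.Reshape([3, 2], name="reshape0")
    sl  = aidge_core.Slice(starts=[0], ends=[2], axes=[0], name="slice0")
    sm  = aidge_core.Softmax(axis=1, name="softmax0")                # Softmax now takes an explicit axis
    gv  = aidge_core.sequential([rm, rsh, sl, sm])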
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include <string> +#include <vector> +#include <array> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Transpose.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/utils/Types.h" +#include "aidge/data/Tensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +template <DimIdx_t DIM> +void declare_Transpose(py::module &m) { + py::class_<Transpose_Op<DIM>, std::shared_ptr<Transpose_Op<DIM>>, OperatorTensor, Attributes>( + m, ("TransposeOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) + .def("get_inputs_name", &Transpose_Op<DIM>::getInputsName) + .def("get_outputs_name", &Transpose_Op<DIM>::getOutputsName); + + m.def(("Transpose" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& output_dims_order, + const std::string& name) { + AIDGE_ASSERT(output_dims_order.size() == DIM, "output_dims_order size [%ld] does not match DIM [%d]", output_dims_order.size(), DIM); + return Transpose<DIM>(to_array<DIM>(output_dims_order.begin()), name); + }, py::arg("output_dims_order"), + py::arg("name") = ""); + +} + +void init_Transpose(py::module &m) { + declare_Transpose<2>(m); + declare_Transpose<3>(m); + declare_Transpose<4>(m); + declare_Transpose<5>(m); + declare_Transpose<6>(m); + +} +} // namespace Aidge diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 0da3ea03a36b0dd6e23d482c73c4bc3b3b468c22..ea61f05adc1ac4471d80bd8f9ebbd780dfdf39bd 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -25,10 +25,13 @@ void init_OperatorTensor(py::module&); void init_Add(py::module&); void init_AvgPooling(py::module&); void init_BatchNorm(py::module&); +void init_Concat(py::module&); void init_Conv(py::module&); void init_ConvDepthWise(py::module&); void init_Div(py::module&); +void init_Erf(py::module&); void init_FC(py::module&); +void init_Gather(py::module&); void init_GenericOperator(py::module&); void init_LeakyReLU(py::module&); void init_MatMul(py::module&); @@ -38,10 +41,14 @@ void init_Mul(py::module&); void init_Producer(py::module&); void init_Pad(py::module&); void init_Pow(py::module&); +void init_ReduceMean(py::module&); void init_ReLU(py::module&); +void init_Reshape(py::module&); +void init_Slice(py::module&); void init_Softmax(py::module&); void init_Sqrt(py::module&); void init_Sub(py::module&); +void init_Transpose(py::module&); void init_Identity(py::module&); void init_Node(py::module&); @@ -74,10 +81,13 @@ void init_Aidge(py::module& m){ init_Add(m); init_AvgPooling(m); init_BatchNorm(m); + init_Concat(m); init_Conv(m); init_ConvDepthWise(m); init_Div(m); + init_Erf(m); init_FC(m); + init_Gather(m); init_GenericOperator(m); init_LeakyReLU(m); init_MatMul(m); @@ -87,10 +97,14 @@ void init_Aidge(py::module& m){ init_Pad(m); init_Pow(m); + init_ReduceMean(m); init_ReLU(m); + init_Reshape(m); + init_Slice(m); init_Softmax(m); init_Sqrt(m); init_Sub(m); + init_Transpose(m); init_Identity(m); init_Producer(m); diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp index 820b6e12b11116b874170bd25a6dc75675894257..bd058defb21c13cea1323e4748129c92519de039 100644 --- a/python_binding/recipies/pybind_Recipies.cpp +++ b/python_binding/recipies/pybind_Recipies.cpp @@ -12,9 +12,11 @@ #include <pybind11/pybind11.h> #include 
<pybind11/stl.h> +#include <cstddef> #include <string> #include "aidge/recipies/Recipies.hpp" +#include "aidge/utils/Types.h" namespace py = pybind11; @@ -28,7 +30,7 @@ void init_Recipies(py::module &m) { :param graph_view: Graph view on which we want to apply the recipie :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - + // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( // Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. @@ -36,6 +38,13 @@ void init_Recipies(py::module &m) { // :type nodes: list of :py:class:`aidge_core.Node` // )mydelimiter"); + m.def("remove_dropout",static_cast<void(*)(std::shared_ptr<GraphView>)>(removeDropout), py::arg("graph_view"), R"mydelimiter( + Recipie to remove a dropout operator. + + :param graph_view: Graph view on which we want to apply the recipie + :type graph_view: :py:class:`aidge_core.GraphView` + )mydelimiter"); + m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( Recipie to remove a flatten operator. @@ -63,7 +72,10 @@ void init_Recipies(py::module &m) { :param graph_view: Graph view on which we want to apply the recipie :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - + + m.def("get_conv_horizontal_tiling", static_cast<std::set<std::shared_ptr<Node>>(*)(const std::shared_ptr<Node>&, const DimIdx_t, const std::size_t)>(getConvHorizontalTiling), + py::arg("node"), py::arg("axis"), py::arg("nb_slices")); + // m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter( // Recipie to remove a flatten operator. diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index 85479d41f51e74dee4079e78a37e7f3a520639e2..d963b81d501f5cd2faf4f69810c897bb4b4da86d 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -21,6 +21,7 @@ void init_Scheduler(py::module& m){ .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false) .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name")) + .def("resetScheduling", &SequentialScheduler::resetScheduling) .def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false) .def("get_static_scheduling", &SequentialScheduler::getStaticScheduling) ; diff --git a/python_binding/utils/pybind_TensorUtils.cpp b/python_binding/utils/pybind_TensorUtils.cpp index 78825a5f3b8d45f22f76c57bd780dc7019fbc123..d82db0355ad641062ec89b1b331c74ccfde4c0b6 100644 --- a/python_binding/utils/pybind_TensorUtils.cpp +++ b/python_binding/utils/pybind_TensorUtils.cpp @@ -51,7 +51,7 @@ void addTensorUtilsFunction(py::module &m){ void init_TensorUtils(py::module &m) { addTensorUtilsFunction<float>(m); addTensorUtilsFunction<double>(m); - addTensorUtilsFunction<int>(m); - addTensorUtilsFunction<long>(m); + addTensorUtilsFunction<int32_t>(m); + addTensorUtilsFunction<int64_t>(m); } } // namespace Aidge diff --git a/src/backend/TensorImpl.cpp b/src/backend/TensorImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3982ee1fed9c9198b539bf9a28edd461992b791f --- /dev/null +++ b/src/backend/TensorImpl.cpp @@ -0,0 +1,51 @@ 
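The tiling recipe and the scheduler reset exposed above can be driven from Python as follows; `model` and the node name "conv0" are hypothetical placeholders for an existing, compiled GraphView:

    import aidge_core

    conv  = model.get_node("conv0")
    tiles = aidge_core.get_conv_horizontal_tiling(conv, axis=2, nb_slices=4)  # set of nodes of the tiled variant

    scheduler = aidge_core.SequentialScheduler(model)
    scheduler.generate_scheduling()
    scheduler.resetScheduling()   # note: exposed under its C++ name rather than snake_case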
+/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/data/Tensor.hpp" +#include "aidge/backend/TensorImpl.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" + +void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length) { + if (&srcImpl == this) { + return; + } + + if (srcImpl.device() != device()) { + if (srcImpl.backend() == backend()) { + // Same backend, but different device + copyFromDevice(srcImpl.rawPtr(), length, srcImpl.device()); + } + else if (srcImpl.hostPtr() != nullptr) { + // Different backend, but input is valid on host + copyFromHost(srcImpl.hostPtr(), length); + } + else if (hostPtr() != nullptr) { + // Different backend, but dst is valid on host + srcImpl.copyToHost(hostPtr(), length); + } + else { + // No direct link possible from src to dst device + // SLOW SOLUTION: must pass through the host, requires TWO copies + // Allocate a temporary host buffer just for the copy + // We might reuse a pre-allocated buffer, but for now this feature is not provided because: + // - There is currently no concrete use case + // - Just providing a pointer would be unsafe (risk of buffer overflow...) + auto tmpHostBuffer = std::unique_ptr<char[]>(new char[scalarSize() * length]); + srcImpl.copyToHost(tmpHostBuffer.get(), length); + copyFromHost(tmpHostBuffer.get(), length); + } + } + else { + // Same device: simple copy on device + copy(srcImpl.rawPtr(), length); + } +} diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..da0c626d78dd1cc4452bfc07bf6c6a7f58b8d1e4 --- /dev/null +++ b/src/data/Tensor.cpp @@ -0,0 +1,140 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
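At the Python level, the copy primitives above are what back the copyCast/copyFrom flags of set_datatype() and set_backend(): they decide whether existing values are converted/moved or simply dropped. A hedged sketch (the DataType member names are assumptions):

    import numpy as np
    import aidge_core  # plus a backend plugin providing "cpu"

    t = aidge_core.Tensor(np.ones((2, 3), dtype=np.float32))
    t.set_datatype(aidge_core.DataType.Int32)            # copyCast=True (default): values are converted
    t.set_datatype(aidge_core.DataType.Float32, False)   # copyCast=False: reallocate only, values dropped
    t.set_backend("cpu", 0, copyFrom=False)              # same choice when moving backend/device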
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/data/Tensor.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" + +void Aidge::Tensor::copyCast(const Tensor& src) { + if (&src == this) { + return; + } + + // Current Tensor has necessarily a data type, but may not have backend + if (!getImpl()) { + // If no backend was set for the current tensor, use the same as src + const auto deviceSrc = src.getImpl()->device(); + setBackend(deviceSrc.first, deviceSrc.second); + } + resize(src.dims()); + + AIDGE_ASSERT(src.getImpl()->device() == getImpl()->device(), "cannot copy-cast from a different backend/device"); + getImpl()->copyCast(src.getImpl()->rawPtr(), src.size(), src.dataType()); +} + +void Aidge::Tensor::copyFrom(const Tensor& src) { + if (&src == this) { + return; + } + + // Current Tensor has necessarily a data type, but may not have backend + if (!getImpl()) { + // If no backend was set for the current tensor, use the same as src + const auto deviceSrc = src.getImpl()->device(); + setBackend(deviceSrc.first, deviceSrc.second); + } + resize(src.dims()); + + AIDGE_ASSERT(src.dataType() == dataType(), "cannot copy from a different data type"); + getImpl()->copyFrom(*(src.getImpl()), src.size()); +} + +void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrcPtr) { + if (&src == this) { + return; + } + + // Current Tensor has necessarily a data type, but may not have backend + if (!getImpl()) { + // If no backend was set for the current tensor, use the same as src + const auto deviceSrc = src.getImpl()->device(); + setBackend(deviceSrc.first, deviceSrc.second); + } + resize(src.dims()); + + if (dataType() != src.dataType()) { + // First move data to the target device (only if needed) + const auto device = getImpl()->device(); + const Tensor& movedSrc = src.refFrom(movedSrcPtr, device.first, device.second); + // Second, copy-cast data (necessary) + getImpl()->copyCast(movedSrc.getImpl()->rawPtr(), movedSrc.size(), movedSrc.dataType()); + } + else { + // Directly copy, no conversion necessary + // Avoid making a double copy if both data type and device are the same + getImpl()->copyFrom(*(src.getImpl()), src.size()); + } +} + +Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) { + // Scott Meyers' solution to avoid code duplication + return const_cast<Tensor&>(static_cast<const Tensor&>(*this).refCast(fallback, dt)); +} + +const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) const { + AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refCast() it"); + + if (dt == dataType()) { + return *this; + } + else { + if (this == fallback.get()) { + // if refFrom() was called before, just change the type + fallback->setDataType(dt); + } + else { + if (!fallback) { + fallback = std::make_shared<Tensor>(dt); + } + else { + fallback->setDataType(dt, false); // don't keep previous data (no copy) + } + + const auto device = getImpl()->device(); + fallback->setBackend(device.first, device.second, false); // don't keep previous data (no copy) + fallback->resize(dims()); + fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dataType()); + } + return *fallback; + } +} + +Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, DeviceIdx_t device) { + // Scott Meyers' solution to avoid code 
duplication + return const_cast<Tensor&>(static_cast<const Tensor&>(*this).refFrom(fallback, backend, device)); +} + +const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, DeviceIdx_t device) const { + AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refFrom() it"); + + if (std::make_pair(backend, device) == getImpl()->device()) { + return *this; + } + else { + if (this == fallback.get()) { + // if refCast() was called before, just change the backend + fallback->setBackend(backend, device); + } + else { + if (!fallback) { + fallback = std::make_shared<Tensor>(dataType()); + } + else { + fallback->setDataType(dataType(), false); // don't keep previous data (no copy) + } + + fallback->setBackend(backend, device, false); // don't keep previous data (no copy) + fallback->resize(dims()); + fallback->getImpl()->copyFrom(*getImpl(), size()); + } + return *fallback; + } +} diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index ce956d115e282c43751619070dd8a10ac5c9cfae..c2439a459dcbe1b53d6aa31fd467ca3cd137aa23 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -71,33 +71,74 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { typeCounter[currentType] = 0; ++typeCounter[currentType]; - const std::string givenName = + std::string givenName = (node_ptr->name().empty()) - ? currentType + std::to_string(typeCounter[currentType]) - : node_ptr->name(); + ? "<em>" + currentType + "#" + std::to_string(typeCounter[currentType]) + "</em>" + : "\"" + node_ptr->name() + "\\n<sub><em>( " + currentType + "#" + std::to_string(typeCounter[currentType]) + " )</em></sub>\""; namePtrTable[node_ptr] = (currentType + "_" + std::to_string(typeCounter[currentType])); - std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(), - givenName.c_str()); + + if (node_ptr == mRootNode) { + std::fprintf(fp, "%s(%s):::rootCls\n", namePtrTable[node_ptr].c_str(), + givenName.c_str()); + } + else { + std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(), + givenName.c_str()); + } } + // Write every link - std::size_t emptyInputCounter = 0; for (const std::shared_ptr<Node> &node_ptr : mNodes) { - for (const std::shared_ptr<Node> &pa_ptr : node_ptr->getParents()) { - if ((pa_ptr == nullptr) || !inView(pa_ptr)) { - std::fprintf(fp, "input%zu((in - %zu))-->%s\n", emptyInputCounter, - emptyInputCounter, namePtrTable[node_ptr].c_str()); - ++emptyInputCounter; - } else { - std::fprintf(fp, "%s-->%s\n", namePtrTable[pa_ptr].c_str(), - namePtrTable[node_ptr].c_str()); - } + IOIndex_t outputIdx = 0; + for (auto childs : node_ptr->getOrderedChildren()) { + for (auto child : childs) { + if (child != nullptr) { + IOIndex_t inputIdx = 0; + for (auto parent : child->inputs()) { + if (parent.first == node_ptr && parent.second == outputIdx) { + if (mNodes.find(child) != mNodes.end()) { + std::fprintf(fp, "%s-->|%u→%u|%s\n", namePtrTable[node_ptr].c_str(), + outputIdx, inputIdx, namePtrTable[child].c_str()); + } + else if (verbose) { + std::fprintf(fp, "%s-->|%u→%u|%p:::externalCls\n", namePtrTable[node_ptr].c_str(), + outputIdx, inputIdx, static_cast<void*>(child.get())); + } + break; + } + ++inputIdx; + } + } } + ++outputIdx; + } + } + + size_t inputIdx = 0; + for (auto input : mInputNodes) { + std::fprintf(fp, "input%lu((in#%lu)):::inputCls--->|→%u|%s\n", inputIdx, inputIdx, + input.second, namePtrTable[input.first].c_str()); + ++inputIdx; + } + + size_t outputIdx = 0; + for (auto output : mOutputNodes) { + std::fprintf(fp, 
"%s--->|%u→|output%lu((out#%lu)):::outputCls\n", + namePtrTable[output.first].c_str(), output.second, + outputIdx, outputIdx); + ++outputIdx; } + + std::fprintf(fp, "classDef inputCls fill:#afa\n"); + std::fprintf(fp, "classDef outputCls fill:#ffa\n"); + std::fprintf(fp, "classDef externalCls fill:#ccc\n"); + std::fprintf(fp, "classDef rootCls stroke:#f00\n"); + if (verbose) { - for (const auto &c : typeCounter) { + for (const auto &c : typeCounter) { std::printf("%s - %zu\n", c.first.c_str(), c.second); - } + } } std::fprintf(fp, "\n"); @@ -108,20 +149,60 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { // TENSOR MANAGEMENT /////////////////////////////////////////////////////// +void Aidge::GraphView::setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs) { + AIDGE_ASSERT(inputs.size() <= mInputNodes.size(), "too many specified number of inputs"); + + std::vector<std::pair<NodePtr, IOIndex_t>> ignoredInputs(mInputNodes); + for (auto input : inputs) { + auto it = std::find(ignoredInputs.begin(), ignoredInputs.end(), input); + AIDGE_ASSERT(it != ignoredInputs.end(), "unknown or duplicate input"); + ignoredInputs.erase(it); + } + + mInputNodes = inputs; + mInputNodes.insert(mInputNodes.end(), ignoredInputs.begin(), ignoredInputs.end()); +} + +void Aidge::GraphView::setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs) { + AIDGE_ASSERT(outputs.size() <= mOutputNodes.size(), "too many specified number of outputs"); + + std::vector<std::pair<NodePtr, IOIndex_t>> ignoredOutputs(mOutputNodes); + for (auto output : outputs) { + auto it = std::find(ignoredOutputs.begin(), ignoredOutputs.end(), output); + AIDGE_ASSERT(it != ignoredOutputs.end(), "unknown or duplicate output"); + ignoredOutputs.erase(it); + } + + mOutputNodes = outputs; + mOutputNodes.insert(mOutputNodes.end(), ignoredOutputs.begin(), ignoredOutputs.end()); +} + Aidge::IOIndex_t Aidge::GraphView::getNbDataInputs() const { - return std::accumulate(mInputNodes.cbegin(), mInputNodes.cend(), 0, - [](IOIndex_t sumData, const std::shared_ptr<Node> inNode) { - return sumData + inNode->nbData(); - } - ); + IOIndex_t nbDataInput = 0; + for (const std::shared_ptr<Node> &inNode : inputNodes()) { + // We cannot simply add inNode->nbDataInputs(), as input nodes may already + // have some inputs connected within the GraphView, which would therefore not + // constitue inputs (from outside) for the GraphView! + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = + inNode->dataInputs(); + + for (const auto& input : inputNodeinputs) { + if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) { + ++nbDataInput; + } + } + } + return nbDataInput; } Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const { - return std::accumulate(mInputNodes.cbegin(), mInputNodes.cend(), 0, - [](IOIndex_t sumData, const std::shared_ptr<Node> inNode) { - return sumData + inNode->getNbFreeDataInputs(); - } - ); + IOIndex_t nbIn = 0; + // Free inputs within the GraphView are logically also free inputs from outside + // the GraphView. 
+ for (const std::shared_ptr<Node>& inputNode : inputNodes()) { + nbIn += inputNode->getNbFreeDataInputs(); + } + return nbIn; } @@ -129,12 +210,12 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::dataInputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = + for (const std::shared_ptr<Node>& inputNode : inputNodes()) { + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->dataInputs(); for (const auto& input : inputNodeinputs) { - if (mNodes.find(input.first) == mNodes.end()) { + if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) { res.push_back(input); } } @@ -147,12 +228,12 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::inputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = + for (const std::shared_ptr<Node>& inputNode : inputNodes()) { + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->inputs(); for (const auto& input : inputNodeinputs) { - if (mNodes.find(input.first) == mNodes.end()) { + if (input.first == nullptr || mNodes.find(input.first) == mNodes.end()) { res.push_back(input); } } @@ -166,10 +247,10 @@ Aidge::GraphView::inputs(std::string name) const { return mNodeRegistry.at(name)->inputs(); } -void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype) { +void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device) { // Backend // TODO: add Backend attribute to Operator - setBackend(backend); + setBackend(backend, device); // Data type // TODO: manage Datatype attribute in OperatorImpl setDataType(datatype); @@ -238,9 +319,9 @@ void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) { } } -void Aidge::GraphView::setBackend(const std::string &backend) { +void Aidge::GraphView::setBackend(const std::string &backend, DeviceIdx_t device) { for (auto node : getNodes()) { - node->getOperator()->setBackend(backend); + node->getOperator()->setBackend(backend, device); } } @@ -250,68 +331,28 @@ void Aidge::GraphView::setDataType(const Aidge::DataType &datatype) { } } -void Aidge::GraphView::updateOutputNodes() { - mOutputNodes.clear(); - for (const std::shared_ptr<Node>& go_it : mNodes) { - if (go_it->nbOutputs() != - go_it->nbValidOutputs()) { // an output linked to nothing - mOutputNodes.insert(go_it); - continue; - } - for (const std::shared_ptr<Node>& ch_ptr : go_it->getChildren()) { - if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph - mOutputNodes.insert(go_it); - break; - } - } - } -} - -void Aidge::GraphView::updateOutputNodes(std::shared_ptr<Node> node) { - if (node->nbOutputs() != - node->nbValidOutputs()) { // an output linked to nothing - mOutputNodes.insert(node); - } else { // don't enter if was already added to outputNodes - for (const std::shared_ptr<Node> &ch_ptr : node->getChildren()) { - if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph - mOutputNodes.insert(node); - break; - } - } - } - // update other outputNodes - for (const std::shared_ptr<Node> &pa_ptr : - node->getParents()) { // check if any parent is in 
OutputNodes too - if ((pa_ptr != nullptr) && - (mOutputNodes.find(pa_ptr) != - mOutputNodes.end())) { // it's a match! Must check if the outputNode - // found is still an outputNode - bool remove = (pa_ptr->nbOutputs() == pa_ptr->nbValidOutputs()); - for (const std::shared_ptr<Node>& ch_ptr : pa_ptr->getChildren()) { - if (mNodes.find(ch_ptr) == mNodes.end()) { // Child not in the graph - remove = false; - break; - } - } - if (remove) { - mOutputNodes.erase(pa_ptr); - } - } - } -} - std::vector< std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>> Aidge::GraphView::outputs() const { std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>> - outputTensors; - for (const std::shared_ptr<Node>& outputNode : mOutputNodes) { - std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>> - tmpOutputs = (outputNode->outputs()); - outputTensors.insert(outputTensors.end(), tmpOutputs.begin(), - tmpOutputs.end()); + outsideOutputs; + for (const std::shared_ptr<Node>& outputNode : outputNodes()) { + const std::vector<std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>>> + outputNodeOutputs = outputNode->outputs(); + + for (const auto& outputPos : outputNodeOutputs) { + // Keep only the nodes connected at this output position that are outside the GraphView + std::vector<std::pair<std::shared_ptr<Node>, Aidge::IOIndex_t>> outsideOutputPos; + for (const auto& output : outputPos) { + if (mNodes.find(output.first) == mNodes.end()) { + outsideOutputPos.push_back(output); + } + } + + outsideOutputs.push_back(outsideOutputPos); + } } - return outputTensors; + return outsideOutputs; } std::vector< @@ -326,11 +367,20 @@ void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/, } void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) { + // first node to be added to the graph is the root node by default + if (mRootNode == nullptr) { + mRootNode = node; + } + // add to the GraphView nodes node->addView(shared_from_this()); mNodes.insert(node); if (!(node->name()).empty()) mNodeRegistry.insert(std::make_pair(node->name(), node)); + + // check if the node is an input/output node + updateInputsOutputsNew(node); + // add learnable parameters to the graph if (includeLearnableParam) { for (IOIndex_t i = node->nbData(); i < node->nbInputs(); ++i) { @@ -340,33 +390,124 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara mNodes.insert(parentNode); if (!(parentNode->name()).empty()) mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode)); - // check if the Node is an input node - updateInputNodes(parentNode); + // check if the parentNode is an input/output node + updateInputsOutputsNew(parentNode); } } } - // check if the Node is an input node - updateInputNodes(node); - // check if the Node is an input node - updateOutputNodes(node); } -void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) { - for (auto& nodePtr : otherNodes) { add(nodePtr, includeLearnableParam); } +bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) { + if (otherNodes.empty()) { + return true; + } + + bool orderUnicity = true; + + // List only the nodes that are not already present in current graph + std::set<NodePtr> nodesToAdd; + std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::inserter(nodesToAdd, nodesToAdd.begin())); + + // List the nodes to rank, initially all the nodes in 
the GraphView + std::set<NodePtr> nodesToRank(mNodes); + nodesToRank.insert(nodesToAdd.begin(), nodesToAdd.end()); + std::vector<NodePtr> rankedNodesToAdd; + + if (mRootNode == nullptr) { + std::set<NodePtr> noParentNodes; + + // If no root node is defined, check nodes without parents + for (auto node : nodesToRank) { + bool noParent = true; + for (auto parent : node->getParents()) { + if (parent != nullptr && nodesToRank.find(parent) != nodesToRank.end()) { + noParent = false; + break; + } + } + + if (noParent) { + noParentNodes.insert(node); + } + } + + // Take the first one found (this is an arbitrary choice) + mRootNode = *noParentNodes.begin(); + + if (noParentNodes.size() > 1) { + // If there is more than one, order unicity cannot be garanteed! + orderUnicity = false; + } + + rankedNodesToAdd.push_back(mRootNode); + } + + nodesToRank.erase(mRootNode); + std::vector<NodePtr> rankedNodes; + rankedNodes.push_back(mRootNode); + + for (size_t curNodeIdx = 0; curNodeIdx < rankedNodes.size(); ++curNodeIdx) { + NodePtr curNode = rankedNodes[curNodeIdx]; + + for (auto childs : curNode->getOrderedChildren()) { + for (auto child : childs) { + if (nodesToRank.find(child) != nodesToRank.end()) { + rankedNodes.push_back(child); + nodesToRank.erase(child); + + if (nodesToAdd.find(child) != nodesToAdd.end()) { + rankedNodesToAdd.push_back(child); + nodesToAdd.erase(child); + } + } + } + } + + for (auto parent : curNode->getParents()) { + if (nodesToRank.find(parent) != nodesToRank.end()) { + rankedNodes.push_back(parent); + nodesToRank.erase(parent); + + if (nodesToAdd.find(parent) != nodesToAdd.end()) { + rankedNodesToAdd.push_back(parent); + nodesToAdd.erase(parent); + } + } + } + } + + if (!nodesToAdd.empty()) { + // There are remaining nodes without path to the root node + orderUnicity = false; + + while (!nodesToAdd.empty()) { + const auto it = nodesToAdd.begin(); + rankedNodesToAdd.push_back(*it); + nodesToAdd.erase(it); + } + } + + for (auto node_ptr : rankedNodesToAdd) { + add(node_ptr, includeLearnableParam); + } + + return orderUnicity; +} + +bool Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) { + if (nodes.first != nullptr) { + mRootNode = nodes.first; + add(nodes.first, includeLearnableParam); + } + return add(nodes.second, includeLearnableParam); } -void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) { - for (const std::shared_ptr<Node> &node_ptr : graph->getNodes()) { - node_ptr->addView(shared_from_this()); - mNodes.insert(node_ptr); - if (!(node_ptr->name()).empty()) - mNodeRegistry.insert(std::make_pair(node_ptr->name(), node_ptr)); - // if node_ptr is part of graph inputNodes or outputNodes - // if (graph->isInputNode(node_ptr) || graph->isOutputNode(node_ptr)) { - // Update OutputNodes/inputNodes - updateInputNodes(); - updateOutputNodes(); +bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph) { + if (mRootNode == nullptr) { + mRootNode = graph->getRootNode(); } + + return add(graph->getNodes(), false); } void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode, @@ -414,7 +555,7 @@ void Aidge::GraphView::addChild( std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents() const { // TODO: choose if we return a set or a vector std::set<std::shared_ptr<Node>> parents; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { + for (const std::shared_ptr<Node>& inputNode : inputNodes()) { parents.insert(inputNode->getParents().begin(), inputNode->getParents().end()); } @@ -433,7 +574,7 @@ 
std::vector<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getParents(const std std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::GraphView::getOrderedParents() const { std::vector<std::vector<std::shared_ptr<Node>>> parents; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { + for (const std::shared_ptr<Node>& inputNode : inputNodes()) { parents.push_back(inputNode->getParents()); } return parents; @@ -441,7 +582,7 @@ Aidge::GraphView::getOrderedParents() const { std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::getChildren() const { std::set<std::shared_ptr<Node>> children; - for (const std::shared_ptr<Node>& outputNode : mOutputNodes) { + for (const std::shared_ptr<Node>& outputNode : outputNodes()) { children.insert((outputNode->getChildren()).begin(), (outputNode->getChildren()).end()); } @@ -475,48 +616,54 @@ std::shared_ptr<Aidge::Node> Aidge::GraphView::getNode(const std::string& nodeName) const { std::map<std::string, std::shared_ptr<Node>>::const_iterator it = mNodeRegistry.find(nodeName); - if (it != mNodeRegistry.end()) { + if (it != mNodeRegistry.cend()) { return it->second; } else { printf("No Node named %s in the current GraphView.\n", nodeName.c_str()); - exit(-1); + return nullptr; } } void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnableParam) { - if (mNodes.find(nodePtr) != mNodes.end()) { - mNodes.erase(nodePtr); - nodePtr->removeView(shared_from_this()); - } - if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); } - // same for learnable params - + // remove learnable params if (includeLearnableParam) { for (IOIndex_t i = nodePtr->nbData(); i < nodePtr->nbInputs(); ++i) { auto inputI = nodePtr->input(i); - bool removeNode = true; - for (const auto& parentOutput : inputI.first->outputs()) { - for (const auto& childOfParentOutput : parentOutput) { - // only remove the learnable parameter if not related to any other Node in the GraphView - if (childOfParentOutput.first != nodePtr) { - removeNode = false; - break; + if (inputI.first != nullptr) { + bool removeNode = true; + for (const auto& parentOutput : inputI.first->outputs()) { + for (const auto& childOfParentOutput : parentOutput) { + // only remove the learnable parameter if not related to any other Node in the GraphView + if (childOfParentOutput.first != nodePtr) { + removeNode = false; + break; + } } } - } - if (removeNode) { - // assert Learnable Parameter in the GraphView scope - if (mNodes.find(inputI.first) != mNodes.end()) { - mNodes.erase(inputI.first); - inputI.first->removeView(shared_from_this()); + if (removeNode) { + // assert Learnable Parameter in the GraphView scope + if (mNodes.find(inputI.first) != mNodes.end()) { + mNodes.erase(inputI.first); + inputI.first->removeView(shared_from_this()); + } + if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); } + + // check if the node was an input/output node + updateInputsOutputsDelete(inputI.first); } - if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); } } } } - updateInputNodes(); - updateOutputNodes(); + + if (mNodes.find(nodePtr) != mNodes.end()) { + mNodes.erase(nodePtr); + nodePtr->removeView(shared_from_this()); + + // check if the nodePtr was an input/output node + updateInputsOutputsDelete(nodePtr); + } + if (!nodePtr->name().empty()) { mNodeRegistry.erase(nodePtr->name()); } } @@ -547,211 +694,369 @@ void Aidge::GraphView::insertParent(NodePtr childNode, add(newParentNode); } - bool Aidge::GraphView::replace(const 
std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) { - // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes) // How to distinguish it from data input? // TODO: Parameter Tensors could be identified with their dimensions // TODO: Take GraphView as input parameters since new Nodes should be connected whatever. // It also avoids specifying each producer since they are automatically included + // (1) create GraphViews from both sets of Nodes auto oldG = std::make_shared<GraphView>("oldG"); oldG->add(oldNodes, false); auto newG = std::make_shared<GraphView>("newG"); newG->add(newNodes, false); - if ((oldG->inputNodes().size() == 0) || (oldG->outputNodes().size() != 1)) { - return false; + const auto oldOI = oldG->getOrderedInputs(); + const auto oldOO = oldG->getOrderedOutputs(); + const auto newOI = newG->getOrderedInputs(); + const auto newOO = newG->getOrderedOutputs(); + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputParents = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOI.size()); + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> outputChildren = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOO.size()); + + // keep in memory every parent + for (std::size_t i = 0; i < oldOI.size(); ++i) { + auto inputParent = oldOI[i].first -> input(oldOI[i].second); + inputParents[i]= inputParent; + // inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second); } - if (!(newNodes.empty()) && ((newG->inputNodes().size() == 0) || - (newG->outputNodes().size() != 1))) { - return false; + for (std::size_t i = 0; i < oldOO.size();) { + auto outputChildList = oldOO[i].first -> output(oldOO[i].second); + if (outputChildList.empty()) { + outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex}); + ++i; + } + else { + for (const auto& child : outputChildList) { + if (oldNodes.find(child.first) == oldNodes.cend()) { + outputChildren[i] = child; + ++i; + } + } + } } - // there is at least one inputNode in the old/new GraphView - std::shared_ptr<Node> firstPreviousInputNode = (*(oldG->inputNodes()).begin()); - std::shared_ptr<Node> firstPreviousOutputNode = (*(oldG->outputNodes()).begin()); + // only keep common views to each node for the new set + // set of common GraphView for oldNodes' Nodes + std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views(); + for (const auto& nodePtr : oldNodes) { + const auto nodeView = nodePtr->views(); + std::set<std::shared_ptr<GraphView>> intersection; + std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(), + nodeView.begin(), nodeView.end(), + std::inserter(intersection, intersection.begin())); + commonGraphViews = intersection; + } + commonGraphViews.erase(oldG); + commonGraphViews.erase(newG); - // find Node to link to new input Node - //compute number of input for firstPreviousInputNode not in oldNodes set - std::size_t nbExternalInputs = 0; - std::shared_ptr<Node> externalInput = nullptr; - IOIndex_t externalInputId = gk_IODefaultIndex; - for (const auto& input : firstPreviousInputNode->inputs()) { - if (oldNodes.find(input.first) == oldNodes.end()) { // Node connected to another Node outside of oldG - nbExternalInputs++; - externalInput = input.first; - externalInputId = input.second; + if ((newNodes.size() > 0) && (oldOI.size() != newOI.size()) && (oldOO.size() != newOO.size())) { + for (const auto& nodePtr : oldNodes) { + 
nodePtr->removeView(oldG); } - } - if (nbExternalInputs > 1) { - AIDGE_INTERNAL_ASSERT("To many input to link for oldNodes set"); + for (const auto& nodePtr : newNodes) { + nodePtr->removeView(newG); + } + return false; } - if (oldG->inputNodes().size() > 1){ - // one or no input has been identified. Checking every input points to the same source - for (const auto& previousInputNode : oldG->inputNodes()) { - for (const auto& input : previousInputNode->inputs()) { - if (oldNodes.find(input.first) == oldNodes.end()) { - if ( (externalInput != input.first) || (externalInputId != input.second) ) { - return false; // an inputNode points to an external Node different from the registered one + if ((oldOI.size() == newOI.size()) && + (oldOO.size() == newOO.size())) { + // Case 1 + for (std::size_t i = 0; i < oldOI.size(); ++i) { + if (inputParents[i].first) { + inputParents[i].first -> addChild(newOI[i].first, inputParents[i].second, newOI[i].second); + } + } + for (std::size_t o = 0; o < oldOO.size(); ++o) { + if (outputChildren[o].first) { + newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second); + } + } + } + else { + // get the number of Parents for oldG->inputNodes() + // get the number of Children for oldg->outputNodes() + if (newNodes.size() == 0) { + // Case 3 + if (oldOI.size() == oldOO.size()) { + for (std::size_t i = 0; i < oldOI.size(); ++i) { + if (inputParents[i].first) + inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second); + } + } + else if ((oldOI.size() == 1) && (inputParents[0].first)) { + for (std::size_t i = 0; i < oldOI.size(); ++i) { + inputParents[0].first -> addChild(outputChildren[i].first, inputParents[0].second, outputChildren[i].second); + } + } + } + else if ( // for tiling-like cases. The number of inputNodes changes but not outputNodes + ((oldOI.size() == 1) || (newOI.size() == 1)) && // (oldOI.size() == newOI.size()) already handled in Case 1 + ((oldOO.size() == newOO.size())) + ) { + // Case 2 + if ((oldOI.size() == 1) && (inputParents[0].first)) { + for (std::size_t i = 0; i < newOI.size(); ++i) { + inputParents[0].first -> addChild(newOI[i].first, inputParents[0].second, newOI[i].second); + } + } else { + for (std::size_t i = 0; i < oldOI.size(); ++i) { + if (inputParents[i].first) { + inputParents[i].first -> addChild(newOI[0].first, inputParents[i].second, newOI[0].second); } } } + for (std::size_t o = 0; o < oldOO.size(); ++o) { + if (outputChildren[o].first) { + newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second); + } + } + } + else { + for (const auto& nodePtr : oldNodes) { + nodePtr->removeView(oldG); + } + for (const auto& nodePtr : newNodes) { + nodePtr->removeView(newG); + } + return false; } } - if (firstPreviousOutputNode->nbOutputs() != 1) { - return false; - } + auto oldGOutputs = oldG->outputNodes(); + for (const auto& nodePtr : oldNodes) { + bool removeFromGraphs = true; + if (std::find(oldGOutputs.cbegin(), oldGOutputs.cend(), nodePtr) == oldGOutputs.cend()) { + for (const auto& chPtr : nodePtr->getChildren()) { + if (oldNodes.find(chPtr) == oldNodes.cend()) { + removeFromGraphs = false; + } + } + } + if (removeFromGraphs) { + for (const auto& g : commonGraphViews) { + g -> remove(nodePtr, false); + g -> updateInputsOutputsDelete(nodePtr); + } + nodePtr -> resetConnections(true); + } - // find Node to replicate output connections - std::shared_ptr<Node> newOutputNode = newNodes.empty() ? 
externalInput : *(newG->outputNodes().begin()); + } - auto copyOutputs = firstPreviousOutputNode->outputs(); - // manage Views for newNodes - // only keep common views to each node for the new set - std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views(); + for (const auto& nodePtr : newNodes) { + for (const auto& g : commonGraphViews) { + g -> add(nodePtr); + } + } for (const auto& nodePtr : oldNodes) { - const auto nodeView = nodePtr->views(); - std::set<std::shared_ptr<GraphView>> intersection; - std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(), - nodeView.begin(), nodeView.end(), - std::inserter(intersection, intersection.begin())); - commonGraphViews = intersection; + nodePtr -> removeView(oldG); } - commonGraphViews.erase(oldG); - commonGraphViews.erase(newG); + for (const auto& nodePtr : newNodes) { + nodePtr -> removeView(newG); + } + return true; +} - // clean Nodes to replace - // Do not include common Nodes to avoid cleaning Producers linked to newNodes - std::set<std::shared_ptr<Node>> nodesToClean; - std::set_difference(oldNodes.begin(), oldNodes.end(), - newNodes.begin(), newNodes.end(), - std::inserter(nodesToClean, nodesToClean.begin())); - for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); } - - // copy output connections - if (newOutputNode) { - for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) { - auto outputPairs = copyOutputs[o]; - for (const auto& onePair : outputPairs) { - newOutputNode->addChild(onePair.first, o, onePair.second); +void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) { + // Can be called several times with the same node, e.g. when addChild() is + // called on a node already part of the GraphView. In this case, inputs/outputs + // need to be updated! 
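+    //
+    // A minimal sketch of the invariant maintained here (hypothetical usage; the
+    // GenericOperator() helper and getOrderedInputs()/getOrderedOutputs() are used
+    // as elsewhere in this diff):
+    //
+    //     auto n1 = GenericOperator("Conv", 1, 0, 1, "n1");
+    //     auto n2 = GenericOperator("ReLU", 1, 0, 1, "n2");
+    //     auto g = std::make_shared<GraphView>("g");
+    //     g->add(n1);              // n1's free input 0 and output 0 become g's I/Os
+    //     n1->addChild(n2, 0, 0);
+    //     g->add(n2);              // updateInputsOutputsNew(n2): (n1, 0) is no longer
+    //                              // a graph output, (n2, 0) becomes one; inputs keep
+    //                              // their original order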
+ std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newInputsInsertionPoint = mInputNodes.cend(); + + // Remove inputs that are not input anymore because connected to newNode + for (auto orderedChilds : newNode->getOrderedChildren()) { + for (auto ch_ptr : orderedChilds) { + // Check that newNode child is in current GraphView + if (mNodes.find(ch_ptr) != mNodes.cend()) { + IOIndex_t inputIdx = 0; + for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) { + // If newNode is connected to it + if (pa_ptr == newNode) { + const auto val = std::make_pair(ch_ptr, inputIdx); + const auto iter = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val); + + // Check that it was not already the case (if node UPDATE) + if (iter != mInputNodes.cend()) { // newNode is linked to an actual inputNode to an input connection + // The first old (removed) input becomes the insertion point for newNode GraphView inputs + if (std::distance(newInputsInsertionPoint, iter) <= 0) { + newInputsInsertionPoint = mInputNodes.erase(iter); + } + else { + mInputNodes.erase(iter); + } } + } + ++inputIdx; } + } } + } - // copy input connections - if (!newNodes.empty() && externalInput) { - for (const auto& newInputNode : newG->inputNodes()) { - IOIndex_t inputId = 0; - for (const auto& input : newInputNode->inputs()) { - if (newNodes.find(input.first) == newNodes.end()) { - externalInput->addChild(newInputNode, externalInputId, inputId); - } - inputId++; + // Manage newNode parents + // Check if any input connection is an input for the GraphView + IOIndex_t inputIdx = 0U; + for (const std::shared_ptr<Node>& pa_ptr : newNode->getParents()) { + const auto val = std::make_pair(newNode, inputIdx); + const auto it = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val); + if ((pa_ptr == nullptr) || + (mNodes.find(pa_ptr) == mNodes.cend())) { + // Parent doesn't exist || Parent not in the graph + if (it == mInputNodes.cend()) { + // If node's inputs are inputs for the GraphView: add them to the input list + // Addition rule: + // - Inputs addition order follows node inputs order + // - Inputs are inserted at the position of the first input removed + newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); + newInputsInsertionPoint = std::next(newInputsInsertionPoint); } + } else if (it != mInputNodes.cend()) { + // Parent already in the graph SO edge is not an input anymore for the graph + mInputNodes.erase(it); } + ++inputIdx; } - // insert new Nodes in the right GraphViews - for (const auto& graphPtr : commonGraphViews) { - graphPtr->add(newNodes, false); - if (newNodes.empty()) { - graphPtr->updateInputNodes(); - graphPtr->updateOutputNodes(); + std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newOutputsInsertionPoint = mOutputNodes.cend(); + + // Remove outputs that are not output anymore because connected to newNode + for (const std::shared_ptr<Node>& parent : newNode->getParents()) { + // Check that newNode parent is in current GraphView + if (mNodes.find(parent) != mNodes.cend()) { + IOIndex_t outputIdx = 0; + for (auto orderedChilds : parent->getOrderedChildren()) { + for (auto ch_ptr : orderedChilds) { + // If newNode is connected to it + if (ch_ptr == newNode) { + const auto val = std::make_pair(parent, outputIdx); + const auto iter = std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val); + + if (iter != mOutputNodes.cend()) { + // The first old (removed) output becomes the insertion point for newNode GraphView outputs + if (std::distance(newOutputsInsertionPoint, 
iter) <= 0) { + newOutputsInsertionPoint = mOutputNodes.erase(iter); + } + else { + mOutputNodes.erase(iter); + } + } + } } + ++outputIdx; + } } + } - for (const auto& node : oldNodes) { - node->removeView(oldG); + // Check if node outputs are outputs for the GraphView and add them to the output list if so + IOIndex_t outputIdx = 0; + for (const auto& orderedChilds : newNode->getOrderedChildren()) { + bool noInsideConnection = true; + for (const auto& ch_ptr : orderedChilds) { + if (mNodes.find(ch_ptr) != mNodes.cend()) { + noInsideConnection = false; + break; + } } - for (const auto& node : newNodes) { - node->removeView(newG); + + if (noInsideConnection) { + const auto val = std::make_pair(newNode, outputIdx); + // Output may be already be present (see addChild() with a node already in GraphView) + if (std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val) == mOutputNodes.cend()) { + newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); + newOutputsInsertionPoint = std::next(newOutputsInsertionPoint); + } } - return true; + ++outputIdx; + } } +void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNode) { + std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newInputsInsertionPoint = mInputNodes.cend(); -void Aidge::GraphView::updateInputNodes() { - mInputNodes.clear(); - for (const std::shared_ptr<Node>& go_ptr : mNodes) { - for (const std::shared_ptr<Node>& pa_ptr : go_ptr->getParents()) { - if ((pa_ptr == nullptr) || - (mNodes.find(pa_ptr) == - mNodes.end())) { // Parent doesn't exist || Parent not in the graph - mInputNodes.insert(go_ptr); - break; + // Check if node inputs were inputs for the GraphView and remove them from the list if so + for (IOIndex_t inputIdx = 0; inputIdx < deletedNode->getParents().size(); ++inputIdx) { + const auto val = std::make_pair(deletedNode, inputIdx); + const auto iter = std::find(mInputNodes.cbegin(), mInputNodes.cend(), val); + + if (iter != mInputNodes.cend()) { + // The first old (removed) input becomes the insertion point for new GraphView inputs + if (std::distance(newInputsInsertionPoint, iter) <= 0) { + newInputsInsertionPoint = mInputNodes.erase(iter); + } + else { + mInputNodes.erase(iter); } } } -} -void Aidge::GraphView::updateInputNodes(std::shared_ptr<Node> node) { - // add node_ptr to inputNode if it can - std::size_t filledWithKnownInputs = 0U; - bool wasAdded = mInputNodes.find(node) != mInputNodes.end(); - for (const std::shared_ptr<Node>& pa_ptr : node->getParents()) { - if ((pa_ptr == nullptr) || - (mNodes.find(pa_ptr) == - mNodes.end())) { // Parent doesn't exist || Parent not in the graph - mInputNodes.insert(node); - wasAdded = true; - break; - } - ++filledWithKnownInputs; - } - if (filledWithKnownInputs == node->nbInputs() && wasAdded) { - mInputNodes.erase(node); - } - // update other inputNodes - for (const std::shared_ptr<Node>& ch_ptr : - node->getChildren()) { // check if any child is in InputNodes too - if (mInputNodes.find(ch_ptr) != - mInputNodes.end()) { // it's a match! 
Must check if the inputNode found - // is still an inputNode - // change here - bool remove = true; - for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) { - if (pa_ptr == nullptr || - mNodes.find(pa_ptr) == - mNodes - .end()) { // Parent doesn't exist || Parent not in the graph - remove = false; - break; + // Add child node inputs that become GraphView input following the removal of the node + // Inputs addition order follows deletedNode outputs order + for (auto orderedChilds : deletedNode->getOrderedChildren()) { + for (auto ch_ptr : orderedChilds) { + // Check that deletedNode child is in current GraphView + if (mNodes.find(ch_ptr) != mNodes.cend()) { + IOIndex_t inputIdx = 0; + for (const std::shared_ptr<Node>& pa_ptr : ch_ptr->getParents()) { + // If newNode was connected to it + if (pa_ptr == deletedNode) { + const auto val = std::make_pair(ch_ptr, inputIdx); + if (std::find(mInputNodes.cbegin(), mInputNodes.cend(), val) == mInputNodes.cend()) { + newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val); + newInputsInsertionPoint = std::next(newInputsInsertionPoint); + } + } + ++inputIdx; } } - if (remove) { - mInputNodes.erase(ch_ptr); - } } } -} + std::vector<std::pair<NodePtr, IOIndex_t>>::const_iterator newOutputsInsertionPoint = mOutputNodes.cend(); -void Aidge::GraphView::removeInputNode(const std::string nodeName) { - std::map<std::string, std::shared_ptr<Node>>::iterator it = - mNodeRegistry.find(nodeName); - if (it != mNodeRegistry.end()) { - const std::shared_ptr<Node> val = (*it).second; - if (mInputNodes.find(val) != mInputNodes.end()) { - mInputNodes.erase(val); + // Check if node outputs were outputs for the GraphView and remove them from the list if so + for (IOIndex_t outputIdx = 0; outputIdx < deletedNode->getOrderedChildren().size(); ++outputIdx) { + const auto val = std::make_pair(deletedNode, outputIdx); + const auto iter = std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val); + + if (iter != mOutputNodes.cend()) { + // The first old (removed) output becomes the insertion point for newNode GraphView outputs + if (std::distance(newOutputsInsertionPoint, iter) <= 0) { + newOutputsInsertionPoint = mOutputNodes.erase(iter); + } + else { + mOutputNodes.erase(iter); + } } } -} -void Aidge::GraphView::removeOutputNode(const std::string nodeName) { - std::map<std::string, std::shared_ptr<Node>>::iterator it = - mNodeRegistry.find(nodeName); - if (it != mNodeRegistry.end()) { - const std::shared_ptr<Node> val = (*it).second; - if (mOutputNodes.find(val) != mOutputNodes.end()) { - mOutputNodes.erase(val); + // Add parent node outputs that become GraphView output following the removal of the node + // Outputs addition order follows deletedNode inputs order + for (const std::shared_ptr<Node>& parent : deletedNode->getParents()) { + if (mNodes.find(parent) != mNodes.end()) { + IOIndex_t outputIdx = 0; + for (auto orderedChilds : parent->getOrderedChildren()) { + bool noInsideConnection = true; + for (auto ch_ptr : orderedChilds) { + if (mNodes.find(ch_ptr) != mNodes.end()) { + noInsideConnection = false; + break; + } + } + + if (noInsideConnection) { + const auto val = std::make_pair(parent, outputIdx); + if (std::find(mOutputNodes.cbegin(), mOutputNodes.cend(), val) == mOutputNodes.cend()) { + newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val); + newOutputsInsertionPoint = std::next(newOutputsInsertionPoint); + } + } + ++outputIdx; + } } } } + std::shared_ptr<Aidge::GraphView> 
Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const { std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName); @@ -759,46 +1064,132 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone std::map<NodePtr, NodePtr> oldToNewNodes; for (const std::shared_ptr<Node> &node_ptr : mNodes) { - oldToNewNodes[node_ptr] = cloneNode(node_ptr); + auto clonedNode = cloneNode(node_ptr); + if (clonedNode == nullptr) { + AIDGE_ASSERT(node_ptr->getChildren().size() <= 1, "deleted nodes in GraphView::clone() cannot have multiple children"); + AIDGE_ASSERT(node_ptr->nbData() <= 1, "deleted nodes in GraphView::clone() cannot have multiple data input parents"); + } + oldToNewNodes[node_ptr] = clonedNode; } // For each node, convert old node -> new node connections for (auto &oldToNewNode : oldToNewNodes) { - if (oldToNewNode.second == nullptr) + if (oldToNewNode.second == nullptr) { continue; // deleted node - - // Add new node to new GraphView - newGraph->add(oldToNewNode.second, false); + } // Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr size_t parentId = 0; for (auto parent : oldToNewNode.first->inputs()) { - while (oldToNewNodes[parent.first] == nullptr) { - // Find next valid parent in line, going backward in the graph - assert(parent.first->nbData() <= 1 && "deleted nodes in GraphView::clone() cannot have multiple data inputs"); - const auto& parents = parent.first->inputs(); + if (parent.first != nullptr) { + while (oldToNewNodes[parent.first] == nullptr) { + // Find next valid parent in line, going backward in the graph + AIDGE_INTERNAL_ASSERT(parent.first->getChildren().size() == 1); + AIDGE_INTERNAL_ASSERT(parent.first->nbData() <= 1); + const auto& parents = parent.first->dataInputs(); + + if (!parents.empty() && parents[0].first != nullptr // a valid parent exists + && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView + { + parent = parents[0]; + } + else { + break; + } + } - if (!parents.empty() && parents[0].first != nullptr // a valid parent exists - && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView - { - parent = parents[0]; + if (oldToNewNodes[parent.first]) { + AIDGE_INTERNAL_ASSERT(oldToNewNodes[parent.first]->nbOutputs() == parent.first->nbOutputs()); + oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId); } - else { - break; + } + + ++parentId; + } + } + + // Once connected, add each new nodes to new GraphView + // This has to be done in a second step to ensure that new GraphView inputs/outputs + // are properly set (otherwise, some node's inputs/outputs may be wrongly registered as + // GraphView inputs/outputs because not yet connected to other nodes) + if (oldToNewNodes[mRootNode] != nullptr) { + // Add root node first if is still exists! 
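+        // (Presumably this keeps the clone's root aligned with the original: add()
+        // appears to adopt the first node it receives as root when none is set, and
+        // the fallback choice in the ranking logic earlier in this diff is otherwise
+        // arbitrary.)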
+ newGraph->add(oldToNewNodes[mRootNode], false); + } + + for (auto &oldToNewNode : oldToNewNodes) { + if (oldToNewNode.second == nullptr) + continue; // deleted node + + newGraph->add(oldToNewNode.second, false); + } + + // Update cloned graph inputs/outputs order to match initial graph order + auto newInputNodes = mInputNodes; + for (auto it = newInputNodes.begin(); it != newInputNodes.end(); ) { + // If input node was removed, find next valid input + while (oldToNewNodes[it->first] == nullptr) { + // Removed node should have only one connected output, otherwise cloning is invalid + AIDGE_INTERNAL_ASSERT(it->first->getChildren().size() <= 1); + bool found = false; + + if (it->first->getChildren().size() == 1) { + auto child = *it->first->getChildren().begin(); + + std::size_t inputIdx = 0; + for (auto parent : child->getParents()) { + if (parent == it->first) { + it->first = child; + it->second = inputIdx; + found = true; + break; + } + ++inputIdx; } } - if (oldToNewNodes[parent.first]) { - oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId); + if (!found) { + break; } + } - ++parentId; + if (oldToNewNodes[it->first] == nullptr) { + it = newInputNodes.erase(it); + } + else { + it->first = oldToNewNodes[it->first]; + ++it; } } + newGraph->setOrderedInputs(newInputNodes); + + auto newOutputNodes = mOutputNodes; + for (auto it = newOutputNodes.begin(); it != newOutputNodes.end(); ) { + // If output node was removed, find previous valid output + while (oldToNewNodes[it->first] == nullptr) { + // Removed node should have only one connected data input, otherwise cloning is invalid + AIDGE_INTERNAL_ASSERT(it->first->nbData() <= 1); + auto parents = it->first->dataInputs(); + + if (!parents.empty() && parents[0].first != nullptr // a valid parent exists + && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView + { + *it = parents[0]; + } + else { + break; + } + } - // Update OutputNodes/inputNodes - newGraph->updateInputNodes(); - newGraph->updateOutputNodes(); + if (oldToNewNodes[it->first] == nullptr) { + it = newOutputNodes.erase(it); + } + else { + it->first = oldToNewNodes[it->first]; + ++it; + } + } + newGraph->setOrderedOutputs(newOutputNodes); return newGraph; } diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 5a7b05e469daab10a4abd468177a3ad137096f63..6f0cc55159b1cc72b87bb34230376eb140b7ab8a 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -11,22 +11,25 @@ #include "aidge/graph/Node.hpp" -#include "aidge/graph/GraphView.hpp" -#include "aidge/operator/Producer.hpp" #include <memory> #include <vector> + +#include "aidge/graph/GraphView.hpp" #include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" #include "aidge/utils/Types.h" Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) : mName(name), mOperator(op), - mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), nullptr)), - mChildren(std::vector<std::vector<std::weak_ptr<Node>>>(static_cast<std::size_t>(op->nbOutputs()), - std::vector<std::weak_ptr<Node>>())), - mIdInChildren( - std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), std::vector<IOIndex_t>())), - mIdOutParents(std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) { + mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), + nullptr)), + mChildren(std::vector<std::vector<std::weak_ptr<Node>>>( + 
static_cast<std::size_t>(op->nbOutputs()), std::vector<std::weak_ptr<Node>>())), + mIdInChildren(std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), + std::vector<IOIndex_t>())), + mIdOutParents( + std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) { // ctor } @@ -34,14 +37,15 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) // FUNCTIONAL DESCRIPTION /////////////////////////////////////////////////////// -Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) { +Aidge::Connector Aidge::Node::operator()(const std::vector<Connector>& ctors) { assert((ctors.size() == nbData()) && "Wrong number of arguments.\n"); - for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inputs()) { - assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n"); - (void) input; // avoid unused warning + for (std::pair<std::shared_ptr<Node>, IOIndex_t>& input : inputs()) { + assert((gk_IODefaultIndex == input.second) && + "At least one input connection is not free.\n"); + (void)input; // avoid unused warning } IOIndex_t i = 0; - for (const Connector &ctor : ctors) { + for (const Connector& ctor : ctors) { if (ctor.node() != nullptr) { // ctor must be associated with a node ctor.node()->addChild(shared_from_this(), ctor.index(), i++); } @@ -53,7 +57,7 @@ Aidge::Connector Aidge::Node::operator()(const std::vector<Connector> &ctors) { // INNER /////////////////////////////////////////////////////// -void Aidge::Node::setName(const std::string &name) { mName = name; } +void Aidge::Node::setName(const std::string& name) { mName = name; } /////////////////////////////////////////////////////// // OPERATORS @@ -92,8 +96,8 @@ Aidge::IOIndex_t Aidge::Node::getNbFreeDataInputs() const { return nbFreeDataIn; } -std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> -Aidge::Node::dataInputs() const { +std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::dataInputs() + const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbData()); for (std::size_t i = 0; i < static_cast<std::size_t>(nbData()); ++i) { @@ -104,15 +108,15 @@ Aidge::Node::dataInputs() const { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::inputs() const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbInputs()); + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbInputs()); for (std::size_t i = 0; i < nbInputs(); ++i) { - res[i] = - std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]); + res[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]); } return res; } -// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) { +// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> +// tensor) { // assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound."); // if (mParents[idx] != nullptr) { // mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]); @@ -128,20 +132,21 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::No std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>> Aidge::Node::outputs() const { 
std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> listOutputs = - std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>(mIdInChildren.size()); + std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>( + mIdInChildren.size()); for (std::size_t i = 0; i < mIdInChildren.size(); ++i) { listOutputs[i] = output(static_cast<IOIndex_t>(i)); } return listOutputs; } -std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> -Aidge::Node::output(Aidge::IOIndex_t outId) const { +std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::output( + Aidge::IOIndex_t outId) const { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(mIdInChildren[outId].size()); for (std::size_t i = 0; i < mIdInChildren[outId].size(); ++i) { - listOutputs[i] = - std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(), mIdInChildren[outId][i]); + listOutputs[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mChildren[outId][i].lock(), + mIdInChildren[outId][i]); } return listOutputs; } @@ -180,7 +185,8 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) // TOPOLOGY /////////////////////////////////////////////////////// -void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, const IOIndex_t otherInId) { +void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t outId, + const IOIndex_t otherInId) { assert((otherInId < otherNode->nbInputs()) && "Input index out of bound."); assert((outId < nbOutputs()) && "Output index out of bound."); if (otherNode->input(otherInId).second != gk_IODefaultIndex) { @@ -196,33 +202,41 @@ void Aidge::Node::addChildOp(std::shared_ptr<Node> otherNode, const IOIndex_t ou } void Aidge::Node::addChildView(std::shared_ptr<GraphView> otherGraph, const IOIndex_t outId, - std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { - assert((otherInId.second < otherInId.first->nbInputs()) && "Other graph input index out of bound."); + std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { + assert((otherInId.second < otherInId.first->nbInputs()) && + "Other graph input index out of bound."); assert((outId < nbOutputs()) && "Output index out of bound."); std::set<std::shared_ptr<Node>> inNodes = otherGraph->inputNodes(); if (inNodes.size() == std::size_t(0)) { // no input Node printf("Cannot add GraphView to the Node. No input node detected.\n"); } else // inNodes.size() >= 1 { - assert((inNodes.find(otherInId.first) != inNodes.end())); // assert it really is an input node + assert((inNodes.find(otherInId.first) != + inNodes.end())); // assert it really is an input node addChildOp(otherInId.first, outId, otherInId.second); } } -void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId, IOIndex_t otherInId) { - otherInId = (otherInId != gk_IODefaultIndex) ? otherInId : otherNode->getFirstFreeDataInput(); - addChildOp(otherNode, outId, otherInId); +void Aidge::Node::addChild(std::shared_ptr<Node> otherNode, const IOIndex_t outId, + IOIndex_t otherInId) { + if (otherNode) { + otherInId = + (otherInId != gk_IODefaultIndex) ? 
otherInId : otherNode->getFirstFreeDataInput(); + addChildOp(otherNode, outId, otherInId); + } } void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t outId, - std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { + std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { if (!otherInId.first) { assert((otherView->inputNodes().size() == 1U) && "Specify an input Node for the GraphView. More or less than one " "Node is not explicit."); otherInId.first = *(otherView->inputNodes().begin()); } - otherInId.second = (otherInId.second != gk_IODefaultIndex) ? otherInId.second : otherInId.first->getFirstFreeDataInput(); + otherInId.second = (otherInId.second != gk_IODefaultIndex) + ? otherInId.second + : otherInId.first->getFirstFreeDataInput(); addChildView(otherView, outId, otherInId); } @@ -255,8 +269,8 @@ bool Aidge::Node::removeParent(const IOIndex_t inId) { std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { std::set<std::shared_ptr<Node>> children; - for (const auto &childrenOfOneOutput : mChildren) { - for (const auto &oneChild : childrenOfOneOutput) { + for (const auto& childrenOfOneOutput : mChildren) { + for (const auto& oneChild : childrenOfOneOutput) { children.insert(oneChild.lock()); } } @@ -264,7 +278,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { } std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const { - std::vector<std::vector<std::shared_ptr<Node>>> children = std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size()); + std::vector<std::vector<std::shared_ptr<Node>>> children = + std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size()); for (std::size_t outId = 0; outId < mChildren.size(); ++outId) { children[outId] = getChildren(outId); } @@ -273,14 +288,16 @@ std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedCh std::vector<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren(const IOIndex_t outId) const { assert((outId < nbOutputs()) && "Output index out of bound."); - std::vector<std::shared_ptr<Node>> children = std::vector<std::shared_ptr<Node>>(mChildren[outId].size()); + std::vector<std::shared_ptr<Node>> children = + std::vector<std::shared_ptr<Node>>(mChildren[outId].size()); for (std::size_t i = 0; i < mChildren[outId].size(); ++i) { - children.push_back(mChildren[outId][i].lock()); - } + children.push_back(mChildren[outId][i].lock()); + } return children; } -bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, const Aidge::IOIndex_t outId) { +bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, + const Aidge::IOIndex_t outId) { assert((outId < nbOutputs()) && "Child index out of bound."); bool removed = false; for (std::size_t j = 0; j < mChildren[outId].size(); ++j) { @@ -301,7 +318,8 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i); if (parent.first) { // number of children linked to the parent's output - while(parent.first->removeChild(shared_from_this(), parent.second) == true) {} + while (parent.first->removeChild(shared_from_this(), parent.second) == true) { + } } // every reference to this object as child has been removed // removing reference to parents. 
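// The next hunk comments out the automatic removal of the node from its GraphViews in
// resetConnections(). Callers are now expected to handle that explicitly, as
// GraphView::replace() does earlier in this diff; a minimal sketch of the assumed
// call pattern:
//
//     for (const auto& g : node->views()) {
//         g->remove(node, false);    // also updates the view's ordered inputs/outputs
//     }
//     node->resetConnections(true);  // now only severs parent/child connections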
@@ -316,24 +334,23 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { mIdInChildren[i] = std::vector<IOIndex_t>(); } // removing this Node from every GraphView it belongs to - for (auto& graph : views()) { - // if keeping connections with LEarnable Parameters, then also remove them from graph - graph->remove(shared_from_this(), !includeLearnableParam); - } + // for (auto& graph : views()) { + // // if keeping connections with LEarnable Parameters, then also remove them from graph + // graph->remove(shared_from_this(), !includeLearnableParam); + // } } - /////////////////////////////////////////////////////// - // CLONE - /////////////////////////////////////////////////////// +/////////////////////////////////////////////////////// +// CLONE +/////////////////////////////////////////////////////// Aidge::NodePtr Aidge::Node::cloneSharedOperators() const { return std::make_shared<Node>(mOperator, mName); } Aidge::NodePtr Aidge::Node::cloneSharedProducers() const { - std::shared_ptr<Operator> op = (mOperator->type() == Producer_Op::Type) - ? mOperator - : mOperator->clone(); + std::shared_ptr<Operator> op = + (mOperator->type() == Producer_Op::Type) ? mOperator : mOperator->clone(); return std::make_shared<Node>(op, mName); } @@ -342,27 +359,25 @@ Aidge::NodePtr Aidge::Node::clone() const { return std::make_shared<Node>(mOperator->clone(), mName); } - -std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee){ - +std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta, std::set<Aidge::NodePtr> nodeSee) { std::set<Aidge::NodePtr> out; nodeSee.insert(shared_from_this()); - if(delta == 0) { + if (delta == 0) { out.insert(shared_from_this()); - }else if (delta > 0){ - for (const NodePtr& node : getChildren()) { - if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance - for (const NodePtr& ch : node->getNodeDelta(delta-1,nodeSee)){ + } else if (delta > 0) { + for (const NodePtr& node : getChildren()) { + if (nodeSee.find(node) == nodeSee.end()) { // loop avoidance + for (const NodePtr& ch : node->getNodeDelta(delta - 1, nodeSee)) { out.insert(ch); } } } - }else{ - for (const NodePtr& node : getParents()) { - if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance - for (const NodePtr& pr : node->getNodeDelta(delta+1,nodeSee)){ + } else { + for (const NodePtr& node : getParents()) { + if (nodeSee.find(node) == nodeSee.end()) { // loop avoidance + for (const NodePtr& pr : node->getNodeDelta(delta + 1, nodeSee)) { out.insert(pr); } } diff --git a/src/graph/Testing.cpp b/src/graph/Testing.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f30ad6e25b81e1ce7768fcc201ddf00c2226eebf --- /dev/null +++ b/src/graph/Testing.cpp @@ -0,0 +1,133 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <algorithm> // std::shuffle, std::transform +#include <cstddef> +#include <memory> +#include <numeric> // std::iota +#include <random> // std::binomial_distribution, std::mt19937, std::discrete_distribution +#include <string> +#include <utility> // std::pair +#include <vector> + +#include "aidge/graph/Testing.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/utils/Types.h" + +std::pair<Aidge::NodePtr, std::set<Aidge::NodePtr>> Aidge::RandomGraph::gen(std::mt19937::result_type seed, std::size_t nbNodes) const { + std::mt19937 gen(seed); + std::binomial_distribution<> dIn(maxIn - 1, avgIn/maxIn); + std::binomial_distribution<> dOut(maxOut - 1, avgOut/maxOut); + std::binomial_distribution<> dLink(1, density); + std::discrete_distribution<> dType(typesWeights.begin(), typesWeights.end()); + + std::vector<std::pair<IOIndex_t, IOIndex_t>> nbIOs; + std::vector<std::string> nodesType; + for (std::size_t i = 0; i < nbNodes; ++i) { + const auto nbIn = 1 + dIn(gen); + nbIOs.push_back(std::make_pair(nbIn, 1 + dOut(gen))); + nodesType.push_back(types[dType(gen)]); + } + + std::vector<std::size_t> nodesSeq(nbNodes); + std::iota(nodesSeq.begin(), nodesSeq.end(), static_cast<std::size_t>(0)); + // Don't use gen or seed here, must be different each time! + std::shuffle(nodesSeq.begin(), nodesSeq.end(), std::default_random_engine(std::random_device{}())); + + std::vector<NodePtr> nodes(nbNodes, nullptr); + for (auto idx : nodesSeq) { + const std::string name = nodesType[idx] + std::to_string(idx); + nodes[idx] = GenericOperator(nodesType[idx], nbIOs[idx].first, 0, nbIOs[idx].second, name); + } + + for (std::size_t i = 0; i < nbNodes; ++i) { + for (std::size_t j = (acyclic) ? i + 1 : 0; j < nbNodes; ++j) { + if (i == j) { + // Do not connect a node to itself in case of cyclic graph! + continue; + } + + for (std::size_t outId = 0; outId < nodes[i]->nbOutputs(); ++outId) { + for (std::size_t inId = 0; inId < nodes[j]->nbInputs(); ++inId) { + if (dLink(gen)) { + // Warning: connections can be set multiple times for the + // same node input! In this case, the previous connection + // is overwritten. This is the expected behavior. + nodes[i]->addChild(nodes[j], outId, inId); + if (nodes[i]->type() == omitType || nodes[j]->type() == omitType) { + // Let nodes[i]->addChild() overwrite the previous connection. + // Now we remove the new one!
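+                            // (Both calls below are assumed necessary with the Node API in this
+                            // diff: removeChild() only clears the parent-side bookkeeping, so
+                            // removeParent() must also reset the child's parent entry.)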
+ nodes[i]->removeChild(nodes[j], outId); + nodes[j]->removeParent(inId); + } +/* + // Alternative: only add child if no node is omitted + // and remove the potential previous connection, like this: + if (nodes[i]->type() != omitType && nodes[j]->type() != omitType) { + nodes[i]->addChild(nodes[j], outId, inId); + } + else { + const auto prevIn = nodes[j]->input(inId); + + if (prevIn.first != nullptr) { + prevIn.first->removeChild(nodes[j], prevIn.second); + nodes[j]->removeParent(inId); + } + } +*/ + break; + } + } + } + } + } + + NodePtr rootNode = nullptr; + std::set<NodePtr> nodesSet; + for (std::size_t i = 0; i < nbNodes; ++i) { + if (nodes[i]->type() != omitType) { + if (rootNode == nullptr) { + rootNode = nodes[i]; + } + nodesSet.insert(nodes[i]); + } + } + + return std::make_pair(rootNode, nodesSet); +} + +std::string Aidge::nodePtrToType(NodePtr node) { + return node->type(); +} + +std::string Aidge::nodePtrToName(NodePtr node) { + return node->name(); +} + +std::set<std::string> Aidge::nodePtrTo(const std::set<NodePtr>& nodes, + std::string(*nodeTo)(NodePtr)) +{ + std::set<std::string> nodesStr; + std::transform(nodes.begin(), nodes.end(), std::inserter(nodesStr, nodesStr.begin()), nodeTo); + return nodesStr; +} + +std::vector<std::pair<std::string, Aidge::IOIndex_t>> Aidge::nodePtrTo( + const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes, + std::string(*nodeTo)(NodePtr)) +{ + std::vector<std::pair<std::string, IOIndex_t>> nodesStr; + std::transform(nodes.begin(), nodes.end(), std::back_inserter(nodesStr), + [nodeTo](const std::pair<NodePtr, IOIndex_t>& node) { + return std::make_pair(nodeTo(node.first), node.second); + }); + return nodesStr; +} diff --git a/src/operator/Add.cpp b/src/operator/Add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4e638fd86da487565a89760925e45339213fa8f9 --- /dev/null +++ b/src/operator/Add.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Add.hpp" + +const std::string Aidge::Add_Op::Type = "Add"; \ No newline at end of file diff --git a/src/operator/Cast.cpp b/src/operator/Cast.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f09d8eb83c6a6dae6416ffebcc01b22fb479a862 --- /dev/null +++ b/src/operator/Cast.cpp @@ -0,0 +1,26 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Cast.hpp" + +const std::string Aidge::Cast_Op::Type = "Cast"; + +void Aidge::Cast_Op::forward() { + if (mImpl) { + mImpl->forward(); + } + else { + mOutputs[0]->copyCast(*(mInputs[0])); + } + + runHooks(); +} diff --git a/src/operator/Concat.cpp b/src/operator/Concat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eafcd126480df6da2c0127bdbb896d3ce98d0e0a --- /dev/null +++ b/src/operator/Concat.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Concat.hpp" + +const std::string Aidge::Concat_Op::Type = "Concat"; \ No newline at end of file diff --git a/src/operator/Div.cpp b/src/operator/Div.cpp index 273eac2e8fa9623e617d1be204ac2ae46d8da02d..85db3ac6ef66c837c86dbece288185deaca88ba6 100644 --- a/src/operator/Div.cpp +++ b/src/operator/Div.cpp @@ -11,6 +11,7 @@ #include <cassert> #include <cstddef> +#include <string> #include <vector> #include <utility> @@ -19,6 +20,8 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +const std::string Aidge::Div_Op::Type = "Div"; + void Aidge::Div_Op::computeOutputDims() { // check inputs have been associated if (!getInput(0) || !getInput(1)) { diff --git a/src/operator/Erf.cpp b/src/operator/Erf.cpp new file mode 100644 index 0000000000000000000000000000000000000000..387af4edf417f8c7ac6ee9b8b2b7069179ad59cb --- /dev/null +++ b/src/operator/Erf.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Erf.hpp" + +const std::string Aidge::Erf_Op::Type = "Erf"; \ No newline at end of file diff --git a/src/operator/FC.cpp b/src/operator/FC.cpp new file mode 100644 index 0000000000000000000000000000000000000000..32114f5bf9e0d160db9fdc2d1971481be0b4e703 --- /dev/null +++ b/src/operator/FC.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/FC.hpp" + +const std::string Aidge::FC_Op::Type = "FC"; \ No newline at end of file diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp new file mode 100644 index 0000000000000000000000000000000000000000..30804994b6084a5a5558f106a38a6087e54471bc --- /dev/null +++ b/src/operator/Gather.cpp @@ -0,0 +1,39 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cassert> +#include <cstddef> +#include <string> +#include <vector> + +#include "aidge/operator/Gather.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" + +const std::string Aidge::Gather_Op::Type = "Gather"; + +void Aidge::Gather_Op::computeOutputDims() { + // check inputs have been associated + if (!getInput(0) || !getInput(1)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); + } + + if (getInput(1)->nbDims()!=2){ + AIDGE_THROW_OR_ABORT(std::runtime_error, "Indices input must be a 2D Tensor"); + } + + std::vector<DimSize_t> outDims = getInput(0)->dims(); + std::vector<DimSize_t> indexesDims = getInput(1)->dims(); + int axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?this->template getAttr<GatherAttr::Axis>():this->template getAttr<GatherAttr::Axis>()+outDims.size(); + outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); + outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indexesDims.begin(),indexesDims.end()); + mOutputs[0]->resize(outDims); +} \ No newline at end of file diff --git a/src/operator/Identity.cpp b/src/operator/Identity.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f57906dd4f3564b52cde16236bda87370e8f86d7 --- /dev/null +++ b/src/operator/Identity.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Identity.hpp" + +const std::string Aidge::Identity_Op::Type = "Identity"; \ No newline at end of file diff --git a/src/operator/LeakyReLU.cpp b/src/operator/LeakyReLU.cpp new file mode 100644 index 0000000000000000000000000000000000000000..32e050ee1595cf83b5cd0ffbfeba6153dc2243af --- /dev/null +++ b/src/operator/LeakyReLU.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/LeakyReLU.hpp" + +const std::string Aidge::LeakyReLU_Op::Type = "LeakyReLU"; \ No newline at end of file diff --git a/src/operator/MatMul.cpp b/src/operator/MatMul.cpp new file mode 100644 index 0000000000000000000000000000000000000000..666ed3921ed1190a91935bd9f38303e23963d912 --- /dev/null +++ b/src/operator/MatMul.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/MatMul.hpp" + +const std::string Aidge::MatMul_Op::Type = "MatMul"; \ No newline at end of file diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index bbc921d3c7b334223b2a92a8fbfee1ffae9c10e1..530357085a16ca3e834669cebd2d26882ca8ddab 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -12,63 +12,19 @@ #include "aidge/operator/MetaOperator.hpp" #include "aidge/utils/ErrorHandling.hpp" -Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph, - std::vector<NodePtr> inputNodes, - std::vector<NodePtr> outputNodes) +Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph) : OperatorTensor(type, graph->dataInputs().size(), (graph->inputs().size() - graph->dataInputs().size()), graph->outputs().size()), mGraph(graph) { - // Fill inputsNodes and outputsNodes when there is no ambiguity - if (inputNodes.empty()) { - AIDGE_ASSERT(mGraph->inputNodes().size() == 1, "need to specify internal nodes input mapping"); - inputNodes.push_back(*mGraph->inputNodes().begin()); + mInputs = std::vector<std::shared_ptr<Tensor>>(mGraph->inputs().size()); + for (std::size_t i = 0; i < mInputs.size(); ++i) { + mInputs[i] = std::make_shared<Tensor>(); } - - if (outputNodes.empty()) { - AIDGE_ASSERT(mGraph->outputNodes().size() == 1, "need to specify internal nodes output mapping"); - outputNodes.push_back(*mGraph->outputNodes().begin()); - } - - AIDGE_ASSERT(mGraph->inputNodes().size() == inputNodes.size(), "wrong number of specified input nodes"); - AIDGE_ASSERT(mGraph->outputNodes().size() == outputNodes.size(), "wrong number of specified output nodes"); - - // Identify inputs that are outside the micro-graph - for (const auto& inputNode : inputNodes) { - AIDGE_ASSERT(mGraph->inView(inputNode), "input node must be in the graph"); - const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = - inputNode->inputs(); - - int inputIdx = 0; // input idx relative to the current node - for (const auto& in : inputNodeinputs) { - if (in.first == nullptr || !mGraph->inView(in.first)) { - // The input is not connected inside the micro-graph - // (no connection to this input or connection outside the micro-graph) - // => it is therefore an input for the meta-operator - mInputOps.push_back(std::make_pair(std::dynamic_pointer_cast<OperatorTensor>(inputNode->getOperator()), inputIdx)); - } - - ++inputIdx; - } - } - - // The outputs of the output nodes are also the outputs of the meta-operator - 
for (const auto& outputNode : outputNodes) { - AIDGE_ASSERT(mGraph->inView(outputNode), "output node must be in the graph"); - const std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> outputNodeoutputs = - outputNode->outputs(); - - for (size_t outputIdx = 0; outputIdx < outputNodeoutputs.size(); ++outputIdx) { - mOutputOps.push_back(std::make_pair(std::dynamic_pointer_cast<OperatorTensor>(outputNode->getOperator()), outputIdx)); - } - } - - - AIDGE_INTERNAL_ASSERT(mInputOps.size() == mGraph->inputs().size()); - AIDGE_INTERNAL_ASSERT(mOutputOps.size() == mGraph->outputs().size()); // Associate outputs to micro-graph outputs for custom implementation - for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) { - const auto& outputOp = mOutputOps[outputIdx]; - mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second); + mOutputs = std::vector<std::shared_ptr<Tensor>>(mGraph->getOrderedOutputs().size()); + for (size_t outputIdx = 0; outputIdx < mOutputs.size(); ++outputIdx) { + const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx]; + mOutputs[outputIdx] = std::dynamic_pointer_cast<Tensor>(outputOp.first->getOperator()->getRawOutput(outputOp.second)); } } @@ -77,8 +33,8 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI return mImpl->getNbRequiredData(inputIdx); } else { - const auto& inputOp = mInputOps[inputIdx]; - return inputOp.first->getNbRequiredData(inputOp.second); + const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + return inputOp.first->getOperator()->getNbRequiredData(inputOp.second); } } @@ -87,8 +43,8 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) co return mImpl->getNbConsumedData(inputIdx); } else { - const auto& inputOp = mInputOps[inputIdx]; - return inputOp.first->getNbConsumedData(inputOp.second); + const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + return inputOp.first->getOperator()->getNbConsumedData(inputOp.second); } } @@ -97,8 +53,8 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) c return mImpl->getNbProducedData(outputIdx); } else { - const auto& outputOp = mOutputOps[outputIdx]; - return outputOp.first->getNbProducedData(outputOp.second); + const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx]; + return outputOp.first->getOperator()->getNbProducedData(outputOp.second); } } diff --git a/src/operator/Move.cpp b/src/operator/Move.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d8776e32fca909663bafe3fae3ebf9f5616c69c9 --- /dev/null +++ b/src/operator/Move.cpp @@ -0,0 +1,26 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Move.hpp" + +const std::string Aidge::Move_Op::Type = "Move"; + +void Aidge::Move_Op::forward() { + if (mImpl) { + mImpl->forward(); + } + else { + mOutputs[0]->copyFrom(*(mInputs[0])); + } + + runHooks(); +} diff --git a/src/operator/Mul.cpp b/src/operator/Mul.cpp index 2e3e77288bf1e0613f0aa572e3c50e94599a902f..bc268263e8a6e2ec7c9944faa31da84dc50c4f53 100644 --- a/src/operator/Mul.cpp +++ b/src/operator/Mul.cpp @@ -19,6 +19,8 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +const std::string Aidge::Mul_Op::Type = "Mul"; + void Aidge::Mul_Op::computeOutputDims() { // check inputs have been associated if (!getInput(0) || !getInput(1)) { diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index 1237fdc0b5565681ab1a6af6d88f74a48cbd5b57..72a71814b1463395443c6a4504f2eef660ec1185 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -88,8 +88,8 @@ const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getOutput(const Aid } -std::vector<std::pair<std::size_t, std::vector<Aidge::DimSize_t>>> Aidge::OperatorTensor::computeReceptiveField( - const std::size_t firstIdx, +std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>> Aidge::OperatorTensor::computeReceptiveField( + const std::vector<DimSize_t>& firstEltDims, const std::vector<Aidge::DimSize_t>& outputDims, const Aidge::IOIndex_t outputIdx) const { @@ -103,14 +103,13 @@ std::vector<std::pair<std::size_t, std::vector<Aidge::DimSize_t>>> Aidge::Operat if (!outputDimsForwarded() || getOutput(0)->nbDims() != outputDims.size()) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet."); } - const auto outputIdxDims = getOutput(0)->getCoord(firstIdx); for (DimIdx_t i = 0; i < outputDims.size(); ++i) { - if (((outputDims[i] + outputIdxDims[i]) > getOutput(0)->dims()[i]) || (outputDims[i] == 0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), outputIdxDims[i], outputDims[i]); + if (((outputDims[i] + firstEltDims[i]) > getOutput(0)->dims()[i]) || (outputDims[i] == 0)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), firstEltDims[i], outputDims[i]); } } // return the same Tensor description as given in function parameter for each data input - return std::vector<std::pair<std::size_t, std::vector<Aidge::DimSize_t>>>(nbData(),std::pair<std::size_t, std::vector<Aidge::DimSize_t>>(firstIdx, outputDims)); + return std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>>(nbData(),std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>(firstEltDims, outputDims)); } void Aidge::OperatorTensor::computeOutputDims() { @@ -149,12 +148,8 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { for (IOIndex_t i = 0; i < nbOutputs(); ++i) { getOutput(i)->setDataType(dataType); } - for (IOIndex_t i = 0; i < nbInputs(); ++i) { - if (!getInput(i)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not set"); - } - else { - getInput(i)->setDataType(dataType); - } + + for (IOIndex_t i = nbData(); i < nbInputs(); ++i) { + getInput(i)->setDataType(dataType); } } \ No 
newline at end of file diff --git a/src/operator/Pow.cpp b/src/operator/Pow.cpp index c213a47a4a590026c07625aeb532d303ca8dbced..de1f0c3694f51fbd5b365573f61d3e3e2b9109ff 100644 --- a/src/operator/Pow.cpp +++ b/src/operator/Pow.cpp @@ -19,6 +19,8 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +const std::string Aidge::Pow_Op::Type = "Pow"; + void Aidge::Pow_Op::computeOutputDims() { // check inputs have been associated if (!getInput(0) || !getInput(1)) { diff --git a/src/operator/Producer.cpp b/src/operator/Producer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..443f2fa7d8a60cd25ccb622f2dad5b4926b88eea --- /dev/null +++ b/src/operator/Producer.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Producer.hpp" + +const std::string Aidge::Producer_Op::Type = "Producer"; \ No newline at end of file diff --git a/src/operator/ReLU.cpp b/src/operator/ReLU.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0f7874acfe7d865ea8c56d4bca02b51864480df6 --- /dev/null +++ b/src/operator/ReLU.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/ReLU.hpp" + +const std::string Aidge::ReLU_Op::Type = "ReLU"; \ No newline at end of file diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b0eea3c1f9f7054021b631c85e0f80e7f8845da6 --- /dev/null +++ b/src/operator/Reshape.cpp @@ -0,0 +1,47 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cstddef> +#include <string> +#include <vector> + +#include "aidge/operator/Reshape.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" + +const std::string Aidge::Reshape_Op::Type = "Reshape"; + +void Aidge::Reshape_Op::computeOutputDims() { + // check inputs have been associated + if (!getInput(0)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); + } + + DimSize_t nbOutDims = this->template getAttr<ReshapeAttr::Shape>().size(); + std::vector<DimSize_t> outDims; + std::size_t outSize = 1; + for(std::size_t i=0; i<nbOutDims; ++i) + { + int dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; + if (dimSize < 1) + { + AIDGE_THROW_OR_ABORT(std::runtime_error, "bad dimension value"); + } + outDims.push_back(dimSize); + outSize *= dimSize; + } + + if (getInput(0)->size() != outSize){ + AIDGE_THROW_OR_ABORT(std::runtime_error, "Output shape must give the same size as input"); + } + + mOutputs[0]->resize(outDims); +} \ No newline at end of file diff --git a/src/operator/Scaling.cpp b/src/operator/Scaling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c121e1268c1e1a62f793f38c6d816e7c6b48c25 --- /dev/null +++ b/src/operator/Scaling.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Scaling.hpp" + +const std::string Aidge::Scaling_Op::Type = "Scaling"; \ No newline at end of file diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp new file mode 100644 index 0000000000000000000000000000000000000000..139e84b561a48c2f6a5ecd14ed9d6905d66dec20 --- /dev/null +++ b/src/operator/Slice.cpp @@ -0,0 +1,51 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#include "aidge/operator/Slice.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" + +#include <cassert> +#include <cstddef> +#include <string> +#include <utility> +#include <vector> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Types.h" + +const std::string Aidge::Slice_Op::Type = "Slice"; + +void Aidge::Slice_Op::computeOutputDims() { + // check input have been associated + if (!getInput(0) || (getInput(0)->empty())) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); + } + + DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); + std::vector<DimSize_t> outDims = getInput(0)->dims(); + for (std::size_t i = 0; i < nbAxes; ++i) { + // For each slice operation get the params and cast them to size_t + const std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i]; + const std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i]; + const std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i]; + const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : axis_ + getInput(0)->nbDims(); + const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : start_ + getInput(0)->dims()[axis]; + const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : end_ + getInput(0)->dims()[axis]; + + const std::size_t sliceLength = end - start + 1; + // Check if slice length is valid + if (sliceLength > getInput(0)->dims()[axis]) + AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds"); + outDims[axis] = sliceLength; + } + mOutputs[0]->resize(outDims); +} diff --git a/src/operator/Softmax.cpp b/src/operator/Softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e88ff4bb4ec6e2cb1357d578c2d07cc4edcb59f7 --- /dev/null +++ b/src/operator/Softmax.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Softmax.hpp" + +const std::string Aidge::Softmax_Op::Type = "Softmax"; \ No newline at end of file diff --git a/src/operator/Sqrt.cpp b/src/operator/Sqrt.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dbcaba42619762f8fd00bb2f6e0aa0de11d92960 --- /dev/null +++ b/src/operator/Sqrt.cpp @@ -0,0 +1,16 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <string> + +#include "aidge/operator/Sqrt.hpp" + +const std::string Aidge::Sqrt_Op::Type = "Sqrt"; \ No newline at end of file diff --git a/src/operator/Sub.cpp b/src/operator/Sub.cpp index 8175f1b7ae5bb5eccd36267c1d739f764bd3c236..639eaf798c1c2a9a6685e8b8d2c4a2cb00a4b57a 100644 --- a/src/operator/Sub.cpp +++ b/src/operator/Sub.cpp @@ -19,6 +19,8 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +const std::string Aidge::Sub_Op::Type = "Sub"; + void Aidge::Sub_Op::computeOutputDims() { // check inputs have been associated if (!getInput(0) || !getInput(1)) { diff --git a/src/recipies/ExplicitCastMove.cpp b/src/recipies/ExplicitCastMove.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5651f2ba4cc939678ab306137464c52caa1db46c --- /dev/null +++ b/src/recipies/ExplicitCastMove.cpp @@ -0,0 +1,123 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/recipies/Recipies.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Cast.hpp" +#include "aidge/operator/Move.hpp" + +void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) { + // First, remove existing Cast and Move operators, if not needed anymore + auto nodes = graph->getNodes(); + for (auto node : nodes) { + // TODO: currently, Operator data type is only reflected in its output tensor data type. + // But an Operator might have multiple outputs of different data type(?) + const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0); + if (output->getImpl() == nullptr) { + continue; + } + const auto& device = output->getImpl()->device(); + + if (node->type() == Cast_Op::Type || node->type() == Move_Op::Type) { + // Remove existing Cast and Move operators, if not needed anymore + AIDGE_INTERNAL_ASSERT(node->inputs().size() == 1); + const auto parent = node->inputs()[0]; + // Check parent is not nullptr, as this Operator may be an entry point of the graph without parent + if (parent.first != nullptr) { + const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second); + + if ((node->type() == Cast_Op::Type && input->dataType() == output->dataType()) + || (node->type() == Move_Op::Type && input->getImpl() != nullptr && input->getImpl()->device() == device)) + { + // Add direct connection bypassing Cast/Move node + const auto childs = node->outputs()[0]; + for (const auto& child : childs) { + parent.first->addChild(child.first, parent.second, child.second); + } + + // Remove all node connections + node->resetConnections(); + // Remove node from view + graph->remove(node); + } + } + } + } + + // Note: why two steps and not merge the two node loops? + // User may have changed some data type/backends on top of existing Cast/Move operators + // This may lead to situation where a Cast should be removed but a Move should + // be inserted at the same place. 
In this case, some conversion may be missed + depending on the order of iteration over the nodes (which are not ordered!). + + // Second, insert Cast and/or Move operator between node inputs and parent output, if needed + nodes = graph->getNodes(); + for (auto node : nodes) { + // TODO: currently, Operator data type is only reflected in its output tensor data type. + // But an Operator might have multiple outputs of different data type(?) + const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0); + if (output->getImpl() == nullptr) { + continue; + } + const auto& device = output->getImpl()->device(); + + IOIndex_t inputIdx = 0; + for (auto parent : node->inputs()) { + // TODO: possible optimization: currently, a Cast/Move Operator may + // be added several times to the same output, if it has multiple children, + // even if it is the same conversion each time. + if (parent.first != nullptr) { + const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second); + + NodePtr moveOp = nullptr; + NodePtr castOp = nullptr; + + if (node->type() != Move_Op::Type && input->getImpl()->device() != device) { + // Change of backend => a Move operator is required + moveOp = Move(); + moveOp->getOperator()->setDataType(input->dataType()); + castOp = moveOp; + } + + if (node->type() != Cast_Op::Type && input->dataType() != output->dataType()) { + // Change of data type => a Cast operator is required + castOp = Cast(); + castOp->getOperator()->setDataType(output->dataType()); + castOp->getOperator()->setBackend(device.first, device.second); + + if (moveOp == nullptr) { + moveOp = castOp; + } + else { + moveOp->addChild(castOp, 0, 0); + } + } + + if (moveOp != nullptr && castOp != nullptr) { + // Move and/or Cast Operator(s) are needed + castOp->addChild(node, 0, inputIdx); + parent.first->addChild(moveOp, parent.second, 0); + // Set backend AFTER connection in case a specific implementation + // of the operator exists for the input type.
+ moveOp->getOperator()->setBackend(device.first, device.second); + + // Add/update nodes in the GraphView + graph->add(moveOp); + graph->add(castOp); + graph->add(parent.first); + graph->add(node); + } + } + + ++inputIdx; + } + } +} diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp index ffb4599d83ba922ce5991460810f5d248806617c..9c4cad3f7a444c627f2324f729cb3bc3d8517f49 100644 --- a/src/recipies/FuseBatchNorm.cpp +++ b/src/recipies/FuseBatchNorm.cpp @@ -33,10 +33,11 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr const std::shared_ptr<BatchNorm_Op<2>> batchOp = std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator()); const std::shared_ptr<Conv_Op<2>> convOp = std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator()); - const std::shared_ptr<Tensor> scale = batchOp->getInput(1); - const std::shared_ptr<Tensor> shift = batchOp->getInput(2); - const std::shared_ptr<Tensor> b_mean = batchOp->getInput(3); - const std::shared_ptr<Tensor> b_var = batchOp->getInput(4); + std::shared_ptr<Tensor> scaleBuf, shiftBuf, b_meanBuf, b_varBuf; + const Tensor& scale = batchOp->getInput(1)->refCastFrom(scaleBuf, DataType::Float32, "cpu"); + const Tensor& shift = batchOp->getInput(2)->refCastFrom(shiftBuf, DataType::Float32, "cpu"); + const Tensor& b_mean = batchOp->getInput(3)->refCastFrom(b_meanBuf, DataType::Float32, "cpu"); + const Tensor& b_var = batchOp->getInput(4)->refCastFrom(b_varBuf, DataType::Float32, "cpu"); const float epsilon = batchOp -> getAttr<float>("Epsilon"); const DimSize_t convNbOutChannels = convOp -> getAttr<DimSize_t>("OutChannels"); @@ -44,10 +45,10 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr const std::array<DimSize_t, 2> kernelDims = convOp -> getAttr<std::array<DimSize_t, 2>>("KernelDims"); - assert(scale->size() == convNbOutChannels); - assert(shift->size() == convNbOutChannels); - assert(b_mean->size() == convNbOutChannels); - assert(b_var->size() == convNbOutChannels); + assert(scale.size() == convNbOutChannels); + assert(shift.size() == convNbOutChannels); + assert(b_mean.size() == convNbOutChannels); + assert(b_var.size() == convNbOutChannels); assert(epsilon > 0.0); // TODO : no no_bias attribute ? @@ -56,9 +57,8 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr unsigned int count = 0; for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) { - // TODO: get() assumed dataType is float... - if (b_var->get<float>(outChId) > 1.0e-12) { - meanVariance += b_var->get<float>(outChId); + if (b_var.get<float>(outChId) > 1.0e-12) { + meanVariance += b_var.get<float>(outChId); ++count; } else { @@ -71,39 +71,43 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr printf("Warning: variance < 1e-12 for all outputs! 
Is the network correctly trained?\n"); } - std::shared_ptr<Tensor> weight = convOp -> getInput(1); - std::shared_ptr<Tensor> bias = convOp -> getInput(2); + std::shared_ptr<Tensor> weightBuf, biasBuf; + Tensor& weight = convOp->getInput(1)->refCastFrom(weightBuf, DataType::Float32, "cpu"); + Tensor& bias = convOp->getInput(2)->refCastFrom(biasBuf, DataType::Float32, "cpu"); for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) { // Corrected for zero-variance issue: // "A Quantization-Friendly Separable Convolution for MobileNets" // https://arxiv.org/pdf/1803.08607.pdf // to help post-training quantization - const float factor = scale->get<float>(outChId) - / std::sqrt(epsilon + ((b_var->get<float>(outChId) > 1.0e-12 || count == 0) - ? b_var->get<float>(outChId) : meanVariance)); + const float factor = scale.get<float>(outChId) + / std::sqrt(epsilon + ((b_var.get<float>(outChId) > 1.0e-12 || count == 0) + ? b_var.get<float>(outChId) : meanVariance)); // Weights adjustments for (std::size_t channel = 0; channel < channelsSize; ++channel) { // TODO : Suppose kerneldims = 2 for(std::size_t k0 = 0; k0 < kernelDims[0]; ++ k0){ for(std::size_t k1 = 0; k1 < kernelDims[1]; ++ k1){ std::vector<DimSize_t> currentIdx = {outChId, channel, k0, k1}; - // TODO : suppose weights are float - float weightValue = weight->get<float>(currentIdx); - weight->set<float>(currentIdx, weightValue*factor); // Update check it update Conv weights + float weightValue = weight.get<float>(currentIdx); + weight.set<float>(currentIdx, weightValue*factor); // Update check it update Conv weights } } } // TODO : check if noBias==true is set, then set biasValue to 0 - float biasValue = bias->get<float>(outChId); + float biasValue = bias.get<float>(outChId); - biasValue = shift->get<float>(outChId) + (biasValue - b_mean->get<float>(outChId)) * factor; + biasValue = shift.get<float>(outChId) + (biasValue - b_mean.get<float>(outChId)) * factor; - bias->set<float>(outChId, biasValue); + bias.set<float>(outChId, biasValue); } + // Copy values back to the original tensors (actual copy only if needed) + convOp->getInput(1)->copyCastFrom(weight); + convOp->getInput(2)->copyCastFrom(bias); + GraphView::replace(std::set<std::shared_ptr<Node>>({ batchnormNode, batchnormNode->input(1).first, diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index d37f4749635b2bf76d10f7f8de3a44e254c56347..322b1d9a0632b893a912c6225ac5b13d63278f8d 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -38,9 +38,8 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< // Fetch the output dimension throught the bias size std::shared_ptr<Node> bias = (addNode->getParent(1)) ? 
addNode->getParent(1)->cloneSharedOperators() : nullptr; - if (!(matmulNode->getParent(1))) { - AIDGE_INTERNAL_ASSERT("No weight detected to produce the fuseMulAdd recipe."); - } + AIDGE_ASSERT(matmulNode->getParent(1), "No weight detected to produce the fuseMulAdd recipe."); + std::shared_ptr<Node> weight = matmulNode->getParent(1)->cloneSharedOperators(); const DimSize_t outSize = std::dynamic_pointer_cast<MatMul_Op>(matmulNode->getOperator()) -> getAttr<DimSize_t>("OutChannels"); diff --git a/src/recipies/HorizontalTiling.cpp b/src/recipies/HorizontalTiling.cpp index d8eb015939e7be19eb866b75e5a5601ba80631d0..6cc34eba076934b884b336ce40081a855d917182 100644 --- a/src/recipies/HorizontalTiling.cpp +++ b/src/recipies/HorizontalTiling.cpp @@ -11,6 +11,7 @@ #include <set> #include <memory> +#include <numeric> // std::iota #include <vector> #include <utility> @@ -74,16 +75,26 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: res.insert(clonedInputs[i]); } - for (; currentFirstDims[axis] < outTensor->dims()[axis]; currentFirstDims[axis] += outputDims[axis]) { - const auto inputDims = op->computeReceptiveField(outTensor->getIdx(currentFirstDims), outputDims, 0); + for (IOIndex_t i = 0; currentFirstDims[axis] < outTensor->dims()[axis]; currentFirstDims[axis] += outputDims[axis], ++i) { + const auto inputDims = op->computeReceptiveField(currentFirstDims, outputDims, 0); auto newNode = node -> clone(); // no input associated to clones newNode -> setName(node->name() + "_" + std::to_string(currentFirstDims[axis])); clonedInputs[1] -> addChild(newNode, 0, 1); clonedInputs[2] -> addChild(newNode, 0, 2); // Slice for input and each parameter - auto slice = Slice(inputDims[0].first, inputDims[0].second, "Slice_" + std::to_string(currentFirstDims[axis])); + std::vector<std::int32_t> inputDimsEnd(inputDims[0].first.size()); + for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) { + inputDimsEnd[dim] = static_cast<std::int32_t>(inputDims[0].first[dim] + inputDims[0].second[dim]) - 1; + } + std::vector<std::int32_t> inputDimsStart(inputDims[0].first.size()); + for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) { + inputDimsStart[dim] = static_cast<std::int32_t>(inputDims[0].first[dim]); + } + std::vector<std::int32_t> usedDims(inputDimsEnd.size()); + std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int32_t>(0)); + auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis])); slice -> addChild(newNode, 0, 0); - newNode -> addChild(concat, 0, currentFirstDims[axis]); + newNode -> addChild(concat, 0, i); res.insert(slice); res.insert(newNode); diff --git a/src/recipies/RemoveDropout.cpp b/src/recipies/RemoveDropout.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1dedac8f19e6ec6b4b1f6dabb6bd3e9b8c759def --- /dev/null +++ b/src/recipies/RemoveDropout.cpp @@ -0,0 +1,57 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <memory> +#include <iostream> + +#include "aidge/graph/Node.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/recipies/Recipies.hpp" + +//Graph Regex +#include "aidge/graphRegex/GraphRegex.hpp" + + +namespace Aidge { + void removeDropout(std::shared_ptr<Node> dropout) { + + std::set<NodePtr> nodesToRemove; + for (auto nodePtr: dropout->getParents()) + { + if(nodePtr->type() == "Producer") + { + nodesToRemove.insert(nodePtr); + } + } + nodesToRemove.insert(dropout); + GraphView::replace(nodesToRemove, {}); + } + + void removeDropout(std::shared_ptr<MatchSolution> solution){ + + assert(solution->at("Dropout").size() == 1 && "Wrong number of nodes Dropout to replace\n"); + + for (const auto& dropout : solution->at("Dropout")) { + + removeDropout(dropout); + } + } + + void removeDropout(std::shared_ptr<GraphView> graphView){ + std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); + regex->setNodeKey("Dropout","getType($) =='Dropout'"); + regex->addQuery("Dropout#"); + + for (const auto& solution : regex->match(graphView)) { + removeDropout(solution); + } + } +} diff --git a/unit_tests/graph/Test_Connector.cpp b/unit_tests/graph/Test_Connector.cpp index a7cee610e0014dc024271a008ed964fa67d367ea..79acce9281039f9f3c67b7235d8999b6c7173685 100644 --- a/unit_tests/graph/Test_Connector.cpp +++ b/unit_tests/graph/Test_Connector.cpp @@ -16,6 +16,7 @@ #include "aidge/operator/GenericOperator.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/OpArgs.hpp" +#include "aidge/graph/Testing.hpp" using namespace Aidge; @@ -112,7 +113,9 @@ TEST_CASE("GraphGeneration from Connector", "[GraphView]") { x= (*node09)({x}); x = (*node10)({a, x}); std::shared_ptr<GraphView> gv = generateGraph({x}); - gv->save("GraphGeneration"); + // gv->save("GraphGeneration"); + REQUIRE(nodePtrTo(gv->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({})); + REQUIRE(nodePtrTo(gv->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_matmul1", 0}})); } TEST_CASE("Connector connection GraphView", "[Connector]") { @@ -131,6 +134,9 @@ TEST_CASE("Connector connection GraphView", "[Connector]") { GenericOperator("g_conv3", 1, 0, 1), GenericOperator("g_matmul1", 2, 0, 1) }); + REQUIRE(nodePtrTo(g->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_conv1", 0}})); + REQUIRE(nodePtrTo(g->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_matmul1", 0}})); + x = (*prod)({}); x = (*g)({x}); std::shared_ptr<GraphView> g2 = generateGraph({x}); @@ -151,10 +157,14 @@ TEST_CASE("Connector connection GraphView", "[Connector]") { GenericOperator("g_concat", 3, 0, 1), GenericOperator("g_conv3", 1, 0, 1) }); + REQUIRE(nodePtrTo(g->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"ElemWise", 0}, {"ElemWise", 1}, {"ElemWise", 2}})); + REQUIRE(nodePtrTo(g->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_conv3", 0}})); x = (*g)({x, y, z}); std::shared_ptr<GraphView> gv = generateGraph({x}); - gv->save("MultiInputSequentialConnector"); + REQUIRE(nodePtrTo(gv->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({})); + REQUIRE(nodePtrTo(gv->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"g_conv3", 0}})); + // gv->save("MultiInputSequentialConnector"); REQUIRE(gv->inputNodes().size() == 0U); } } @@ -169,7 
+179,9 @@ TEST_CASE("Connector Mini-graph", "[Connector]") { } y = (*GenericOperator("ElemWise", 2, 0, 1))({y, x}); std::shared_ptr<GraphView> g = generateGraph({y}); - g->save("TestGraph"); + REQUIRE(nodePtrTo(g->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({})); + REQUIRE(nodePtrTo(g->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"ElemWise", 0}})); + // g->save("TestGraph"); } TEST_CASE("Structural descrition - Sequential", "[GraphView]") { diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index bb726bd4d92b5674d0e19ea3138e165e1329959a..ebbfb3ad89721eb4f1390c3efca475acbb0b6f46 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -9,6 +9,7 @@ * ********************************************************************************/ +#include <algorithm> // std::sort #include <cassert> #include <map> #include <memory> @@ -20,20 +21,201 @@ #include "aidge/backend/OperatorImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Testing.hpp" #include "aidge/operator/Conv.hpp" +#include "aidge/operator/ReLU.hpp" +#include "aidge/graph/OpArgs.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Producer.hpp" using namespace Aidge; -TEST_CASE("[core/graph] GraphView(Constructor)") { +TEST_CASE("genRandomGraph", "[GraphView][randomGen]") { + const size_t nbTests = 100; + size_t nbUnicity = 0; + + for (int test = 0; test < nbTests; ++test) { + std::random_device rd; + const std::mt19937::result_type seed(rd()); + + RandomGraph randGraph; + const auto g1 = std::make_shared<GraphView>("g1"); + const bool unicity1 = g1->add(randGraph.gen(seed, 10)); + const auto g2 = std::make_shared<GraphView>("g2"); + const bool unicity2 = g2->add(randGraph.gen(seed, 10)); + + // g1->save("./genRandomGraph1"); + // g2->save("./genRandomGraph2"); + + REQUIRE(unicity1 == unicity2); + + if (unicity1) { + REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->inputNodes(), nodePtrToName) == nodePtrTo(g2->inputNodes(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->outputNodes(), nodePtrToName) == nodePtrTo(g2->outputNodes(), nodePtrToName)); + ++nbUnicity; + + // Check that inputs/outputs are the same regardless of the order + auto orderedInputs1 = nodePtrTo(g1->getOrderedInputs(), nodePtrToName); + auto orderedInputs2 = nodePtrTo(g2->getOrderedInputs(), nodePtrToName); + auto orderedOutputs1 = nodePtrTo(g1->getOrderedOutputs(), nodePtrToName); + auto orderedOutputs2 = nodePtrTo(g2->getOrderedOutputs(), nodePtrToName); + std::sort(orderedInputs1.begin(), orderedInputs1.end()); + std::sort(orderedInputs2.begin(), orderedInputs2.end()); + std::sort(orderedOutputs1.begin(), orderedOutputs1.end()); + std::sort(orderedOutputs2.begin(), orderedOutputs2.end()); + + REQUIRE(orderedInputs1 == orderedInputs2); + REQUIRE(orderedOutputs1 == orderedOutputs2); + REQUIRE(nodePtrTo(g1->inputNodes(), nodePtrToName) == nodePtrTo(g2->inputNodes(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->outputNodes(), nodePtrToName) == nodePtrTo(g2->outputNodes(), nodePtrToName)); + } + } + + printf("nbUnicity = %zu/%zu\n", nbUnicity, nbTests); +} + +TEST_CASE("clone", 
"[GraphView][clone]") { + const size_t nbTests = 100; + + for (int test = 0; test < nbTests; ++test) { + std::random_device rd; + const std::mt19937::result_type seed(rd()); + + RandomGraph randGraph; + const auto g1 = std::make_shared<GraphView>("g1"); + g1->add(randGraph.gen(seed, 10)); + // g1 -> save("GraphView_clone"); + const auto g2 = g1->clone(); + + REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName)); + REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName)); + } +} + +NodePtr nodeDel(NodePtr node) { + if (node->type() == "DelFictive") { + return nullptr; + } + return node->clone(); +} + +TEST_CASE("clone_with_delete", "[GraphView][cloneDelete]") { + const size_t nbTests = 100; + size_t nbClonedWithDelete = 0; + + // Note: initial seed is chosen such that for nbTests=100, the generated + // graphs keep the same inputs/outputs despites the deleted nodes + // (meaning the deleted nodes are not input/output of the graph). + // Otherwise, the last two REQUIRE are not garanteed to be true! + // Warning: distributions are not required to behave the same way by the standard, + // therefore the seed has to work for both GCC and MSVC... + // See https://stackoverflow.com/questions/38532927/why-gcc-and-msvc-stdnormal-distribution-are-different + std::mt19937::result_type seed(243); + + for (int test = 0; test < nbTests; ++test) { + RandomGraph randGraph; + randGraph.types = {"Fictive", "DelFictive"}; + randGraph.typesWeights = {0.9, 0.1}; + const auto g1 = std::make_shared<GraphView>("g1"); + const bool unicity1 = g1->add(randGraph.gen(seed, 10)); + + if (unicity1) { + randGraph.omitType = "DelFictive"; + const auto g2 = std::make_shared<GraphView>("g2"); + const bool unicity2 = g2->add(randGraph.gen(seed, 10)); + + // g1->save("./clone_with_delete1"); + // g2->save("./clone_with_delete2"); + + try { + const auto gCloned = g1->cloneCallback(&nodeDel); + + REQUIRE(nodePtrTo(gCloned->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); + REQUIRE(nodePtrTo(gCloned->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName)); + REQUIRE(nodePtrTo(gCloned->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName)); + ++nbClonedWithDelete; + } + catch (const std::runtime_error& error) { + // pass + } + } + + ++seed; + } + + printf("nbClonedWithDelete = %zu/%zu\n", nbClonedWithDelete, nbTests); +} + +TEST_CASE("remove", "[GraphView][remove]") { + const size_t nbTests = 100; + size_t nbTested = 0; + + for (int test = 0; test < nbTests; ++test) { + std::random_device rd; + const std::mt19937::result_type seed(rd()); + + RandomGraph randGraph; + randGraph.types = {"Fictive", "DelFictive"}; + randGraph.typesWeights = {0.8, 0.2}; + const auto g1 = std::make_shared<GraphView>("g1"); + const bool unicity1 = g1->add(randGraph.gen(seed, 10)); + + if (unicity1) { + // g1->save("./remove1_before"); + const auto nodes = g1->getNodes(); + int step = 1; + for (auto node : nodes) { + if (node->type() == "DelFictive") { + g1->remove(node, false); + // g1->save("./remove1_after" + std::to_string(step)); + step++; + } + } + + randGraph.omitType = "DelFictive"; + const auto g2 = std::make_shared<GraphView>("g2"); + g2->add(randGraph.gen(seed, 10)); + + // g1->save("./remove1"); + // g2->save("./remove2"); + + 
REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); + // Order not guaranteed, because when a node is removed, it can create new GraphView inputs/outputs + // Their order thus depends on the deletion order! + //REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName)); + //REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName)); + + // Check that inputs/outputs are the same regardless of the order + auto orderedInputs1 = nodePtrTo(g1->getOrderedInputs(), nodePtrToName); + auto orderedInputs2 = nodePtrTo(g2->getOrderedInputs(), nodePtrToName); + auto orderedOutputs1 = nodePtrTo(g1->getOrderedOutputs(), nodePtrToName); + auto orderedOutputs2 = nodePtrTo(g2->getOrderedOutputs(), nodePtrToName); + std::sort(orderedInputs1.begin(), orderedInputs1.end()); + std::sort(orderedInputs2.begin(), orderedInputs2.end()); + std::sort(orderedOutputs1.begin(), orderedOutputs1.end()); + std::sort(orderedOutputs2.begin(), orderedOutputs2.end()); + + REQUIRE(orderedInputs1 == orderedInputs2); + REQUIRE(orderedOutputs1 == orderedOutputs2); + ++nbTested; + } + } + + printf("nbTested = %zu/%zu\n", nbTested, nbTests); +} + +TEST_CASE("[core/graph] GraphView(Constructor)", "[GraphView][constructor()]") { std::shared_ptr<GraphView> g0 = std::make_shared<GraphView>(); std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("G1"); REQUIRE(g0 != nullptr); REQUIRE(g1 != nullptr); } -TEST_CASE("[core/graph] GraphView(add)") { +TEST_CASE("[core/graph] GraphView(add)", "[GraphView][add]") { SECTION("Node alone") { std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 0, 0, "Gop1"); @@ -48,6 +230,9 @@ TEST_CASE("[core/graph] GraphView(add)") { g->add(GOp5); std::shared_ptr<Node> GOp6 = GenericOperator("Fictive", 1, 1, 1, "Gop6"); g->add(GOp6); + // g->save("node_alone"); + REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop3", 0}, {"Gop4", 0}, {"Gop5", 0}, {"Gop6", 0}, {"Gop6", 1}})); + REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop2", 0}, {"Gop5", 0}, {"Gop6", 0}})); } SECTION("Several Nodes") { @@ -58,10 +243,14 @@ TEST_CASE("[core/graph] GraphView(add)") { GOp1parent->addChild(GOp1, 0, 0); g->add(GOp1); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent})); + REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({})); + REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop1", 0}})); // there should be no duplicates g->add(GOp1); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent})); + REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({})); + REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop1", 0}})); } SECTION("Initializer list for Node") { @@ -221,7 +410,7 @@ TEST_CASE("[core/graph] GraphView(resetConnections)") { } } - SECTION("disconnect data iput + learnable parameters") { + SECTION("disconnect data input + learnable parameters") { std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 2, 1, "c1"); std::shared_ptr<Node>
conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); @@ -402,6 +591,56 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight0, newAddBias0, newAddBias1, newMatmulWeight1, fc1, fc0})); } + + SECTION("Nodes with shared parameters") { + + auto myConv1 = Conv(1, 5, {1,1}, "conv1"); + auto myConv2 = Conv(5, 5, {1,1}, "conv2"); + auto myConv3 = Conv(5, 5, {1,1}, "conv3"); + auto myConv4 = Conv(5, 5, {1,1}, "conv4"); + auto myConv5 = Conv(5, 5, {1,1}, "conv5"); + + auto sharedWeightTensor = std::make_shared<Tensor>(); + sharedWeightTensor->resize({5,5,1,1}); + auto sharedWeight = Producer(sharedWeightTensor, "sharedWeight"); + sharedWeight -> addChild(myConv2, 0, 1); + sharedWeight -> addChild(myConv3, 0, 1); + sharedWeight -> addChild(myConv4, 0, 1); + + auto sharedBiasTensor = std::make_shared<Tensor>(); + sharedBiasTensor->resize({5}); + auto sharedBias = Producer(sharedBiasTensor, "sharedBias"); + sharedBias -> addChild(myConv2, 0, 2); + sharedBias -> addChild(myConv3, 0, 2); + sharedBias -> addChild(myConv4, 0, 2); + + auto g = Sequential({ + myConv1, + myConv2, + myConv3, + myConv4, + myConv5 + }); + + REQUIRE(g->getNode("sharedWeight") != nullptr); + REQUIRE(g->getNode("sharedBias") != nullptr); + + + auto newReLU4 = ReLU("relu4"); + GraphView::replace({myConv4, myConv4->getParent(1), myConv4->getParent(2)}, {newReLU4}); + REQUIRE(g->getNode("sharedWeight") != nullptr); + REQUIRE(g->getNode("sharedBias") != nullptr); + + auto newReLU3 = ReLU("relu3"); + GraphView::replace({myConv3, myConv3->getParent(1), myConv3->getParent(2)}, {newReLU3}); + REQUIRE(g->getNode("sharedWeight") != nullptr); + REQUIRE(g->getNode("sharedBias") != nullptr); + + auto newReLU2 = ReLU("relu2"); + GraphView::replace({myConv2, myConv2->getParent(1), myConv2->getParent(2)}, {newReLU2}); + REQUIRE(g->getNode("sharedWeight") == nullptr); + REQUIRE(g->getNode("sharedBias") == nullptr); + } } TEST_CASE("[GraphView] clone") { diff --git a/unit_tests/operator/Test_ConvDepthWise_Op.cpp b/unit_tests/operator/Test_ConvDepthWise_Op.cpp index 14d4dc537f527b32414151ee7f93e601f5a4bd8a..6008e3bfac346725935d5d8ffe87f392c49a3409 100644 --- a/unit_tests/operator/Test_ConvDepthWise_Op.cpp +++ b/unit_tests/operator/Test_ConvDepthWise_Op.cpp @@ -45,20 +45,20 @@ TEST_CASE("[core/operator] ConvDepthWise_Op(computeReceptiveField)", "[Operator] auto op4 = std::dynamic_pointer_cast<OperatorTensor>(cdw4 -> getOperator()); SECTION("Check individual receptive fields") { - auto res1 = op1->computeReceptiveField(0, {16,3,10,10}); - auto res2 = op2->computeReceptiveField(op2->getOutput(0)->getIdx({3,1,100,28}), {4,2,30,40}); - auto res3 = op3->computeReceptiveField(0, {1,1,109,109}); - auto res4 = op4->computeReceptiveField(op4->getInput(0)->getIdx({5,0,108,108}), {10,1,1,1}); + auto res1 = op1->computeReceptiveField({0,0,0,0}, {16,3,10,10}); + auto res2 = op2->computeReceptiveField({3,1,100,28}, {4,2,30,40}); + auto res3 = op3->computeReceptiveField({0,0,0,0}, {1,1,109,109}); + auto res4 = op4->computeReceptiveField({5,0,108,108}, {10,1,1,1}); - REQUIRE(((res1[0].first == 0) && (res1[0].second == std::vector<DimSize_t>({16, 3, 14, 14})))); - REQUIRE(((res2[0].first == op2->getInput(0)->getIdx({3,1,100,28})) && (res2[0].second == std::vector<DimSize_t>({4, 2, 32, 42})))); - REQUIRE(((res3[0].first == 0) && (res3[0].second == std::vector<DimSize_t>({1, 1, 218, 218})))); - REQUIRE(((res4[0].first == op4->getInput(0)->getIdx({5, 0, 108, 108})) && (res4[0].second 
== std::vector<DimSize_t>({10, 1, 1, 1})))); + REQUIRE(((res1[0].first == std::vector<DimSize_t>({0,0,0,0})) && (res1[0].second == std::vector<DimSize_t>({16, 3, 14, 14})))); + REQUIRE(((res2[0].first == std::vector<DimSize_t>({3,1,100,28})) && (res2[0].second == std::vector<DimSize_t>({4, 2, 32, 42})))); + REQUIRE(((res3[0].first == std::vector<DimSize_t>({0,0,0,0})) && (res3[0].second == std::vector<DimSize_t>({1, 1, 218, 218})))); + REQUIRE(((res4[0].first == std::vector<DimSize_t>({5,0,108,108})) && (res4[0].second == std::vector<DimSize_t>({10, 1, 1, 1})))); } SECTION("Check receptive field propagation") { // input: first-{5, 0, 50, 50} dims-{1, 1, 1, 1} - auto res4 = op4->computeReceptiveField(op4->getInput(0)->getIdx({5,0,50,50}), {1,1,1,1}); + auto res4 = op4->computeReceptiveField({5,0,50,50}, {1,1,1,1}); // cdw4 RF: first-{5, 0, 50, 50} dims-{1, 1, 1, 1} auto res3 = op3->computeReceptiveField(res4[0].first, res4[0].second); // cdw3 RF: first-{5, 0, 100, 100} dims-{1, 1, 2, 2} @@ -67,7 +67,7 @@ TEST_CASE("[core/operator] ConvDepthWise_Op(computeReceptiveField)", "[Operator] auto res1 = op1->computeReceptiveField(res2[0].first, res2[0].second); // cdw1 RF: first-{5, 0, 100, 100} dims-{1, 1, 8, 8} - REQUIRE(((res1[0].first == op1->getInput(0)->getIdx({5, 0, 100, 100})) && (res1[0].second == std::vector<DimSize_t>({1, 1, 8, 8})))); + REQUIRE(((res1[0].first == std::vector<DimSize_t>({5, 0, 100, 100})) && (res1[0].second == std::vector<DimSize_t>({1, 1, 8, 8})))); } } } // namespace Aidge \ No newline at end of file diff --git a/unit_tests/operator/Test_Conv_Op.cpp b/unit_tests/operator/Test_Conv_Op.cpp index a3e84999eb2e2a31f1217330ac9718f35b0ca396..bc24fc8081d78dedf853450ff648b6d91b47c1dc 100644 --- a/unit_tests/operator/Test_Conv_Op.cpp +++ b/unit_tests/operator/Test_Conv_Op.cpp @@ -45,22 +45,22 @@ TEST_CASE("[core/operator] Conv_Op(computeReceptiveField)", "[Operator][computeR auto op4 = std::dynamic_pointer_cast<OperatorTensor>(conv4 -> getOperator()); SECTION("Check individual receptive fields") { - auto res1 = op1 -> computeReceptiveField(0, {16,32,10,10}); - auto res2 = op2 -> computeReceptiveField(op2 -> getOutput(0)->getIdx({3,20,100,28}), {4,20,30,40}); - auto res3 = op3 -> computeReceptiveField(0, {1,1,109,109}); - auto res4 = op4 -> computeReceptiveField(op4 -> getOutput(0)->getIdx({5,0,108,108}), {10,10,1,1}); + auto res1 = op1 -> computeReceptiveField({0,0,0,0}, {16,32,10,10}); + auto res2 = op2 -> computeReceptiveField({3,20,100,28}, {4,20,30,40}); + auto res3 = op3 -> computeReceptiveField({0,0,0,0}, {1,1,109,109}); + auto res4 = op4 -> computeReceptiveField({5,0,108,108}, {10,10,1,1}); - REQUIRE(((res1[0].first == 0) && (res1[0].second == std::vector<DimSize_t>({16, 3, 14, 14})))); - REQUIRE(((res1[1].first == 0) && (res1[1].second == std::vector<DimSize_t>({32, 3, 5, 5})))); - REQUIRE(((res1[2].first == 0) && (res1[2].second == std::vector<DimSize_t>({32})))); - REQUIRE(((res2[0].first == op2->getInput(0)->getIdx({3,0,100,28})) && (res2[0].second == std::vector<DimSize_t>({4, 32, 32, 42})))); - REQUIRE(((res3[0].first == 0) && (res3[0].second == std::vector<DimSize_t>({1, 64, 218, 218})))); - REQUIRE(((res4[0].first == op4->getInput(0)->getIdx({5, 0, 108, 108})) && (res4[0].second == std::vector<DimSize_t>({10, 10, 1, 1})))); + REQUIRE(((res1[0].first == std::vector<DimSize_t>({0,0,0,0})) && (res1[0].second == std::vector<DimSize_t>({16, 3, 14, 14})))); + REQUIRE(((res1[1].first == std::vector<DimSize_t>({0,0,0,0})) && (res1[1].second == std::vector<DimSize_t>({32, 
3, 5, 5})))); + REQUIRE(((res1[2].first == std::vector<DimSize_t>({0})) && (res1[2].second == std::vector<DimSize_t>({32})))); + REQUIRE(((res2[0].first == std::vector<DimSize_t>({3,0,100,28})) && (res2[0].second == std::vector<DimSize_t>({4, 32, 32, 42})))); + REQUIRE(((res3[0].first == std::vector<DimSize_t>({0,0,0,0})) && (res3[0].second == std::vector<DimSize_t>({1, 64, 218, 218})))); + REQUIRE(((res4[0].first == std::vector<DimSize_t>({5, 0, 108, 108})) && (res4[0].second == std::vector<DimSize_t>({10, 10, 1, 1})))); } SECTION("Check receptive field propagation") { // input: first-{5, 0, 50, 50} dims-{1, 1, 1, 1} - auto res4 = op4->computeReceptiveField(op4->getOutput(0)->getIdx({5,0,50,50}), {1,1,1,1}); + auto res4 = op4->computeReceptiveField({5,0,50,50}, {1,1,1,1}); // conv4 RF: first-{5, 0, 50, 50} dims-{1, 10, 1, 1} auto res3 = op3->computeReceptiveField(res4[0].first, res4[0].second); // conv3 RF: first-{5, 0, 100, 100} dims-{1, 64, 2, 2} @@ -69,7 +69,7 @@ TEST_CASE("[core/operator] Conv_Op(computeReceptiveField)", "[Operator][computeR auto res1 = op1->computeReceptiveField(res2[0].first, res2[0].second); // conv1 RF: first-{5, 0, 100, 100} dims-{1, 3, 8, 8} - REQUIRE(((res1[0].first == op1->getInput(0)->getIdx({5, 0, 100, 100})) && (res1[0].second == std::vector<DimSize_t>({1, 3, 8, 8})))); + REQUIRE(((res1[0].first == std::vector<DimSize_t>({5, 0, 100, 100})) && (res1[0].second == std::vector<DimSize_t>({1, 3, 8, 8})))); // std::cout << "conv1: {"; diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index ef0c4e7f72d3148eccb97896a3d6e3d5ae5ad6e1..68e2d4d4d5b4fe1b40f83c087eb61c7865d3db75 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -14,6 +14,7 @@ #include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/MetaOperatorDefs.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Testing.hpp" #include <cstddef> using namespace Aidge; @@ -26,8 +27,8 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { REQUIRE(microGraph->getNodes().size() == 2); REQUIRE(microGraph->inputNodes().size() == 2); // 2 because Conv has inputs outside the meta-op (Producers for weight and bias) - // Order not garanteed by the GraphView - //REQUIRE((*microGraph->inputNodes().begin())->getOperator()->type() == "Pad"); + REQUIRE(nodePtrTo(microGraph->getOrderedInputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"Pad", 0}, {"Conv", 1}, {"Conv", 2}})); + REQUIRE(nodePtrTo(microGraph->getOrderedOutputs()) == std::vector<std::pair<std::string, IOIndex_t>>({{"Conv", 0}})); REQUIRE(microGraph->outputNodes().size() == 1); REQUIRE((*microGraph->outputNodes().begin())->getOperator()->type() == "Conv"); REQUIRE(op->nbInputs() == 3); @@ -43,8 +44,7 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { REQUIRE(opTensor->outputDimsForwarded()); REQUIRE(std::static_pointer_cast<Tensor>(opTensor->getRawOutput(0))->dims() == std::vector<size_t>({2,3,5,5})); REQUIRE(std::static_pointer_cast<Tensor>(opTensor->getRawInput(0)) == myInput); - // Order not garanteed by the GraphView - //REQUIRE((*microGraph->inputNodes().begin())->getOperator()->getRawInput(0) == myInput); + REQUIRE(microGraph->getOrderedInputs()[0].first->getOperator()->getRawInput(0) == myInput); REQUIRE(opTensor->getRawOutput(0) == (*microGraph->outputNodes().begin())->getOperator()->getRawOutput(0)); //op->getOperator()->updateConsummerProducer(); // require implementation diff --git 
a/unit_tests/recipies/Test_FuseBatchNorm.cpp b/unit_tests/recipies/Test_FuseBatchNorm.cpp deleted file mode 100644 index 5d9c02d5582e3c56aba9d374d7087946c7d94bde..0000000000000000000000000000000000000000 --- a/unit_tests/recipies/Test_FuseBatchNorm.cpp +++ /dev/null @@ -1,70 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ -/* -#include <catch2/catch_test_macros.hpp> -#include <set> - - -//#include "aidge/backend/cpu/operator/BatchNormImpl.hpp" -//#include "aidge/backend/cpu/operator/ConvImpl.hpp" - - - -#include "aidge/operator/Conv.hpp" -#include "aidge/operator/GenericOperator.hpp" -#include "aidge/operator/Producer.hpp" -#include "aidge/graph/OpArgs.hpp" -#include "aidge/operator/BatchNorm.hpp" -#include "aidge/utils/Recipies.hpp" - -//#include "aidge/backend/TensorImpl.hpp" -//#include "aidge/backend/cpu.hpp" -//#include "aidge/" - -#include <cstddef> - - -namespace Aidge { - - - TEST_CASE("[FuseBatchNorm] conv") { - auto g1 = Sequential({ - Producer({16, 3, 224, 224}, "dataProvider"), - Conv(3, 32, {3, 3}, "conv1"), - BatchNorm<2>() - }); - - g1->setDataType(DataType::Float32); - g1->setBackend("cpu"); - g1->forwardDims(); - - // std::set<std::string> availableBackends = Tensor::getAvailableBackends(); - // if (availableBackends.find("cpu") != availableBackends.end()){ - // g1->setBackend("cpu"); - // newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr)); - // }else{ - // printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n"); - // } - - fuseBatchNorm(g1); - - SECTION("Check resulting nodes") { - // REQUIRE(g1->getNodes().size() == 2); - // REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); - // REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0)); - // REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); - // REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0)); - // REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); - } - } - -} -*/ \ No newline at end of file diff --git a/unit_tests/recipies/Test_FuseMulAdd.cpp b/unit_tests/recipies/Test_FuseMulAdd.cpp index 0c65db98917e33a11f4b7bac678b271b1a10fb94..968826230dfdf85290ee377aee155e06855c4b28 100644 --- a/unit_tests/recipies/Test_FuseMulAdd.cpp +++ b/unit_tests/recipies/Test_FuseMulAdd.cpp @@ -61,7 +61,6 @@ TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { // Transform GraphView inplace fuseMulAdd(g); - g->save("bonjour"); // Check new GraphView std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); diff --git a/unit_tests/recipies/Test_removeFlatten.cpp b/unit_tests/recipies/Test_removeFlatten.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8d0ff29dae19ba2dd8009441c39da53bf44378f0 --- /dev/null +++ b/unit_tests/recipies/Test_removeFlatten.cpp @@ -0,0 +1,49 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public 
License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <set> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/recipies/Recipies.hpp" + +namespace Aidge { + + +TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") { + // generate the original GraphView + auto flatten = GenericOperator("Flatten", 1, 0, 1, "myFlatten"); + auto fc = FC(10, 50, "myFC"); + + flatten -> addChild(fc); + + auto g = std::make_shared<GraphView>(); + g->add({fc, flatten}); + + // Check original graph + // g -> save("before_remove_flatten"); + + // use recipie + removeFlatten(g); + + // Check transformed graph + // g -> save("after_remove_flatten"); + + REQUIRE(g->getOrderedInputs().size() == 1); + REQUIRE(g->getOrderedOutputs().size() == 1); + REQUIRE(g->getOrderedInputs()[0].first == fc); + REQUIRE(g->getOrderedOutputs()[0].first == fc); +} + +} // namespace Aidge \ No newline at end of file
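A minimal usage sketch of the removeDropout recipe introduced above, mirroring the RemoveFlatten test; it is not part of the patch itself. The standalone helper name removeDropoutSketch and the node names are hypothetical, and only the GenericOperator, FC, GraphView and removeDropout calls already present in this change set are assumed.

#include <memory>

#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/recipies/Recipies.hpp"

// Hypothetical helper, for illustration only: build FC -> Dropout, then apply the recipe.
void removeDropoutSketch() {
    auto fc = Aidge::FC(10, 50, "myFC");
    auto dropout = Aidge::GenericOperator("Dropout", 1, 0, 1, "myDropout");
    fc->addChild(dropout);

    auto g = std::make_shared<Aidge::GraphView>();
    g->add({fc, dropout});

    // removeDropout() matches every "Dropout" node (and its Producer parents, if any)
    // and removes the matched set from the view, reconnecting the surrounding nodes.
    Aidge::removeDropout(g);

    // Expected outcome, as in the Flatten test above: only the FC node remains,
    // and it is both the ordered input and the ordered output of the graph.
}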