Skip to content
Snippets Groups Projects
Commit 8ebc6ad5 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added utils functions

parent 6923a80c
No related branches found
No related tags found
1 merge request!69Draft: Add support for lower than 8-bits precision
Pipeline #80562 failed
#pragma once
#ifndef __AIDGE_EXPORT_CPP_ACTIVATION_UTILS_HPP__
#define __AIDGE_EXPORT_CPP_ACTIVATION_UTILS_HPP__
#include <type_traits>
#include "network/typedefs.hpp"
......@@ -50,3 +51,5 @@ Output_T activation_forward_value (Sum_T weightedSum,
return saturate<Output_T>(rescaling(weightedSum, output), 8 * sizeof(Output_T));
}
#endif
This diff is collapsed.
#pragma once
#ifndef __AIDGE_EXPORT_CPP_RESCALING_UTILS_HPP__
#define __AIDGE_EXPORT_CPP_RESCALING_UTILS_HPP__
// ---------------------------------------------------
// ----------------- Saturate Utils ------------------
......@@ -15,64 +16,176 @@ constexpr int64_t smlal(int32_t lhs, int32_t rhs,
}
// ---------------------------------------------------
// --------------- Scaling by Shifting ---------------
// ------------------- No Scaling --------------------
// ---------------------------------------------------
template<int SHIFT>
struct SingleShiftScaling {
struct NoScaling {
template<typename Sum_T>
Sum_T operator()(Sum_T weightedSum, size_t /*output*/) const {
return weightedSum;
}
};
struct FloatingPointScaling {
template<typename Sum_T>
Sum_T operator()(Sum_T weightedSum, size_t /*output*/) const
{
return (SHIFT != 0) ? ((weightedSum >> (SHIFT - 1)) + 1) >> 1 // Rounding
: weightedSum;
Sum_T operator()(Sum_T weightedSum, size_t /*output*/) const {
return round(weightedSum*mScaling);
}
// // Shift attribute
// static const int mShift = SHIFT;
// static const Scaling_T mScalingType = SingleShift;
double mScaling;
};
// // FP Attribute
// static const int32_t mScaling = 0;
// static const int64_t mFractionalBits = 0;
template<size_t SIZE>
struct FloatingPointScalingPerChannel {
template<typename Sum_T>
Sum_T operator()(Sum_T weightedSum, size_t output) const {
return round(weightedSum * mScaling[output]);
}
double mScaling[SIZE];
};
// ---------------------------------------------------
// --------------- Fixed Point Scaling ---------------
// ---------------------------------------------------
template<size_t SIZE>
struct FloatingPointClippingAndScaling {
template<typename Sum_T>
Sum_T operator()(Sum_T weightedSum, size_t /*output*/) const {
Sum_T clipValue = weightedSum;
clipValue = (clipValue < Sum_T(0)) ?
Sum_T(0) : (clipValue > Sum_T(mClipping)) ?
Sum_T(mClipping) : clipValue;
return round(clipValue * mScaling);
}
double mScaling;
int32_t mClipping;
};
template<int64_t SHIFT, int32_t COEF>
template<size_t SIZE>
struct FloatingPointClippingAndScalingPerChannel {
template<typename Sum_T>
Sum_T operator()(Sum_T weightedSum, size_t output) const {
Sum_T clipValue = weightedSum;
clipValue = (clipValue < Sum_T(0)) ?
Sum_T(0) : (clipValue > Sum_T(mClipping[output])) ?
Sum_T(mClipping[output]) : clipValue;
return round(clipValue * mScaling[output]);
}
double mScaling[SIZE];
int32_t mClipping[SIZE];
};
template<int32_t SCALING, int64_t FRACTIONAL_BITS>
struct FixedPointScaling {
template<typename Sum_T>
Sum_T operator()(Sum_T weightedSum, size_t /*output*/) const {
// Different rounding if weightesSum < 0
// if(weightedSum < 0) {
// HALF--;
// }
return smlal(weightedSum, SCALING, HALF_LO, HALF_HI) >> FRACTIONAL_BITS;
}
static const uint32_t HALF_LO = (FRACTIONAL_BITS > 0)
? (1ull << (FRACTIONAL_BITS - 1)) & 0xFFFFFFFF : 0;
static const uint32_t HALF_HI = (FRACTIONAL_BITS > 0)
? (1ull << (FRACTIONAL_BITS - 1)) >> 32u : 0;
};
template<size_t SIZE, int64_t FRACTIONAL_BITS>
struct FixedPointScalingScalingPerChannel {
template<typename Sum_T>
Sum_T operator()(Sum_T weightedSum, size_t /*output*/) const
{
return smlal(weightedSum, COEF, HALF_LO, HALF_HI) >> SHIFT;
Sum_T operator()(Sum_T weightedSum, size_t output) const {
// Different rounding if weightesSum < 0
// if(weightedSum < 0) {
// HALF--;
// }
return smlal(weightedSum, mScaling[output], HALF_LO, HALF_HI) >> FRACTIONAL_BITS;
}
// Attributes
static constexpr uint32_t HALF_LO = (SHIFT > 0)
? (1ull << (SHIFT - 1)) & 0xFFFFFFFF : 0;
static constexpr uint32_t HALF_HI = (SHIFT > 0)
? (1ull << (SHIFT - 1)) >> 32u : 0;
// static const int32_t mScaling = SCALING;
// static const int64_t mFractionalBits = FRACTIONAL_BITS;
// static const Scaling_T mScalingType = FixedPoint;
// static const int mShift = 0;
static const uint32_t HALF_LO = (FRACTIONAL_BITS > 0)
? (1ull << (FRACTIONAL_BITS - 1)) & 0xFFFFFFFF : 0;
static const uint32_t HALF_HI = (FRACTIONAL_BITS > 0)
? (1ull << (FRACTIONAL_BITS - 1)) >> 32u : 0;
int32_t mScaling[SIZE];
};
// ---------------------------------------------------
// ------------------- No Scaling --------------------
// ---------------------------------------------------
template<size_t SIZE, int64_t FRACTIONAL_BITS>
struct FixedPointClippingAndScalingPerChannel {
template<typename Sum_T>
Sum_T operator()(Sum_T weightedSum, size_t output) const {
// Different rounding if weightesSum < 0
// if(weightedSum < 0) {
// HALF--;
// }
Sum_T clipValue = weightedSum;
clipValue = (clipValue < Sum_T(0)) ?
Sum_T(0) : (clipValue > Sum_T(mClipping[output])) ?
Sum_T(mClipping[output]) : clipValue;
return smlal(clipValue, mScaling[output], HALF_LO, HALF_HI) >> FRACTIONAL_BITS;
}
struct NoScaling {
static const uint32_t HALF_LO = (1ull << (FRACTIONAL_BITS - 1)) & 0xFFFFFFFF;
static const uint32_t HALF_HI = (1ull << (FRACTIONAL_BITS - 1)) >> 32u;
int32_t mScaling[SIZE];
int32_t mClipping[SIZE];
};
template<size_t SHIFT>
struct SingleShiftScaling {
template<typename Sum_T>
Sum_T operator()(Sum_T weightedSum, unsigned int /*output*/) const
{
return weightedSum;
Sum_T operator()(Sum_T weightedSum, size_t /*output*/) const {
return (SHIFT != 0) ? ((weightedSum >> (SHIFT - 1)) + 1) >> 1 // Rounding
: weightedSum;
}
};
template<size_t SIZE>
struct SingleShiftScalingPerChannel {
template<typename Sum_T>
Sum_T operator()(Sum_T weightedSum, size_t output) const {
return (mScaling[output] != 0) ? ((weightedSum >> (mScaling[output] - 1)) + 1) >> 1 // Rounding
: weightedSum;
}
unsigned char mScaling[SIZE];
};
template<size_t SHIFT1, size_t SHIFT2, typename Sum_T>
struct DoubleShiftScaling {
Sum_T operator()(Sum_T weightedSum, size_t /*output*/) const {
// Different rounding if weightesSum < 0
// if(weightedSum < 0) {
// HALF--;
// }
return (weightedSum + (weightedSum << SHIFT1) + HALF) >> SHIFT2;
}
static const Sum_T HALF = ((Sum_T) 1) << (SHIFT2 - 1);
};
template<size_t SIZE, bool UNSIGNED_WEIGHTED_SUM, typename Sum_T>
struct DoubleShiftScalingPerChannel {
Sum_T operator()(Sum_T weightedSum, size_t output) const {
const Sum_T SHIFT1 = mScaling[output][0];
const Sum_T SHIFT2 = mScaling[output][1];
const Sum_T HALF = mScaling[output][2];
// Different rounding if weightesSum < 0
// if(weightedSum < 0) {
// HALF--;
// }
return (weightedSum + (weightedSum << SHIFT1) + HALF) >> SHIFT2;
}
Sum_T mScaling[SIZE][3];
};
#endif
#ifndef __AIDGE_EXPORT_CPP_NETWORK_TYPEDEFS__
#define __AIDGE_EXPORT_CPP_NETWORK_TYPEDEFS__
#include <stdint.h>
#include <cmath>
#include <type_traits>
#include <limits>
#include <cstdint>
typedef enum {
Tanh,
......@@ -30,4 +33,509 @@ typedef enum {
} CoeffMode_T;
// ----------------------------------------------------------------------------
// -------------------- Generic custom bit-width types ------------------------
// ----------------------------------------------------------------------------
template <int BITWIDTH>
struct data {};
template <int BITWIDTH>
struct udata{};
namespace std {
// Specialization of STL, allows to use std::is_unsigned<> for example.
template <int BITWIDTH>
struct is_integral<data<BITWIDTH>>
: std::is_integral<decltype(data<BITWIDTH>::value)>::type {};
template <int BITWIDTH>
struct is_integral<udata<BITWIDTH>>
: std::is_integral<decltype(udata<BITWIDTH>::value)>::type {};
template <int BITWIDTH>
struct is_floating_point<data<BITWIDTH>>
: std::is_floating_point<decltype(data<BITWIDTH>::value)>::type {};
template <int BITWIDTH>
struct is_unsigned<data<BITWIDTH>>
: std::is_unsigned<decltype(data<BITWIDTH>::value)>::type {};
template <int BITWIDTH>
struct is_unsigned<udata<BITWIDTH>>
: std::is_unsigned<decltype(udata<BITWIDTH>::value)>::type {};
template <int BITWIDTH>
class numeric_limits<data<BITWIDTH>> {
public:
static constexpr int is_integer = (BITWIDTH > 0);
static constexpr int is_signed = true;
static constexpr int digits = std::abs(BITWIDTH);
static constexpr decltype(data<BITWIDTH>::value) min() noexcept
{ return (BITWIDTH > 0) ? -(1 << (BITWIDTH - 1)) :
std::numeric_limits<decltype(data<BITWIDTH>::value)>::min(); };
static constexpr decltype(data<BITWIDTH>::value) lowest() noexcept
{ return (BITWIDTH > 0) ? -(1 << (BITWIDTH - 1)) :
std::numeric_limits<decltype(data<BITWIDTH>::value)>::lowest(); };
static constexpr decltype(data<BITWIDTH>::value) max() noexcept
{ return (BITWIDTH > 0) ? ((1 << (BITWIDTH - 1)) - 1) :
std::numeric_limits<decltype(data<BITWIDTH>::value)>::max(); };
};
template <int BITWIDTH>
class numeric_limits<udata<BITWIDTH>> {
public:
static constexpr int is_integer = true;
static constexpr int is_signed = false;
static constexpr int digits = BITWIDTH;
static constexpr decltype(data<BITWIDTH>::value) min() noexcept
{ return 0; };
static constexpr decltype(data<BITWIDTH>::value) lowest() noexcept
{ return 0; };
static constexpr decltype(data<BITWIDTH>::value) max() noexcept
{ return ((1 << BITWIDTH) - 1); };
};
}
// ----------------------------------------------------------------------------
// -------------- Custom bit-width types operator overloading -----------------
// ----------------------------------------------------------------------------
// data
// template<int BITWIDTH, typename T>
// constexpr data<BITWIDTH>& operator+=(data<BITWIDTH>& d, T rhs)
// {return d.value += decltype(data<BITWIDTH>::value)(rhs);}
// template<int BITWIDTH, typename T>
// constexpr data<BITWIDTH> operator+(data<BITWIDTH> d, T rhs)
// {return d += rhs;}
// template<int BITWIDTH, typename T>
// constexpr data<BITWIDTH>& operator-=(data<BITWIDTH>& d, T rhs)
// {return d.value -= decltype(data<BITWIDTH>::value)(rhs);}
// template<int BITWIDTH, typename T>
// constexpr data<BITWIDTH> operator-(data<BITWIDTH> d, T rhs)
// {return d -= rhs;}
// template<int BITWIDTH, typename T>
// constexpr data<BITWIDTH>& operator*=(data<BITWIDTH>& d, T rhs)
// {return d.value *= decltype(data<BITWIDTH>::value)(rhs);}
// template<int BITWIDTH, typename T>
// constexpr data<BITWIDTH> operator*(data<BITWIDTH> d, T rhs)
// {return d *= rhs;}
// template<int BITWIDTH, typename T>
// constexpr data<BITWIDTH>& operator/=(data<BITWIDTH>& d, T rhs)
// {return d.value /= decltype(data<BITWIDTH>::value)(rhs);}
// template<int BITWIDTH, typename T>
// constexpr data<BITWIDTH> operator/(data<BITWIDTH> d, T rhs)
// {return d /= rhs;}
// udata
// template<int BITWIDTH, typename T>
// constexpr udata<BITWIDTH>& operator+=(udata<BITWIDTH>& d, T rhs)
// {return d.value += decltype(udata<BITWIDTH>::value)(rhs);}
// template<int BITWIDTH, typename T>
// constexpr udata<BITWIDTH> operator+(udata<BITWIDTH> d, T rhs)
// {return d += rhs;}
// template<int BITWIDTH, typename T>
// constexpr udata<BITWIDTH>& operator-=(udata<BITWIDTH>& d, T rhs)
// {return d.value -= decltype(udata<BITWIDTH>::value)(rhs);}
// template<int BITWIDTH, typename T>
// constexpr udata<BITWIDTH> operator-(udata<BITWIDTH> d, T rhs)
// {return d -= rhs;}
// template<int BITWIDTH, typename T>
// constexpr udata<BITWIDTH>& operator*=(udata<BITWIDTH>& d, T rhs)
// {return d.value *= decltype(udata<BITWIDTH>::value)(rhs);}
// template<int BITWIDTH, typename T>
// constexpr udata<BITWIDTH> operator*(udata<BITWIDTH> d, T rhs)
// {return d *= rhs;}
// template<int BITWIDTH, typename T>
// constexpr udata<BITWIDTH>& operator/=(udata<BITWIDTH>& d, T rhs)
// {return d.value /= decltype(udata<BITWIDTH>::value)(rhs);}
// template<int BITWIDTH, typename T>
// constexpr udata<BITWIDTH> operator/(udata<BITWIDTH> d, T rhs)
// {return d /= rhs;}
// ----------------------------------------------------------------------------
// ---------------- Custom bit-width types specializations --------------------
// ----------------------------------------------------------------------------
// Data structure for double
template <>
struct data<-64>
{
data<-64>() = default;
constexpr data<-64>(double v): value(v) {};
constexpr operator double() const { return value; }
union {
double value;
};
};
// Data structure for float
template <>
struct data<-32>
{
data<-32>() = default;
constexpr data<-32>(float v): value(v) {};
constexpr operator float() const { return value; }
union {
float value;
};
};
// Data structure for half float
template <>
struct data<-16>
{
data<-16>() = default;
constexpr data<-16>(float v): value(v) {};
constexpr operator float() const { return value; }
union {
float value;
};
};
// Data structure for int32
template <>
struct data<32>
{
data<32>() = default;
constexpr data<32>(int32_t v): value(v) {};
constexpr operator int32_t() const { return value; }
union {
int32_t value;
};
};
// Data structure for uint32
template <>
struct udata<32>
{
udata<32>() = default;
constexpr udata<32>(uint32_t v): value(v) {};
constexpr operator uint32_t() const { return value; }
union {
uint32_t value;
};
};
// Data structure for int16
template <>
struct data<16>
{
data<16>() = default;
constexpr data<16>(int16_t v): value(v) {};
constexpr operator int16_t() const { return value; }
union {
int16_t value;
};
};
// Data structure for uint16
template <>
struct udata<16>
{
udata<16>() = default;
constexpr udata<16>(uint16_t v): value(v) {};
constexpr operator uint16_t() const { return value; }
union {
uint16_t value;
};
};
// Data structure for int8
template <>
struct data<8>
{
data<8>() = default;
constexpr data<8>(int8_t v): value(v) {};
constexpr operator int8_t() const { return value; }
union {
int8_t value;
};
};
// Data structure for uint8
template <>
struct udata<8>
{
udata<8>() = default;
constexpr udata<8>(uint8_t v): value(v) {};
constexpr operator uint8_t() const { return value; }
union {
uint8_t value;
};
};
// Data structure for int7
template <>
struct data<7>
{
data<7>() = default;
constexpr data<7>(int8_t v): value(v) {};
constexpr operator int8_t() const { return value; }
union {
int8_t value;
};
};
// Data structure for uint7
template <>
struct udata<7>
{
udata<7>() = default;
constexpr udata<7>(uint8_t v): value(v) {};
constexpr operator uint8_t() const { return value; }
union {
uint8_t value;
};
};
// Data structure for int6
template <>
struct data<6>
{
data<6>() = default;
constexpr data<6>(int8_t v): value(v) {};
constexpr operator int8_t() const { return value; }
union {
int8_t value;
};
};
// Data structure for uint6
template <>
struct udata<6>
{
udata<6>() = default;
constexpr udata<6>(uint8_t v): value(v) {};
constexpr operator uint8_t() const { return value; }
union {
uint8_t value;
};
};
// Data structure for int5
template <>
struct data<5>
{
data<5>() = default;
constexpr data<5>(int8_t v): value(v) {};
constexpr operator int8_t() const { return value; }
union {
int8_t value;
};
};
// Data structure for uint5
template <>
struct udata<5>
{
udata<5>() = default;
constexpr udata<5>(uint8_t v): value(v) {};
constexpr operator uint8_t() const { return value; }
union {
uint8_t value;
};
};
// Data structure for 2 * int4
template <>
struct data<4>
{
data<4>() = default;
constexpr data<4>(uint8_t v): value(v) {};
constexpr operator uint8_t() const { return value; }
union {
int8_t value;
uint8_t uvalue;
struct
{
int8_t op0 : 4;
int8_t op1 : 4;
} fields;
};
};
// Data structure for 2 * uint4
template <>
struct udata<4>
{
udata<4>() = default;
constexpr udata<4>(uint8_t v): value(v) {};
constexpr operator uint8_t() const { return value; }
union {
uint8_t value;
uint8_t uvalue;
struct
{
uint8_t op0 : 4;
uint8_t op1 : 4;
} fields;
};
};
// Data structure for 2 * int3
template <>
struct data<3>
{
data<3>() = default;
constexpr data<3>(uint8_t v): value(v) {};
constexpr operator uint8_t() const { return value; }
union {
int8_t value;
uint8_t uvalue;
struct
{
int8_t op0 : 4;
int8_t op1 : 4;
} fields;
};
};
// Data structure for 2 * uint3
template <>
struct udata<3>
{
udata<3>() = default;
constexpr udata<3>(uint8_t v): value(v) {};
constexpr operator uint8_t() const { return value; }
union {
uint8_t value;
uint8_t uvalue;
struct
{
uint8_t op0 : 4;
uint8_t op1 : 4;
} fields;
};
};
// Data structure for 4 * int2
template <>
struct data<2>
{
data<2>() = default;
constexpr data<2>(uint8_t v): value(v) {};
constexpr operator uint8_t() const { return value; }
union {
int8_t value;
uint8_t uvalue;
struct
{
int8_t op0 : 2;
int8_t op1 : 2;
int8_t op2 : 2;
int8_t op3 : 2;
} fields;
};
};
// Data structure for 4 * uint2
template <>
struct udata<2>
{
udata<2>() = default;
constexpr udata<2>(uint8_t v): value(v) {};
constexpr operator uint8_t() const { return value; }
union {
uint8_t value;
uint8_t uvalue;
struct
{
uint8_t op0 : 2;
uint8_t op1 : 2;
uint8_t op2 : 2;
uint8_t op3 : 2;
} fields;
};
};
// Data structure for 8 * int1
template <>
struct data<1>
{
data<1>() = default;
constexpr data<1>(uint8_t v): value(v) {};
constexpr operator uint8_t() const { return value; }
union {
int8_t value;
uint8_t uvalue;
struct
{
int8_t op0 : 1;
int8_t op1 : 1;
int8_t op2 : 1;
int8_t op3 : 1;
int8_t op4 : 1;
int8_t op5 : 1;
int8_t op6 : 1;
int8_t op7 : 1;
} fields;
};
};
// Data structure for 8 * uint1
template <>
struct udata<1>
{
udata<1>() = default;
constexpr udata<1>(uint8_t v): value(v) {};
constexpr operator uint8_t() const { return value; }
union {
uint8_t value;
uint8_t uvalue;
struct
{
uint8_t op0 : 1;
uint8_t op1 : 1;
uint8_t op2 : 1;
uint8_t op3 : 1;
uint8_t op4 : 1;
uint8_t op5 : 1;
uint8_t op6 : 1;
uint8_t op7 : 1;
} fields;
};
};
// ----------------------------------------------------------------------------
// ------------------------- Structures and Unions ----------------------------
// ----------------------------------------------------------------------------
/* Object for compressing the outputs after mac operations */
typedef struct PackSupport {
uint8_t accumulator;
unsigned int cptAccumulator;
} PackSupport;
/* Union to access the data<32>/data<8>/data<4>/data<1> types */
union dataword
{
data<32> word;
data<8> bytes[4];
data<4> half_bytes[4];
data<1> bitfields[4];
};
/* Union to access the udata<32>/udata<8>/udata<4>/udata<1> types */
union udataword
{
udata<32> word;
udata<8> bytes[4];
udata<4> half_bytes[4];
udata<1> bitfields[4];
};
#endif // __AIDGE_EXPORT_CPP_NETWORK_TYPEDEFS__
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment