From e3813cafd5f9d0174fc6b15786644ec464e36090 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Sun, 19 Jan 2025 20:23:31 +0000 Subject: [PATCH] UPD 'Tensor::toString()' - Add support for precision parameter for floating point numbers in 'toString()' and Tensor formatter - Add precision parameter to 'toString()' virtual function in Data - Select a specific formatter function for Tensor elements once instead of selecting one for each element of the Tensor --- include/aidge/data/Data.hpp | 2 +- include/aidge/data/Tensor.hpp | 30 ++++- src/data/Tensor.cpp | 243 ++++++++++++++++++---------------- 3 files changed, 157 insertions(+), 118 deletions(-) diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp index 156e4d8c1..1ce3782e8 100644 --- a/include/aidge/data/Data.hpp +++ b/include/aidge/data/Data.hpp @@ -37,7 +37,7 @@ public: return mType; } virtual ~Data() = default; - virtual std::string toString() const = 0; + virtual std::string toString(int precision = -1) const = 0; private: const std::string mType; diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index c8df815bb..686657e94 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -642,7 +642,7 @@ public: set<expectedType>(getStorageIdx(coordIdx), value); } - std::string toString() const override; + std::string toString(int precision = -1) const override; inline void print() const { fmt::print("{}\n", toString()); } @@ -981,14 +981,34 @@ private: template<> struct fmt::formatter<Aidge::Tensor> { + // Only stores override precision from format string + int precision_override = -1; + template<typename ParseContext> - inline constexpr auto parse(ParseContext& ctx) { - return ctx.begin(); + constexpr auto parse(ParseContext& ctx) { + auto it = ctx.begin(); + if (it != ctx.end() && *it == '.') { + ++it; + if (it != ctx.end() && *it >= '0' && *it <= '9') { + precision_override = 0; + do { + precision_override = precision_override * 10 + (*it - '0'); + ++it; + } while (it != ctx.end() && *it >= '0' && *it <= '9'); + } + } + + if (it != ctx.end() && *it == 'f') { + ++it; + } + + return it; } template<typename FormatContext> - inline auto format(Aidge::Tensor const& t, FormatContext& ctx) const { - return fmt::format_to(ctx.out(), "{}", t.toString()); + auto format(Aidge::Tensor const& t, FormatContext& ctx) const { + // Use precision_override if specified, otherwise toString will use default + return fmt::format_to(ctx.out(), "{}", t.toString(precision_override)); } }; diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index a14ae4187..58e157f9f 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -312,141 +312,160 @@ void Tensor::resize(const std::vector<DimSize_t>& dims, } } -std::string Tensor::toString() const { - +std::string Tensor::toString(int precision) const { if (!hasImpl() || undefined()) { - // Return no value on no implementation or undefined size - return std::string("{}"); + return "{}"; } - // TODO: move lambda elsewhere? - auto ptrToString = [](DataType dt, void* ptr, std::size_t idx) { - switch (dt) { - case DataType::Float64: - return std::to_string(static_cast<double*>(ptr)[idx]); - case DataType::Float32: - return std::to_string(static_cast<float*>(ptr)[idx]); - case DataType::Float16: - return std::to_string(static_cast<half_float::half*>(ptr)[idx]); - case DataType::Binary: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::Octo_Binary: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::Dual_Int4: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::Dual_UInt4: - return std::to_string(static_cast<uint8_t*>(ptr)[idx]); - case DataType::Dual_Int3: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::Dual_UInt3: - return std::to_string(static_cast<uint8_t*>(ptr)[idx]); - case DataType::Quad_Int2: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::Quad_UInt2: - return std::to_string(static_cast<uint8_t*>(ptr)[idx]); - case DataType::Int4: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::UInt4: - return std::to_string(static_cast<uint8_t*>(ptr)[idx]); - case DataType::Int3: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::UInt3: - return std::to_string(static_cast<uint8_t*>(ptr)[idx]); - case DataType::Int2: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::UInt2: - return std::to_string(static_cast<uint8_t*>(ptr)[idx]); - case DataType::Int8: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::Int16: - return std::to_string(static_cast<int16_t*>(ptr)[idx]); - case DataType::Int32: - return std::to_string(static_cast<int32_t*>(ptr)[idx]); - case DataType::Int64: - return std::to_string(static_cast<int64_t*>(ptr)[idx]); - case DataType::UInt8: - return std::to_string(static_cast<uint8_t*>(ptr)[idx]); - case DataType::UInt16: - return std::to_string(static_cast<uint16_t*>(ptr)[idx]); - case DataType::UInt32: - return std::to_string(static_cast<uint32_t*>(ptr)[idx]); - case DataType::UInt64: - return std::to_string(static_cast<uint64_t*>(ptr)[idx]); - default: - AIDGE_ASSERT(true, "unsupported type to convert to string"); - } - return std::string("?"); // To make Clang happy - }; + // Use default precision if no override provided + precision = (precision >= 0) ? precision : Log::getPrecision(); + + // Create a type-specific formatter function upfront + std::function<std::string(void*, std::size_t)> formatter; + + switch (mDataType) { + case DataType::Float64: + formatter = [precision](void* ptr, std::size_t idx) { + return fmt::format("{:.{}f}", static_cast<cpptype_t<DataType::Float64>*>(ptr)[idx], precision); + }; + break; + case DataType::Float32: + formatter = [precision](void* ptr, std::size_t idx) { + return fmt::format("{:.{}f}", static_cast<cpptype_t<DataType::Float32>*>(ptr)[idx], precision); + }; + break; + case DataType::Float16: + formatter = [precision](void* ptr, std::size_t idx) { + return fmt::format("{:.{}f}", static_cast<cpptype_t<DataType::Float16>*>(ptr)[idx], precision); + }; + break; + case DataType::Binary: + case DataType::Octo_Binary: + case DataType::Dual_Int4: + case DataType::Dual_Int3: + case DataType::Dual_UInt3: + case DataType::Quad_Int2: + case DataType::Quad_UInt2: + case DataType::Int4: + case DataType::UInt4: + case DataType::Int3: + case DataType::UInt3: + case DataType::Int2: + case DataType::UInt2: + case DataType::Int8: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::Int32>>(static_cast<cpptype_t<DataType::Int8>*>(ptr)[idx])); + }; + break; + case DataType::Dual_UInt4: + case DataType::UInt8: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::UInt32>>(static_cast<cpptype_t<DataType::UInt8>*>(ptr)[idx])); + }; + break; + case DataType::Int16: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::Int16>*>(ptr)[idx]); + }; + break; + case DataType::Int32: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::Int32>*>(ptr)[idx]); + }; + break; + case DataType::Int64: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::Int64>*>(ptr)[idx]); + }; + break; + case DataType::UInt16: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::UInt16>*>(ptr)[idx]); + }; + break; + case DataType::UInt32: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::UInt32>*>(ptr)[idx]); + }; + break; + case DataType::UInt64: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::UInt64>*>(ptr)[idx]); + }; + break; + default: + AIDGE_ASSERT(true, "unsupported type to convert to string"); + return "{}"; + } if (dims().empty()) { - // The Tensor is defined with rank 0, hence scalar - return ptrToString(mDataType, mImpl->hostPtr(), 0); + return formatter(mImpl->hostPtr(), 0); } - std::string res; - std::size_t dim = 0; - std::size_t counter = 0; + void* dataPtr = mImpl->hostPtr(mImplOffset); + std::string result; + if (nbDims() >= 2) { - std::vector<std::size_t> dimVals(nbDims(), 0); - res += "{\n"; + std::vector<std::size_t> currentDim(nbDims(), 0); + std::size_t depth = 0; + std::size_t counter = 0; + + result = "{\n"; while (counter < mSize) { - std::string spaceString = std::string((dim + 1) << 1, ' '); - if (dim < nbDims() - 2) { - if (dimVals[dim] == 0) { - res += spaceString + "{\n"; - ++dim; - } else if (dimVals[dim] < - static_cast<std::size_t>(dims()[dim])) { - res += spaceString + "},\n" + spaceString + "{\n"; - ++dim; + // Create indent string directly + std::string indent((depth + 1) * 2, ' '); + + if (depth < nbDims() - 2) { + if (currentDim[depth] == 0) { + result += indent + "{\n"; + ++depth; + } else if (currentDim[depth] < static_cast<std::size_t>(dims()[depth])) { + result += indent + "},\n" + indent + "{\n"; + ++depth; } else { - res += spaceString + "}\n"; - dimVals[dim--] = 0; - dimVals[dim]++; + result += indent + "}\n"; + currentDim[depth--] = 0; + ++currentDim[depth]; } } else { - for (; dimVals[dim] < static_cast<std::size_t>(dims()[dim]); - ++dimVals[dim]) { - res += spaceString + "{"; - for (DimSize_t j = 0; j < dims()[dim + 1] - 1; ++j) { - res += - " " + - ptrToString(mDataType, mImpl->hostPtr(mImplOffset), - counter++) + - ","; + for (; currentDim[depth] < static_cast<std::size_t>(dims()[depth]); ++currentDim[depth]) { + result += indent + "{"; + + for (DimSize_t j = 0; j < dims()[depth + 1]; ++j) { + result += " " + formatter(dataPtr, counter++); + if (j < dims()[depth + 1] - 1) { + result += ","; + } } - res += " " + - ptrToString(mDataType, mImpl->hostPtr(mImplOffset), - counter++) + - "}"; - if (dimVals[dim] < - static_cast<std::size_t>(dims()[dim] - 1)) { - res += ","; + + result += " }"; + if (currentDim[depth] < static_cast<std::size_t>(dims()[depth] - 1)) { + result += ","; } - res += "\n"; + result += "\n"; } - if (dim == 0) { - break; - } - dimVals[dim--] = 0; - dimVals[dim]++; + + if (depth == 0) break; + currentDim[depth--] = 0; + ++currentDim[depth]; } } - if (nbDims() != 2) { // If nbDims == 2, parenthesis is already closed - for (int i = static_cast<int>(dim); i >= 0; --i) { - res += std::string((i + 1) << 1, ' ') + "}\n"; + + if (nbDims() != 2) { + for (std::size_t i = depth + 1; i > 0;) { + result += std::string(((--i) + 1) * 2, ' ') + "}\n"; } } } else { - res += "{"; + result = "{"; for (DimSize_t j = 0; j < dims()[0]; ++j) { - res += " " + - ptrToString(mDataType, mImpl->hostPtr(mImplOffset), j) + - ((j < dims()[0] - 1) ? "," : " "); + result += " " + formatter(dataPtr, j); + result += (j < dims()[0] - 1) ? "," : " "; } } - res += "}"; - return res; + + result += "}"; + return result; } Tensor Tensor::extract( -- GitLab