Skip to content
Snippets Groups Projects
Commit e3813caf authored by Maxence Naud's avatar Maxence Naud
Browse files

UPD 'Tensor::toString()'

- Add support for precision parameter for floating point numbers in 'toString()' and Tensor formatter
- Add precision parameter to 'toString()' virtual function in Data
- Select a specific formatter function for Tensor elements once instead of selecting one for each element of the Tensor
parent 9084a9b2
No related branches found
No related tags found
2 merge requests!318[Upd] release version 0.5.0,!307[UPD] Tensor formatting
Pipeline #63318 passed
...@@ -37,7 +37,7 @@ public: ...@@ -37,7 +37,7 @@ public:
return mType; return mType;
} }
virtual ~Data() = default; virtual ~Data() = default;
virtual std::string toString() const = 0; virtual std::string toString(int precision = -1) const = 0;
private: private:
const std::string mType; const std::string mType;
......
...@@ -642,7 +642,7 @@ public: ...@@ -642,7 +642,7 @@ public:
set<expectedType>(getStorageIdx(coordIdx), value); set<expectedType>(getStorageIdx(coordIdx), value);
} }
std::string toString() const override; std::string toString(int precision = -1) const override;
inline void print() const { fmt::print("{}\n", toString()); } inline void print() const { fmt::print("{}\n", toString()); }
...@@ -981,14 +981,34 @@ private: ...@@ -981,14 +981,34 @@ private:
/// fmt formatter specialization for Aidge::Tensor.
///
/// Accepts an optional precision in the format spec, mirroring the
/// floating-point syntax, e.g. "{:.4f}" or "{:.4}". When a precision is
/// given it is forwarded to Tensor::toString(); otherwise -1 is passed so
/// toString() falls back to its own default precision.
template<>
struct fmt::formatter<Aidge::Tensor> {
    // Precision parsed from the format string; -1 means "no override".
    int precision_override = -1;

    template<typename ParseContext>
    constexpr auto parse(ParseContext& ctx) {
        auto pos = ctx.begin();
        // Optional ".N" precision, as in standard float specs.
        if (pos != ctx.end() && *pos == '.') {
            ++pos;
            if (pos != ctx.end() && *pos >= '0' && *pos <= '9') {
                int parsed = 0;
                while (pos != ctx.end() && *pos >= '0' && *pos <= '9') {
                    parsed = parsed * 10 + (*pos - '0');
                    ++pos;
                }
                precision_override = parsed;
            }
        }
        // Tolerate a trailing 'f' type specifier ("{:.3f}").
        if (pos != ctx.end() && *pos == 'f') {
            ++pos;
        }
        return pos;
    }

    template<typename FormatContext>
    auto format(Aidge::Tensor const& t, FormatContext& ctx) const {
        // -1 lets toString() pick its default precision.
        return fmt::format_to(ctx.out(), "{}", t.toString(precision_override));
    }
};
......
...@@ -312,141 +312,160 @@ void Tensor::resize(const std::vector<DimSize_t>& dims, ...@@ -312,141 +312,160 @@ void Tensor::resize(const std::vector<DimSize_t>& dims,
} }
} }
std::string Tensor::toString() const { std::string Tensor::toString(int precision) const {
if (!hasImpl() || undefined()) { if (!hasImpl() || undefined()) {
// Return no value on no implementation or undefined size return "{}";
return std::string("{}");
} }
// TODO: move lambda elsewhere? // Use default precision if no override provided
auto ptrToString = [](DataType dt, void* ptr, std::size_t idx) { precision = (precision >= 0) ? precision : Log::getPrecision();
switch (dt) {
case DataType::Float64: // Create a type-specific formatter function upfront
return std::to_string(static_cast<double*>(ptr)[idx]); std::function<std::string(void*, std::size_t)> formatter;
case DataType::Float32:
return std::to_string(static_cast<float*>(ptr)[idx]); switch (mDataType) {
case DataType::Float16: case DataType::Float64:
return std::to_string(static_cast<half_float::half*>(ptr)[idx]); formatter = [precision](void* ptr, std::size_t idx) {
case DataType::Binary: return fmt::format("{:.{}f}", static_cast<cpptype_t<DataType::Float64>*>(ptr)[idx], precision);
return std::to_string(static_cast<int8_t*>(ptr)[idx]); };
case DataType::Octo_Binary: break;
return std::to_string(static_cast<int8_t*>(ptr)[idx]); case DataType::Float32:
case DataType::Dual_Int4: formatter = [precision](void* ptr, std::size_t idx) {
return std::to_string(static_cast<int8_t*>(ptr)[idx]); return fmt::format("{:.{}f}", static_cast<cpptype_t<DataType::Float32>*>(ptr)[idx], precision);
case DataType::Dual_UInt4: };
return std::to_string(static_cast<uint8_t*>(ptr)[idx]); break;
case DataType::Dual_Int3: case DataType::Float16:
return std::to_string(static_cast<int8_t*>(ptr)[idx]); formatter = [precision](void* ptr, std::size_t idx) {
case DataType::Dual_UInt3: return fmt::format("{:.{}f}", static_cast<cpptype_t<DataType::Float16>*>(ptr)[idx], precision);
return std::to_string(static_cast<uint8_t*>(ptr)[idx]); };
case DataType::Quad_Int2: break;
return std::to_string(static_cast<int8_t*>(ptr)[idx]); case DataType::Binary:
case DataType::Quad_UInt2: case DataType::Octo_Binary:
return std::to_string(static_cast<uint8_t*>(ptr)[idx]); case DataType::Dual_Int4:
case DataType::Int4: case DataType::Dual_Int3:
return std::to_string(static_cast<int8_t*>(ptr)[idx]); case DataType::Dual_UInt3:
case DataType::UInt4: case DataType::Quad_Int2:
return std::to_string(static_cast<uint8_t*>(ptr)[idx]); case DataType::Quad_UInt2:
case DataType::Int3: case DataType::Int4:
return std::to_string(static_cast<int8_t*>(ptr)[idx]); case DataType::UInt4:
case DataType::UInt3: case DataType::Int3:
return std::to_string(static_cast<uint8_t*>(ptr)[idx]); case DataType::UInt3:
case DataType::Int2: case DataType::Int2:
return std::to_string(static_cast<int8_t*>(ptr)[idx]); case DataType::UInt2:
case DataType::UInt2: case DataType::Int8:
return std::to_string(static_cast<uint8_t*>(ptr)[idx]); formatter = [](void* ptr, std::size_t idx) {
case DataType::Int8: return fmt::format("{}", static_cast<cpptype_t<DataType::Int32>>(static_cast<cpptype_t<DataType::Int8>*>(ptr)[idx]));
return std::to_string(static_cast<int8_t*>(ptr)[idx]); };
case DataType::Int16: break;
return std::to_string(static_cast<int16_t*>(ptr)[idx]); case DataType::Dual_UInt4:
case DataType::Int32: case DataType::UInt8:
return std::to_string(static_cast<int32_t*>(ptr)[idx]); formatter = [](void* ptr, std::size_t idx) {
case DataType::Int64: return fmt::format("{}", static_cast<cpptype_t<DataType::UInt32>>(static_cast<cpptype_t<DataType::UInt8>*>(ptr)[idx]));
return std::to_string(static_cast<int64_t*>(ptr)[idx]); };
case DataType::UInt8: break;
return std::to_string(static_cast<uint8_t*>(ptr)[idx]); case DataType::Int16:
case DataType::UInt16: formatter = [](void* ptr, std::size_t idx) {
return std::to_string(static_cast<uint16_t*>(ptr)[idx]); return fmt::format("{}", static_cast<cpptype_t<DataType::Int16>*>(ptr)[idx]);
case DataType::UInt32: };
return std::to_string(static_cast<uint32_t*>(ptr)[idx]); break;
case DataType::UInt64: case DataType::Int32:
return std::to_string(static_cast<uint64_t*>(ptr)[idx]); formatter = [](void* ptr, std::size_t idx) {
default: return fmt::format("{}", static_cast<cpptype_t<DataType::Int32>*>(ptr)[idx]);
AIDGE_ASSERT(true, "unsupported type to convert to string"); };
} break;
return std::string("?"); // To make Clang happy case DataType::Int64:
}; formatter = [](void* ptr, std::size_t idx) {
return fmt::format("{}", static_cast<cpptype_t<DataType::Int64>*>(ptr)[idx]);
};
break;
case DataType::UInt16:
formatter = [](void* ptr, std::size_t idx) {
return fmt::format("{}", static_cast<cpptype_t<DataType::UInt16>*>(ptr)[idx]);
};
break;
case DataType::UInt32:
formatter = [](void* ptr, std::size_t idx) {
return fmt::format("{}", static_cast<cpptype_t<DataType::UInt32>*>(ptr)[idx]);
};
break;
case DataType::UInt64:
formatter = [](void* ptr, std::size_t idx) {
return fmt::format("{}", static_cast<cpptype_t<DataType::UInt64>*>(ptr)[idx]);
};
break;
default:
AIDGE_ASSERT(true, "unsupported type to convert to string");
return "{}";
}
if (dims().empty()) { if (dims().empty()) {
// The Tensor is defined with rank 0, hence scalar return formatter(mImpl->hostPtr(), 0);
return ptrToString(mDataType, mImpl->hostPtr(), 0);
} }
std::string res; void* dataPtr = mImpl->hostPtr(mImplOffset);
std::size_t dim = 0; std::string result;
std::size_t counter = 0;
if (nbDims() >= 2) { if (nbDims() >= 2) {
std::vector<std::size_t> dimVals(nbDims(), 0); std::vector<std::size_t> currentDim(nbDims(), 0);
res += "{\n"; std::size_t depth = 0;
std::size_t counter = 0;
result = "{\n";
while (counter < mSize) { while (counter < mSize) {
std::string spaceString = std::string((dim + 1) << 1, ' '); // Create indent string directly
if (dim < nbDims() - 2) { std::string indent((depth + 1) * 2, ' ');
if (dimVals[dim] == 0) {
res += spaceString + "{\n"; if (depth < nbDims() - 2) {
++dim; if (currentDim[depth] == 0) {
} else if (dimVals[dim] < result += indent + "{\n";
static_cast<std::size_t>(dims()[dim])) { ++depth;
res += spaceString + "},\n" + spaceString + "{\n"; } else if (currentDim[depth] < static_cast<std::size_t>(dims()[depth])) {
++dim; result += indent + "},\n" + indent + "{\n";
++depth;
} else { } else {
res += spaceString + "}\n"; result += indent + "}\n";
dimVals[dim--] = 0; currentDim[depth--] = 0;
dimVals[dim]++; ++currentDim[depth];
} }
} else { } else {
for (; dimVals[dim] < static_cast<std::size_t>(dims()[dim]); for (; currentDim[depth] < static_cast<std::size_t>(dims()[depth]); ++currentDim[depth]) {
++dimVals[dim]) { result += indent + "{";
res += spaceString + "{";
for (DimSize_t j = 0; j < dims()[dim + 1] - 1; ++j) { for (DimSize_t j = 0; j < dims()[depth + 1]; ++j) {
res += result += " " + formatter(dataPtr, counter++);
" " + if (j < dims()[depth + 1] - 1) {
ptrToString(mDataType, mImpl->hostPtr(mImplOffset), result += ",";
counter++) + }
",";
} }
res += " " +
ptrToString(mDataType, mImpl->hostPtr(mImplOffset), result += " }";
counter++) + if (currentDim[depth] < static_cast<std::size_t>(dims()[depth] - 1)) {
"}"; result += ",";
if (dimVals[dim] <
static_cast<std::size_t>(dims()[dim] - 1)) {
res += ",";
} }
res += "\n"; result += "\n";
} }
if (dim == 0) {
break; if (depth == 0) break;
} currentDim[depth--] = 0;
dimVals[dim--] = 0; ++currentDim[depth];
dimVals[dim]++;
} }
} }
if (nbDims() != 2) { // If nbDims == 2, parenthesis is already closed
for (int i = static_cast<int>(dim); i >= 0; --i) { if (nbDims() != 2) {
res += std::string((i + 1) << 1, ' ') + "}\n"; for (std::size_t i = depth + 1; i > 0;) {
result += std::string(((--i) + 1) * 2, ' ') + "}\n";
} }
} }
} else { } else {
res += "{"; result = "{";
for (DimSize_t j = 0; j < dims()[0]; ++j) { for (DimSize_t j = 0; j < dims()[0]; ++j) {
res += " " + result += " " + formatter(dataPtr, j);
ptrToString(mDataType, mImpl->hostPtr(mImplOffset), j) + result += (j < dims()[0] - 1) ? "," : " ";
((j < dims()[0] - 1) ? "," : " ");
} }
} }
res += "}";
return res; result += "}";
return result;
} }
Tensor Tensor::extract( Tensor Tensor::extract(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment