diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp index 156e4d8c14510d52780b512a03d95c7c01d25e53..fac8e7fb4ac0df25caedca1862939444cf7c4391 100644 --- a/include/aidge/data/Data.hpp +++ b/include/aidge/data/Data.hpp @@ -12,6 +12,7 @@ #ifndef AIDGE_DATA_H_ #define AIDGE_DATA_H_ +#include <cstddef> // std::size_t #include <string> #include "aidge/utils/ErrorHandling.hpp" @@ -37,7 +38,7 @@ public: return mType; } virtual ~Data() = default; - virtual std::string toString() const = 0; + virtual std::string toString(int precision = -1, std::size_t offset = 0) const = 0; private: const std::string mType; diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index c8df815bbe294e90e86f125a804bde2db82a739b..7aa2ed52b95e11598a2975558212b00a85dac598 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -642,7 +642,7 @@ public: set<expectedType>(getStorageIdx(coordIdx), value); } - std::string toString() const override; + std::string toString(int precision = -1, std::size_t offset = 0) const override; inline void print() const { fmt::print("{}\n", toString()); } @@ -981,14 +981,34 @@ private: template<> struct fmt::formatter<Aidge::Tensor> { + // Only stores override precision from format string + int precision_override = -1; + template<typename ParseContext> - inline constexpr auto parse(ParseContext& ctx) { - return ctx.begin(); + constexpr auto parse(ParseContext& ctx) { + auto it = ctx.begin(); + if (it != ctx.end() && *it == '.') { + ++it; + if (it != ctx.end() && *it >= '0' && *it <= '9') { + precision_override = 0; + do { + precision_override = precision_override * 10 + (*it - '0'); + ++it; + } while (it != ctx.end() && *it >= '0' && *it <= '9'); + } + } + + if (it != ctx.end() && *it == 'f') { + ++it; + } + + return it; } template<typename FormatContext> - inline auto format(Aidge::Tensor const& t, FormatContext& ctx) const { - return fmt::format_to(ctx.out(), "{}", t.toString()); + auto format(Aidge::Tensor const& t, FormatContext& ctx) const { + // Use precision_override if specified, otherwise toString will use default + return fmt::format_to(ctx.out(), "Tensor({})", t.toString(precision_override, 7)); } }; diff --git a/include/aidge/utils/Log.hpp b/include/aidge/utils/Log.hpp index 91619b15be572d77b83e339a3aac92b4576600ce..ca16018db472b7daa6d6ff5f248420b6ab1f14e9 100644 --- a/include/aidge/utils/Log.hpp +++ b/include/aidge/utils/Log.hpp @@ -216,6 +216,30 @@ public: */ static void setFileName(const std::string& fileName); + /** + * @brief Set the precision format for floating point numbers. + * @param precision number of digits displayed on the right-hand of the + * decimal point. + */ + static void setPrecision(int precision) noexcept { + if (precision < 0) { + Log::notice("Impossible to set precision to {}. Must be a positive number.", precision); + return; + } + mFloatingPointPrecision = precision; +#ifdef PYBIND +#define _CRT_SECURE_NO_WARNINGS + if (Py_IsInitialized()){ + // Note: Setting mFloatingPointPrecision is important + // to avoid garbage collection of the pointer. + py::set_shared_data("floatingPointPrecision", &mFloatingPointPrecision); + } +#endif // PYBIND + } + static int getPrecision() noexcept { + return mFloatingPointPrecision; + } + private: static void log(Level level, const std::string& msg); static void initFile(const std::string& fileName); @@ -230,6 +254,8 @@ private: static std::string mFileName; ///< Path to log file static std::unique_ptr<FILE, fcloseDeleter> mFile; ///< File handle static std::vector<std::string> mContext; ///< Stack of active contexts + + static int mFloatingPointPrecision; }; } // namespace Aidge diff --git a/python_binding/data/pybind_Data.cpp b/python_binding/data/pybind_Data.cpp index 02a692dea47a1ba270df5b7c710db0c07da2043a..52e773b6962c5de8eb7499be03c4f2deb0be24c9 100644 --- a/python_binding/data/pybind_Data.cpp +++ b/python_binding/data/pybind_Data.cpp @@ -10,65 +10,14 @@ ********************************************************************************/ #include <pybind11/pybind11.h> -#include <pybind11/stl.h> #include "aidge/data/Data.hpp" -#include "aidge/data/DataType.hpp" -#include "aidge/data/DataFormat.hpp" namespace py = pybind11; namespace Aidge { -template <class T> -void bindEnum(py::module& m, const std::string& name) { - // Define enumeration names for python as lowercase type name - // This defined enum names compatible with basic numpy type - // name such as: float32, flot64, [u]int32, [u]int64, ... - auto python_enum_name = [](const T& type) { - auto str_lower = [](std::string& str) { - std::transform(str.begin(), str.end(), str.begin(), - [](unsigned char c){ - return std::tolower(c); - }); - }; - auto type_name = std::string(Aidge::format_as(type)); - str_lower(type_name); - return type_name; - }; - // Auto generate enumeration names from lowercase type strings - std::vector<std::string> enum_names; - for (auto type_str : EnumStrings<T>::data) { - auto type = static_cast<T>(enum_names.size()); - auto enum_name = python_enum_name(type); - enum_names.push_back(enum_name); - } - - // Define python side enumeration aidge_core.type - auto e_type = py::enum_<T>(m, name.c_str()); - - // Add enum value for each enum name - for (std::size_t idx = 0; idx < enum_names.size(); idx++) { - e_type.value(enum_names[idx].c_str(), static_cast<T>(idx)); - } - - // Define str() to return the bare enum name value, it allows - // to compare directly for instance str(tensor.type()) - // with str(nparray.type) - e_type.def("__str__", [enum_names](const T& type) { - return enum_names[static_cast<int>(type)]; - }, py::prepend());; -} - void init_Data(py::module& m){ - bindEnum<DataType>(m, "dtype"); - bindEnum<DataFormat>(m, "dformat"); - py::class_<Data, std::shared_ptr<Data>>(m,"Data"); - - - m.def("format_as", (const char* (*)(DataType)) &format_as, py::arg("dt")); - m.def("format_as", (const char* (*)(DataFormat)) &format_as, py::arg("df")); - m.def("get_data_format_transpose", &getDataFormatTranspose, py::arg("src"), py::arg("dst")); - -} } + +} // namespace Aidge diff --git a/python_binding/data/pybind_DataFormat.cpp b/python_binding/data/pybind_DataFormat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a63df321c3298284df7de8fd2c3eb0fc0cecae24 --- /dev/null +++ b/python_binding/data/pybind_DataFormat.cpp @@ -0,0 +1,72 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <algorithm> // std::transform +#include <cctype> // std::tolower +#include <string> // std::string +#include <vector> + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include "aidge/data/DataFormat.hpp" + +namespace py = pybind11; +namespace Aidge { + +template <class T> +void bindEnum(py::module& m, const std::string& name) { + // Define enumeration names for python as lowercase type name + // This defined enum names compatible with basic numpy type + // name such as: float32, flot64, [u]int32, [u]int64, ... + auto python_enum_name = [](const T& type) { + auto str_lower = [](std::string& str) { + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c){ + return std::tolower(c); + }); + }; + auto type_name = std::string(Aidge::format_as(type)); + str_lower(type_name); + return type_name; + }; + + // Auto generate enumeration names from lowercase type strings + std::vector<std::string> enum_names; + for (auto type_str : EnumStrings<T>::data) { + auto type = static_cast<T>(enum_names.size()); + auto enum_name = python_enum_name(type); + enum_names.push_back(enum_name); + } + + // Define python side enumeration aidge_core.type + auto e_type = py::enum_<T>(m, name.c_str()); + + // Add enum value for each enum name + for (std::size_t idx = 0; idx < enum_names.size(); idx++) { + e_type.value(enum_names[idx].c_str(), static_cast<T>(idx)); + } + + // Define str() to return the bare enum name value, it allows + // to compare directly for instance str(tensor.type()) + // with str(nparray.type) + e_type.def("__str__", [enum_names](const T& type) { + return enum_names[static_cast<int>(type)]; + }, py::prepend()); +} + +void init_DataFormat(py::module& m) { + bindEnum<DataFormat>(m, "dformat"); + m.def("format_as", (const char* (*)(DataFormat)) &format_as, py::arg("df")); + m.def("get_data_format_transpose", &getDataFormatTranspose, py::arg("src"), py::arg("dst")); +} + +} // namespace Aidge diff --git a/python_binding/data/pybind_DataType.cpp b/python_binding/data/pybind_DataType.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6aab399760c52b3886cddb0fb26c2b649252e7bb --- /dev/null +++ b/python_binding/data/pybind_DataType.cpp @@ -0,0 +1,71 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <algorithm> // std::transform +#include <cctype> // std::tolower +#include <string> // std::string +#include <vector> + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include "aidge/data/DataType.hpp" + +namespace py = pybind11; +namespace Aidge { + +template <class T> +void bindEnum(py::module& m, const std::string& name) { + // Define enumeration names for python as lowercase type name + // This defined enum names compatible with basic numpy type + // name such as: float32, flot64, [u]int32, [u]int64, ... + auto python_enum_name = [](const T& type) { + auto str_lower = [](std::string& str) { + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c){ + return std::tolower(c); + }); + }; + auto type_name = std::string(Aidge::format_as(type)); + str_lower(type_name); + return type_name; + }; + + // Auto generate enumeration names from lowercase type strings + std::vector<std::string> enum_names; + for (auto type_str : EnumStrings<T>::data) { + auto type = static_cast<T>(enum_names.size()); + auto enum_name = python_enum_name(type); + enum_names.push_back(enum_name); + } + + // Define python side enumeration aidge_core.type + auto e_type = py::enum_<T>(m, name.c_str()); + + // Add enum value for each enum name + for (std::size_t idx = 0; idx < enum_names.size(); idx++) { + e_type.value(enum_names[idx].c_str(), static_cast<T>(idx)); + } + + // Define str() to return the bare enum name value, it allows + // to compare directly for instance str(tensor.type()) + // with str(nparray.type) + e_type.def("__str__", [enum_names](const T& type) { + return enum_names[static_cast<int>(type)]; + }, py::prepend()); +} + +void init_DataType(py::module& m) { + bindEnum<DataType>(m, "dtype"); + m.def("format_as", (const char* (*)(DataType)) &format_as, py::arg("dt")); +} + +} // namespace Aidge diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index 973fc6f9a94d108d8b81c93384ef8468d8247c41..2171d48975db8f4029abe7982bf6dfc17640dd52 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -345,7 +345,7 @@ void init_Tensor(py::module& m){ return b.toString(); }) .def("__repr__", [](Tensor& b) { - return fmt::format("Tensor(dims = {}, dtype = {})", b.dims(), std::string(EnumStrings<DataType>::data[static_cast<int>(b.dataType())])); + return fmt::format("Tensor({}, dims = {}, dtype = {})", b.toString(-1, 7), b.dims(), b.dataType()); }) .def("__len__", [](Tensor& b) -> size_t{ return b.size(); diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index c292a893779196707a3206f899d49d896bb7c2d6..435badb6cfe3628815d76f0105d8a1680f17a86a 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -19,6 +19,8 @@ namespace Aidge { void init_CoreSysInfo(py::module&); void init_Random(py::module&); void init_Data(py::module&); +void init_DataFormat(py::module&); +void init_DataType(py::module&); void init_Database(py::module&); void init_DataProvider(py::module&); void init_Interpolation(py::module&); @@ -110,6 +112,8 @@ void init_Aidge(py::module& m) { init_Random(m); init_Data(m); + init_DataFormat(m); + init_DataType(m); init_Database(m); init_DataProvider(m); init_Interpolation(m); diff --git a/python_binding/utils/pybind_Log.cpp b/python_binding/utils/pybind_Log.cpp index bb81c10b2ee868bdc8b3e359b19154e4717adb76..d91a96a781374662d6242abba6b94b3905b9ea10 100644 --- a/python_binding/utils/pybind_Log.cpp +++ b/python_binding/utils/pybind_Log.cpp @@ -1,8 +1,21 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + #include <pybind11/pybind11.h> + #include "aidge/utils/Log.hpp" namespace py = pybind11; namespace Aidge { + void init_Log(py::module& m){ py::enum_<Log::Level>(m, "Level") .value("Debug", Log::Debug) @@ -133,7 +146,14 @@ void init_Log(py::module& m){ :param fileName: Log file name. :type fileName: str + )mydelimiter") + .def_static("set_precision", &Log::setPrecision, py::arg("precision"), + R"mydelimiter( + Set the precision format for floating point numbers. + + :param precision: number of digits displayed on the right-hand of the decimal point. + :type precision: int )mydelimiter"); } -} +} // namespace Aidge diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index a14ae4187707490cfb70681fc418daf961cb053b..7bd2754d6cf959d8c0c11becf6d23b1c7f80192e 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -312,143 +312,164 @@ void Tensor::resize(const std::vector<DimSize_t>& dims, } } -std::string Tensor::toString() const { - +std::string Tensor::toString(int precision, std::size_t offset) const { if (!hasImpl() || undefined()) { - // Return no value on no implementation or undefined size - return std::string("{}"); - } - - // TODO: move lambda elsewhere? - auto ptrToString = [](DataType dt, void* ptr, std::size_t idx) { - switch (dt) { - case DataType::Float64: - return std::to_string(static_cast<double*>(ptr)[idx]); - case DataType::Float32: - return std::to_string(static_cast<float*>(ptr)[idx]); - case DataType::Float16: - return std::to_string(static_cast<half_float::half*>(ptr)[idx]); - case DataType::Binary: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::Octo_Binary: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::Dual_Int4: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::Dual_UInt4: - return std::to_string(static_cast<uint8_t*>(ptr)[idx]); - case DataType::Dual_Int3: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::Dual_UInt3: - return std::to_string(static_cast<uint8_t*>(ptr)[idx]); - case DataType::Quad_Int2: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::Quad_UInt2: - return std::to_string(static_cast<uint8_t*>(ptr)[idx]); - case DataType::Int4: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::UInt4: - return std::to_string(static_cast<uint8_t*>(ptr)[idx]); - case DataType::Int3: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::UInt3: - return std::to_string(static_cast<uint8_t*>(ptr)[idx]); - case DataType::Int2: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::UInt2: - return std::to_string(static_cast<uint8_t*>(ptr)[idx]); - case DataType::Int8: - return std::to_string(static_cast<int8_t*>(ptr)[idx]); - case DataType::Int16: - return std::to_string(static_cast<int16_t*>(ptr)[idx]); - case DataType::Int32: - return std::to_string(static_cast<int32_t*>(ptr)[idx]); - case DataType::Int64: - return std::to_string(static_cast<int64_t*>(ptr)[idx]); - case DataType::UInt8: - return std::to_string(static_cast<uint8_t*>(ptr)[idx]); - case DataType::UInt16: - return std::to_string(static_cast<uint16_t*>(ptr)[idx]); - case DataType::UInt32: - return std::to_string(static_cast<uint32_t*>(ptr)[idx]); - case DataType::UInt64: - return std::to_string(static_cast<uint64_t*>(ptr)[idx]); - default: - AIDGE_ASSERT(true, "unsupported type to convert to string"); - } - return std::string("?"); // To make Clang happy - }; + return "{}"; + } + + // Use default precision if no override provided + precision = (precision >= 0) ? precision : Log::getPrecision(); + + // Create a type-specific formatter function upfront + std::function<std::string(void*, std::size_t)> formatter; + + switch (mDataType) { + case DataType::Float64: + formatter = [precision](void* ptr, std::size_t idx) { + return fmt::format("{:.{}f}", static_cast<cpptype_t<DataType::Float64>*>(ptr)[idx], precision); + }; + break; + case DataType::Float32: + formatter = [precision](void* ptr, std::size_t idx) { + return fmt::format("{:.{}f}", static_cast<cpptype_t<DataType::Float32>*>(ptr)[idx], precision); + }; + break; + case DataType::Float16: + formatter = [precision](void* ptr, std::size_t idx) { + return fmt::format("{:.{}f}", static_cast<cpptype_t<DataType::Float16>*>(ptr)[idx], precision); + }; + break; + case DataType::Binary: + case DataType::Octo_Binary: + case DataType::Dual_Int4: + case DataType::Dual_Int3: + case DataType::Dual_UInt3: + case DataType::Quad_Int2: + case DataType::Quad_UInt2: + case DataType::Int4: + case DataType::UInt4: + case DataType::Int3: + case DataType::UInt3: + case DataType::Int2: + case DataType::UInt2: + case DataType::Int8: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::Int32>>(static_cast<cpptype_t<DataType::Int8>*>(ptr)[idx])); + }; + break; + case DataType::Dual_UInt4: + case DataType::UInt8: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::UInt32>>(static_cast<cpptype_t<DataType::UInt8>*>(ptr)[idx])); + }; + break; + case DataType::Int16: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::Int16>*>(ptr)[idx]); + }; + break; + case DataType::Int32: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::Int32>*>(ptr)[idx]); + }; + break; + case DataType::Int64: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::Int64>*>(ptr)[idx]); + }; + break; + case DataType::UInt16: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::UInt16>*>(ptr)[idx]); + }; + break; + case DataType::UInt32: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::UInt32>*>(ptr)[idx]); + }; + break; + case DataType::UInt64: + formatter = [](void* ptr, std::size_t idx) { + return fmt::format("{}", static_cast<cpptype_t<DataType::UInt64>*>(ptr)[idx]); + }; + break; + default: + AIDGE_ASSERT(true, "unsupported type to convert to string"); + return "{}"; + } if (dims().empty()) { - // The Tensor is defined with rank 0, hence scalar - return ptrToString(mDataType, mImpl->hostPtr(), 0); - } - - std::string res; - std::size_t dim = 0; - std::size_t counter = 0; - if (nbDims() >= 2) { - std::vector<std::size_t> dimVals(nbDims(), 0); - res += "{\n"; - while (counter < mSize) { - std::string spaceString = std::string((dim + 1) << 1, ' '); - if (dim < nbDims() - 2) { - if (dimVals[dim] == 0) { - res += spaceString + "{\n"; - ++dim; - } else if (dimVals[dim] < - static_cast<std::size_t>(dims()[dim])) { - res += spaceString + "},\n" + spaceString + "{\n"; - ++dim; - } else { - res += spaceString + "}\n"; - dimVals[dim--] = 0; - dimVals[dim]++; - } - } else { - for (; dimVals[dim] < static_cast<std::size_t>(dims()[dim]); - ++dimVals[dim]) { - res += spaceString + "{"; - for (DimSize_t j = 0; j < dims()[dim + 1] - 1; ++j) { - res += - " " + - ptrToString(mDataType, mImpl->hostPtr(mImplOffset), - counter++) + - ","; - } - res += " " + - ptrToString(mDataType, mImpl->hostPtr(mImplOffset), - counter++) + - "}"; - if (dimVals[dim] < - static_cast<std::size_t>(dims()[dim] - 1)) { - res += ","; - } - res += "\n"; - } - if (dim == 0) { - break; - } - dimVals[dim--] = 0; - dimVals[dim]++; + return formatter(mImpl->hostPtr(), 0); + } + + void* dataPtr = mImpl->hostPtr(mImplOffset); + + // Calculate maximum width across all elements + std::size_t maxWidth = 0; + for (std::size_t i = 0; i < mSize; ++i) { + std::string value = formatter(dataPtr, i); + maxWidth = std::max(maxWidth, value.length()); + } + + // Initialize variables similar to Python version + std::vector<std::size_t> indexCoord(nbDims(), 0); + const std::size_t initialDepth = nbDims() > 1 ? nbDims() - 2 : 0; + std::size_t depth = initialDepth; + std::size_t nbBrackets = nbDims() - 1; + std::size_t index = 0; + + // Calculate number of lines (product of all dimensions except last) + std::size_t nbLines = 1; + for (std::size_t d = 0; d < nbDims() - 1; ++d) { + nbLines *= dims()[d]; + } + + std::string result = "{"; // Using { instead of [ for C++ style + + for (std::size_t l = 0; l < nbLines; ++l) { + // Add spacing and opening braces + if (l != 0) { + result += std::string(1 + offset, ' '); + } + result += std::string(nbDims() - 1 - nbBrackets, ' ') + + std::string(nbBrackets, '{'); + + // Print numbers of a line + for (DimSize_t i = 0; i < dims().back(); ++i) { + std::string value = formatter(dataPtr, index); + result += std::string(1 + maxWidth - value.length(), ' ') + value; + if (i + 1 < dims().back()) { + result += ','; } + ++index; } - if (nbDims() != 2) { // If nbDims == 2, parenthesis is already closed - for (int i = static_cast<int>(dim); i >= 0; --i) { - res += std::string((i + 1) << 1, ' ') + "}\n"; + + // Check for end + if (index == mSize) { + result += std::string(nbDims(), '}'); + return result; + } else { + // Update coordinates and depth + while (indexCoord[depth] + 1 >= static_cast<std::size_t>(dims()[depth])) { + indexCoord[depth] = 0; + --depth; } + ++indexCoord[depth]; + nbBrackets = initialDepth - depth + 1; + depth = initialDepth; } - } else { - res += "{"; - for (DimSize_t j = 0; j < dims()[0]; ++j) { - res += " " + - ptrToString(mDataType, mImpl->hostPtr(mImplOffset), j) + - ((j < dims()[0] - 1) ? "," : " "); + + // Add closing braces and newlines + result += std::string(nbBrackets, '}') + ",\n"; + if (nbBrackets > 1) { + result += '\n'; } } - res += "}"; - return res; + + return result; } + Tensor Tensor::extract( const std::vector<std::size_t>& fixedCoord) const { AIDGE_ASSERT(isContiguous(), "Tensor must be contiguous"); diff --git a/src/utils/Log.cpp b/src/utils/Log.cpp index 5fc7a604fb795786511c09c1ce46437917e10de9..fb567e355fef28391c894be9ca8ca01e56f36418 100644 --- a/src/utils/Log.cpp +++ b/src/utils/Log.cpp @@ -70,6 +70,7 @@ std::string Log::mFileName = []() { std::unique_ptr<FILE, Log::fcloseDeleter> Log::mFile{nullptr}; std::vector<std::string> Log::mContext; +int Log::mFloatingPointPrecision = 5; /** * @brief Internal logging implementation diff --git a/unit_tests/operator/Test_ConcatImpl.cpp b/unit_tests/operator/Test_ConcatImpl.cpp index fcdf3e8cc1bc07493cfa84608f200f9f334a29cc..677f78e54f850001ab648bfd03c3415b212ba3f2 100644 --- a/unit_tests/operator/Test_ConcatImpl.cpp +++ b/unit_tests/operator/Test_ConcatImpl.cpp @@ -33,22 +33,27 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") { std::shared_ptr<Tensor> input4 = std::make_shared<Tensor>(Array1D<int,5>{{ 11, 12, 13, 14, 15 }}); std::shared_ptr<Tensor> input5 = std::make_shared<Tensor>(Array1D<int,6>{{ 16, 17, 18, 19, 20, 21 }}); - std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array1D<int,20>{ - { 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,15,16,17,18,19,20,21 }}); - - auto myConcat = Concat(5, 0); - myConcat->getOperator()->associateInput(0, input1); - myConcat->getOperator()->associateInput(1, input2); - myConcat->getOperator()->associateInput(2, input3); - myConcat->getOperator()->associateInput(3, input4); - myConcat->getOperator()->associateInput(4, input5); - myConcat->getOperator()->setBackend("cpu"); - myConcat->getOperator()->setDataType(DataType::Int32); - myConcat->forward(); - - std::static_pointer_cast<Tensor>(myConcat->getOperator()->getRawOutput(0))->print(); - - REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput); + Tensor expectedOutput = Array1D<int,20>{ + { 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,15,16,17,18,19,20,21 }}; + + std::shared_ptr<Concat_Op> op = std::make_shared<Concat_Op>(5,0); + op->associateInput(0, input1); + op->associateInput(1, input2); + op->associateInput(2, input3); + op->associateInput(3, input4); + op->associateInput(4, input5); + op->setBackend("cpu"); + op->setDataType(DataType::Int32); + fmt::print("{}\n", *(op->getInput(0))); + fmt::print("{}\n", *(op->getInput(1))); + fmt::print("{}\n", *(op->getInput(2))); + fmt::print("{}\n", *(op->getInput(3))); + fmt::print("{}\n", *(op->getInput(4))); + op->forward(); + + fmt::print("res: {}\n", *(op->getOutput(0))); + + REQUIRE(*(op->getOutput(0)) == expectedOutput); } SECTION("Concat 4D inputs on 1st axis") { std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> { @@ -75,7 +80,7 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") { } // }); // - std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,3,3,3,2> { + Tensor expectedOutput = Array4D<int,3,3,3,2> { { // { // {{20, 47},{21, 48},{22, 49}}, // @@ -93,18 +98,19 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") { {{44, 71},{45, 72},{46, 73}} // } // } // - }); // + }; // auto myConcat = Concat(2, 0); - myConcat->getOperator()->associateInput(0, input1); - myConcat->getOperator()->associateInput(1, input2); - myConcat->getOperator()->setBackend("cpu"); - myConcat->getOperator()->setDataType(DataType::Int32); + std::shared_ptr<Concat_Op> op = std::static_pointer_cast<Concat_Op>(myConcat->getOperator()); + op->associateInput(0, input1); + op->associateInput(1, input2); + op->setBackend("cpu"); + op->setDataType(DataType::Int32); myConcat->forward(); - std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0)->print(); + fmt::print("res: {}\n", *(op->getOutput(0))); - REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput); + REQUIRE(*(op->getOutput(0)) == expectedOutput); } SECTION("Concat 4D inputs on 3rd axis") { @@ -127,7 +133,7 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") { } }); - std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,1,3,9,2> { + Tensor expectedOutput = Array4D<int,1,3,9,2> { { // { // {{20, 47},{21, 48},{22, 49},{29, 56},{30, 57},{31, 58},{38, 65},{39, 66},{40, 67}}, // @@ -135,17 +141,18 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") { {{26, 53},{27, 54},{28, 55},{35, 62},{36, 63},{37, 64},{44, 71},{45, 72},{46, 73}} // }, // } // - }); // + }; // auto myConcat = Concat(2, 2); - myConcat->getOperator()->associateInput(0, input1); - myConcat->getOperator()->associateInput(1, input2); - myConcat->getOperator()->setBackend("cpu"); - myConcat->getOperator()->setDataType(DataType::Int32); + std::shared_ptr<Concat_Op> op = std::static_pointer_cast<Concat_Op>(myConcat->getOperator()); + op->associateInput(0, input1); + op->associateInput(1, input2); + op->setBackend("cpu"); + op->setDataType(DataType::Int32); myConcat->forward(); std::static_pointer_cast<Tensor>(myConcat->getOperator()->getRawOutput(0))->print(); - REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput); + REQUIRE(*(op->getOutput(0)) == expectedOutput); } }