diff --git a/include/aidge/utils/TensorUtils.hpp b/include/aidge/utils/TensorUtils.hpp index 794abf7632766fb55d8a9a96c0ffa0c046d9de62..a62850c7342836494b29405acea5b74aaae7d2e0 100644 --- a/include/aidge/utils/TensorUtils.hpp +++ b/include/aidge/utils/TensorUtils.hpp @@ -11,25 +11,23 @@ #ifndef AIDGE_CORE_UTILS_TENSOR_UTILS_H_ #define AIDGE_CORE_UTILS_TENSOR_UTILS_H_ - -#include <cmath> // std::abs - -#include "aidge/data/DataType.hpp" #include "aidge/data/Tensor.hpp" -#include "aidge/utils/ErrorHandling.hpp" -#include "aidge/utils/Log.hpp" +#include <cmath> // std::abs +#include <fmt/base.h> namespace Aidge { /** - * @brief Compare two Aidge::Tensor value wise. The comparison function is: + * @brief Compare two :cpp:class:`Aidge::Tensor` value wise. The comparison + * function is: * * |t1-t2| <= absolute + relative * |t2| * * If a tensor value is different from the other tensor return False * If the tensor does not have the same size, return False * If the datatype is not the same between each tensor return False - * If the templated type does not correspond to the datatype of each tensor, raise an assertion error + * If the templated type does not correspond to the datatype of each tensor, + * raise an assertion error * * @tparam T1 should correspond to the type of the first tensor, defines part of the type for absolute and relative error * @tparam T2 should correspond to the type of the second tensor, defaults to T1 @@ -43,33 +41,42 @@ template <typename T1, typename T2 = T1> bool approxEq(const Tensor& t1, const Tensor& t2, float relative = 1e-5f, float absolute = 1e-8f) { // Check template type matches tensor datatype if (t1.dataType() != NativeType_v<T1>) { - Log::error("First tensor datatype ({}) does not match template type", t1.dataType()); + Log::error("approxEq : First tensor datatype ({}) does not match" + "template type (NativeType_v<T1> = {}) .", + t1.dataType(), + NativeType_v<T1>); return false; } if (t2.dataType() != NativeType_v<T2>) { - Log::error("Second tensor datatype ({}) does not match template type", t2.dataType()); + Log::error("approxEq : Second tensor datatype ({}) does not match" + "template type (NativeType_v<T1> = {}) .", + t2.dataType(), + NativeType_v<T1>); return false; } // Validate parameters if (relative < 0.0f) { - Log::error("Relative error must be non-negative (got {})", relative); + Log::error("approxEq : Relative error must be non-negative (got : {}).", + relative); return false; } if (absolute < 0.0f || absolute > 1.0f) { - Log::error("Absolute error must be between 0 and 1 (got {})", absolute); + Log::error("approxEq : Absolute error must be between 0 and 1 (got : {}).", + absolute); return false; } - // Check tensor sizes match - if (t1.size() != t2.size()) { - Log::error("Tensor sizes do not match: {} vs {}", t1.size(), t2.size()); + if (t1.dims() != t2.dims()) { + Log::error("approxEq: Dimension mismatch.\nt1 :\t{}\nt2 :\t{}", + t1.dims(), + t2.dims()); return false; } - // Compare values/ + // Compare values for (size_t i = 0; i < t1.size(); ++i) { const auto val1 = t1.get<T1>(i); const auto val2 = t2.get<T2>(i); @@ -77,9 +84,15 @@ bool approxEq(const Tensor& t1, const Tensor& t2, float relative = 1e-5f, float const float threshold = absolute + (relative * static_cast<float>(std::abs(val2))); if (diff > threshold) { - Log::notice("Tensor values differ at index {}: {} vs {} (diff: {}, threshold: {})\n" - "Tensor 1:\n{}\nTensor 2:\n{}", - i, val1, val2, diff, threshold, t1, t2); + Log::error("approxEq : value mismatch at index {} : {} != " + "{} (diff: {}, threshold: {}) \nt1:\n{}\nt2:\n{}\n", + i, + t1.get<T1>(i), + t2.get<T1>(i), + diff, + threshold, + t1, + t2); return false; } }