Skip to content
Snippets Groups Projects
Commit 5e91f669 authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

feat : approxEqual, better error checking & error messages

replaced tensor->size by tensor->dims comparison
replaced fmt::print by log::error
replaced assert with AIDGE_ASSERT
parent cd3da81c
No related branches found
No related tags found
1 merge request!319feat_operator_convtranspose
...@@ -11,25 +11,23 @@ ...@@ -11,25 +11,23 @@
#ifndef AIDGE_CORE_UTILS_TENSOR_UTILS_H_ #ifndef AIDGE_CORE_UTILS_TENSOR_UTILS_H_
#define AIDGE_CORE_UTILS_TENSOR_UTILS_H_ #define AIDGE_CORE_UTILS_TENSOR_UTILS_H_
#include <cmath> // std::abs
#include "aidge/data/DataType.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp" #include <cmath> // std::abs
#include "aidge/utils/Log.hpp" #include <fmt/base.h>
namespace Aidge { namespace Aidge {
/** /**
* @brief Compare two Aidge::Tensor value wise. The comparison function is: * @brief Compare two :cpp:class:`Aidge::Tensor` value wise. The comparison
* function is:
* *
* |t1-t2| <= absolute + relative * |t2| * |t1-t2| <= absolute + relative * |t2|
* *
* If a tensor value is different from the other tensor return False * If a tensor value is different from the other tensor return False
* If the tensor does not have the same size, return False * If the tensor does not have the same size, return False
* If the datatype is not the same between each tensor return False * If the datatype is not the same between each tensor return False
* If the templated type does not correspond to the datatype of each tensor, raise an assertion error * If the templated type does not correspond to the datatype of each tensor,
* raise an assertion error
* *
* @tparam T1 should correspond to the type of the first tensor, defines part of the type for absolute and relative error * @tparam T1 should correspond to the type of the first tensor, defines part of the type for absolute and relative error
* @tparam T2 should correspond to the type of the second tensor, defaults to T1 * @tparam T2 should correspond to the type of the second tensor, defaults to T1
...@@ -43,33 +41,42 @@ template <typename T1, typename T2 = T1> ...@@ -43,33 +41,42 @@ template <typename T1, typename T2 = T1>
bool approxEq(const Tensor& t1, const Tensor& t2, float relative = 1e-5f, float absolute = 1e-8f) { bool approxEq(const Tensor& t1, const Tensor& t2, float relative = 1e-5f, float absolute = 1e-8f) {
// Check template type matches tensor datatype // Check template type matches tensor datatype
if (t1.dataType() != NativeType_v<T1>) { if (t1.dataType() != NativeType_v<T1>) {
Log::error("First tensor datatype ({}) does not match template type", t1.dataType()); Log::error("approxEq : First tensor datatype ({}) does not match"
"template type (NativeType_v<T1> = {}) .",
t1.dataType(),
NativeType_v<T1>);
return false; return false;
} }
if (t2.dataType() != NativeType_v<T2>) { if (t2.dataType() != NativeType_v<T2>) {
Log::error("Second tensor datatype ({}) does not match template type", t2.dataType()); Log::error("approxEq : Second tensor datatype ({}) does not match"
"template type (NativeType_v<T1> = {}) .",
t2.dataType(),
NativeType_v<T1>);
return false; return false;
} }
// Validate parameters // Validate parameters
if (relative < 0.0f) { if (relative < 0.0f) {
Log::error("Relative error must be non-negative (got {})", relative); Log::error("approxEq : Relative error must be non-negative (got : {}).",
relative);
return false; return false;
} }
if (absolute < 0.0f || absolute > 1.0f) { if (absolute < 0.0f || absolute > 1.0f) {
Log::error("Absolute error must be between 0 and 1 (got {})", absolute); Log::error("approxEq : Absolute error must be between 0 and 1 (got : {}).",
absolute);
return false; return false;
} }
// Check tensor sizes match if (t1.dims() != t2.dims()) {
if (t1.size() != t2.size()) { Log::error("approxEq: Dimension mismatch.\nt1 :\t{}\nt2 :\t{}",
Log::error("Tensor sizes do not match: {} vs {}", t1.size(), t2.size()); t1.dims(),
t2.dims());
return false; return false;
} }
// Compare values/ // Compare values
for (size_t i = 0; i < t1.size(); ++i) { for (size_t i = 0; i < t1.size(); ++i) {
const auto val1 = t1.get<T1>(i); const auto val1 = t1.get<T1>(i);
const auto val2 = t2.get<T2>(i); const auto val2 = t2.get<T2>(i);
...@@ -77,9 +84,15 @@ bool approxEq(const Tensor& t1, const Tensor& t2, float relative = 1e-5f, float ...@@ -77,9 +84,15 @@ bool approxEq(const Tensor& t1, const Tensor& t2, float relative = 1e-5f, float
const float threshold = absolute + (relative * static_cast<float>(std::abs(val2))); const float threshold = absolute + (relative * static_cast<float>(std::abs(val2)));
if (diff > threshold) { if (diff > threshold) {
Log::notice("Tensor values differ at index {}: {} vs {} (diff: {}, threshold: {})\n" Log::error("approxEq : value mismatch at index {} : {} != "
"Tensor 1:\n{}\nTensor 2:\n{}", "{} (diff: {}, threshold: {}) \nt1:\n{}\nt2:\n{}\n",
i, val1, val2, diff, threshold, t1, t2); i,
t1.get<T1>(i),
t2.get<T1>(i),
diff,
threshold,
t1,
t2);
return false; return false;
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment