Skip to content
Snippets Groups Projects

feat_operator_convtranspose

Merged Grégoire Kubler requested to merge feat_operator_convtranspose into dev
1 file
+ 32
19
Compare changes
  • Side-by-side
  • Inline
@@ -11,25 +11,23 @@
#ifndef AIDGE_CORE_UTILS_TENSOR_UTILS_H_
#define AIDGE_CORE_UTILS_TENSOR_UTILS_H_
#include <cmath> // std::abs
#include "aidge/data/DataType.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Log.hpp"
#include <cmath> // std::abs
#include <fmt/base.h>
namespace Aidge {
/**
* @brief Compare two Aidge::Tensor value wise. The comparison function is:
* @brief Compare two :cpp:class:`Aidge::Tensor` value wise. The comparison
* function is:
*
* |t1-t2| <= absolute + relative * |t2|
*
* If a tensor value is different from the other tensor return False
* If the tensor does not have the same size, return False
* If the datatype is not the same between each tensor return False
* If the templated type does not correspond to the datatype of each tensor, raise an assertion error
* If the templated type does not correspond to the datatype of each tensor,
* raise an assertion error
*
* @tparam T1 should correspond to the type of the first tensor, defines part of the type for absolute and relative error
* @tparam T2 should correspond to the type of the second tensor, defaults to T1
@@ -43,33 +41,42 @@ template <typename T1, typename T2 = T1>
bool approxEq(const Tensor& t1, const Tensor& t2, float relative = 1e-5f, float absolute = 1e-8f) {
// Check template type matches tensor datatype
if (t1.dataType() != NativeType_v<T1>) {
Log::error("First tensor datatype ({}) does not match template type", t1.dataType());
Log::error("approxEq : First tensor datatype ({}) does not match"
"template type (NativeType_v<T1> = {}) .",
t1.dataType(),
NativeType_v<T1>);
return false;
}
if (t2.dataType() != NativeType_v<T2>) {
Log::error("Second tensor datatype ({}) does not match template type", t2.dataType());
Log::error("approxEq : Second tensor datatype ({}) does not match"
"template type (NativeType_v<T1> = {}) .",
t2.dataType(),
NativeType_v<T1>);
return false;
}
// Validate parameters
if (relative < 0.0f) {
Log::error("Relative error must be non-negative (got {})", relative);
Log::error("approxEq : Relative error must be non-negative (got : {}).",
relative);
return false;
}
if (absolute < 0.0f || absolute > 1.0f) {
Log::error("Absolute error must be between 0 and 1 (got {})", absolute);
Log::error("approxEq : Absolute error must be between 0 and 1 (got : {}).",
absolute);
return false;
}
// Check tensor sizes match
if (t1.size() != t2.size()) {
Log::error("Tensor sizes do not match: {} vs {}", t1.size(), t2.size());
if (t1.dims() != t2.dims()) {
Log::error("approxEq: Dimension mismatch.\nt1 :\t{}\nt2 :\t{}",
t1.dims(),
t2.dims());
return false;
}
// Compare values/
// Compare values
for (size_t i = 0; i < t1.size(); ++i) {
const auto val1 = t1.get<T1>(i);
const auto val2 = t2.get<T2>(i);
@@ -77,9 +84,15 @@ bool approxEq(const Tensor& t1, const Tensor& t2, float relative = 1e-5f, float
const float threshold = absolute + (relative * static_cast<float>(std::abs(val2)));
if (diff > threshold) {
Log::notice("Tensor values differ at index {}: {} vs {} (diff: {}, threshold: {})\n"
"Tensor 1:\n{}\nTensor 2:\n{}",
i, val1, val2, diff, threshold, t1, t2);
Log::error("approxEq : value mismatch at index {} : {} != "
"{} (diff: {}, threshold: {}) \nt1:\n{}\nt2:\n{}\n",
i,
t1.get<T1>(i),
t2.get<T1>(i),
diff,
threshold,
t1,
t2);
return false;
}
}
Loading