Skip to content
Snippets Groups Projects
Forked from Eclipse Projects / aidge / aidge_core
2016 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
TensorUtils.hpp 2.10 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef AIDGE_CORE_UTILS_TENSOR_UTILS_H_
#define AIDGE_CORE_UTILS_TENSOR_UTILS_H_
#include <cassert>  // assert
#include <cmath>    // std::abs (floating-point overloads)
#include <cstddef>  // std::size_t
#include <cstdlib>  // std::abs (integer overloads)

#include "aidge/data/Tensor.hpp"

namespace Aidge {
/**
 * @brief Compare two :cpp:class:`Aidge::Tensor` value wise. Two values are
 * considered approximately equal when:
 *
 * |t1-t2| <= absolute + relative * |t2|
 *
 * The test is asymmetric: the relative tolerance is scaled by the values of t2,
 * so t2 acts as the reference tensor. The difference and the tolerance are both
 * evaluated as ``float``.
 *
 * If a tensor value is different from the other tensor, return false
 * If the tensors do not have the same size, return false
 * If a compared value is NaN, return false (NaN is never approximately equal to anything)
 * If a templated type does not correspond to the datatype of its tensor, raise an assertion error
 *
 * @note Only ``size()`` of both tensors is compared; their dimensions are not
 * checked by this function.
 * @note The preconditions are checked with ``assert`` and are therefore not
 * enforced when ``NDEBUG`` is defined.
 *
 * @tparam T1 type of the values stored in t1, must correspond to the datatype of t1
 * @tparam T2 type of the values stored in t2, must correspond to the datatype of t2 (defaults to T1)
 * @param t1  first :cpp:class:`Aidge::Tensor` to test
 * @param t2  second :cpp:class:`Aidge::Tensor` to test, used as the reference for the relative tolerance
 * @param relative relative difference allowed (should be non-negative, typically between 0 and 1)
 * @param absolute absolute error allowed (should be non-negative)
 * @return true if both tensors have the same size and all their values are approximately equal. Else return false
 */
template <typename T1, typename T2 = T1>
bool approxEq(Aidge::Tensor t1, Aidge::Tensor t2, float relative, float absolute){
    assert(t1.dataType() == NativeType<T1>::type);
    assert(t2.dataType() == NativeType<T2>::type);
    // Tolerances only have to be non-negative: an absolute tolerance above 1 is
    // legitimate when the tensors hold large values.
    assert(relative >= 0);
    assert(absolute >= 0);

    const std::size_t nbElements = t1.size();
    if (nbElements != t2.size()){
        return false;
    }
    for(std::size_t i = 0; i < nbElements; ++i){
        const T1 value = t1.get<T1>(i);
        const T2 reference = t2.get<T2>(i);
        // Strictly identical values always match. This also covers equal
        // infinities, whose difference would otherwise evaluate to NaN.
        if (value == reference){
            continue;
        }
        const float difference = static_cast<float>(std::abs(value - reference));
        const float tolerance = absolute + (relative * static_cast<float>(std::abs(reference)));
        // Written as a negated "<=" so that a NaN on either side is reported as
        // a mismatch instead of silently passing a ">" test.
        if (!(difference <= tolerance)){
            return false;
        }
    }
    return true;
}
}

#endif /* AIDGE_CORE_UTILS_TENSOR_UTILS_H_ */