Skip to content
Snippets Groups Projects
Commit f8a49175 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[TensorUtils] Add approxEq method to check if two tensors are approximatly equalts.

parent 6682b885
No related branches found
No related tags found
1 merge request!9Fuse bn
Pipeline #31928 failed
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_UTILS_TENSOR_UTILS_H_
#define AIDGE_CORE_UTILS_TENSOR_UTILS_H_
#include "aidge/data/Tensor.hpp"
/**
* @brief Compare two :cpp:class:`Aidge::Tensor` value wise. The comparison function is:
*
* |t1-t2| <= absolute + relative * |t2|
*
* If a tensor value is different from the other tensor return False
* If the tensor does not have the same size, return False
* If the datatype is not the same between each tensor return False
* If the templated type does not correspond to the datatype of each tensor, raise an assertion error
*
* @tparam T should correspond to the type of the tensor, define the type of the absolute and relative error
* @param t1 first :cpp:class:`Aidge::Tensor` to test
* @param t2 second :cpp:class:`Aidge::Tensor` to test
* @param relative relative difference allowed
* @param absolute absolute error allowed
* @return true if both tensor are approximately equal and have the datatype, shape. Else return false
*/
template <typename T>
bool approxEq(Aidge::Tensor t1, Aidge::Tensor t2, T relative, T absolute){
assert(t1.dataType() == t2.dataType());
assert(t1.dataType() == NativeType<T>::type);
if (t1.size() != t2.size()){
return false;
}
for(size_t i; i < t1.size(); ++i){
if (abs(t1.get<T>(i) - t2.get<T>(i)) > (absolute + (relative * abs(t2.get<T>(i))))){
return false;
}
}
return true;
}
#endif /* AIDGE_CORE_UTILS_TENSOR_UTILS_H_s */
...@@ -45,7 +45,7 @@ void init_GRegex(py::module&); ...@@ -45,7 +45,7 @@ void init_GRegex(py::module&);
void init_Recipies(py::module&); void init_Recipies(py::module&);
void init_Scheduler(py::module&); void init_Scheduler(py::module&);
void init_TensorUtils(py::module&);
void set_python_flag(){ void set_python_flag(){
// Set an env variable to know if we run with ypthon or cpp // Set an env variable to know if we run with ypthon or cpp
...@@ -84,6 +84,7 @@ void init_Aidge(py::module& m){ ...@@ -84,6 +84,7 @@ void init_Aidge(py::module& m){
init_GRegex(m); init_GRegex(m);
init_Recipies(m); init_Recipies(m);
init_Scheduler(m); init_Scheduler(m);
init_TensorUtils(m);
} }
PYBIND11_MODULE(aidge_core, m) { PYBIND11_MODULE(aidge_core, m) {
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include "aidge/utils/TensorUtils.hpp"
namespace py = pybind11;
namespace Aidge {
template<typename T>
void addTensorUtilsFunction(py::module &m){
m.def("approx_eq",
& approxEq<T>,
py::arg("t1"),
py::arg("t2"),
py::arg("relative"),
py::arg("absolute"),
R"mydelimiter(
Compare two :cpp:class:`Aidge::Tensor` value wise. The comparison function is:
|t1-t2| <= absolute + relative * |t2|
If a tensor value is different from the other tensor return False
If the tensor does not have the same size, return False
If the datatype is not the same between each tensor return False
If the templated type does not correspond to the datatype of each tensor, raise an assertion error
:param t1: first tensor to test
:type t1: :py:class:`aidge_core.Tensor`
:param t2: second tensor to test
:type t2: :py:class:`aidge_core.Tensor`
:param relative: relative difference allowed
:type relative: compatible datatype with compared :py:class:`aidge_core.Tensor`
:param absolute: absolute error allowed
:type absolute: compatible datatype with compared :py:class:`aidge_core.Tensor`
)mydelimiter");
}
void init_TensorUtils(py::module &m) {
addTensorUtilsFunction<float>(m);
addTensorUtilsFunction<double>(m);
addTensorUtilsFunction<int>(m);
addTensorUtilsFunction<long>(m);
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment