diff --git a/include/aidge/utils/TensorUtils.hpp b/include/aidge/utils/TensorUtils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7a4dfc56b5d61074bf1a64112763c7328d31a960 --- /dev/null +++ b/include/aidge/utils/TensorUtils.hpp @@ -0,0 +1,49 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_UTILS_TENSOR_UTILS_H_ +#define AIDGE_CORE_UTILS_TENSOR_UTILS_H_ + +#include "aidge/data/Tensor.hpp" + +/** + * @brief Compare two :cpp:class:`Aidge::Tensor` value wise. The comparison function is: + * + * |t1-t2| <= absolute + relative * |t2| + * + * If a tensor value is different from the other tensor return False + * If the tensor does not have the same size, return False + * If the datatype is not the same between each tensor return False + * If the templated type does not correspond to the datatype of each tensor, raise an assertion error + * + * @tparam T should correspond to the type of the tensor, define the type of the absolute and relative error + * @param t1 first :cpp:class:`Aidge::Tensor` to test + * @param t2 second :cpp:class:`Aidge::Tensor` to test + * @param relative relative difference allowed + * @param absolute absolute error allowed + * @return true if both tensor are approximately equal and have the datatype, shape. Else return false + */ +template <typename T> +bool approxEq(Aidge::Tensor t1, Aidge::Tensor t2, T relative, T absolute){ + assert(t1.dataType() == t2.dataType()); + assert(t1.dataType() == NativeType<T>::type); + if (t1.size() != t2.size()){ + return false; + } + for(size_t i; i < t1.size(); ++i){ + if (abs(t1.get<T>(i) - t2.get<T>(i)) > (absolute + (relative * abs(t2.get<T>(i))))){ + return false; + } + } + return true; +} + +#endif /* AIDGE_CORE_UTILS_TENSOR_UTILS_H_s */ diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index b861f881c684a2fbe800ab672299871cfc89d7ac..83619032c3ef8e5b4b279c1ffb550f1f4340f450 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -45,7 +45,7 @@ void init_GRegex(py::module&); void init_Recipies(py::module&); void init_Scheduler(py::module&); - +void init_TensorUtils(py::module&); void set_python_flag(){ // Set an env variable to know if we run with ypthon or cpp @@ -84,6 +84,7 @@ void init_Aidge(py::module& m){ init_GRegex(m); init_Recipies(m); init_Scheduler(m); + init_TensorUtils(m); } PYBIND11_MODULE(aidge_core, m) { diff --git a/python_binding/utils/pybind_TensorUtils.cpp b/python_binding/utils/pybind_TensorUtils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..089e8a2114982d53a0eed6cf8a73c9aea68876c6 --- /dev/null +++ b/python_binding/utils/pybind_TensorUtils.cpp @@ -0,0 +1,57 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include <string> + +#include "aidge/utils/TensorUtils.hpp" + +namespace py = pybind11; + +namespace Aidge { + +template<typename T> +void addTensorUtilsFunction(py::module &m){ + m.def("approx_eq", + & approxEq<T>, + py::arg("t1"), + py::arg("t2"), + py::arg("relative"), + py::arg("absolute"), + R"mydelimiter( + Compare two :cpp:class:`Aidge::Tensor` value wise. The comparison function is: + |t1-t2| <= absolute + relative * |t2| + + If a tensor value is different from the other tensor return False + If the tensor does not have the same size, return False + If the datatype is not the same between each tensor return False + If the templated type does not correspond to the datatype of each tensor, raise an assertion error + + :param t1: first tensor to test + :type t1: :py:class:`aidge_core.Tensor` + :param t2: second tensor to test + :type t2: :py:class:`aidge_core.Tensor` + :param relative: relative difference allowed + :type relative: compatible datatype with compared :py:class:`aidge_core.Tensor` + :param absolute: absolute error allowed + :type absolute: compatible datatype with compared :py:class:`aidge_core.Tensor` + )mydelimiter"); +} + +void init_TensorUtils(py::module &m) { + addTensorUtilsFunction<float>(m); + addTensorUtilsFunction<double>(m); + addTensorUtilsFunction<int>(m); + addTensorUtilsFunction<long>(m); +} +} // namespace Aidge