Skip to content
Snippets Groups Projects
Commit a81e082c authored by Grégoire Kubler's avatar Grégoire Kubler Committed by Olivier BICHLER
Browse files

feat : added interpolation, linear & nearest

Also added generic interpolation function that serves as a wrapper for all future interpolations functions to implement.
parent a200c447
No related branches found
No related tags found
2 merge requests!118v0.4.0,!104update Resize operator
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_DATA_INTERPOLATION_H_
#define AIDGE_CPU_DATA_INTERPOLATION_H_
#include <vector>
#include <aidge/data/Interpolation.hpp>
#include <aidge/utils/Types.h>
namespace Aidge {
class InterpolationCPU : public Interpolation {
public:
/*
* @brief Interpolates values given via input in given mode.
*
* Values are contiguously arranged in a "square" shape around the point to
* interpolate. Depending on interpolation mode.
* The point that will be interpolated is located right in the
* middle of all points.
* Immediate neighbours :
* 1D interp : 2D interp :
* . . . . . .
* . . 1 2 . . . . . . . .
* . . 1 2 . .
* . . 3 4 . .
* . . . . . .
* . . . . . .
*
* 2 neighbours :
* 1D interp : 2D interp :
* . . . . . . . .
* . . . . . . . .
* . . 1 2 3 4 . . . . 1 2 3 4 . .
* . . 5 6 7 8 . .
* . . 9 10 11 12 . .
* . . 13 14 15 16 . .
* . . . . . . . .
* . . . . . . . .
*
* @param[in] originalCoords: coord of the point to interpolate in the
* original picture. These coords are generated with
* Interpolation::untransformCoords(coordsInInterpolatedTensor)
* @param[in] points : points to interpolate, arranged in a vector of a
* pairs ((point_coord), value) :
* [[[X1, X2, ..., XN], Xval], ...., [[A1, A2, ..., AN],Aval]].
* With :
* - N: the number of dimensions.
* - A: the number of points of the grid to interpolate.
* - All coordinates expressed in originalTensor frame.
* @param[in] interpMode: interpolation mode
* @return interpolated value
*/
template <typename T>
static T interpolate(const std::vector<float> &coordsToInterpolate,
const std::set<Point<T>> &points,
const Mode interpMode = Interpolation::Mode::Linear);
/**
* @brief performs linear interpolation on given points.
* @param[in] values: values to interpolate, since we only do an average of
* all values, their indexes isn't useful.
* @return interpolated value
*/
template <typename T>
static T linear(const std::vector<float> &originalCoords,
const std::set<Point<T>> &points);
/**
* @brief performs nearest interpolation on given points.
* @note it is a wrapper for linearRecurse() private method
* @param[in] coordsToInterpolate: coordinates to interpolate
* @param[in] points: points to interpolate
* @param[in] interpMode: interpolation method, must be a Nearest...
* otherwise function will throw an error.
* @return interpolated value
*/
template <typename T>
static T nearest(const std::vector<float> &coordsToInterpolate,
const std::set<Point<T>> &points,
const Interpolation::Mode nearestMode);
private:
/**
* @brief actual linear interpolation function.
* will :
* - Split all points along each dimension depending of if their coords at
* idx alongDim are above or under coordsToInterpolate until they are
* 1-to-1.
* - Perform interpolation in 2 leftover points and return interpolated
* point to parent call with a set of size 1.
* - repeat until all dimensions have been interpolated.
* @param[in] coordsToInterpolate: coordinates to interpolate
* @param[in] points: points to interpolate
* @param[in] alongDim: discriminant on along which dimension are being
* segregated.
* @return
*/
template <typename T>
static std::set<Interpolation::Point<T>>
linearRecurse(const std::vector<float> &coordsToInterpolate,
const std::set<Point<T>> &points,
const DimIdx_t alongDim = 0);
};
} // namespace Aidge
#endif // AIDGE_CPU_DATA_INTERPOLATION_H_
#include "aidge/backend/cpu/data/Interpolation.hpp"
#include <aidge/utils/Log.hpp>
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iterator>
#include <stdexcept>
#include <utility>
#include <vector>
#include <aidge/data/Interpolation.hpp>
#include <aidge/data/half.hpp>
#include <aidge/utils/ErrorHandling.hpp>
#include <aidge/utils/Types.h>
namespace Aidge {
template <typename T>
std::set<Interpolation::Point<T>>
InterpolationCPU::linearRecurse(const std::vector<float> &coordToInterpolate,
const std::set<Point<T>> &points,
const DimIdx_t alongDim) {
// all points have been discriminated properly along given dimension.
if (points.size() == 1) {
return points;
}
auto extractPtCoords = [](std::set<Point<T>> pts) -> std::set<Coords> {
std::set<Coords> result;
for (const auto &pt : pts) {
result.insert(pt.first);
}
return result;
};
///////////////////
// ERROR CHECKING
if (alongDim > coordToInterpolate.size() || points.size() == 0) {
// retrieving points coords as points values can be in half_float &
// this type is not fmt compatible
std::vector<Coords> pointsCoords;
for (const auto &point : points) {
pointsCoords.push_back(point.first);
}
AIDGE_ASSERT(
alongDim >= coordToInterpolate.size(),
"InterpolationCPU::linearInterpolationRecurse: alongDim value "
"exceeded exceeded number of dimensions of coordsTointerpolate. "
"Interpolation has failed. Input values : \n - "
"coordsToInterpolate {}\n - pointsToInterpolate {}\n - alongDim "
"{}",
coordToInterpolate,
pointsCoords,
alongDim);
AIDGE_ASSERT(
points.size() == 0,
"InterpolationCPU::linearInterpolationRecurse: entering recursive "
"function with 0 points. Interpolation has failed."
"Please file a bug report to aidge_backend_cpu repo: "
"https://gitlab.eclipse.org/eclipse/aidge/aidge_backend_cpu/-/"
"issues."
"\nInput values : \n - "
"coordsToInterpolate {}\n - pointsToInterpolate {}\n - alongDim "
"{}",
coordToInterpolate,
pointsCoords,
alongDim);
}
Log::debug("\nEntering linear recurse with {} points.", points.size());
Log::debug("Points : {}", extractPtCoords(points));
Log::debug("coordsToInterpolate : {}", coordToInterpolate);
Log::debug("alongDim : {}", alongDim);
///////////////////
// COMPUTATION
// split all points along each dimension
// depending on if their coords[alongDim] are above or under
// coords to interpolate values
std::set<Point<T>> lowerPoints;
std::set<Point<T>> upperPoints;
for (const auto &point : points) {
if (point.first[alongDim] <= coordToInterpolate[alongDim]) {
lowerPoints.insert(point);
} else {
upperPoints.insert(point);
}
}
Log::debug("alongDim : {}", alongDim);
Log::debug("lowerPoints : {}", extractPtCoords(lowerPoints));
Log::debug("upperPoints : {}", extractPtCoords(upperPoints));
// Here are 3 cases
// 1. upper/lowerPoints.size() == 0
// Coordinates to interpolate along current dimension are round.
// That would be equivalent to a linear interpolation with a
// ponderation of 1 for lowerPoints & 0 for upperPoints(or the
// opposite idk), hence we will only take lower/upperPoints values
// from there.
//
// Why this happens :
// If coordinates are round, the floor()/ceil() operations called
// in retrieveNeighbours to generate direct neighbours of floating
// coordinates returned the same value.
//
// 2. lower/upperPoints.size() == 1
// All dimensions have been discriminated, we can proceed to
// weighted interpolation
//
// 3. lower/upperPoints.size() > 1
// points have not been all discriminated and must be further split
// so we call linearRecurse()
switch (lowerPoints.size()) {
case 0: {
return linearRecurse(coordToInterpolate, upperPoints, alongDim + 1);
}
case 1: {
break;
}
default: {
lowerPoints =
linearRecurse(coordToInterpolate, lowerPoints, alongDim + 1);
break;
}
}
switch (upperPoints.size()) {
case 0: {
return linearRecurse(coordToInterpolate, lowerPoints, alongDim + 1);
}
case 1: {
break;
}
default: {
upperPoints =
linearRecurse(coordToInterpolate, upperPoints, alongDim + 1);
break;
}
}
// At this point lowerPoints & upperPoints are garanteed to be
// 1 sized arrays
AIDGE_ASSERT(lowerPoints.size() == 1,
"LowerPoints Size = {} != 1",
lowerPoints.size());
AIDGE_ASSERT(upperPoints.size() == 1,
"upperPoints Size = {} != 1",
upperPoints.size());
// ( point[dim] - Pl[dim] )
// t = ------------------------
// ( Pu[dim] - Pl[dim] )
float weight =
(coordToInterpolate[alongDim] - lowerPoints.begin()->first[alongDim]) /
(upperPoints.begin()->first[alongDim] -
lowerPoints.begin()->first[alongDim]);
Point<T> interpolatedPoint = std::make_pair(
lowerPoints.begin()->first,
static_cast<T>((1.F - weight) * lowerPoints.begin()->second +
weight * upperPoints.begin()->second));
// 0 is just a sanity check to ensure later that all dims have been
// interpolate
interpolatedPoint.first[alongDim] = 0;
Log::debug("successfully returned from alongDim : {}", alongDim);
return std::set({interpolatedPoint});
};
template <typename T>
T InterpolationCPU::linear(const std::vector<float> &coordToInterpolate,
const std::set<Point<T>> &pointsToInterpolate) {
auto result = linearRecurse(coordToInterpolate, pointsToInterpolate, 0);
AIDGE_ASSERT(result.size() == 1,
"Result size is not 1 but {}",
result.size());
// if (!std::all_of(result.begin()->first.begin(),
// result.begin()->first.end(),
// [](DimSize_t coord) -> bool { return coord == 0; })) {
// std::vector<Coords> ptCoords;
// std::transform(pointsToInterpolate.begin(),
// pointsToInterpolate.end(),
// std::back_inserter(ptCoords),
// [](Point<T> pt) { return pt.first; });
// AIDGE_THROW_OR_ABORT(std::runtime_error,
// "Not all dimensions have been interpolated."
// "Input data :"
// "\n\t coord to interpolate : {}"
// "\n\t pointsToInterpolate : {}",
// // "\n\tAll non 0 values show dimensions
// // that were not interpolated : {}",
// coordToInterpolate,
// ptCoords //,
// // result.begin()->first
// );
// }
return result.begin()->second;
}
template <typename T>
T InterpolationCPU::nearest(const std::vector<float> &coordsToInterpolate,
const std::set<Point<T>> &points,
const Interpolation::Mode nearestMode) {
AIDGE_ASSERT(
coordsToInterpolate.size() == points.begin()->first.size(),
"Interpolation::nearest(): dimension mismatch : coordinate "
"to interpolate ({}) have not the same number of dimensions than "
"the points to interpolate({}).",
coordsToInterpolate,
points.begin()->first);
std::function<int64_t(const float &)> updateCoordinates;
switch (nearestMode) {
case Interpolation::Mode::NearestCeil: {
updateCoordinates = [](const float &coord) -> int64_t {
return ceil(coord);
};
break;
}
case Interpolation::Mode::NearestFloor: {
updateCoordinates = [](const float &coord) -> int64_t {
return floor(coord);
};
break;
}
case Interpolation::Mode::NearestRoundPreferFloor: {
updateCoordinates = [](const float &coord) -> int64_t {
return (coord - floor(coord)) == 0.5 ? floor(coord)
: std::round(coord);
};
break;
}
case Interpolation::Mode::NearestRoundPreferCeil: {
updateCoordinates = [](const float &coord) -> int64_t {
return (coord - floor(coord)) == 0.5 ? ceil(coord)
: std::round(coord);
};
break;
}
default: {
AIDGE_THROW_OR_ABORT(
std::runtime_error,
"Invalid Interpolation mode for "
"InterpolationCPU::interpolateNearest. Accepted modes are : "
"NearestCeil({}),NearestFloor({}),NearestRoundPreferCeil({}), "
"NearestRoundPreferFloor({}). Got {}.",
static_cast<int>(NearestCeil),
static_cast<int>(NearestFloor),
static_cast<int>(NearestRoundPreferCeil),
static_cast<int>(NearestRoundPreferFloor),
static_cast<int>(nearestMode));
}
}
Coords nearestCoords;
nearestCoords.reserve(coordsToInterpolate.size());
for (const auto &coord : coordsToInterpolate) {
nearestCoords.push_back(updateCoordinates(coord));
}
auto it = std::find_if(
points.begin(),
points.end(),
[nearestCoords](auto &point) { return nearestCoords == point.first; });
if (it != points.end()) {
return it->second;
} else {
Log::warn("Interpolate::nearest(): did not find a fitting point in "
"the neighbours whose coordinates were {}, returning 0. "
"Available neighbours are at following indexes: ",
coordsToInterpolate);
for (const auto &point : points) {
Log::warn("idx : [{}]\t\tvalue {}", point.first);
}
return static_cast<T>(0);
}
}
template <typename T>
T InterpolationCPU::interpolate(const std::vector<float> &coordsToInterpolate,
const std::set<Point<T>> &points,
const Mode interpMode) {
T result{0};
switch (interpMode) {
case Interpolation::Mode::Cubic: {
AIDGE_THROW_OR_ABORT(
std::runtime_error,
"Unsupported interpolation mode selected : Cubic.");
break;
}
case Interpolation::Mode::Linear: {
return linear(coordsToInterpolate, points);
break;
}
case Interpolation::Mode::NearestCeil:
case Interpolation::Mode::NearestFloor:
case Interpolation::Mode::NearestRoundPreferFloor:
case Interpolation::Mode::NearestRoundPreferCeil: {
result =
InterpolationCPU::nearest(coordsToInterpolate, points, interpMode);
break;
}
default: {
AIDGE_THROW_OR_ABORT(std::runtime_error,
"InterpolationCPU::Interpolate({}): Unsupported "
"interpolation mode given as input.",
static_cast<int>(interpMode));
break;
}
}
return result;
}
//////////////////////////////////////////////////////////////////////////////////////////////////////
// TEMPLATE DECLARATION
//////////////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////
// INTERPOLATE
template int8_t InterpolationCPU::interpolate<int8_t>(
const std::vector<float> &originalCoords,
const std::set<Point<int8_t>> &points,
const Mode interpMode = Interpolation::Mode::Linear);
template int16_t InterpolationCPU::interpolate<int16_t>(
const std::vector<float> &originalCoords,
const std::set<Point<int16_t>> &points,
const Mode interpMode = Interpolation::Mode::Linear);
template int32_t InterpolationCPU::interpolate<int32_t>(
const std::vector<float> &originalCoords,
const std::set<Point<int32_t>> &points,
const Mode interpMode = Interpolation::Mode::Linear);
template int64_t InterpolationCPU::interpolate<int64_t>(
const std::vector<float> &originalCoords,
const std::set<Point<int64_t>> &points,
const Mode interpMode = Interpolation::Mode::Linear);
template half_float::half InterpolationCPU::interpolate<half_float::half>(
const std::vector<float> &originalCoords,
const std::set<Point<half_float::half>> &points,
const Mode interpMode = Interpolation::Mode::Linear);
template float InterpolationCPU::interpolate<float>(
const std::vector<float> &originalCoords,
const std::set<Point<float>> &points,
const Mode interpMode = Interpolation::Mode::Linear);
template double InterpolationCPU::interpolate<double>(
const std::vector<float> &originalCoords,
const std::set<Point<double>> &points,
const Mode interpMode = Interpolation::Mode::Linear);
////////////////////////////////////////////////////////////////////
// INTERPOLATE LINEAR (& its associated recursive function)
template int8_t
InterpolationCPU::linear(const std::vector<float> &coordsToInterpolate,
const std::set<Point<int8_t>> &points);
template std::set<Interpolation::Point<int8_t>>
InterpolationCPU::linearRecurse(const std::vector<float> &coordsToInterpolate,
const std::set<Point<int8_t>> &points,
DimIdx_t alongDim);
template int16_t
InterpolationCPU::linear(const std::vector<float> &coordsToInterpolate,
const std::set<Point<int16_t>> &points);
template std::set<Interpolation::Point<int16_t>>
InterpolationCPU::linearRecurse(const std::vector<float> &coordsToInterpolate,
const std::set<Point<int16_t>> &points,
DimIdx_t alongDim);
template int32_t
InterpolationCPU::linear(const std::vector<float> &coordsToInterpolate,
const std::set<Point<int32_t>> &points);
template std::set<Interpolation::Point<int32_t>>
InterpolationCPU::linearRecurse(const std::vector<float> &coordsToInterpolate,
const std::set<Point<int32_t>> &points,
DimIdx_t alongDim);
template half_float::half
InterpolationCPU::linear(const std::vector<float> &coordsToInterpolate,
const std::set<Point<half_float::half>> &points);
template std::set<Interpolation::Point<half_float::half>>
InterpolationCPU::linearRecurse(
const std::vector<float> &coordsToInterpolate,
const std::set<Point<half_float::half>> &points,
DimIdx_t alongDim);
template float
InterpolationCPU::linear(const std::vector<float> &coordsToInterpolate,
const std::set<Point<float>> &points);
template std::set<Interpolation::Point<float>>
InterpolationCPU::linearRecurse(const std::vector<float> &coordsToInterpolate,
const std::set<Point<float>> &points,
DimIdx_t alongDim);
template double
InterpolationCPU::linear(const std::vector<float> &coordsToInterpolate,
const std::set<Point<double>> &points);
template std::set<Interpolation::Point<double>>
InterpolationCPU::linearRecurse(const std::vector<float> &coordsToInterpolate,
const std::set<Point<double>> &points,
DimIdx_t alongDim);
//////////////////////////////////
// INTERPOLATE NEAREST
template int8_t
InterpolationCPU::nearest(const std::vector<float> &originalCoords,
const std::set<Point<int8_t>> &points,
const Interpolation::Mode nearestMode);
template int16_t
InterpolationCPU::nearest(const std::vector<float> &originalCoords,
const std::set<Point<int16_t>> &points,
const Interpolation::Mode nearestMode);
template int32_t
InterpolationCPU::nearest(const std::vector<float> &originalCoords,
const std::set<Point<int32_t>> &points,
const Interpolation::Mode nearestMode);
template half_float::half
InterpolationCPU::nearest(const std::vector<float> &originalCoords,
const std::set<Point<half_float::half>> &points,
const Interpolation::Mode nearestMode);
template float
InterpolationCPU::nearest(const std::vector<float> &originalCoords,
const std::set<Point<float>> &points,
const Interpolation::Mode nearestMode);
template double
InterpolationCPU::nearest(const std::vector<float> &originalCoords,
const std::set<Point<double>> &points,
const Interpolation::Mode nearestMode);
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <aidge/backend/cpu/data/Interpolation.hpp>
#include <aidge/data/Interpolation.hpp>
#include <aidge/data/Tensor.hpp>
#include <aidge/filler/Filler.hpp>
#include <aidge/utils/Types.h>
#include <catch2/catch_test_macros.hpp>
#include <limits>
#include "aidge/backend/cpu/data/Interpolation.hpp"
namespace Aidge {
TEST_CASE("Interpolation", "[Interpolation][Data]") {
SECTION("Linear") {
std::set<Interpolation::Point<int>> pointsToInterpolateInt;
std::set<Interpolation::Point<float>> pointsToInterpolateFloat;
SECTION("1D") {
pointsToInterpolateInt =
std::set<Interpolation::Point<int>>({{{0}, 10}, {{1}, 20}});
CHECK(abs(InterpolationCPU::linear({0.5}, pointsToInterpolateInt) -
15) <= std::numeric_limits<int>::epsilon());
pointsToInterpolateFloat = std::set<Interpolation::Point<float>>(
{{{0}, .0F}, {{1}, 0.2F}});
CHECK(fabs(InterpolationCPU::linear({0.3},
pointsToInterpolateFloat) -
.06F) <= 1e-5);
}
SECTION("2D") {
// example taken from
// https://en.wikipedia.org/wiki/Bilinear_interpolation
pointsToInterpolateFloat = {{{14, 20}, 91.F},
{{14, 21}, 162.F},
{{15, 20}, 210.F},
{{15, 21}, 95.F}};
CHECK(fabs(InterpolationCPU::linear<float>(
{14.5F, 20.2F},
pointsToInterpolateFloat) -
146.1) < 1e-5);
// pointsToInterpolateFloat = {{{0, 0}, .10F},
// {{0, 1}, .20F},
// {{1, 0}, .30F},
// {{1, 1}, .40F}};
// CHECK(abs(InterpolationCPU::linear<float>({1.5, 0.5},
// pointsToInterpolateInt)
// -
// 25) < std::numeric_limits<int>::epsilon());
// pointsToInterpolateFloat = std::vector({0.1F, 0.2F, 0.3F,
// 0.4F}); CHECK(InterpolationCPU::linear(pointsToInterpolateFloat)
// == .25f);
}
SECTION("3D") {
pointsToInterpolateFloat = {{{0, 0, 0}, .1F},
{{0, 0, 1}, .2F},
{{0, 1, 0}, .3F},
{{0, 1, 1}, .4F},
{{1, 0, 0}, .5F},
{{1, 0, 1}, .6F},
{{1, 1, 0}, .7F},
{{1, 1, 1}, .8F}};
CHECK(fabs(InterpolationCPU::linear({.5, .5, .5},
pointsToInterpolateFloat) -
.45f) < 1e-5);
}
SECTION("4D") {
SECTION("Casual") {
pointsToInterpolateFloat = {{{0, 0, 0, 0}, .1F},
{{0, 0, 0, 1}, .2F},
{{0, 0, 1, 0}, .3F},
{{0, 0, 1, 1}, .4F},
{{0, 1, 0, 0}, .5F},
{{0, 1, 0, 1}, .6F},
{{0, 1, 1, 0}, .7F},
{{0, 1, 1, 1}, .8F},
{{1, 0, 0, 0}, .9F},
{{1, 0, 0, 1}, 1.F},
{{1, 0, 1, 0}, 1.1F},
{{1, 0, 1, 1}, 1.2F},
{{1, 1, 0, 0}, 1.3F},
{{1, 1, 0, 1}, 1.4F},
{{1, 1, 1, 0}, 1.5F},
{{1, 1, 1, 1}, 1.6F}};
CHECK(fabs(InterpolationCPU::linear<float>(
{.5, .5, .5, .5},
pointsToInterpolateFloat) -
.85f) < 0.0001);
}
}
SECTION("Some of the coords to interpolate were round") {
// In this case retrieveNeighbours()
// only retrieved the neighbours against not round dimensions
auto tensor =
std::make_shared<Tensor>(std::vector<DimSize_t>({10, 10}));
tensor->setDataType(DataType::Float32);
tensor->setBackend("cpu");
Aidge::constantFiller(tensor, 1337.F);
std::set<Interpolation::Point<float>> expectedResult = {
{{0, 0, -1, -1}, 0.F},
{{0, 0, 0, -1}, 0.F},
{{0, 0, -1, 0}, 0.F},
{{0, 0, 0, 0}, 1337.F}};
pointsToInterpolateFloat = Interpolation::retrieveNeighbours(
reinterpret_cast<float *>(tensor->getImpl()->rawPtr()),
tensor->dims(),
std::vector<float>({0.F, 0.F, -0.25F, -0.25F}));
pointsToInterpolateFloat = {{{0, 0, -1, -1}, 1337.F},
{{0, 0, 0, -1}, 1337.F},
{{0, 0, -1, 0}, 1337.F},
{{0, 0, 0, 0}, 1337.F}};
}
}
SECTION("Nearest") {
std::set<Interpolation::Point<float>> pointsToInterpolate;
std::vector<float> coordToInterpolate;
SECTION("1D") {
coordToInterpolate = {0.5F};
pointsToInterpolate =
std::set<Interpolation::Point<float>>{{{0}, 1.0F},
{{1}, 2.0F},
{{2}, 3.0F},
{{3}, 4.0F},
{{4}, 5.0F}};
SECTION("NearestFloor") {
CHECK(InterpolationCPU::nearest(
coordToInterpolate,
pointsToInterpolate,
Interpolation::Mode::NearestFloor) == 1);
}
SECTION("NearestCeil") {
CHECK(InterpolationCPU::nearest(
coordToInterpolate,
pointsToInterpolate,
Interpolation::Mode::NearestCeil) == 2);
}
SECTION("NearestRoundPreferFloor") {
CHECK(InterpolationCPU::nearest(
coordToInterpolate,
pointsToInterpolate,
Interpolation::Mode::NearestRoundPreferFloor) == 1);
}
SECTION("NearestRoundPreferCeil") {
CHECK(InterpolationCPU::nearest(
coordToInterpolate,
pointsToInterpolate,
Interpolation::Mode::NearestRoundPreferCeil) == 2);
}
}
SECTION("2D") {
coordToInterpolate = {2.5F, 3.97F};
pointsToInterpolate = {{{0, 0}, 10.0},
{{1, 1}, 20.0},
{{2, 3}, 30.0},
{{2, 4}, 40.0},
{{3, 3}, 50.0},
{{3, 4}, 60.0}};
SECTION("NearestFloor") {
CHECK(InterpolationCPU::nearest(
coordToInterpolate,
pointsToInterpolate,
Interpolation::Mode::NearestFloor) == 30.);
}
SECTION("NearestCeil") {
CHECK(InterpolationCPU::nearest(
coordToInterpolate,
pointsToInterpolate,
Interpolation::Mode::NearestCeil) == 60.);
}
SECTION("NearestRoundPreferFloor") {
CHECK(InterpolationCPU::nearest(
coordToInterpolate,
pointsToInterpolate,
Interpolation::Mode::NearestRoundPreferFloor) ==
40.);
}
SECTION("NearestRoundPreferCeil") {
CHECK(InterpolationCPU::nearest(
coordToInterpolate,
pointsToInterpolate,
Interpolation::Mode::NearestRoundPreferCeil) == 60.);
}
}
SECTION("3D") {
coordToInterpolate = {1.9, 2.1, 3.6};
pointsToInterpolate = {{{0, 0, 0}, 5.0},
{{1, 2, 3}, 10.0},
{{2, 1, 4}, 20.0},
{{2, 2, 4}, 30.0},
{{2, 3, 3}, 40.0},
{{2, 3, 4}, 50.0},
{{3, 3, 4}, 60.0}};
SECTION("NearestFloor") {
CHECK(InterpolationCPU::nearest(
coordToInterpolate,
pointsToInterpolate,
Interpolation::Mode::NearestFloor) == 10.);
}
SECTION("NearestCeil") {
CHECK(InterpolationCPU::nearest(
coordToInterpolate,
pointsToInterpolate,
Interpolation::Mode::NearestCeil) == 50.);
}
SECTION("NearestRoundPreferFloor") {
CHECK(InterpolationCPU::nearest(
coordToInterpolate,
pointsToInterpolate,
Interpolation::Mode::NearestRoundPreferFloor) ==
30.);
}
SECTION("NearestRoundPreferCeil") {
CHECK(InterpolationCPU::nearest(
coordToInterpolate,
pointsToInterpolate,
Interpolation::Mode::NearestRoundPreferCeil) == 30.);
}
}
}
}
} // namespace Aidge
......@@ -25,7 +25,7 @@
namespace Aidge {
TEST_CASE("Test addition of Tensors","[TensorImpl][Add]") {
TEST_CASE("Test addition of Tensors","[TensorImpl][Add][Data]") {
constexpr std::uint16_t NBTRIALS = 10;
// Create a random number generator
std::random_device rd;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment