Skip to content
Snippets Groups Projects
Commit 1d2d8d00 authored by Grégoire Kubler's avatar Grégoire Kubler Committed by Maxence Naud
Browse files

feat : Interpolation class implemented

The interpolate fucntion is just a placeholder to reminder that its implementation is backend dependant.
parent ee26ba87
No related branches found
No related tags found
2 merge requests!279v0.4.0,!242Extends the functionalities of Resize Operator
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_UTILS_INTERPOLATION_H_
#define AIDGE_CORE_UTILS_INTERPOLATION_H_
#include <cstdint>
#include <vector>
#include "aidge/operator/Pad.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
/* @brief generic class to hold interpolation */
class Interpolation {
public:
/**
* @brief simple type alias to describe a coordinates
* @note the indexes are deliberately chosen to be signed values as some
* points retrieved by interpolation are out of bound, hence their coords
* can be < 0
*/
using Coords = std::vector<int64_t>;
/**
* @brief type alias to designate a point of any type : hence coordinates &
* associated value
*/
template <class T> using Point = std::pair<Coords, T>;
/**
* @brief details how coordinates are transformed from interpolated tensor
* to original tensor
*/
enum CoordinateTransformation {
HalfPixel,
HalfPixelSymmetric,
PytorchHalfPixel,
AlignCorners,
Asymmetric,
};
/**
* @brief apply transformation to coords in interpolated Tensor to find
* equivalent coordinates in original tensor reference frame.
* @warning it is assumed that all parameters have the same
* number of dimensions.
* @param[in] transformedCoords : coords in interpolated tensor
* @param[in] inputDims: input dimensions of tensor
* @param[in] inputDims: output dimensions of tensor
* @return std::vector containing coords in orginal tensor reference frame
*/
static std::vector<float> untransformCoordinates(
const std::vector<DimSize_t> &transformedCoords,
const std::vector<DimSize_t> &inputDims,
const std::vector<DimSize_t> &outputDims,
const Interpolation::CoordinateTransformation coordTransfoMode);
/**
* @brief retrieves neighbouring value of a given index
* @param[in] tensorValues raw pointer of the tensor values
* retrieved with
* @code
* tensor->getImpl()->rawPtr()
* @endcode
* @param[in] tensorDimensions dimensions of given tensor
* retrieved with
* @code
* tensor->dims()
* @endcode
* @param[in] coords coordinates in the tensor of the values we want to
* find the neighbours of.
* @return static std::vector<std::pair<std::vector<DimSize_t>, T>>
* containing both indexes of neighbours & their values
*/
template <typename T>
static std::set<Point<T>>
retrieveNeighbours(const T *tensorValues,
const std::vector<DimSize_t> &tensorDims,
const std::vector<float> &coords,
const PadBorderType paddingMode = PadBorderType::Zero);
/* @brief interpolation type */
enum Mode {
Cubic,
Linear,
NearestRoundPreferFloor,
NearestRoundPreferCeil,
NearestFloor,
NearestCeil
};
/*
* @brief Interpolates values given via input in given mode.
*
* @warning This function is empty and is meant to be overriden in derived
* class in backend libraries.
*
* Values are contiguously arranged in a "square" shape around the point to
* interpolate. Depending on interpolation mode.
* The point that will be interpolated is located right in the
* middle of all points.
* Immediate neighbours :
* 1D interp : 2D interp :
* . . . . . .
* . . 1 2 . . . . . . . .
* . . 1 2 . .
* . . 3 4 . .
* . . . . . .
* . . . . . .
*
* 2 neighbours :
* 1D interp : 2D interp :
* . . . . . . . .
* . . . . . . . .
* . . 1 2 3 4 . . . . 1 2 3 4 . .
* . . 5 6 7 8 . .
* . . 9 10 11 12 . .
* . . 13 14 15 16 . .
* . . . . . . . .
* . . . . . . . .
*
* @param[in] originalIndex: index of the point to in the original picture
* Since the coord are being transformed from the interpolatedTensor frame
* to originalTensor frame, the result might be in float.
* @param[in] points : points to interpolate, arranged in a vector of a
* pairs ((point_coord), value) :
* [[[X1, X2, ..., XN], Xval], ...., [[A1, A2, ..., AN],Aval]].
* With :
* - N: the number of dimensions.
* - A: the number of points of the grid to interpolate.
* - All coordinates expressed in originalTensor frame.
* @param[in] interpMode: interpolation mode
* @return interpolated value
*/
template <typename T>
[[noreturn]] static T interpolate(const std::vector<float> &originalIndex,
const std::vector<Point<T>> &points,
const Mode interpMode);
};
} // namespace Aidge
#endif
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/data/Interpolation.hpp"
#include <algorithm>
#include <bitset>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/data/half.hpp"
#include "aidge/operator/Pad.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Log.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
template <typename T>
[[noreturn]] T
Interpolation::interpolate(const std::vector<float> & /*originalIndex*/,
const std::vector<Point<T>> & /*points*/,
const Mode /*interpMode*/) {
AIDGE_THROW_OR_ABORT(
std::runtime_error,
"interpolate() is backend dependendant and should be"
"called from derived classes: Interpolation<Backend>::interpolate(...)"
"Meaning that for CPU backend, InterpolationCPU::interpolate() should "
"be called.");
}
std::vector<float> Interpolation::untransformCoordinates(
const std::vector<DimSize_t> &transformedCoords,
const std::vector<DimSize_t> &inputDims,
const std::vector<DimSize_t> &outputDims,
const Interpolation::CoordinateTransformation coordTransfoMode) {
AIDGE_ASSERT(
inputDims.size() == outputDims.size(),
"Interpolate::untransformCoordinates: input and output coordinates "
"dimension number mismatch, they should be equal."
"Got inputDims({}) and outputDims ({}).",
inputDims,
outputDims);
AIDGE_ASSERT(
transformedCoords.size() == outputDims.size(),
"Interpolate::untransformCoordinates: coordinates dimension mismatch, "
"transformed coords number should be equal to output dimension number."
"Got coords to transform ({}) and outputDims ({})",
transformedCoords,
outputDims);
std::vector<float> originalCoords;
originalCoords.resize(transformedCoords.size());
for (DimIdx_t i = 0; i < transformedCoords.size(); ++i) {
float scale = static_cast<float>(outputDims[i]) /
static_cast<float>(inputDims[i]);
switch (coordTransfoMode) {
case CoordinateTransformation::AlignCorners:
AIDGE_THROW_OR_ABORT(
std::runtime_error,
"Interpolation::untransformCoords: Unsupported Coordinate "
"transform : AlignCorners");
break;
case CoordinateTransformation::Asymmetric:
originalCoords[i] = transformedCoords[i] / scale;
break;
case CoordinateTransformation::HalfPixel:
originalCoords[i] = (transformedCoords[i] + 0.5) / scale - 0.5;
break;
case CoordinateTransformation::HalfPixelSymmetric:
AIDGE_THROW_OR_ABORT(
std::runtime_error,
"Interpolation::untransformCoords: Unsupported Coordinate "
"transform : HalfPixelSymmetric");
break;
case Interpolation::CoordinateTransformation::PytorchHalfPixel:
AIDGE_THROW_OR_ABORT(
std::runtime_error,
"Interpolation::untransformCoords: Unsupported Coordinate "
"transform : PytorchHalfPixel");
break;
}
}
return originalCoords;
}
/**
* @details Generates a list of all neighbours of a given coordinate.
* Since the coordinates are floating points as they are the result of
* Interpolation::untransformCoords, they are approximation of coordinates in
* originalTensor frame from coordinates in interpolatedTensor frame.
*
* So to retrieve the neghbouring values, we must apply either floor() or
* ceil() to each coordinate.
*
* In order to generate the list of all combinations
* available, we simply iterate through the bits of each values from 0 to
* tensorDims.
* @example : in 2 dimensions , we have the point (1.3, 3.4)
* we iterate up to 2^2 - 1 and
* 0 = 0b00 -> (floor(x), floor(y)) = (1,3)
* 1 = 0b01 -> (floor(x), ceil(y)) = (1,4)
* 2 = 0b10 -> (ceil(x) , floor(y)) = (2,3)
* 3 = 0b11 -> (ceil(x) , ceil(y)) = (2,4)
*/
template <typename T>
std::set<Interpolation::Point<T>>
Interpolation::retrieveNeighbours(const T *tensorValues,
const std::vector<DimSize_t> &tensorDims,
const std::vector<float> &coords,
const PadBorderType paddingMode) {
Log::debug("retrieveNeighbours: TensorDims : {}", tensorDims);
Log::debug("retrieveNeighbours: coords to interpolate : {}", coords);
// Will retrieve out of bound values depending on given padding mode.
auto retrieveOutOfBoundValue =
[&tensorValues, &tensorDims, &paddingMode](Coords coord) -> T {
std::vector<DimSize_t> rectifiedCoord;
rectifiedCoord.reserve(coord.size());
switch (paddingMode) {
case Aidge::PadBorderType::Nearest: {
for (DimSize_t i = 0; i < coord.size(); ++i) {
rectifiedCoord[i] =
std::clamp<int64_t>(coord[i],
0,
static_cast<int64_t>(tensorDims[i]));
}
return tensorValues[Tensor::getIdx(tensorDims, rectifiedCoord)];
}
case Aidge::PadBorderType::Zero: {
return static_cast<T>(0);
}
default: {
AIDGE_THROW_OR_ABORT(
std::runtime_error,
"Unsupported padding mode as of now for interpolation.");
}
}
};
std::set<Point<T>> neighbours;
std::bitset<MaxDim> bits;
DimSize_t nbNeighbours = pow(2, tensorDims.size());
Coords neighbourCoords;
neighbourCoords.resize(tensorDims.size());
for (DimSize_t i = 0; i < nbNeighbours; ++i) {
bits = std::bitset<MaxDim>{i};
for (size_t j = 0; j < tensorDims.size(); ++j) {
neighbourCoords[j] =
bits[j] == 0 ? ceil(coords[j]) : std::floor(coords[j]);
}
T value;
if (Tensor::isInBounds(tensorDims, neighbourCoords)) {
// cast from unsigned to signed won't create problem as we ensured
// that all neighboursCoords values are > 0 with isInBounds
value = tensorValues[Tensor::getIdx(
tensorDims,
std::vector<DimSize_t>(neighbourCoords.begin(),
neighbourCoords.end()))];
} else {
value = retrieveOutOfBoundValue(neighbourCoords);
}
neighbours.insert(std::make_pair(neighbourCoords, value));
}
Log::debug("Interpolation::retrieveNeighbours(): neighbourCoords: {}",
neighbours);
return neighbours;
}
template std::set<Interpolation::Point<int16_t>>
Interpolation::retrieveNeighbours(const int16_t *tensorValues,
const std::vector<DimSize_t> &tensorDims,
const std::vector<float> &coords,
const PadBorderType paddingMode);
template std::set<Interpolation::Point<int32_t>>
Interpolation::retrieveNeighbours(const int32_t *tensorValues,
const std::vector<DimSize_t> &tensorDims,
const std::vector<float> &coords,
const PadBorderType paddingMode);
template std::set<Interpolation::Point<int64_t>>
Interpolation::retrieveNeighbours(const int64_t *tensorValues,
const std::vector<DimSize_t> &tensorDims,
const std::vector<float> &coords,
const PadBorderType paddingMode);
template std::set<Interpolation::Point<half_float::half>>
Interpolation::retrieveNeighbours(const half_float::half *tensorValues,
const std::vector<DimSize_t> &tensorDims,
const std::vector<float> &coords,
const PadBorderType paddingMode);
template std::set<Interpolation::Point<float>>
Interpolation::retrieveNeighbours(const float *tensorValues,
const std::vector<DimSize_t> &tensorDims,
const std::vector<float> &coords,
const PadBorderType paddingMode);
template std::set<Interpolation::Point<double>>
Interpolation::retrieveNeighbours(const double *tensorValues,
const std::vector<DimSize_t> &tensorDims,
const std::vector<float> &coords,
const PadBorderType paddingMode);
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>
#include <catch2/catch_test_macros.hpp>
#include <catch2/generators/catch_generators_random.hpp>
#include "aidge/data/Data.hpp"
#include "aidge/data/Interpolation.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/filler/Filler.hpp"
#include "aidge/operator/Pad.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
TEST_CASE("[core/data] Interpolation", "[Interpolation][Data]") {
Log::setConsoleLevel(Log::Debug);
auto tensor = std::make_shared<Tensor>(std::vector<DimSize_t>({10, 10}));
tensor->setDataType(DataType::Float32);
tensor->setBackend("cpu");
Aidge::constantFiller(tensor, 1337.F);
SECTION("retrieveNeighbours") {
std::set<Interpolation::Point<float>> neighbours;
std::set<Interpolation::Point<float>> expectedResult;
std::vector<float> coords;
SECTION("Out of bounds") {
coords = {-0.5, -0.5};
expectedResult = {{{-1, -1}, 0.f},
{{0, -1}, 0.F},
{{-1, 0}, 0.F},
{{0, 0}, 1337.F}};
neighbours = Interpolation::retrieveNeighbours<float>(
reinterpret_cast<float *>(tensor->getImpl()->rawPtr()),
tensor->dims(),
coords,
PadBorderType::Zero);
CHECK(neighbours == expectedResult);
}
SECTION("Some coords are rounds hence duplicates are filtered out") {
tensor = std::make_shared<Tensor>(
std::vector<DimSize_t>({5, 10, 10, 10}));
tensor->setDataType(DataType::Float32);
tensor->setBackend("cpu");
Aidge::constantFiller(tensor, 1337.F);
expectedResult = {{{0, 0, -1, -1}, 0.F},
{{0, 0, 0, -1}, 0.F},
{{0, 0, -1, 0}, 0.F},
{{0, 0, 0, 0}, 1337.F}};
neighbours = Interpolation::retrieveNeighbours(
reinterpret_cast<float *>(tensor->getImpl()->rawPtr()),
tensor->dims(),
std::vector<float>({0, 0, -0.25, -0.25}));
CHECK(expectedResult == neighbours);
}
}
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment