diff --git a/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp b/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp
index 4e04b1a595a8660b1528e49921e7e3e7a567829a..a71174c03216dc04e27325d59062d0383f5224ea 100644
--- a/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp
+++ b/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp
@@ -18,12 +18,11 @@
 #include "aidge/backend/cpu/operator/OperatorImpl.hpp"
 #include "aidge/operator/GlobalAveragePooling.hpp"
 #include "aidge/utils/Registrar.hpp"
-#include "aidge/utils/Types.h"
 
 namespace Aidge {
 // Operator implementation entry point for the backend
 using GlobalAveragePoolingImpl_cpu = OperatorImpl_cpu<GlobalAveragePooling_Op,
-    void(const std::vector<DimSize_t> &, const void *, void *)>;
+    void(const std::shared_ptr<Tensor>&, void *)>;
 
 // Implementation entry point registration to Operator
 REGISTRAR(GlobalAveragePooling_Op, "cpu", Aidge::GlobalAveragePoolingImpl_cpu::create);
diff --git a/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp b/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp
index 7a47ccf3be4de2ee066d2d7c27a6c04f115059b4..cbe4f110fc74f387625132c4f0872123814c1a62 100644
--- a/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp
@@ -12,92 +12,83 @@
 #ifndef AIDGE_CPU_OPERATOR_GLOBALAVERAGEPOOLINGIMPL_KERNELS_H_
 #define AIDGE_CPU_OPERATOR_GLOBALAVERAGEPOOLINGIMPL_KERNELS_H_
 
-#include <cstddef>
-#include <functional>  // std::multiplies
-#include <numeric>     // std::accumulate
+#include <cstddef>  // std::size_t
 #include <vector>
 
 #include "aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp"
-#include "aidge/data/Data.hpp"
-#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/data/Tensor.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
-
 
 namespace Aidge {
 
 template <typename T>
-typename std::enable_if<std::is_floating_point<T>::value, T>::type
-stableMean(const T* vec, size_t size) {
-  T mean = 0;
-  for (size_t i = 0; i < size; ++i) {
-    mean = std::fma<T>(vec[i] - mean, 1.0f / (i + 1), mean);
-  }
-  return mean;
+typename std::enable_if_t<std::is_floating_point<T>::value, T>
+static stableMean(const T* vec, std::size_t size) {
+    T mean{0};
+    for (std::size_t i = 0; i < size; ++i) {
+        mean = std::fma(vec[i] - mean, static_cast<T>(1) / static_cast<T>(i + 1), mean);
+    }
+    return mean;
 }
 
 // Specialization for integers: perform the mean computation in float
 template <typename T>
-typename std::enable_if<!std::is_floating_point<T>::value, double>::type
-stableMean(const T* vec, size_t size) {
-  double mean = 0;
-  for (size_t i = 0; i < size; ++i) {
-    mean = std::fma<double>(vec[i] - mean, 1.0f / (i + 1), mean);
-  }
-  return mean;
+typename std::enable_if_t<!std::is_floating_point<T>::value, double>
+static stableMean(const T* vec, std::size_t size) {
+    double mean{0};
+    for (std::size_t i = 0; i < size; ++i) {
+        mean = std::fma<double>(static_cast<double>(vec[i]) - mean, 1.0 / static_cast<double>(i + 1), mean);
+    }
+    return mean;
 }
 
 template <typename T>
-typename std::enable_if<std::is_floating_point<T>::value, T>::type
-castFromFloat(T value) {
-  return value;
+typename std::enable_if_t<std::is_floating_point<T>::value, T>
+static castFromFloat(T value) {
+    return value;
 }
 
 template <typename T>
-typename std::enable_if<!std::is_floating_point<T>::value, T>::type
-castFromFloat(double value) {
-  return static_cast<T>(std::nearbyint(value));
+typename std::enable_if_t<!std::is_floating_point<T>::value, T>
+static castFromFloat(double value) {
+    return static_cast<T>(std::nearbyint(value));
 }
 
-template <class I, class O>
-void GlobalAveragePoolingImpl_cpu_forward_kernel(
-    const std::vector<DimSize_t> &dims, const void *input_, void *output_) {
-  // error checking
-  AIDGE_ASSERT(dims.size() >= 3,"GlobalAveragePool needs at least a 3 dimensions "
-                                "input, number of input dim : {}",
-               dims.size());
+template <DataType DT_I, DataType DT_O = DT_I>
+void GlobalAveragePoolingImpl_cpu_forward_kernel(const std::shared_ptr<Tensor>& inputTensor, void *output_) {
 
-  // computation
-  const I *input = static_cast<const I *>(input_);
-  O *output = static_cast<O *>(output_);
+    // computation
+    using I = cpptype_t<DT_I>;
+    using O = cpptype_t<DT_O>;
+    const I *input = static_cast<const I *>(inputTensor->getImpl()->rawPtr());
+    O *output = static_cast<O *>(output_);
+    const auto& dims = inputTensor->dims();
 
-  DimSize_t nb_elems = std::accumulate(dims.begin(), dims.end(), std::size_t(1),
-                                       std::multiplies<std::size_t>());
+    const DimSize_t strides_channels = inputTensor->strides()[1];
 
-  const DimSize_t in_batch_nb_elems{nb_elems / dims[0]};
-  const DimSize_t in_channel_nb_elems{in_batch_nb_elems / dims[1]};
-  const DimSize_t out_batch_nb_elems{dims[1]};
-  // parse channel by channel and fill each output with the average of the
-  // values in the channel
-  for (DimSize_t batch = 0; batch < dims[0]; ++batch) {
-    for (DimSize_t channel = 0; channel < dims[1]; ++channel) {
-      const I *filter_start = std::next(
-          input, (batch * in_batch_nb_elems) + (channel * in_channel_nb_elems));
-      output[batch * out_batch_nb_elems + channel] = castFromFloat<O>(stableMean<I>(filter_start, in_channel_nb_elems));
+    // parse channel by channel and fill each output with the average of the
+    // values in the channel
+    std::size_t input_idx = 0;
+    std::size_t output_idx = 0;
+    for (DimSize_t batch = 0; batch < dims[0]; ++batch) {
+        for (DimSize_t channel = 0; channel < dims[1]; ++channel) {
+            output[output_idx++] = castFromFloat<O>(stableMean<I>(input + input_idx, strides_channels));
+            input_idx += strides_channels;
+        }
     }
-  }
 }
 
 // Kernels registration to implementation entry point
 REGISTRAR(GlobalAveragePoolingImpl_cpu,
     {DataType::Float32},
-    {ProdConso::defaultModel, Aidge::GlobalAveragePoolingImpl_cpu_forward_kernel<float, float>, nullptr});
+    {ProdConso::defaultModel, Aidge::GlobalAveragePoolingImpl_cpu_forward_kernel<DataType::Float32>, nullptr});
 REGISTRAR(GlobalAveragePoolingImpl_cpu,
     {DataType::Float64},
-    {ProdConso::defaultModel, Aidge::GlobalAveragePoolingImpl_cpu_forward_kernel<double, double>, nullptr});
+    {ProdConso::defaultModel, Aidge::GlobalAveragePoolingImpl_cpu_forward_kernel<DataType::Float64>, nullptr});
 REGISTRAR(GlobalAveragePoolingImpl_cpu,
     {DataType::Int32},
-    {ProdConso::defaultModel, Aidge::GlobalAveragePoolingImpl_cpu_forward_kernel<int32_t, int32_t>, nullptr});
+    {ProdConso::defaultModel, Aidge::GlobalAveragePoolingImpl_cpu_forward_kernel<DataType::Int32>, nullptr});
 } // namespace Aidge
 
 #endif /* AIDGE_CPU_OPERATOR_GLOBALAVERAGEPOOLINGIMPL_KERNELS_H_ */
diff --git a/include/aidge/backend/cpu/operator/ReduceMeanImpl.hpp b/include/aidge/backend/cpu/operator/ReduceMeanImpl.hpp
index 1c50805d5af768dfc160488fda1e8fadfa798454..d6c60c352dc862095bad9ac67ab50d05129b8dc2 100644
--- a/include/aidge/backend/cpu/operator/ReduceMeanImpl.hpp
+++ b/include/aidge/backend/cpu/operator/ReduceMeanImpl.hpp
@@ -12,7 +12,6 @@
 #ifndef AIDGE_CPU_OPERATOR_REDUCEMEANIMPL_H_
 #define AIDGE_CPU_OPERATOR_REDUCEMEANIMPL_H_
 
-#include <array>
 #include <memory>
 #include <tuple>
 #include <vector>
diff --git a/include/aidge/backend/cpu/operator/ReduceMeanImpl_kernels.hpp b/include/aidge/backend/cpu/operator/ReduceMeanImpl_kernels.hpp
index a156232230bd6e9be496eec2b76ccf8fcab4d9e9..73aa283d51d72e28d135ae5bb422f3f9f8dcd8c6 100644
--- a/include/aidge/backend/cpu/operator/ReduceMeanImpl_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ReduceMeanImpl_kernels.hpp
@@ -25,39 +25,39 @@
 #include "aidge/utils/Registrar.hpp"
 
 namespace Aidge {
- 
+
 template <typename T>
-using Acc_T = typename std::conditional<std::is_floating_point<T>::value, T, double>::type;
+using Acc_T = typename std::conditional_t<std::is_floating_point<T>::value, T, double>;
 
 template <typename T>
 typename std::enable_if<std::is_floating_point<T>::value, T>::type
-stableMean(const T* vec, size_t len, size_t stride) {
+stableMean(const T* vec, std::size_t len, std::size_t stride) {
     T mean = 0;
-    for (size_t i = 0; i < len; ++i) {
-        mean = std::fma<T>(vec[i * stride] - mean, 1.0f / (i + 1), mean);
+    for (std::size_t i = 0; i < len; ++i) {
+        mean = std::fma(vec[i * stride] - mean, static_cast<T>(1) / static_cast<T>(i + 1), mean);
     }
     return mean;
 }
 
 // Specialization for integers: perform the mean computation in float
 template <typename T>
-typename std::enable_if<!std::is_floating_point<T>::value, double>::type
-stableMean(const T* vec, size_t len, size_t stride) {
+typename std::enable_if_t<!std::is_floating_point<T>::value, double>
+stableMean(const T* vec, std::size_t len, std::size_t stride) {
     double mean = 0;
     for (size_t i = 0; i < len; ++i) {
-        mean = std::fma<double>(vec[i * stride] - mean, 1.0f / (i + 1), mean);
+        mean = std::fma<double>(static_cast<double>(vec[i * stride]) - mean, 1.0 / static_cast<double>(i + 1), mean);
     }
     return mean;
 }
 
 template <typename T>
-typename std::enable_if<std::is_floating_point<T>::value, T>::type
+typename std::enable_if_t<std::is_floating_point<T>::value, T>
 castFromFloat(T value) {
     return value;
 }
 
 template <typename T>
-typename std::enable_if<!std::is_floating_point<T>::value, T>::type
+typename std::enable_if_t<!std::is_floating_point<T>::value, T>
 castFromFloat(double value) {
     return static_cast<T>(std::nearbyint(value));
 }
diff --git a/src/operator/GlobalAveragePoolingImpl.cpp b/src/operator/GlobalAveragePoolingImpl.cpp
index c53f92e199aee30d55ddafe39b5ef121979acbf7..1b6d9a0629d856c2ad1fc3eae35db4c12058bc4f 100644
--- a/src/operator/GlobalAveragePoolingImpl.cpp
+++ b/src/operator/GlobalAveragePoolingImpl.cpp
@@ -30,13 +30,15 @@ void Aidge::GlobalAveragePoolingImpl_cpu::forward()
     const GlobalAveragePooling_Op& op_ = static_cast<const GlobalAveragePooling_Op&>(mOp);
     // Check if input is provided
     AIDGE_ASSERT(op_.getInput(0), "missing input 0");
+    // error checking
+    AIDGE_ASSERT(op_.getInput(0)->nbDims() >= 3,"GlobalAveragePool needs at least a 3 dimensions "
+                 "input. Got input dims {}", op_.getInput(0)->dims());
 
     // Find the correct kernel type
     const auto impl = Registrar<GlobalAveragePoolingImpl_cpu>::create(getBestMatch(getRequiredSpec()));
 
     // Call kernel
-    impl.forward(op_.getInput(0)->dims(),
-                 op_.getInput(0)->getImpl()->rawPtr(),
+    impl.forward(op_.getInput(0),
                  op_.getOutput(0)->getImpl()->rawPtr());
 }
diff --git a/unit_tests/data/Test_Interpolation.cpp b/unit_tests/data/Test_Interpolation.cpp
index 5c3b56f02ab17092a6ba238cc74e1bf75e203718..4886885d7d979c7ea4aaa70a33d75cb553b361de 100644
--- a/unit_tests/data/Test_Interpolation.cpp
+++ b/unit_tests/data/Test_Interpolation.cpp
@@ -9,15 +9,21 @@
  *
  ********************************************************************************/
 
-#include <aidge/backend/cpu/data/Interpolation.hpp>
-#include <aidge/data/Interpolation.hpp>
-#include <aidge/data/Tensor.hpp>
-#include <aidge/filler/Filler.hpp>
-#include <aidge/utils/Types.h>
-#include <catch2/catch_test_macros.hpp>
+#include <cmath>    // std::fabs
+#include <cstdlib>  // std::abs
 #include <limits>
+#include <memory>
+#include <set>
+#include <vector>
+
+#include <catch2/catch_test_macros.hpp>
 
 #include "aidge/backend/cpu/data/Interpolation.hpp"
+#include "aidge/data/Interpolation.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/filler/Filler.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/utils/TensorUtils.hpp"
 
 namespace Aidge {
 
@@ -30,12 +36,12 @@ TEST_CASE("Interpolation", "[Interpolation][Data]") {
     SECTION("1D") {
         pointsToInterpolateInt =
             std::set<Interpolation::Point<int>>({{{0}, 10}, {{1}, 20}});
-        CHECK(abs(InterpolationCPU::linear({0.5}, pointsToInterpolateInt) -
+        REQUIRE(std::abs(InterpolationCPU::linear({0.5}, pointsToInterpolateInt) -
                   15) <= std::numeric_limits<int>::epsilon());
 
         pointsToInterpolateFloat = std::set<Interpolation::Point<float>>(
             {{{0}, .0F}, {{1}, 0.2F}});
-        CHECK(fabs(InterpolationCPU::linear({0.3},
+        REQUIRE(std::fabs(InterpolationCPU::linear({0.3},
                                             pointsToInterpolateFloat) -
                    .06F) <= 1e-5);
     }
@@ -46,21 +52,21 @@ TEST_CASE("Interpolation", "[Interpolation][Data]") {
             {{14, 21}, 162.F},
             {{15, 20}, 210.F},
             {{15, 21}, 95.F}};
-        CHECK(fabs(InterpolationCPU::linear<float>(
-                  {14.5F, 20.2F},
-                  pointsToInterpolateFloat) -
-              146.1) < 1e-5);
+        const Tensor interpolatedValue = Tensor(std::fabs(InterpolationCPU::linear<float>(
+            {14.5F, 20.2F},
+            pointsToInterpolateFloat)));
+        REQUIRE(approxEq<float>(interpolatedValue, Tensor(146.1f)));
         // pointsToInterpolateFloat = {{{0, 0}, .10F},
         //                             {{0, 1}, .20F},
         //                             {{1, 0}, .30F},
         //                             {{1, 1}, .40F}};
-        // CHECK(abs(InterpolationCPU::linear<float>({1.5, 0.5},
+        // REQUIRE(std::abs(InterpolationCPU::linear<float>({1.5, 0.5},
         //                                     pointsToInterpolateInt)
        //                                     -
         //           25) < std::numeric_limits<int>::epsilon());
 
         // pointsToInterpolateFloat = std::vector({0.1F, 0.2F, 0.3F,
-        // 0.4F}); CHECK(InterpolationCPU::linear(pointsToInterpolateFloat)
+        // 0.4F}); REQUIRE(InterpolationCPU::linear(pointsToInterpolateFloat)
        // == .25f);
     }
     SECTION("3D") {
@@ -72,7 +78,7 @@ TEST_CASE("Interpolation", "[Interpolation][Data]") {
             {{1, 0, 1}, .6F},
             {{1, 1, 0}, .7F},
             {{1, 1, 1}, .8F}};
-        CHECK(fabs(InterpolationCPU::linear({.5, .5, .5},
+        REQUIRE(std::fabs(InterpolationCPU::linear({.5, .5, .5},
                                             pointsToInterpolateFloat) -
                   .45f) < 1e-5);
     }
@@ -94,7 +100,7 @@ TEST_CASE("Interpolation", "[Interpolation][Data]") {
             {{1, 1, 0, 1}, 1.4F},
             {{1, 1, 1, 0}, 1.5F},
             {{1, 1, 1, 1}, 1.6F}};
-        CHECK(fabs(InterpolationCPU::linear<float>(
+        REQUIRE(std::fabs(InterpolationCPU::linear<float>(
                    {.5, .5, .5, .5},
                    pointsToInterpolateFloat) -
                   .85f) < 0.0001);
@@ -139,25 +145,25 @@ TEST_CASE("Interpolation", "[Interpolation][Data]") {
                                                  {{4}, 5.0F}};
 
         SECTION("Floor") {
-            CHECK(InterpolationCPU::nearest(
+            REQUIRE(InterpolationCPU::nearest(
                       coordToInterpolate,
                       pointsToInterpolate,
                       Interpolation::Mode::Floor) == 1);
         }
         SECTION("Ceil") {
-            CHECK(InterpolationCPU::nearest(
+            REQUIRE(InterpolationCPU::nearest(
                       coordToInterpolate,
                       pointsToInterpolate,
                       Interpolation::Mode::Ceil) == 2);
         }
         SECTION("RoundPreferFloor") {
-            CHECK(InterpolationCPU::nearest(
+            REQUIRE(InterpolationCPU::nearest(
                       coordToInterpolate,
                       pointsToInterpolate,
                       Interpolation::Mode::RoundPreferFloor) == 1);
         }
         SECTION("RoundPreferCeil") {
-            CHECK(InterpolationCPU::nearest(
+            REQUIRE(InterpolationCPU::nearest(
                       coordToInterpolate,
                       pointsToInterpolate,
                       Interpolation::Mode::RoundPreferCeil) == 2);
@@ -172,26 +178,26 @@ TEST_CASE("Interpolation", "[Interpolation][Data]") {
                                                  {{3, 3}, 50.0},
                                                  {{3, 4}, 60.0}};
 
         SECTION("Floor") {
-            CHECK(InterpolationCPU::nearest(
+            REQUIRE(InterpolationCPU::nearest(
                       coordToInterpolate,
                       pointsToInterpolate,
                       Interpolation::Mode::Floor) == 30.);
         }
         SECTION("Ceil") {
-            CHECK(InterpolationCPU::nearest(
+            REQUIRE(InterpolationCPU::nearest(
                       coordToInterpolate,
                       pointsToInterpolate,
                       Interpolation::Mode::Ceil) == 60.);
         }
         SECTION("RoundPreferFloor") {
-            CHECK(InterpolationCPU::nearest(
+            REQUIRE(InterpolationCPU::nearest(
                       coordToInterpolate,
                       pointsToInterpolate,
                       Interpolation::Mode::RoundPreferFloor) == 40.);
         }
         SECTION("RoundPreferCeil") {
-            CHECK(InterpolationCPU::nearest(
+            REQUIRE(InterpolationCPU::nearest(
                       coordToInterpolate,
                       pointsToInterpolate,
                       Interpolation::Mode::RoundPreferCeil) == 60.);
@@ -207,26 +213,26 @@ TEST_CASE("Interpolation", "[Interpolation][Data]") {
                                                  {{2, 3, 4}, 50.0},
                                                  {{3, 3, 4}, 60.0}};
         SECTION("Floor") {
-            CHECK(InterpolationCPU::nearest(
+            REQUIRE(InterpolationCPU::nearest(
                       coordToInterpolate,
                       pointsToInterpolate,
                       Interpolation::Mode::Floor) == 10.);
         }
         SECTION("Ceil") {
-            CHECK(InterpolationCPU::nearest(
+            REQUIRE(InterpolationCPU::nearest(
                       coordToInterpolate,
                       pointsToInterpolate,
                       Interpolation::Mode::Ceil) == 50.);
         }
         SECTION("RoundPreferFloor") {
-            CHECK(InterpolationCPU::nearest(
+            REQUIRE(InterpolationCPU::nearest(
                       coordToInterpolate,
                       pointsToInterpolate,
                       Interpolation::Mode::RoundPreferFloor) == 30.);
         }
         SECTION("RoundPreferCeil") {
-            CHECK(InterpolationCPU::nearest(
+            REQUIRE(InterpolationCPU::nearest(
                       coordToInterpolate,
                       pointsToInterpolate,
                       Interpolation::Mode::RoundPreferCeil) == 30.);
diff --git a/unit_tests/operator/Test_ReduceMeanImpl.cpp b/unit_tests/operator/Test_ReduceMeanImpl.cpp
index 30ffeb0dd0b584f50349c206863c7ab9ac776721..8841d6773dc5ce793ca75244fedc18fdf245ca26 100644
--- a/unit_tests/operator/Test_ReduceMeanImpl.cpp
+++ b/unit_tests/operator/Test_ReduceMeanImpl.cpp
@@ -156,7 +156,7 @@ TEST_CASE("[cpu/operator] ReduceMean(forward)", "[ReduceMean][CPU]") {
     }
     SECTION("KeepDims") {
         SECTION("test 1") {
-            std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array3D<float,3,2,2> {
+            std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array3D<cpptype_t<DataType::Float32>,3,2,2> {
                 {
                     {
                         { 5.0, 1.0 },
@@ -172,12 +172,12 @@ TEST_CASE("[cpu/operator] ReduceMean(forward)", "[ReduceMean][CPU]") {
                     }
                 }
             });
-            Tensor myOutput = Tensor(Array3D<float,3,1,2> {
+            Tensor myOutput = Tensor(Array3D<cpptype_t<DataType::Float32>,3,1,2> {
                 {
-
-                    {{ 12.5, 1.5 }},
-                    {{ 35.0, 1.5 }},
-                    {{ 57.5, 1.5 }}
+                    {{ 12.5f, 1.5f }},
+                    {{ 35.0f, 1.5f }},
+                    {{ 57.5f, 1.5f }}
                 }
             });
diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp
index be87e8ac77020b5c05469fb959752a66512e6ffb..eed4185d7ac98107f6811f38d7f37851cb6801af 100644
--- a/unit_tests/scheduler/Test_Scheduler.cpp
+++ b/unit_tests/scheduler/Test_Scheduler.cpp
@@ -482,7 +482,7 @@ TEST_CASE("[cpu/scheduler] Accumulate", "[scheduler]") {
          {{2.0, 3.0}, {4.0, 5.0}, {6.0, 7.0}}}});
 
     std::shared_ptr<Tensor> MemInit =
-        std::make_shared<Tensor>(Array2D<float, 3, 2>{
+        std::make_shared<Tensor>(Array2D<cpptype_t<DataType::Float32>, 3, 2>{
             {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}});
 
     auto meta = Accumulate(2, "accumulate");
@@ -517,14 +517,14 @@ TEST_CASE("[cpu/scheduler] Accumulate", "[scheduler]") {
     REQUIRE_NOTHROW(scheduler.forward(true));
 
     std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(
-        Array2D<float, 3, 2>{{{3.0, 5.0}, {7.0, 9.0}, {11.0, 13.0}}});
+        Array2D<cpptype_t<DataType::Float32>, 3, 2>{{{3.0, 5.0}, {7.0, 9.0}, {11.0, 13.0}}});
 
     std::shared_ptr<Tensor> output = std::static_pointer_cast<OperatorTensor>(pop_o->getOperator())->getOutput(0);
     REQUIRE(*output == *expectedOutput);
 }
 
 TEST_CASE("[cpu/scheduler] Branch", "[scheduler]") {
     std::shared_ptr<Tensor> in = std::make_shared<Tensor>(
-        Array2D<float, 2, 3>{{{1, 2, 3}, {4, 5, 6}}});
+        Array2D<cpptype_t<DataType::Float32>, 2, 3>{{{1, 2, 3}, {4, 5, 6}}});
 
     std::shared_ptr<GraphView> g = Sequential({
         Producer(in, "input"),
@@ -576,7 +576,7 @@ TEST_CASE("[cpu/scheduler] Branch", "[scheduler]") {
 #ifdef WITH_OPENSSL
 TEST_CASE("[cpu/scheduler] Select", "[scheduler]") {
     std::shared_ptr<Tensor> in = std::make_shared<Tensor>(
-        Array2D<float, 2, 3>{{{1, 2, 3}, {4, 5, 6}}});
+        Array2D<cpptype_t<DataType::Float32>, 2, 3>{{{1, 2, 3}, {4, 5, 6}}});
 
     std::shared_ptr<GraphView> g = Sequential({
         Producer(in, "input"),
@@ -605,21 +605,21 @@ TEST_CASE("[cpu/scheduler] Select", "[scheduler]") {
     scheduler.generateScheduling();
     scheduler.saveStaticSchedulingDiagram("select_scheduling");
     REQUIRE_NOTHROW(scheduler.forward(true));
-    
+
     g->save("select_forwarded");
 
     auto expectedOutputHash = std::make_shared<Tensor>(
-        Array1D<uint64_t, 4>{{0x1b7cf58dfe2dae24, 0x3bac903def4ce580, 0x5f5a347389d97f41, 0x2c2dc759abc6b61}});
+        Array1D<cpptype_t<DataType::UInt64>, 4>{{0x1b7cf58dfe2dae24, 0x3bac903def4ce580, 0x5f5a347389d97f41, 0x2c2dc759abc6b61}});
     auto outputHash = std::static_pointer_cast<OperatorTensor>(g->getNode("hash")->getOperator())->getOutput(0);
     REQUIRE(*outputHash == *expectedOutputHash);
 
     auto expectedOutputMod = std::make_shared<Tensor>(
-        Array1D<uint64_t, 4>{{2, 1, 1, 2}});
+        Array1D<cpptype_t<DataType::UInt64>, 4>{{2, 1, 1, 2}});
     auto outputMod = std::static_pointer_cast<OperatorTensor>(g->getNode("mod")->getOperator())->getOutput(0);
     REQUIRE(*outputMod == *expectedOutputMod);
 
     auto expectedOutput = std::make_shared<Tensor>(
-        Array2D<float, 2, 3>{{{std::sqrt(1), std::sqrt(2), std::sqrt(3)}, {std::sqrt(4), std::sqrt(5), std::sqrt(6)}}});
+        Array2D<cpptype_t<DataType::Float32>, 2, 3>{{{std::sqrt(1.0f), std::sqrt(2.0f), std::sqrt(3.0f)}, {std::sqrt(4.0f), std::sqrt(5.0f), std::sqrt(6.0f)}}});
     auto output = std::static_pointer_cast<OperatorTensor>(g->getNode("select")->getOperator())->getOutput(0);
     REQUIRE(*output == *expectedOutput);
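
Note on the stableMean kernels touched above: instead of summing a whole channel and dividing once, both overloads fold each element into a running mean, mean = fma(x[i] - mean, 1/(i+1), mean), so the intermediate value stays on the order of the data and long channels do not overflow or lose precision in the accumulator. A minimal standalone sketch of that update rule follows; it is illustrative only and not part of the patch (the naiveMean comparison and the main driver are made up for the example).

#include <cmath>    // std::fma
#include <cstddef>  // std::size_t
#include <iostream>
#include <vector>

// Incremental (running) mean, as used by the stableMean kernels above:
// after element i, `mean` already holds the mean of the first i+1 values,
// so the intermediate never grows beyond the magnitude of the data.
double stableMean(const std::vector<float>& v) {
    double mean = 0.0;
    for (std::size_t i = 0; i < v.size(); ++i) {
        mean = std::fma(static_cast<double>(v[i]) - mean,
                        1.0 / static_cast<double>(i + 1),
                        mean);
    }
    return mean;
}

// Naive sum-then-divide with a float accumulator, for comparison: the sum
// grows with the element count and starts absorbing rounding error.
float naiveMean(const std::vector<float>& v) {
    float sum = 0.0f;
    for (float x : v) { sum += x; }
    return v.empty() ? 0.0f : sum / static_cast<float>(v.size());
}

int main() {
    // One synthetic "channel" of ~1M identical values.
    const std::vector<float> channel(1u << 20, 0.1f);
    std::cout << "stable mean: " << stableMean(channel) << '\n';
    std::cout << "naive mean:  " << naiveMean(channel) << '\n';
    return 0;
}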