From cbc6472f191916712250044bd21998d875f93032 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 8 Dec 2023 23:00:59 +0100 Subject: [PATCH] Added half float (float16) support... yep, it's that easy! --- include/aidge/backend/cpu/data/TensorImpl.hpp | 7 +++++++ .../backend/cpu/operator/ConvImpl_forward_kernels.hpp | 4 ++++ unit_tests/scheduler/Test_Convert.cpp | 4 ++-- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp index 0b77f2e4..8b2987b3 100644 --- a/include/aidge/backend/cpu/data/TensorImpl.hpp +++ b/include/aidge/backend/cpu/data/TensorImpl.hpp @@ -3,6 +3,7 @@ #include "aidge/backend/TensorImpl.hpp" #include "aidge/data/Tensor.hpp" +#include "aidge/data/half.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" @@ -67,6 +68,10 @@ class TensorImpl_cpu : public TensorImpl { std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length, static_cast<T *>(rawPtr())); } + else if (srcDt == DataType::Float16) { + std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length, + static_cast<T *>(rawPtr())); + } else if (srcDt == DataType::Int64) { std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length, static_cast<T *>(rawPtr())); @@ -169,6 +174,8 @@ static Registrar<Tensor> registrarTensorImpl_cpu_Float64( {"cpu", DataType::Float64}, Aidge::TensorImpl_cpu<double>::create); static Registrar<Tensor> registrarTensorImpl_cpu_Float32( {"cpu", DataType::Float32}, Aidge::TensorImpl_cpu<float>::create); +static Registrar<Tensor> registrarTensorImpl_cpu_Float16( + {"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create); static Registrar<Tensor> registrarTensorImpl_cpu_Int32( {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int>::create); } // namespace diff --git a/include/aidge/backend/cpu/operator/ConvImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/ConvImpl_forward_kernels.hpp index 03e2c351..1af64bf4 100644 --- a/include/aidge/backend/cpu/operator/ConvImpl_forward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/ConvImpl_forward_kernels.hpp @@ -14,6 +14,7 @@ #include "aidge/utils/Registrar.hpp" +#include "aidge/data/half.hpp" #include "aidge/backend/cpu/operator/ConvImpl.hpp" #include "aidge/utils/Types.h" #include <cmath> @@ -150,6 +151,9 @@ namespace { static Registrar<ConvImpl2DForward_cpu> registrarConvImpl2DForward_cpu_Float32( {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32}, Aidge::ConvImpl2D_cpu_forward_kernel<float, float, float, float>); +static Registrar<ConvImpl2DForward_cpu> registrarConvImpl2DForward_cpu_Float16( + {DataType::Float16, DataType::Float16, DataType::Float16, DataType::Float16}, + Aidge::ConvImpl2D_cpu_forward_kernel<half_float::half, half_float::half, half_float::half, half_float::half>); static Registrar<ConvImpl2DForward_cpu> registrarConvImpl2DForward_cpu_Int32( {DataType::Int32, DataType::Int32, DataType::Int32, DataType::Int32}, Aidge::ConvImpl2D_cpu_forward_kernel<int, int, int, int>); diff --git a/unit_tests/scheduler/Test_Convert.cpp b/unit_tests/scheduler/Test_Convert.cpp index df3db1c2..8a71ed35 100644 --- a/unit_tests/scheduler/Test_Convert.cpp +++ b/unit_tests/scheduler/Test_Convert.cpp @@ -188,7 +188,7 @@ TEST_CASE("[cpu/convert] Convert(forward)") { // input->addChild(g); g->setDataType(Aidge::DataType::Int32); g->getNode("conv1")->getOperator()->setDataType(DataType::Float32); - g->getNode("conv3")->getOperator()->setDataType(DataType::Float64); + g->getNode("conv3")->getOperator()->setDataType(DataType::Float16); explicitConvert(g); g->setBackend("cpu"); @@ -232,7 +232,7 @@ TEST_CASE("[cpu/convert] Convert(forward)") { std::shared_ptr<Tensor> other2 = std::static_pointer_cast<OperatorTensor>(g->getNode("conv2")->getOperator())->getOutput(0); REQUIRE(approxEq<int>(*other2, *expectedOutput2, 0.0, 1.0e-12)); std::shared_ptr<Tensor> other3 = std::static_pointer_cast<OperatorTensor>(g->getNode("conv3")->getOperator())->getOutput(0); - REQUIRE(approxEq<double, int>(*other3, *expectedOutput3, 0.0, 1.0e-12)); + REQUIRE(approxEq<half_float::half, int>(*other3, *expectedOutput3, 0.0, 1.0e-12)); std::shared_ptr<Tensor> other4 = std::static_pointer_cast<OperatorTensor>(g->getNode("fc")->getOperator())->getOutput(0); REQUIRE(approxEq<int>(*other4, expectedOutput4, 0.0, 1.0e-12)); } -- GitLab