From cbc6472f191916712250044bd21998d875f93032 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 8 Dec 2023 23:00:59 +0100
Subject: [PATCH] Added half float (float16) support... yep, it's that easy!

---
 include/aidge/backend/cpu/data/TensorImpl.hpp              | 7 +++++++
 .../backend/cpu/operator/ConvImpl_forward_kernels.hpp      | 4 ++++
 unit_tests/scheduler/Test_Convert.cpp                      | 4 ++--
 3 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp
index 0b77f2e4..8b2987b3 100644
--- a/include/aidge/backend/cpu/data/TensorImpl.hpp
+++ b/include/aidge/backend/cpu/data/TensorImpl.hpp
@@ -3,6 +3,7 @@
 
 #include "aidge/backend/TensorImpl.hpp"
 #include "aidge/data/Tensor.hpp"
+#include "aidge/data/half.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
 #include "aidge/utils/ErrorHandling.hpp"
@@ -67,6 +68,10 @@ class TensorImpl_cpu : public TensorImpl {
             std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length,
                     static_cast<T *>(rawPtr()));
         }
+        else if (srcDt == DataType::Float16) {
+            std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length,
+                    static_cast<T *>(rawPtr()));
+        }
         else if (srcDt == DataType::Int64) {
             std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length,
                     static_cast<T *>(rawPtr()));
@@ -169,6 +174,8 @@ static Registrar<Tensor> registrarTensorImpl_cpu_Float64(
         {"cpu", DataType::Float64}, Aidge::TensorImpl_cpu<double>::create);
 static Registrar<Tensor> registrarTensorImpl_cpu_Float32(
         {"cpu", DataType::Float32}, Aidge::TensorImpl_cpu<float>::create);
+static Registrar<Tensor> registrarTensorImpl_cpu_Float16(
+        {"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create);
 static Registrar<Tensor> registrarTensorImpl_cpu_Int32(
         {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int>::create);
 }  // namespace
diff --git a/include/aidge/backend/cpu/operator/ConvImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/ConvImpl_forward_kernels.hpp
index 03e2c351..1af64bf4 100644
--- a/include/aidge/backend/cpu/operator/ConvImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ConvImpl_forward_kernels.hpp
@@ -14,6 +14,7 @@
 
 #include "aidge/utils/Registrar.hpp"
 
+#include "aidge/data/half.hpp"
 #include "aidge/backend/cpu/operator/ConvImpl.hpp"
 #include "aidge/utils/Types.h"
 #include <cmath>
@@ -150,6 +151,9 @@ namespace {
 static Registrar<ConvImpl2DForward_cpu> registrarConvImpl2DForward_cpu_Float32(
         {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32},
         Aidge::ConvImpl2D_cpu_forward_kernel<float, float, float, float>);
+static Registrar<ConvImpl2DForward_cpu> registrarConvImpl2DForward_cpu_Float16(
+        {DataType::Float16, DataType::Float16, DataType::Float16, DataType::Float16},
+        Aidge::ConvImpl2D_cpu_forward_kernel<half_float::half, half_float::half, half_float::half, half_float::half>);
 static Registrar<ConvImpl2DForward_cpu> registrarConvImpl2DForward_cpu_Int32(
         {DataType::Int32, DataType::Int32, DataType::Int32, DataType::Int32},
         Aidge::ConvImpl2D_cpu_forward_kernel<int, int, int, int>);
diff --git a/unit_tests/scheduler/Test_Convert.cpp b/unit_tests/scheduler/Test_Convert.cpp
index df3db1c2..8a71ed35 100644
--- a/unit_tests/scheduler/Test_Convert.cpp
+++ b/unit_tests/scheduler/Test_Convert.cpp
@@ -188,7 +188,7 @@ TEST_CASE("[cpu/convert] Convert(forward)") {
         // input->addChild(g);
         g->setDataType(Aidge::DataType::Int32);
         g->getNode("conv1")->getOperator()->setDataType(DataType::Float32);
-        g->getNode("conv3")->getOperator()->setDataType(DataType::Float64);
+        g->getNode("conv3")->getOperator()->setDataType(DataType::Float16);
 
         explicitConvert(g);
         g->setBackend("cpu");
@@ -232,7 +232,7 @@ TEST_CASE("[cpu/convert] Convert(forward)") {
         std::shared_ptr<Tensor> other2 = std::static_pointer_cast<OperatorTensor>(g->getNode("conv2")->getOperator())->getOutput(0);
         REQUIRE(approxEq<int>(*other2, *expectedOutput2, 0.0, 1.0e-12));
         std::shared_ptr<Tensor> other3 = std::static_pointer_cast<OperatorTensor>(g->getNode("conv3")->getOperator())->getOutput(0);
-        REQUIRE(approxEq<double, int>(*other3, *expectedOutput3, 0.0, 1.0e-12));
+        REQUIRE(approxEq<half_float::half, int>(*other3, *expectedOutput3, 0.0, 1.0e-12));
         std::shared_ptr<Tensor> other4 = std::static_pointer_cast<OperatorTensor>(g->getNode("fc")->getOperator())->getOutput(0);
         REQUIRE(approxEq<int>(*other4, expectedOutput4, 0.0, 1.0e-12));
     }
-- 
GitLab