From d4826e7e2b83e1469a767bc2ed5dadce6435de8f Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 6 Jun 2024 11:13:58 +0200
Subject: [PATCH] Add initial support for data format

---
 include/aidge/data/Data.hpp               | 28 ++++++++++++++++
 include/aidge/data/Tensor.hpp             | 40 ++++++++++++++++++++---
 include/aidge/operator/Operator.hpp       |  1 +
 include/aidge/operator/OperatorTensor.hpp |  1 +
 src/operator/OperatorTensor.cpp           | 11 +++++++
 5 files changed, 76 insertions(+), 5 deletions(-)

diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp
index 2752ec484..219a37da3 100644
--- a/include/aidge/data/Data.hpp
+++ b/include/aidge/data/Data.hpp
@@ -16,6 +16,7 @@
 #include <fmt/format.h>
 #include <string>
 #include <tuple>
+#include <array>
 
 #include "aidge/data/half.hpp"
 #include "aidge/utils/Attributes.hpp"
@@ -50,6 +51,28 @@ enum class DataType {
     UInt64
 };
 
+enum class DataFormat {
+    Default,
+    NCHW,
+    NHWC,
+    CHWN,
+    NCDHW,
+    NDHWC,
+    CDHWN
+};
+
+constexpr std::array<std::array<size_t, 5>, 7> DataFormatTranspose = {{
    // Important: in this array only, dimension indices are 1-based, not 0-based,
    // because 0 is reserved as the value-initialized "no mapping" default entry.
+    {},
+    {1, 2, 3, 4},
+    {1, 3, 4, 2},
+    {2, 3, 4, 1},
+    {1, 2, 3, 4, 5},
+    {1, 3, 4, 5, 2},
+    {2, 3, 4, 5, 1}
+}};
+
 class Data {
 public:
     Data(const std::string& type): mType(type) {};
@@ -85,6 +108,10 @@ const char* const EnumStrings<Aidge::DataType>::data[]
        "Int32", "Int64", "UInt2", "UInt3", "UInt4", "UInt5", "UInt6",
        "UInt7", "UInt8", "UInt16", "UInt32", "UInt64"};
 
+template <>
+const char* const EnumStrings<Aidge::DataFormat>::data[]
+    = {"Default", "NCHW", "NHWC", "CHWN", "NCDHW", "NDHWC", "CDHWN"};
+
 template <Aidge::DataType D> struct cpptype {
     using type = void; // Placeholder
 };
@@ -106,6 +133,7 @@ template <Aidge::DataType D> using cpptype_t = typename cpptype<D>::type;
 
 namespace Aidge {
 inline auto format_as(DataType dt) { return EnumStrings<Aidge::DataType>::data[static_cast<int>(dt)]; }
+inline auto format_as(DataFormat df) { return EnumStrings<Aidge::DataFormat>::data[static_cast<int>(df)]; }
 }
 
 #endif /* AIDGE_DATA_H_ */
diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index 3dbf54a5f..a442af8e3 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -42,6 +42,7 @@ class Tensor : public Data,
                public Registrable<Tensor, std::tuple<std::string, DataType>, std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)> {
    private:
     DataType mDataType = DataType::Float32; /** enum to specify data type. */
+    DataFormat mDataFormat = DataFormat::Default; /** enum to specify data format. */
     std::vector<DimSize_t> mDims; /** Dimensions of the tensor. */
     std::vector<DimSize_t> mStrides; /** Stride dimensions of the tensor. */
     std::shared_ptr<TensorImpl> mImpl = nullptr; /** Pointer to the actual data implementation. */
@@ -61,9 +62,10 @@ class Tensor : public Data,
      * @brief Construct a new empty Tensor object.
      * It has the features of an undefined scalar.
      */
-    Tensor(DataType dtype = DataType::Float32)
+    Tensor(DataType dtype = DataType::Float32, DataFormat dformat = DataFormat::Default)
         : Data(Type),
           mDataType(dtype),
+          mDataFormat(dformat),
           mDims(std::vector<DimSize_t>({})),
           mStrides({1}),
           mSize(1)
@@ -83,6 +85,7 @@ class Tensor : public Data,
     Tensor(T val)
         : Data(Type),
           mDataType(NativeType<VT>::type),
+          mDataFormat(DataFormat::Default),
           mDims({}),
           mStrides({1}),
           mImpl(Registrar<Tensor>::create({"cpu", NativeType<VT>::type})(0, std::vector<std::size_t>())),
@@ -112,6 +115,7 @@ class Tensor : public Data,
     constexpr Tensor(Array1D<T, SIZE_0> &&arr)
         : Data(Type),
           mDataType(NativeType<T>::type),
+          mDataFormat(DataFormat::Default),
           mDims({SIZE_0}),
           mStrides({1}),
           mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, {SIZE_0})),
@@ -130,6 +134,7 @@ class Tensor : public Data,
     constexpr Tensor(Array2D<T, SIZE_0, SIZE_1> &&arr)
         : Data(Type),
           mDataType(NativeType<T>::type),
+          mDataFormat(DataFormat::Default),
           mDims({SIZE_0, SIZE_1}),
           mStrides({SIZE_1, 1}),
           mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, {SIZE_0, SIZE_1})),
@@ -148,6 +153,7 @@ class Tensor : public Data,
     constexpr Tensor(Array3D<T, SIZE_0, SIZE_1, SIZE_2> &&arr)
         : Data(Type),
           mDataType(NativeType<T>::type),
+          mDataFormat(DataFormat::Default),
           mDims({SIZE_0, SIZE_1, SIZE_2}),
           mStrides({SIZE_1 * SIZE_2, SIZE_2, 1}),
           mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, {SIZE_0, SIZE_1, SIZE_2})),
@@ -167,6 +173,7 @@ class Tensor : public Data,
     constexpr Tensor(Array4D<T, SIZE_0, SIZE_1, SIZE_2, SIZE_3> &&arr)
         : Data(Type),
           mDataType(NativeType<T>::type),
+          mDataFormat(DataFormat::Default),
           mDims({SIZE_0, SIZE_1, SIZE_2, SIZE_3}),
           mStrides({SIZE_1 * SIZE_2 * SIZE_3, SIZE_2 * SIZE_3, SIZE_3, 1}),
           mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, {SIZE_0, SIZE_1, SIZE_2, SIZE_3})),
@@ -247,11 +254,13 @@ class Tensor : public Data,
     Tensor operator+(const Tensor& other) const {
         AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
         AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
-        AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same backend");
+        AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type");
+        AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format");
         auto add_ = Add_Op(2);
         add_.associateInput(0, std::make_shared<Tensor>(*this));
         add_.associateInput(1, std::make_shared<Tensor>(other));
         add_.setDataType(dataType());
+        add_.setDataFormat(dataFormat());
         add_.setBackend(mImpl->backend());
         add_.forward();
         // using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
@@ -270,11 +279,13 @@ class Tensor : public Data,
     Tensor operator-(const Tensor& other) const {
         AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
         AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
-        AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same backend");
+        AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type");
+        AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format");
         auto sub_ = Sub_Op();
         sub_.associateInput(0, std::make_shared<Tensor>(*this));
         sub_.associateInput(1, std::make_shared<Tensor>(other));
         sub_.setDataType(dataType());
+        sub_.setDataFormat(dataFormat());
         sub_.setBackend(mImpl->backend());
         sub_.forward();
         // using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
@@ -293,11 +304,13 @@ class Tensor : public Data,
     Tensor operator*(const Tensor& other) const {
         AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
         AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
-        AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same backend");
+        AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type");
+        AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format");
         auto mul_ = Mul_Op();
         mul_.associateInput(0, std::make_shared<Tensor>(*this));
         mul_.associateInput(1, std::make_shared<Tensor>(other));
         mul_.setDataType(dataType());
+        mul_.setDataFormat(dataFormat());
         mul_.setBackend(mImpl->backend());
         mul_.forward();
         // using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
@@ -316,11 +329,13 @@ class Tensor : public Data,
     Tensor operator/(const Tensor& other) const {
         AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
         AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
-        AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same backend");
+        AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type");
+        AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format");
         auto div_ = Div_Op();
         div_.associateInput(0, std::make_shared<Tensor>(*this));
         div_.associateInput(1, std::make_shared<Tensor>(other));
         div_.setDataType(dataType());
+        div_.setDataFormat(dataFormat());
         div_.setBackend(mImpl->backend());
         div_.forward();
         // using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
@@ -390,6 +405,12 @@ public:
      */
     constexpr DataType dataType() const noexcept { return mDataType; }
 
+    /**
+     * @brief Get the data format enum.
+     * @return constexpr DataFormat
+     */
+    constexpr DataFormat dataFormat() const noexcept { return mDataFormat; }
+
     /**
      * @brief Set the DataType of the Tensor and converts data
      * if the Tensor has already been initialized and copyCast is true.
@@ -408,6 +429,14 @@ public:
         mDataType = dt;
     }
 
+    /**
+     * @brief Set the DataFormat of the Tensor.
+     * @param df DataFormat
+     */
+    void setDataFormat(const DataFormat df) {
+        mDataFormat = df;
+    }
+
     /**
      * @brief Get the Impl object
      * @return constexpr const std::shared_ptr<TensorImpl>&
@@ -575,6 +604,7 @@ public:
         }
         if (!mGrad->hasImpl()) {
             mGrad->setDataType(dataType());
+            mGrad->setDataFormat(dataFormat());
             mGrad->setBackend(hasImpl() ? mImpl->backend() : "cpu");
             mGrad->zeros();
         }
diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp
index 3ee234229..d5259b7fd 100644
--- a/include/aidge/operator/Operator.hpp
+++ b/include/aidge/operator/Operator.hpp
@@ -116,6 +116,7 @@ public:
 
     virtual void setBackend(const std::string& name, DeviceIdx_t device = 0) = 0;
     virtual void setDataType(const DataType& dataType) const = 0;
+    virtual void setDataFormat(const DataFormat& dataFormat) const = 0;
 
     /**
      * @brief Set a new OperatorImpl to the Operator
diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp
index a49379327..c6f2d78f1 100644
--- a/include/aidge/operator/OperatorTensor.hpp
+++ b/include/aidge/operator/OperatorTensor.hpp
@@ -85,6 +85,7 @@ public:
     ///////////////////////////////////////////////////
 
     virtual void setDataType(const DataType& dataType) const override;
+    virtual void setDataFormat(const DataFormat& dataFormat) const override;
     
     virtual void forward() override;
 };
diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp
index 25c9deb2a..48c139d36 100644
--- a/src/operator/OperatorTensor.cpp
+++ b/src/operator/OperatorTensor.cpp
@@ -180,6 +180,17 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
     }
 }
 
+void Aidge::OperatorTensor::setDataFormat(const DataFormat& dataFormat) const {
+    for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
+        getOutput(i)->setDataFormat(dataFormat);
+    }
+
+    for (IOIndex_t i = nbData(); i < nbInputs(); ++i) {
+        AIDGE_ASSERT(getInput(i) != nullptr, "Missing input#{} for operator {}", i, type());
+        getInput(i)->setDataFormat(dataFormat);
+    }
+}
+
 void Aidge::OperatorTensor::forward() {
     if (!dimsForwarded()) {
         // Allow data dependent forwardDims at this point (data is available)
-- 
GitLab