Commit d4826e7e authored by Olivier BICHLER

Added initial support for data format

parent fda608ab
2 merge requests: !152 "Update Aidge export to take a graph view as an argument instead of a...", !131 "New features to simplify exports"
@@ -16,6 +16,7 @@
#include <fmt/format.h>
#include <string>
#include <tuple>
+#include <array>
#include "aidge/data/half.hpp"
#include "aidge/utils/Attributes.hpp"
@@ -50,6 +51,28 @@ enum class DataType {
UInt64
};
+enum class DataFormat {
+Default,
+NCHW,
+NHWC,
+CHWN,
+NCDHW,
+NDHWC,
+CDHWN
+};
+constexpr std::array<std::array<size_t, 5>, 7> DataFormatTranspose = {{
+// Important: in this array only, dimension index must start at 1, not 0!
+// (0 is the default value)
+{},
+{1, 2, 3, 4},
+{1, 3, 4, 2},
+{2, 3, 4, 1},
+{1, 2, 3, 4, 5},
+{1, 3, 4, 5, 2},
+{2, 3, 4, 5, 1}
+}};
class Data {
public:
Data(const std::string& type): mType(type) {};
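Note on the DataFormatTranspose table added above: it gives, for each DataFormat, the axis order relative to the default layout, and the 1-based convention exists because std::array<size_t, 5> zero-fills unused trailing slots, so 0 can safely mark "no axis" in the 4D rows. A minimal sketch (not part of this commit; the helper name getPermutation is hypothetical and assumes these declarations are in scope) of how the table could yield the permutation converting one format into another:

// Hypothetical helper: 0-based axis permutation that relayouts a tensor
// from `src` format to `dst` format, derived from DataFormatTranspose.
std::array<std::size_t, 5> getPermutation(DataFormat src, DataFormat dst) {
    const auto& srcRow = DataFormatTranspose[static_cast<std::size_t>(src)];
    const auto& dstRow = DataFormatTranspose[static_cast<std::size_t>(dst)];
    std::array<std::size_t, 5> perm{};
    for (std::size_t i = 0; i < dstRow.size() && dstRow[i] != 0; ++i) {
        // Find where the axis wanted at position i currently sits in `src`.
        for (std::size_t j = 0; j < srcRow.size(); ++j) {
            if (srcRow[j] == dstRow[i]) { perm[i] = j; break; } // back to 0-based
        }
    }
    return perm; // e.g. NCHW -> NHWC yields {0, 2, 3, 1} (trailing slots unused)
}

A caller would presumably treat DataFormat::Default (the all-zero first row) as "no permutation" before consulting the table.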
@@ -85,6 +108,10 @@ const char* const EnumStrings<Aidge::DataType>::data[]
"Int32", "Int64", "UInt2", "UInt3", "UInt4", "UInt5", "UInt6",
"UInt7", "UInt8", "UInt16", "UInt32", "UInt64"};
+template <>
+const char* const EnumStrings<Aidge::DataFormat>::data[]
+= {"Default", "NCHW", "NHWC", "CHWN", "NCDHW", "NDHWC", "CDHWN"};
template <Aidge::DataType D> struct cpptype {
using type = void; // Placeholder
};
@@ -106,6 +133,7 @@ template <Aidge::DataType D> using cpptype_t = typename cpptype<D>::type;
namespace Aidge {
inline auto format_as(DataType dt) { return EnumStrings<Aidge::DataType>::data[static_cast<int>(dt)]; }
+inline auto format_as(DataFormat df) { return EnumStrings<Aidge::DataFormat>::data[static_cast<int>(df)]; }
}
#endif /* AIDGE_DATA_H_ */
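Since format_as is fmt's ADL customization point, both enums become directly printable with the <fmt/format.h> include already present in this header (assuming a libfmt version that supports format_as, i.e. v9 or later); for illustration:

fmt::print("{} / {}\n", Aidge::DataType::Float32, Aidge::DataFormat::NHWC);
// prints: Float32 / NHWC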
@@ -42,6 +42,7 @@ class Tensor : public Data,
public Registrable<Tensor, std::tuple<std::string, DataType>, std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)> {
private:
DataType mDataType = DataType::Float32; /** enum to specify data type. */
+DataFormat mDataFormat = DataFormat::Default; /** enum to specify data format. */
std::vector<DimSize_t> mDims; /** Dimensions of the tensor. */
std::vector<DimSize_t> mStrides; /** Stride dimensions of the tensor. */
std::shared_ptr<TensorImpl> mImpl = nullptr; /** Pointer to the actual data implementation. */
@@ -61,9 +62,10 @@ class Tensor : public Data,
* @brief Construct a new empty Tensor object.
* It has the features of an undefined scalar.
*/
-Tensor(DataType dtype = DataType::Float32)
+Tensor(DataType dtype = DataType::Float32, DataFormat dformat = DataFormat::Default)
: Data(Type),
mDataType(dtype),
+mDataFormat(dformat),
mDims(std::vector<DimSize_t>({})),
mStrides({1}),
mSize(1)
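For illustration (not part of the diff): because the new parameter is defaulted, existing call sites compile unchanged while new code can pick a layout explicitly.

Tensor t1;                                      // Float32, Default, as before
Tensor t2(DataType::Float16, DataFormat::NHWC); // explicit type and layout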
@@ -83,6 +85,7 @@ class Tensor : public Data,
Tensor(T val)
: Data(Type),
mDataType(NativeType<VT>::type),
+mDataFormat(DataFormat::Default),
mDims({}),
mStrides({1}),
mImpl(Registrar<Tensor>::create({"cpu", NativeType<VT>::type})(0, std::vector<std::size_t>())),
@@ -112,6 +115,7 @@ class Tensor : public Data,
constexpr Tensor(Array1D<T, SIZE_0> &&arr)
: Data(Type),
mDataType(NativeType<T>::type),
+mDataFormat(DataFormat::Default),
mDims({SIZE_0}),
mStrides({1}),
mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, {SIZE_0})),
@@ -130,6 +134,7 @@ class Tensor : public Data,
constexpr Tensor(Array2D<T, SIZE_0, SIZE_1> &&arr)
: Data(Type),
mDataType(NativeType<T>::type),
+mDataFormat(DataFormat::Default),
mDims({SIZE_0, SIZE_1}),
mStrides({SIZE_1, 1}),
mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, {SIZE_0, SIZE_1})),
@@ -148,6 +153,7 @@ class Tensor : public Data,
constexpr Tensor(Array3D<T, SIZE_0, SIZE_1, SIZE_2> &&arr)
: Data(Type),
mDataType(NativeType<T>::type),
+mDataFormat(DataFormat::Default),
mDims({SIZE_0, SIZE_1, SIZE_2}),
mStrides({SIZE_1 * SIZE_2, SIZE_2, 1}),
mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, {SIZE_0, SIZE_1, SIZE_2})),
@@ -167,6 +173,7 @@ class Tensor : public Data,
constexpr Tensor(Array4D<T, SIZE_0, SIZE_1, SIZE_2, SIZE_3> &&arr)
: Data(Type),
mDataType(NativeType<T>::type),
+mDataFormat(DataFormat::Default),
mDims({SIZE_0, SIZE_1, SIZE_2, SIZE_3}),
mStrides({SIZE_1 * SIZE_2 * SIZE_3, SIZE_2 * SIZE_3, SIZE_3, 1}),
mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, {SIZE_0, SIZE_1, SIZE_2, SIZE_3})),
@@ -247,11 +254,13 @@
Tensor operator+(const Tensor& other) const {
AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
-AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same backend");
+AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type");
+AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format");
auto add_ = Add_Op(2);
add_.associateInput(0, std::make_shared<Tensor>(*this));
add_.associateInput(1, std::make_shared<Tensor>(other));
add_.setDataType(dataType());
+add_.setDataFormat(dataFormat());
add_.setBackend(mImpl->backend());
add_.forward();
// using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
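The same three-way contract (backend, data type, and now data format) is enforced on all four arithmetic operators in this file, and the commit also fixes the copy-pasted "same backend" message on the data-type assert. A usage sketch, with shape and backend setup elided:

Tensor a(DataType::Float32, DataFormat::NHWC);
Tensor b(DataType::Float32, DataFormat::NHWC);
// ... resize both and give them the same backend ...
Tensor c = a + b; // OK: the Add_Op inherits the shared type, format and backend
// had b carried DataFormat::NCHW instead, the new format assert would fire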
@@ -270,11 +279,13 @@
Tensor operator-(const Tensor& other) const {
AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
-AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same backend");
+AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type");
+AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format");
auto sub_ = Sub_Op();
sub_.associateInput(0, std::make_shared<Tensor>(*this));
sub_.associateInput(1, std::make_shared<Tensor>(other));
sub_.setDataType(dataType());
+sub_.setDataFormat(dataFormat());
sub_.setBackend(mImpl->backend());
sub_.forward();
// using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
@@ -293,11 +304,13 @@
Tensor operator*(const Tensor& other) const {
AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
-AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same backend");
+AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type");
+AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format");
auto mul_ = Mul_Op();
mul_.associateInput(0, std::make_shared<Tensor>(*this));
mul_.associateInput(1, std::make_shared<Tensor>(other));
mul_.setDataType(dataType());
+mul_.setDataFormat(dataFormat());
mul_.setBackend(mImpl->backend());
mul_.forward();
// using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
@@ -316,11 +329,13 @@
Tensor operator/(const Tensor& other) const {
AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
-AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same backend");
+AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same data type");
+AIDGE_ASSERT(dataFormat() == other.dataFormat(), "Tensors must have the same data format");
auto div_ = Div_Op();
div_.associateInput(0, std::make_shared<Tensor>(*this));
div_.associateInput(1, std::make_shared<Tensor>(other));
div_.setDataType(dataType());
+div_.setDataFormat(dataFormat());
div_.setBackend(mImpl->backend());
div_.forward();
// using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
@@ -390,6 +405,12 @@ public:
*/
constexpr DataType dataType() const noexcept { return mDataType; }
+/**
+* @brief Get the data format enum.
+* @return constexpr DataFormat
+*/
+constexpr DataFormat dataFormat() const noexcept { return mDataFormat; }
/**
* @brief Set the DataType of the Tensor and converts data
* if the Tensor has already been initialized and copyCast is true.
@@ -408,6 +429,14 @@ public:
mDataType = dt;
}
+/**
+* @brief Set the DataFormat of the Tensor.
+* @param df DataFormat
+*/
+void setDataFormat(const DataFormat df) {
+mDataFormat = df;
+}
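Note that, unlike setDataType (which can convert existing data through copyCast), this setter only retags the tensor: no element is moved or transposed. A small sketch of the resulting semantics:

Tensor t(DataType::Float32, DataFormat::NCHW);
t.setDataFormat(DataFormat::NHWC);          // metadata only; storage untouched
assert(t.dataFormat() == DataFormat::NHWC); // requires <cassert>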
/**
* @brief Get the Impl object
* @return constexpr const std::shared_ptr<TensorImpl>&
@@ -575,6 +604,7 @@ public:
}
if (!mGrad->hasImpl()) {
mGrad->setDataType(dataType());
+mGrad->setDataFormat(dataFormat());
mGrad->setBackend(hasImpl() ? mImpl->backend() : "cpu");
mGrad->zeros();
}
@@ -116,6 +116,7 @@ public:
virtual void setBackend(const std::string& name, DeviceIdx_t device = 0) = 0;
virtual void setDataType(const DataType& dataType) const = 0;
+virtual void setDataFormat(const DataFormat& dataFormat) const = 0;
/**
* @brief Set a new OperatorImpl to the Operator
@@ -85,6 +85,7 @@ public:
///////////////////////////////////////////////////
virtual void setDataType(const DataType& dataType) const override;
+virtual void setDataFormat(const DataFormat& dataFormat) const override;
virtual void forward() override;
};
@@ -180,6 +180,17 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
}
}
+void Aidge::OperatorTensor::setDataFormat(const DataFormat& dataFormat) const {
+for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
+getOutput(i)->setDataFormat(dataFormat);
+}
+for (IOIndex_t i = nbData(); i < nbInputs(); ++i) {
+AIDGE_ASSERT(getInput(i) != nullptr, "Missing input#{} for operator {}", i, type());
+getInput(i)->setDataFormat(dataFormat);
+}
+}
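This mirrors setDataType just above in the same file: the format is pushed to every output and to the parameter inputs (indices nbData() through nbInputs() - 1, e.g. weights and bias), while data inputs are left for their producers to tag. Hypothetical usage, assuming a Conv node built through Aidge's usual factory:

auto conv = Conv(3, 32, {3, 3});                      // hypothetical setup
conv->getOperator()->setDataFormat(DataFormat::NHWC); // tags outputs + params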
void Aidge::OperatorTensor::forward() {
if (!dimsForwarded()) {
// Allow data dependent forwardDims at this point (data is available)