diff --git a/aidge_core/export_utils/data_conversion.py b/aidge_core/export_utils/data_conversion.py
index 401fc39f2a70245a67719699b5f0cdc61108e0cf..6dba5b78cd7b8e79baddb160a1110c3e830c7cd7 100644
--- a/aidge_core/export_utils/data_conversion.py
+++ b/aidge_core/export_utils/data_conversion.py
@@ -1,8 +1,9 @@
 import numpy as np
 import aidge_core
+from typing import Dict
 
-datatype_converter_aide2c = {
+datatype_converter_aidge2c = {
     aidge_core.dtype.float64 : "double",
     aidge_core.dtype.float32 : "float",
     aidge_core.dtype.float16 : "half_float::half",
@@ -19,12 +20,31 @@ datatype_converter_aide2c = {
 def aidge2c(datatype):
     """Convert a aidge datatype to C type
 
+    If the type is not convertible to a C type (e.g. int4), a ValueError is raised.
+
     :param datatype: Aidge datatype to convert
     :type datatype: :py:object:`aidge_core.DataType`
     :return: A string representing the C type
    :rtype: string
     """
-    if datatype in datatype_converter_aide2c:
-        return datatype_converter_aide2c[datatype]
+    if datatype in datatype_converter_aidge2c:
+        return datatype_converter_aidge2c[datatype]
     else:
         raise ValueError(f"Unsupported {datatype} aidge datatype")
+
+def aidge2export_type(datatype: aidge_core.dtype, conversion_map: Dict[aidge_core.dtype, str] = datatype_converter_aidge2c) -> str:
+    """Convert an aidge datatype to the export type specified by the map passed in argument
+
+    If the aidge type is not convertible, i.e. not specified in the map, a ValueError is raised.
+
+    :param datatype: Aidge datatype to convert
+    :type datatype: :py:object:`aidge_core.DataType`
+    :param conversion_map: Map specifying the conversion
+    :type conversion_map: Dict[:py:object:`aidge_core.DataType`, str]
+    :return: A string representing the export type
+    :rtype: string
+    """
+    if datatype in conversion_map:
+        return conversion_map[datatype]
+    else:
+        raise ValueError(f"Unsupported type conversion {datatype} aidge datatype for export")
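The new aidge2export_type lets each export backend supply its own type table instead of hard-coding the C map. A minimal sketch of the intended use, built only from names visible in this diff (the "int8_t" entry below is an illustrative choice for a packed container, not part of the patch):

    import aidge_core
    from aidge_core.export_utils.data_conversion import (
        aidge2export_type, datatype_converter_aidge2c)

    # Start from the default C map and reroute int4 to the container name
    # expected by a hypothetical target kernel.
    custom_map = dict(datatype_converter_aidge2c)
    custom_map[aidge_core.dtype.int4] = "int8_t"

    print(aidge2export_type(aidge_core.dtype.float32, custom_map))  # "float"
    print(aidge2export_type(aidge_core.dtype.int4, custom_map))     # "int8_t"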
diff --git a/aidge_core/export_utils/export_registry.py b/aidge_core/export_utils/export_registry.py
index e5b6b2098cd760c4d425b96caf7b41cc8e82c46e..8927ae5169978da81e39912ebd4e26e2655137ad 100644
--- a/aidge_core/export_utils/export_registry.py
+++ b/aidge_core/export_utils/export_registry.py
@@ -80,6 +80,14 @@ class ExportLib(aidge_core.OperatorImpl):
                                    aidge_core.get_key_value_Tensor(["cpu", aidge_core.dtype.uint32]))
         aidge_core.register_Tensor([self._name, aidge_core.dtype.uint64],
                                    aidge_core.get_key_value_Tensor(["cpu", aidge_core.dtype.uint64]))
+        aidge_core.register_Tensor([self._name, aidge_core.dtype.int4],
+                                   aidge_core.get_key_value_Tensor(["cpu", aidge_core.dtype.int4]))
+        aidge_core.register_Tensor([self._name, aidge_core.dtype.uint4],
+                                   aidge_core.get_key_value_Tensor(["cpu", aidge_core.dtype.uint4]))
+        aidge_core.register_Tensor([self._name, aidge_core.dtype.dual_int4],
+                                   aidge_core.get_key_value_Tensor(["cpu", aidge_core.dtype.dual_int4]))
+        aidge_core.register_Tensor([self._name, aidge_core.dtype.dual_uint4],
+                                   aidge_core.get_key_value_Tensor(["cpu", aidge_core.dtype.dual_uint4]))
 
     @classproperty
     def _export_node_registry(cls) -> Dict[str, List['ExportNode']]:
diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py
index 5777814a0b10c49d0f75245bdc4e9681027bdfb8..c24727adf11bb936cb99c1f40312c4da8c0705f3 100644
--- a/aidge_core/export_utils/node_export.py
+++ b/aidge_core/export_utils/node_export.py
@@ -3,7 +3,8 @@
 from pathlib import Path
 from aidge_core.export_utils import data_conversion, code_generation
 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Dict
+
 
 def get_chan(tensor: aidge_core.Tensor) -> int:
@@ -14,12 +15,19 @@ def get_chan(tensor: aidge_core.Tensor) -> int:
            return dims[1]
         elif len(dims) == 2: # Suppose NC
             return dims[1]
+        elif len(dims) == 1: # Suppose C (for bias)
+            return dims[0]
         else:
             return None
     elif dformat == aidge_core.dformat.nchw:
         return dims[1]
     elif dformat == aidge_core.dformat.nhwc:
-        return dims[3]
+        if len(dims) == 4: # NHWC
+            return dims[3]
+        elif len(dims) == 2: # NC
+            return 1
+        elif len(dims) == 1: # C for bias
+            return 1
     elif dformat == aidge_core.dformat.chwn:
         return dims[0]
     elif dformat == aidge_core.dformat.ncdhw:
@@ -40,12 +48,19 @@ def get_height(tensor: aidge_core.Tensor) -> int:
             return dims[2]
         elif len(dims) == 2: # Suppose NC
             return 1
+        elif len(dims) == 1: # Suppose C for bias
+            return 1
         else:
             return None
     elif dformat == aidge_core.dformat.nchw:
         return dims[2]
     elif dformat == aidge_core.dformat.nhwc:
-        return dims[1]
+        if len(dims) == 4: # NHWC
+            return dims[1]
+        elif len(dims) == 2: # NC
+            return 1
+        elif len(dims) == 1: # C for bias
+            return 1
     elif dformat == aidge_core.dformat.chwn:
         return dims[1]
     elif dformat == aidge_core.dformat.ncdhw:
@@ -66,12 +81,19 @@ def get_width(tensor: aidge_core.Tensor) -> int:
             return dims[3]
         elif len(dims) == 2: # Suppose NC
             return 1
+        elif len(dims) == 1: # Suppose C for bias
+            return 1
         else:
             return None
     elif dformat == aidge_core.dformat.nchw:
         return dims[3]
     elif dformat == aidge_core.dformat.nhwc:
-        return dims[2]
+        if len(dims) == 4: # NHWC
+            return dims[2]
+        elif len(dims) == 2: # NC
+            return 1
+        elif len(dims) == 1: # C for bias
+            return 1
     elif dformat == aidge_core.dformat.chwn:
         return dims[2]
     elif dformat == aidge_core.dformat.ncdhw:
@@ -162,7 +184,9 @@ class ExportNode(ABC):
     """
 
     @abstractmethod
-    def __init__(self, aidge_node: aidge_core.Node, mem_info: List[dict]=None) -> None:
+    def __init__(self, aidge_node: aidge_core.Node,
+                 mem_info: List[dict]=None,
+                 conversion_map: Dict[aidge_core.dtype, str] = data_conversion.datatype_converter_aidge2c) -> None:
         """Create ExportNode and retrieve attributes from ``aidge_node``:
         """
@@ -231,8 +255,8 @@ class ExportNode(ABC):
                 self.attributes["in_dformat"][idx] = tensor.dformat()
                 self.attributes["in_format"][idx] = aidge_core.format_as(tensor.dformat())
                 self.attributes["in_dtype"][idx] = tensor.dtype()
-                self.attributes["in_cdtype"][idx] = data_conversion.aidge2c(
-                    tensor.dtype())
+                self.attributes["in_cdtype"][idx] = data_conversion.aidge2export_type(tensor.dtype(), conversion_map)
                 self.attributes["in_chan"][idx] = get_chan(tensor)
                 self.attributes["in_height"][idx] = get_height(tensor)
                 self.attributes["in_width"][idx] = get_width(tensor)
@@ -254,8 +278,8 @@ class ExportNode(ABC):
                 self.attributes["out_dformat"][idx] = tensor.dformat()
                 self.attributes["out_format"][idx] = aidge_core.format_as(tensor.dformat())
                 self.attributes["out_dtype"][idx] = tensor.dtype()
-                self.attributes["out_cdtype"][idx] = data_conversion.aidge2c(
-                    tensor.dtype())
+                self.attributes["out_cdtype"][idx] = data_conversion.aidge2export_type(tensor.dtype(), conversion_map)
                 self.attributes["out_chan"][idx] = get_chan(tensor)
                 self.attributes["out_height"][idx] = get_height(tensor)
                 self.attributes["out_width"][idx] = get_width(tensor)
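Concrete export nodes can now thread a custom table down to the cdtype resolution. A hedged sketch of a subclass, assuming ExportNode is importable from aidge_core.export_utils (the map entry is illustrative):

    import aidge_core
    from aidge_core.export_utils import ExportNode, data_conversion

    # Assumed export-specific table layered on the default C map.
    low_bit_map = dict(data_conversion.datatype_converter_aidge2c)
    low_bit_map[aidge_core.dtype.int4] = "int8_t"  # illustrative packed container

    class MyKernelExport(ExportNode):
        # in_cdtype/out_cdtype will now be resolved through low_bit_map
        # instead of the fixed aidge2c table.
        def __init__(self, aidge_node, mem_info=None):
            super().__init__(aidge_node, mem_info, conversion_map=low_bit_map)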
diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp
index d04624fc530a21730cc4dc1f4f1ac90a58e6590b..2115b660fa38d3d077eaa9c416525a23c1d4c536 100644
--- a/include/aidge/backend/cpu/data/TensorImpl.hpp
+++ b/include/aidge/backend/cpu/data/TensorImpl.hpp
@@ -126,6 +126,20 @@
 REGISTRAR(Tensor, {"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<int64_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Int4}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::UInt4}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Int3}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::UInt3}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Int2}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::UInt2}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Dual_Int4}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Dual_UInt4}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Dual_Int3}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Dual_UInt3}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Quad_Int2}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Quad_UInt2}, Aidge::TensorImpl_cpu<uint8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Binary}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Octo_Binary}, Aidge::TensorImpl_cpu<int8_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create);
 REGISTRAR(Tensor, {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create);
diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp
index a34718296e4ccddbfca0b4eb0daf14b08124389a..35df9c0e0bf24ee175fe27eb7c831fcae7a700e7 100644
--- a/include/aidge/data/Data.hpp
+++ b/include/aidge/data/Data.hpp
@@ -29,10 +29,14 @@ enum class DataType {
     Float16,
     BFloat16,
     Binary,
+    Octo_Binary,
     Ternary,
     Int2,
+    Quad_Int2,
     Int3,
+    Dual_Int3,
     Int4,
+    Dual_Int4,
     Int5,
     Int6,
     Int7,
@@ -41,8 +45,11 @@ enum class DataType {
     Int32,
     Int64,
     UInt2,
+    Quad_UInt2,
     UInt3,
+    Dual_UInt3,
     UInt4,
+    Dual_UInt4,
     UInt5,
     UInt6,
     UInt7,
@@ -117,6 +124,17 @@ private:
 }
 
 namespace {
+
+template <Aidge::DataType D> struct WeightInterleavingType { static const Aidge::DataType type; };
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Int4>::type = Aidge::DataType::Dual_Int4;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::UInt4>::type = Aidge::DataType::Dual_UInt4;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Int3>::type = Aidge::DataType::Dual_Int3;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::UInt3>::type = Aidge::DataType::Dual_UInt3;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Int2>::type = Aidge::DataType::Quad_Int2;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::UInt2>::type = Aidge::DataType::Quad_UInt2;
+template <> const Aidge::DataType WeightInterleavingType<Aidge::DataType::Binary>::type = Aidge::DataType::Octo_Binary;
+
+
 template <typename T> struct NativeType { static const Aidge::DataType type; };
 template <> const Aidge::DataType NativeType<double>::type = Aidge::DataType::Float64;
 template <> const Aidge::DataType NativeType<float>::type = Aidge::DataType::Float32;
@@ -132,9 +150,9 @@ template <> const Aidge::DataType NativeType<std::uint64_t>::type = Aidge::DataType::UInt64;
 
 template <>
 const char* const EnumStrings<Aidge::DataType>::data[]
-    = {"Float64", "Float32", "Float16", "BFloat16", "Binary", "Ternary",
-       "Int2", "Int3", "Int4", "Int5", "Int6", "Int7", "Int8", "Int16",
-       "Int32", "Int64", "UInt2", "UInt3", "UInt4", "UInt5", "UInt6",
+    = {"Float64", "Float32", "Float16", "BFloat16", "Binary", "Octo_Binary", "Ternary",
+       "Int2", "Quad_Int2", "Int3", "Dual_Int3", "Int4", "Dual_Int4", "Int5", "Int6", "Int7", "Int8", "Int16",
+       "Int32", "Int64", "UInt2", "Quad_UInt2", "UInt3", "Dual_UInt3", "UInt4", "Dual_UInt4", "UInt5", "UInt6",
        "UInt7", "UInt8", "UInt16", "UInt32", "UInt64", "Any"};
 
 template <>
@@ -147,6 +165,20 @@ template <Aidge::DataType D> struct cpptype {
 template <> struct cpptype<Aidge::DataType::Float16> { using type = half_float::half; };
 template <> struct cpptype<Aidge::DataType::Float32> { using type = float; };
 template <> struct cpptype<Aidge::DataType::Float64> { using type = double; };
+template <> struct cpptype<Aidge::DataType::Int4> { using type = std::int8_t; };
+template <> struct cpptype<Aidge::DataType::UInt4> { using type = std::uint8_t; };
+template <> struct cpptype<Aidge::DataType::Int3> { using type = std::int8_t; };
+template <> struct cpptype<Aidge::DataType::UInt3> { using type = std::uint8_t; };
+template <> struct cpptype<Aidge::DataType::Int2> { using type = std::int8_t; };
+template <> struct cpptype<Aidge::DataType::UInt2> { using type = std::uint8_t; };
+template <> struct cpptype<Aidge::DataType::Dual_Int4> { using type = std::int8_t; };
+template <> struct cpptype<Aidge::DataType::Dual_UInt4> { using type = std::uint8_t; };
+template <> struct cpptype<Aidge::DataType::Dual_Int3> { using type = std::int8_t; };
+template <> struct cpptype<Aidge::DataType::Dual_UInt3> { using type = std::uint8_t; };
+template <> struct cpptype<Aidge::DataType::Quad_Int2> { using type = std::int8_t; };
+template <> struct cpptype<Aidge::DataType::Quad_UInt2> { using type = std::uint8_t; };
+template <> struct cpptype<Aidge::DataType::Binary> { using type = std::int8_t; };
+template <> struct cpptype<Aidge::DataType::Octo_Binary> { using type = std::int8_t; };
 template <> struct cpptype<Aidge::DataType::Int8> { using type = std::int8_t; };
 template <> struct cpptype<Aidge::DataType::Int16> { using type = std::int16_t; };
 template <> struct cpptype<Aidge::DataType::Int32> { using type = std::int32_t; };
@@ -157,6 +189,7 @@ template <> struct cpptype<Aidge::DataType::UInt32> { using type = std::uint32_t; };
 template <> struct cpptype<Aidge::DataType::UInt64> { using type = std::uint64_t; };
 
 template <Aidge::DataType D> using cpptype_t = typename cpptype<D>::type;
+
 }
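The packed types are also expected to be visible from Python through aidge_core.dtype; the export_registry changes earlier in this patch rely on exactly these names. A quick check, mirroring the WeightInterleavingType trait above (Int4 and its two-per-byte counterpart):

    import aidge_core

    print(aidge_core.dtype.int4)        # base 4-bit type ("Int4" in EnumStrings)
    print(aidge_core.dtype.dual_int4)   # interleaved counterpart, two Int4 per byte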
diff --git a/include/aidge/operator/WeightInterleaving.hpp b/include/aidge/operator/WeightInterleaving.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..e9e51441aab7772ca5cbb26195c94a0a837d7157
--- /dev/null
+++ b/include/aidge/operator/WeightInterleaving.hpp
@@ -0,0 +1,83 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CORE_OPERATOR_WEIGHTINTERLEAVING_H_
+#define AIDGE_CORE_OPERATOR_WEIGHTINTERLEAVING_H_
+
+#include <cassert>
+#include <memory>
+#include <vector>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/graph/Node.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+
+
+namespace Aidge {
+
+class WeightInterleaving_Op :
+    public OperatorTensor,
+    public Registrable<WeightInterleaving_Op,  // <Op, backend, implementation creation function>
+                       std::string,
+                       std::function<std::shared_ptr<OperatorImpl>(const WeightInterleaving_Op&)>>
+{
+public:
+    static const std::string Type;
+
+    WeightInterleaving_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) {}
+
+    /**
+     * @brief Copy-constructor.
+     * @param op WeightInterleaving_Op to copy.
+     * @details Copies the operator attributes and its output tensor(s), but not
+     * its input tensors. The new operator has no associated input.
+     */
+    WeightInterleaving_Op(const WeightInterleaving_Op& op);
+
+    /**
+     * @brief Clone the operator using its copy-constructor.
+     * @see Operator::WeightInterleaving_Op
+     */
+    std::shared_ptr<Operator> clone() const override;
+
+    bool forwardDims(bool allowDataDependency = false) override final;
+
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override final;
+    std::set<std::string> getAvailableBackends() const override;
+
+    static const std::vector<std::string> getInputsName(){
+        return {"data_input"};
+    }
+    static const std::vector<std::string> getOutputsName(){
+        return {"data_output"};
+    }
+
+    /**
+     * @brief Calculates the required size of the 8-bit `compactData` vector.
+     *
+     * This function determines the minimum number of bytes needed in `compactData`
+     * to store `dataSize` elements compacted to `nb_bits` bits each.
+     *
+     * @param dataSize The total number of elements in the input data array.
+     * @param nb_bits The number of bits to use for each compacted element (from 1 to 7).
+     * @return std::size_t The required size in bytes for `compactData`.
+     */
+    std::size_t compactDataSize(std::size_t dataSize, std::uint8_t nb_bits);
+
+};
+
+std::shared_ptr<Node> WeightInterleaving(const std::string& name = "");
+} // namespace Aidge
+
+#endif /* AIDGE_CORE_OPERATOR_WEIGHTINTERLEAVING_H_ */
diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp
index 0fb405bfe5e74f159fbd5504cc199e3b29842254..5f16c480c233c0aee23962549c1d86695af81d89 100644
--- a/include/aidge/recipes/Recipes.hpp
+++ b/include/aidge/recipes/Recipes.hpp
@@ -180,19 +180,24 @@ size_t convToMatMul(std::shared_ptr<GraphView> graph);
 */
 void adaptToBackend(std::shared_ptr<GraphView> graph);
 
-// /**
-// * @brief The node passed contains an operator which input of index 1 is supposed be be weights of type Int4, Int3, Int2, binary.
-// * This recipie only operates memory transformations on the weight tensor.
-// * First, permutes the dimensions to match the dataformat NHWC
-// * Second, compact the last dimension (Channel dimension) into int8_t
-// *
-// * @param node Node
-// */
-// void applyWeightInterleaving(std::shared_ptr<Node> node);
-
+/**
+ * @brief Create a GenericOp from an Operator and replace it
+ *
+ * @param node Node whose Operator will be changed into a generic Operator
+ */
 void toGenericOp(std::shared_ptr<Node> node);
 
+/**
+ * @brief The node passed contains an operator whose input of index 1 is supposed to be weights of type Int4, Int3, Int2 or binary.
+ * This recipe only operates memory transformations on the weight tensor.
+ * First, it permutes the dimensions to match the data format NHWC.
+ * Second, it compacts the last dimension of the weights (the channel dimension) into 8-bit containers.
+ *
+ * @param node Node
+ */
+void applyWeightInterleaving(std::shared_ptr<Node> node);
+
 } // namespace Aidge
 
 #endif /* AIDGE_CORE_UTILS_RECIPES_H_ */
diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp
index 49e45ed7e447c00cf7300e8228ee7d1b04800083..cd94997cf7a584633be8b811b1e694bdf9886fc1 100644
--- a/python_binding/backend/pybind_OperatorImpl.cpp
+++ b/python_binding/backend/pybind_OperatorImpl.cpp
@@ -81,6 +81,7 @@ void init_OperatorImpl(py::module& m){
     .def(py::init<const DynamicAttributes&>(), py::arg("attr") = DynamicAttributes())
     .def(py::init<const ImplSpec::IOSpec&, const DynamicAttributes&>(), py::arg("io"), py::arg("attr") = DynamicAttributes())
     .def(py::init<const ImplSpec::IOSpec&, const ImplSpec::IOSpec&, const DynamicAttributes&>(), py::arg("i"), py::arg("o"), py::arg("attr") = DynamicAttributes())
+    .def(py::init<const std::vector<ImplSpec::IOSpec>&, const std::vector<ImplSpec::IOSpec>&, const DynamicAttributes&>(), py::arg("i"), py::arg("o"), py::arg("attr") = DynamicAttributes())
     .def("__eq__", static_cast<bool(*)(const ImplSpec&, const ImplSpec&)>(&operator==))
     .def("__repr__", [](ImplSpec self){
         return fmt::format("{}\n", self);
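The added constructor overload accepts one IOSpec per input and per output. A hedged Python sketch, assuming the bindings expose the classes as aidge_core.ImplSpec and aidge_core.IOSpec and that IOSpec can be built from a datatype (neither name nor constructor is shown in this hunk, so treat them as hypothetical):

    import aidge_core

    in_spec  = aidge_core.IOSpec(aidge_core.dtype.int8)   # hypothetical ctor
    w_spec   = aidge_core.IOSpec(aidge_core.dtype.int4)
    out_spec = aidge_core.IOSpec(aidge_core.dtype.int32)

    # New list-based overload: a distinct spec per input and per output.
    spec = aidge_core.ImplSpec([in_spec, w_spec], [out_spec])
    print(spec)  # __repr__ is bound just below the new overload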
diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index 0d4ed716ca6c65c2e8a0153a729ebecef771ea9e..35e60e1589ce5599affbc2b466171acc6bf4ef01 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -226,6 +226,8 @@ static T castToNativeType(const py::object val_obj) {
     DataType dtype;
     getConservativeNativeVal(val_obj, &val, &dtype);
     switch (dtype) {
+    case DataType::Int4:
+        return (T)val.i8;
     case DataType::Int8:
         return (T)val.i8;
     case DataType::Int16:
@@ -353,6 +355,22 @@ void init_Tensor(py::module& m){
                 return py::cast(b.get<float>(idx));
             case DataType::Int8:
                 return py::cast(b.get<std::int8_t>(idx));
+            case DataType::Int4:
+                return py::cast(b.get<std::int8_t>(idx));
+            case DataType::Dual_Int4:
+                return py::cast(b.get<std::int8_t>(idx));
+            case DataType::Int3:
+                return py::cast(b.get<std::int8_t>(idx));
+            case DataType::Dual_Int3:
+                return py::cast(b.get<std::int8_t>(idx));
+            case DataType::Int2:
+                return py::cast(b.get<std::int8_t>(idx));
+            case DataType::Quad_Int2:
+                return py::cast(b.get<std::int8_t>(idx));
+            case DataType::Binary:
+                return py::cast(b.get<std::int8_t>(idx));
+            case DataType::Octo_Binary:
+                return py::cast(b.get<std::int8_t>(idx));
             case DataType::Int16:
                 return py::cast(b.get<std::int16_t>(idx));
             case DataType::Int32:
@@ -361,6 +379,18 @@ void init_Tensor(py::module& m){
                 return py::cast(b.get<std::int64_t>(idx));
             case DataType::UInt8:
                 return py::cast(b.get<std::uint8_t>(idx));
+            case DataType::UInt4:
+                return py::cast(b.get<std::uint8_t>(idx));
+            case DataType::Dual_UInt4:
+                return py::cast(b.get<std::uint8_t>(idx));
+            case DataType::UInt3:
+                return py::cast(b.get<std::uint8_t>(idx));
+            case DataType::Dual_UInt3:
+                return py::cast(b.get<std::uint8_t>(idx));
+            case DataType::UInt2:
+                return py::cast(b.get<std::uint8_t>(idx));
+            case DataType::Quad_UInt2:
+                return py::cast(b.get<std::uint8_t>(idx));
             case DataType::UInt16:
                 return py::cast(b.get<std::uint16_t>(idx));
             case DataType::UInt32:
@@ -380,6 +410,22 @@ void init_Tensor(py::module& m){
                 return py::cast(b.get<float>(coordIdx));
             case DataType::Int8:
                 return py::cast(b.get<std::int8_t>(coordIdx));
+            case DataType::Int4:
+                return py::cast(b.get<std::int8_t>(coordIdx));
+            case DataType::Dual_Int4:
+                return py::cast(b.get<std::int8_t>(coordIdx));
+            case DataType::Int3:
+                return py::cast(b.get<std::int8_t>(coordIdx));
+            case DataType::Dual_Int3:
+                return py::cast(b.get<std::int8_t>(coordIdx));
+            case DataType::Int2:
+                return py::cast(b.get<std::int8_t>(coordIdx));
+            case DataType::Quad_Int2:
+                return py::cast(b.get<std::int8_t>(coordIdx));
+            case DataType::Binary:
+                return py::cast(b.get<std::int8_t>(coordIdx));
+            case DataType::Octo_Binary:
+                return py::cast(b.get<std::int8_t>(coordIdx));
             case DataType::Int16:
                 return py::cast(b.get<std::int16_t>(coordIdx));
             case DataType::Int32:
@@ -388,6 +434,18 @@ void init_Tensor(py::module& m){
                 return py::cast(b.get<std::int64_t>(coordIdx));
             case DataType::UInt8:
                 return py::cast(b.get<std::uint8_t>(coordIdx));
+            case DataType::UInt4:
+                return py::cast(b.get<std::uint8_t>(coordIdx));
+            case DataType::Dual_UInt4:
+                return py::cast(b.get<std::uint8_t>(coordIdx));
+            case DataType::UInt3:
+                return py::cast(b.get<std::uint8_t>(coordIdx));
+            case DataType::Dual_UInt3:
+                return py::cast(b.get<std::uint8_t>(coordIdx));
+            case DataType::UInt2:
+                return py::cast(b.get<std::uint8_t>(coordIdx));
+            case DataType::Quad_UInt2:
+                return py::cast(b.get<std::uint8_t>(coordIdx));
             case DataType::UInt16:
                 return py::cast(b.get<std::uint16_t>(coordIdx));
             case DataType::UInt32:
@@ -410,6 +468,30 @@ void init_Tensor(py::module& m){
             case DataType::Int8:
                 b.set(idx, castToNativeType<std::int8_t>(val));
                 break;
+            case DataType::Int4:
+                b.set(idx, castToNativeType<std::int8_t>(val));
+                break;
+            case DataType::Dual_Int4:
+                b.set(idx, castToNativeType<std::int8_t>(val));
+                break;
+            case DataType::Int3:
+                b.set(idx, castToNativeType<std::int8_t>(val));
+                break;
+            case DataType::Dual_Int3:
+                b.set(idx, castToNativeType<std::int8_t>(val));
+                break;
+            case DataType::Int2:
+                b.set(idx, castToNativeType<std::int8_t>(val));
+                break;
+            case DataType::Quad_Int2:
+                b.set(idx, castToNativeType<std::int8_t>(val));
+                break;
+            case DataType::Binary:
+                b.set(idx, castToNativeType<std::int8_t>(val));
+                break;
+            case DataType::Octo_Binary:
+                b.set(idx, castToNativeType<std::int8_t>(val));
+                break;
             case DataType::Int16:
                 b.set(idx, castToNativeType<std::int16_t>(val));
                 break;
@@ -422,6 +504,24 @@ void init_Tensor(py::module& m){
             case DataType::UInt8:
                 b.set(idx, castToNativeType<std::uint8_t>(val));
                 break;
+            case DataType::UInt4:
+                b.set(idx, castToNativeType<std::uint8_t>(val));
+                break;
+            case DataType::Dual_UInt4:
+                b.set(idx, castToNativeType<std::uint8_t>(val));
+                break;
+            case DataType::UInt3:
+                b.set(idx, castToNativeType<std::uint8_t>(val));
+                break;
+            case DataType::Dual_UInt3:
+                b.set(idx, castToNativeType<std::uint8_t>(val));
+                break;
+            case DataType::UInt2:
+                b.set(idx, castToNativeType<std::uint8_t>(val));
+                break;
+            case DataType::Quad_UInt2:
+                b.set(idx, castToNativeType<std::uint8_t>(val));
+                break;
             case DataType::UInt16:
                 b.set(idx, castToNativeType<std::uint16_t>(val));
                 break;
@@ -448,6 +548,30 @@ void init_Tensor(py::module& m){
             case DataType::Int8:
                 b.set(coordIdx, castToNativeType<std::int8_t>(val));
                 break;
+            case DataType::Int4:
+                b.set(coordIdx, castToNativeType<std::int8_t>(val));
+                break;
+            case DataType::Dual_Int4:
+                b.set(coordIdx, castToNativeType<std::int8_t>(val));
+                break;
+            case DataType::Int3:
+                b.set(coordIdx, castToNativeType<std::int8_t>(val));
+                break;
+            case DataType::Dual_Int3:
+                b.set(coordIdx, castToNativeType<std::int8_t>(val));
+                break;
+            case DataType::Int2:
+                b.set(coordIdx, castToNativeType<std::int8_t>(val));
+                break;
+            case DataType::Quad_Int2:
+                b.set(coordIdx, castToNativeType<std::int8_t>(val));
+                break;
+            case DataType::Binary:
+                b.set(coordIdx, castToNativeType<std::int8_t>(val));
+                break;
+            case DataType::Octo_Binary:
+                b.set(coordIdx, castToNativeType<std::int8_t>(val));
+                break;
             case DataType::Int16:
                 b.set(coordIdx, castToNativeType<std::int16_t>(val));
                 break;
@@ -460,6 +584,24 @@ void init_Tensor(py::module& m){
             case DataType::UInt8:
                 b.set(coordIdx, castToNativeType<std::uint8_t>(val));
                 break;
+            case DataType::UInt4:
+                b.set(coordIdx, castToNativeType<std::uint8_t>(val));
+                break;
+            case DataType::Dual_UInt4:
+                b.set(coordIdx, castToNativeType<std::uint8_t>(val));
+                break;
+            case DataType::UInt3:
+                b.set(coordIdx, castToNativeType<std::uint8_t>(val));
+                break;
+            case DataType::Dual_UInt3:
+                b.set(coordIdx, castToNativeType<std::uint8_t>(val));
+                break;
+            case DataType::UInt2:
+                b.set(coordIdx, castToNativeType<std::uint8_t>(val));
+                break;
+            case DataType::Quad_UInt2:
+                b.set(coordIdx, castToNativeType<std::uint8_t>(val));
+                break;
             case DataType::UInt16:
                 b.set(coordIdx, castToNativeType<std::uint16_t>(val));
                 break;
@@ -497,6 +639,48 @@ void init_Tensor(py::module& m){
             case DataType::Float32:
                 dataFormatDescriptor = py::format_descriptor<float>::format();
                 break;;
+            case DataType::Int4:
+                dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
+                break;
+            case DataType::UInt4:
+                dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
+                break;
+            case DataType::Int3:
+                dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
+                break;
+            case DataType::UInt3:
+                dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
+                break;
+            case DataType::Int2:
+                dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
+                break;
+            case DataType::UInt2:
+                dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
+                break;
+            case DataType::Dual_Int4:
+                dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
+                break;
+            case DataType::Dual_UInt4:
+                dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
+                break;
+            case DataType::Dual_Int3:
+                dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
+                break;
+            case DataType::Dual_UInt3:
+                dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
+                break;
+            case DataType::Quad_Int2:
+                dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
+                break;
+            case DataType::Quad_UInt2:
+                dataFormatDescriptor = py::format_descriptor<std::uint8_t>::format();
+                break;
+            case DataType::Binary:
+                dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
+                break;
+            case DataType::Octo_Binary:
+                dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
+                break;
             case DataType::Int8:
                 dataFormatDescriptor = py::format_descriptor<std::int8_t>::format();
                 break;
diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp
index ded3b54088e6d1ed473ed614e23fc08cd89a0346..2191d866f2a2b1f1d490b2016de97afd8ec8157b 100644
--- a/python_binding/operator/pybind_Operator.cpp
+++ b/python_binding/operator/pybind_Operator.cpp
@@ -61,6 +61,7 @@ void init_Operator(py::module& m){
     )mydelimiter")
     .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data"))
     .def("set_datatype", &Operator::setDataType, py::arg("dataType"))
+    .def("set_dataformat", &Operator::setDataFormat, py::arg("dataFormat"))
     .def("set_backend", py::overload_cast<const std::string&, DeviceIdx_t>(&Operator::setBackend), py::arg("name"), py::arg("device") = 0)
     .def("set_backend", py::overload_cast<const std::vector<std::pair<std::string, DeviceIdx_t>>&>(&Operator::setBackend), py::arg("backends"))
     .def("forward", &Operator::forward)
diff --git a/python_binding/operator/pybind_WeightInterleaving.cpp b/python_binding/operator/pybind_WeightInterleaving.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..25b423bd66503b39f031695121cf673c45c34bbe
--- /dev/null
+++ b/python_binding/operator/pybind_WeightInterleaving.cpp
@@ -0,0 +1,39 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <pybind11/pybind11.h>
+#include "aidge/operator/WeightInterleaving.hpp"
+
+namespace py = pybind11;
+
+namespace Aidge {
+
+void declare_WeightInterleaving(py::module &m) {
+  py::class_<WeightInterleaving_Op, std::shared_ptr<WeightInterleaving_Op>, OperatorTensor>(m, "WeightInterleavingOp", py::multiple_inheritance())
+    .def(py::init<>())
+    .def_static("get_inputs_name", &WeightInterleaving_Op::getInputsName)
+    .def_static("get_outputs_name", &WeightInterleaving_Op::getOutputsName)
+    .def_readonly_static("Type", &WeightInterleaving_Op::Type)
+
+    .def("__repr__", [](WeightInterleaving_Op& b) {
+        return fmt::format("Operator(type='{}')", b.Type);
+    });
+
+  declare_registrable<WeightInterleaving_Op>(m, "WeightInterleavingOp");
+
+  m.def("WeightInterleaving", &WeightInterleaving, py::arg("name") = "");
+}
+
+void init_WeightInterleaving(py::module &m) {
+  declare_WeightInterleaving(m);
+}
+
+} // namespace Aidge
diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp
index 006eeb289f25570ddf337f048b05816102624028..f572c024d3cf69a1a06cd9be3e60cc7106fccfe3 100644
--- a/python_binding/pybind_core.cpp
+++ b/python_binding/pybind_core.cpp
@@ -83,6 +83,7 @@ void init_Sub(py::module&);
 void init_Tanh(py::module&);
 void init_Transpose(py::module&);
 void init_Unsqueeze(py::module&);
+void init_WeightInterleaving(py::module&);
 
 void init_Node(py::module&);
 void init_GraphView(py::module&);
@@ -177,6 +178,7 @@ void init_Aidge(py::module& m) {
     init_Tanh(m);
     init_Transpose(m);
     init_Unsqueeze(m);
+    init_WeightInterleaving(m);
 
     init_Producer(m);
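Once init_WeightInterleaving is wired into init_Aidge, the operator is reachable from Python. A quick smoke test using only the names bound above (get_operator is the standard aidge Node accessor and is assumed here):

    import aidge_core

    node = aidge_core.WeightInterleaving("wi0")
    print(aidge_core.WeightInterleavingOp.Type)  # "WeightInterleaving"
    print(node.get_operator())  # Operator(type='WeightInterleaving'), per the bound __repr__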
diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp
index f656af70dfa05678875afd4b4748f358437852a8..21478a5b14d609801f232b20cda25e7e1c0d9475 100644
--- a/python_binding/recipes/pybind_Recipes.cpp
+++ b/python_binding/recipes/pybind_Recipes.cpp
@@ -151,6 +151,14 @@ void init_Recipes(py::module &m)
     :param node: Node which Operator will turn into a Generic Operator
     :type graph_view: :py:class:`aidge_core.Node`
     )mydelimiter");
+
+  m.def("apply_weightinterleaving", applyWeightInterleaving, py::arg("node"), R"mydelimiter(
+    Replace the weight Producer linked to the given node with a Producer holding interleaved weights in NHWC format.
+    This recipe is specific to the ARM Cortex-M export for low-bit integer support.
+
+    :param node: Node whose linked weights will receive interleaving
+    :type node: :py:class:`aidge_core.Node`
+    )mydelimiter");
 }
 
 } // namespace Aidge
diff --git a/src/backend/cpu/data/TensorImpl.cpp b/src/backend/cpu/data/TensorImpl.cpp
index 506287a0c520915e6426f1f0b64d9c562c754d33..236e5bb8e1e867d5a0dad85571d754bc9e2a2a22 100644
--- a/src/backend/cpu/data/TensorImpl.cpp
+++ b/src/backend/cpu/data/TensorImpl.cpp
@@ -95,6 +95,62 @@ void Aidge::TensorImpl_cpu<T>::copyCast(const void *src, const Aidge::DataType srcDt,
             std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
                       dstT);
             break;
+        case DataType::Int4:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                      dstT);
+            break;
+        case DataType::UInt4:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                      dstT);
+            break;
+        case DataType::Dual_Int4:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                      dstT);
+            break;
+        case DataType::Dual_UInt4:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                      dstT);
+            break;
+        case DataType::Int3:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                      dstT);
+            break;
+        case DataType::UInt3:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                      dstT);
+            break;
+        case DataType::Dual_Int3:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                      dstT);
+            break;
+        case DataType::Dual_UInt3:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                      dstT);
+            break;
+        case DataType::Int2:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                      dstT);
+            break;
+        case DataType::UInt2:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                      dstT);
+            break;
+        case DataType::Quad_Int2:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                      dstT);
+            break;
+        case DataType::Quad_UInt2:
+            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
+                      dstT);
+            break;
+        case DataType::Binary:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                      dstT);
+            break;
+        case DataType::Octo_Binary:
+            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
+                      dstT);
+            break;
         default:
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
             break;
diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp
index c834167abe15fb8a7ce96053a87a958b7515fe17..ee19796098bf1d755448d833aa6a8a2c24180baa 100644
--- a/src/data/Tensor.cpp
+++ b/src/data/Tensor.cpp
@@ -322,6 +322,34 @@ std::string Tensor::toString() const {
                 return std::to_string(static_cast<float*>(ptr)[idx]);
             case DataType::Float16:
                 return std::to_string(static_cast<half_float::half*>(ptr)[idx]);
+            case DataType::Binary:
+                return std::to_string(static_cast<int8_t*>(ptr)[idx]);
+            case DataType::Octo_Binary:
+                return std::to_string(static_cast<int8_t*>(ptr)[idx]);
+            case DataType::Dual_Int4:
+                return std::to_string(static_cast<int8_t*>(ptr)[idx]);
+            case DataType::Dual_UInt4:
+                return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
+            case DataType::Dual_Int3:
+                return std::to_string(static_cast<int8_t*>(ptr)[idx]);
+            case DataType::Dual_UInt3:
+                return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
+            case DataType::Quad_Int2:
+                return std::to_string(static_cast<int8_t*>(ptr)[idx]);
+            case DataType::Quad_UInt2:
+                return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
+            case DataType::Int4:
+                return std::to_string(static_cast<int8_t*>(ptr)[idx]);
+            case DataType::UInt4:
+                return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
+            case DataType::Int3:
+                return std::to_string(static_cast<int8_t*>(ptr)[idx]);
+            case DataType::UInt3:
+                return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
+            case DataType::Int2:
+                return std::to_string(static_cast<int8_t*>(ptr)[idx]);
+            case DataType::UInt2:
+                return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
             case DataType::Int8:
                 return std::to_string(static_cast<int8_t*>(ptr)[idx]);
             case DataType::Int16:
diff --git a/src/operator/WeightInterleaving.cpp b/src/operator/WeightInterleaving.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..66af1d51f87c24b5b8d7d9c1f0ab3701f122515d
--- /dev/null
+++ b/src/operator/WeightInterleaving.cpp
@@ -0,0 +1,121 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "aidge/operator/WeightInterleaving.hpp"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "aidge/data/Data.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/StaticAttributes.hpp"
+#include "aidge/utils/Types.h"
+
+const std::string Aidge::WeightInterleaving_Op::Type = "WeightInterleaving";
+
+/**
+ * @brief Copy-constructor.
+ * @param op WeightInterleaving_Op to copy.
+ * @details Copies the operator attributes and its output tensor(s), but not
+ * its input tensors. The new operator has no associated input.
+ */
+Aidge::WeightInterleaving_Op::WeightInterleaving_Op(const WeightInterleaving_Op& op)
+    : OperatorTensor(op)
+{
+    if (op.mImpl) {
+        SET_IMPL_MACRO(WeightInterleaving_Op, *this, op.backend());
+    } else {
+        mImpl = nullptr;
+    }
+}
+
+
+std::shared_ptr<Aidge::Operator> Aidge::WeightInterleaving_Op::clone() const {
+    return std::make_shared<WeightInterleaving_Op>(*this);
+}
+
+
+bool Aidge::WeightInterleaving_Op::forwardDims(bool /*allowDataDependency*/) {
+
+    if (inputsAssociated()) {
+
+        // Check that the input data format is NHWC
+        AIDGE_ASSERT((getInput(0)->dataFormat() == DataFormat::NHWC),
+                     "Wrong input tensor data format {} for the WeightInterleaving operator (expected DataFormat::NHWC).", getInput(0)->dataFormat());
+
+        // Take the last dimension of the tensor: it is the channel dimension in NHWC format.
+        // The weights are compacted along the channel dimension only.
+        const DimSize_t& lastDim = getInput(0)->dims().back();
+
+        // Compute the last dimension size of the tensor after the weight interleaving compression.
+        // TODO: implement a mechanism to get the number of bits directly from the DataType.
+        // Coverage below matches the types handled by applyWeightInterleaving().
+        const DataType& dt = getInput(0)->dataType();
+
+        std::uint8_t nbBits = 0;
+
+        switch (dt) {
+            case DataType::Int4:
+            case DataType::UInt4:
+                nbBits = 4;
+                break;
+            case DataType::Int3:
+            case DataType::UInt3:
+                nbBits = 3;
+                break;
+            case DataType::Int2:
+            case DataType::UInt2:
+                nbBits = 2;
+                break;
+            case DataType::Binary:
+                nbBits = 1;
+                break;
+            default:
+                AIDGE_ASSERT(false, "Unsupported data type {} for WeightInterleaving", dt);
+        }
+
+        const auto lastDimCompression = compactDataSize(lastDim, nbBits);
+
+        std::vector<DimSize_t> outputDims = getInput(0)->dims();
+        outputDims.back() = lastDimCompression;
+
+        // Output keeps the input shape, with the channel dimension compacted
+        mOutputs[0]->resize(outputDims);
+
+        return true;
+    }
+
+    return false;
+}
+
+
+void Aidge::WeightInterleaving_Op::setBackend(const std::string& name, DeviceIdx_t device) {
+    SET_IMPL_MACRO(WeightInterleaving_Op, *this, name);
+    mOutputs[0]->setBackend(name, device);
+}
+
+std::set<std::string> Aidge::WeightInterleaving_Op::getAvailableBackends() const {
+    return Registrar<WeightInterleaving_Op>::getKeys();
+}
+
+std::shared_ptr<Aidge::Node> Aidge::WeightInterleaving(const std::string& name) {
+    return std::make_shared<Node>(std::make_shared<WeightInterleaving_Op>(), name);
+}
+
+
+std::size_t Aidge::WeightInterleaving_Op::compactDataSize(std::size_t dataSize, std::uint8_t nbBits) {
+    AIDGE_ASSERT(nbBits > 0 && nbBits < 8, "nbBits must be between 1 and 7");  // Ensure valid bit width
+
+    // Calculate the number of `nbBits` segments that can fit in an 8-bit byte.
+    const unsigned int nbSlot = 8 / nbBits;
+
+    // Calculate the number of compacted bytes needed to store all data elements.
+    // (dataSize + nbSlot - 1) / nbSlot rounds the division up, so remaining
+    // elements that do not fill a whole byte are still accounted for.
+    std::size_t requiredSize = (dataSize + nbSlot - 1) / nbSlot;
+
+    return requiredSize;
+}
\ No newline at end of file
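The packing arithmetic is easy to sanity-check. A Python mirror of compactDataSize for illustration only (the C++ above is authoritative); note that 3-bit packing also fits only two values per byte, since 8 // 3 == 2, which is why the type is named Dual_Int3:

    def compact_data_size(data_size: int, nb_bits: int) -> int:
        """Bytes needed to store data_size elements of nb_bits bits each."""
        assert 0 < nb_bits < 8
        nb_slot = 8 // nb_bits                       # elements per byte
        return (data_size + nb_slot - 1) // nb_slot  # ceil(data_size / nb_slot)

    assert compact_data_size(10, 4) == 5  # two 4-bit values per byte
    assert compact_data_size(10, 3) == 5  # 8 // 3 = 2 slots, two bits wasted per byte
    assert compact_data_size(10, 2) == 3  # four 2-bit values per byte
    assert compact_data_size(10, 1) == 2  # eight binary values per byte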
diff --git a/src/recipes/ApplyWeightInterleaving.cpp b/src/recipes/ApplyWeightInterleaving.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b9c042a538bc1ece754c5f659048e9c5f6c0d249
--- /dev/null
+++ b/src/recipes/ApplyWeightInterleaving.cpp
@@ -0,0 +1,119 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <memory>
+
+#include "aidge/data/Data.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/Node.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/operator/WeightInterleaving.hpp"
+#include "aidge/operator/Transpose.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/recipes/Recipes.hpp"
+
+
+void Aidge::applyWeightInterleaving(std::shared_ptr<Node> node){
+    auto weightProducer = node->getParent(1);
+    AIDGE_ASSERT(weightProducer, "Cannot apply weight interleaving on {} because it has no weights linked", node->name());
+
+    auto weightTensor = std::make_shared<Aidge::Tensor>(std::static_pointer_cast<Aidge::OperatorTensor>(weightProducer->getOperator())->getOutput(0)->clone());
+    // Fall back to "cpu" to cover the case of generic operators with no backend set
+    auto backend = node->getOperator()->backend().empty() ? "cpu" : node->getOperator()->backend();
+
+    const Aidge::DataType weightDataType = weightTensor->dataType();
+
+    // 1 - Apply the NHWC data format to match the custom kernel implementations for ARM Cortex-M
+    // Issue: if the dataFormat is Default, setting it to NHWC will not permute the dimensions
+    // Fix: if the dataFormat is Default, set it to NCHW first, THEN set it to NHWC
+
+    std::shared_ptr<Tensor> transposedWeightTensor;
+
+    // Case 4D tensor (conv)
+    if (weightTensor->nbDims() == 4)
+    {
+        if (weightTensor->dataFormat() == Aidge::DataFormat::Default) {
+            weightTensor->setDataFormat(Aidge::DataFormat::NCHW);
+        }
+
+        // Apply permutation for NHWC format
+        if (weightTensor->dataFormat() != Aidge::DataFormat::NHWC) {
+            weightTensor->setDataFormat(Aidge::DataFormat::NHWC);
+        }
+
+        transposedWeightTensor = weightTensor;
+    }
+    else if (weightTensor->nbDims() == 2)
+    {
+        std::shared_ptr<Node> myTranspose = Transpose({1, 0});
+        auto op = std::static_pointer_cast<OperatorTensor>(myTranspose->getOperator());
+        op->associateInput(0, weightTensor);
+        op->setDataType(weightDataType);
+        op->setBackend("cpu");
+        myTranspose->forward();
+
+        transposedWeightTensor = op->getOutput(0);
+        transposedWeightTensor->setDataFormat(Aidge::DataFormat::NHWC);
+    } else {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot transpose {} weights.", node->name());
+    }
+
+    // 2 - Apply weight interleaving
+    // Instantiate the WeightInterleaving operator
+    auto WIOp = WeightInterleaving_Op();
+
+    // Forward the WeightInterleaving op
+    WIOp.associateInput(0, transposedWeightTensor);
+
+    switch (weightDataType) {
+        case Aidge::DataType::Int4:
+            WIOp.setDataType(Aidge::DataType::Dual_Int4);
+            break;
+        case Aidge::DataType::UInt4:
+            WIOp.setDataType(Aidge::DataType::Dual_UInt4);
+            break;
+        case Aidge::DataType::Int3:
+            WIOp.setDataType(Aidge::DataType::Dual_Int3);
+            break;
+        case Aidge::DataType::UInt3:
+            WIOp.setDataType(Aidge::DataType::Dual_UInt3);
+            break;
+        case Aidge::DataType::Int2:
+            WIOp.setDataType(Aidge::DataType::Quad_Int2);
+            break;
+        case Aidge::DataType::UInt2:
+            WIOp.setDataType(Aidge::DataType::Quad_UInt2);
+            break;
+        case Aidge::DataType::Binary:
+            WIOp.setDataType(Aidge::DataType::Octo_Binary);
+            break;
+        default:
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type {} not supported for weight interleaving.", weightDataType);
+    }
+
+    WIOp.setDataFormat(Aidge::DataFormat::NHWC);
+    WIOp.setBackend(backend);
+
+    WIOp.forward();
+
+    // 3 - Replace the weight Producer
+    auto newProducer = {Producer(WIOp.getOutput(0), weightProducer->name())};
+    auto oldProducer = {weightProducer};
+
+    GraphView::replace(oldProducer, newProducer);
+}
\ No newline at end of file
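End to end, the recipe is meant to be driven from Python through the apply_weightinterleaving binding added earlier. A hedged sketch (the node-type names checked here are assumptions about the target graph, not part of this patch):

    import aidge_core

    def interleave_low_bit_weights(graph):
        # Apply the recipe to every node expected to carry weights on input 1.
        for node in graph.get_nodes():
            if node.type() in ("FC", "Conv2D"):  # assumed type names
                aidge_core.apply_weightinterleaving(node)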