From 3ab0a780df92e09a9acf18d02d1725d90fac14f4 Mon Sep 17 00:00:00 2001
From: thibault allenet <thibault.allenet@cea.fr>
Date: Wed, 4 Dec 2024 13:29:14 +0000
Subject: [PATCH] Add recipe apply weightInterleaving

---
 include/aidge/recipes/Recipes.hpp       | 25 ++++---
 src/recipes/ApplyWeightInterleaving.cpp | 97 +++++++++++++++++++++++++
 2 files changed, 112 insertions(+), 10 deletions(-)
 create mode 100644 src/recipes/ApplyWeightInterleaving.cpp

diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp
index 0fb405bfe..5f16c480c 100644
--- a/include/aidge/recipes/Recipes.hpp
+++ b/include/aidge/recipes/Recipes.hpp
@@ -180,19 +180,24 @@ size_t convToMatMul(std::shared_ptr<GraphView> graph);
  */
 void adaptToBackend(std::shared_ptr<GraphView> graph);
 
-// /**
-//  * @brief The node passed contains an operator which input of index 1 is supposed be be weights of type Int4, Int3, Int2, binary.
-//  *        This recipie only operates memory transformations on the weight tensor. 
-//  *        First, permutes the dimensions to match the dataformat NHWC
-//  *        Second, compact the last dimension (Channel dimension) into int8_t
-//  * 
-//  * @param node Node 
-//  */
-// void applyWeightInterleaving(std::shared_ptr<Node> node);
-
 
+/**
+ * @brief Create a GenericOp from an Operator and replace it
+ * 
+ * @param node Node which Operator will be changed into a generic Operator
+ */
 void toGenericOp(std::shared_ptr<Node> node);
 
+/**
+ * @brief The node passed contains an operator which input of index 1 is supposed be be weights of type Int4, Int3, Int2, binary.
+ *        This recipie only operates memory transformations on the weight tensor. 
+ *        First, permutes the dimensions to match the dataformat NHWC
+ *        Second, compact the last dimension of the weights (Channel dimension) into 8bits 
+ * 
+ * @param node Node 
+ */
+void applyWeightInterleaving(std::shared_ptr<Node> node);
+
 } // namespace Aidge
 
 #endif /* AIDGE_CORE_UTILS_RECIPES_H_ */
diff --git a/src/recipes/ApplyWeightInterleaving.cpp b/src/recipes/ApplyWeightInterleaving.cpp
new file mode 100644
index 000000000..42d65788b
--- /dev/null
+++ b/src/recipes/ApplyWeightInterleaving.cpp
@@ -0,0 +1,97 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <memory>
+
+#include "aidge/data/Data.hpp"
+#include "aidge/graph/Node.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/operator/WeightInterleaving.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/recipes/Recipes.hpp"
+
+
+
+
+void Aidge::applyWeightInterleaving(std::shared_ptr<Node> node){
+    auto weightProducer = node->getParent(1);
+    AIDGE_ASSERT(weightProducer, "Cannot Apply Weight Interleaving on {} because it has no weights linked", node->name())
+
+    auto weightTensor = std::make_shared<Aidge::Tensor>(std::static_pointer_cast<Aidge::OperatorTensor>(weightProducer->getOperator())->getOutput(0)->clone());
+    auto backend = node->getOperator()->backend();
+
+    const Aidge::DataType weightDataType = weightTensor->dataType();
+
+    weightTensor->print();
+
+    // 1 - Apply dataformat NHWC to match the custom kernel implementation for ARM cortexM
+    // Issue : If the dataFormat is Default then setting it to NHWC won't permute dimensions
+    // Fix : If the datatype is at default then set it to NCHW THEN set it to NHWC
+    
+    if (weightTensor->dataFormat() == Aidge::DataFormat::Default) {
+        weightTensor->setDataFormat(Aidge::DataFormat::NCHW);
+    }
+    
+    // Apply permutation for NHWC format
+    if (weightTensor->dataFormat() != Aidge::DataFormat::NHWC) {
+        weightTensor->setDataFormat(Aidge::DataFormat::NHWC);
+    }
+
+    weightTensor->print();
+
+    // 2 - Apply Weight interleaving 
+    // Instanciate weight Interleaving operator
+    auto WIOp = WeightInterleaving_Op();
+
+
+    // Forward the Weight INterleaving op
+    WIOp.associateInput(0, weightTensor);
+
+    switch (weightDataType) {
+        case Aidge::DataType::Int4:
+            WIOp.setDataType(Aidge::DataType::Dual_Int4);
+            break;
+        case Aidge::DataType::UInt4:
+            WIOp.setDataType(Aidge::DataType::Dual_UInt4);
+            break;
+        case Aidge::DataType::Int3:
+            WIOp.setDataType(Aidge::DataType::Dual_Int3);
+            break;
+        case Aidge::DataType::UInt3:
+            WIOp.setDataType(Aidge::DataType::Dual_UInt3);
+            break;
+        case Aidge::DataType::Int2:
+            WIOp.setDataType(Aidge::DataType::Quad_Int2);
+            break;
+        case Aidge::DataType::UInt2:
+            WIOp.setDataType(Aidge::DataType::Quad_UInt2);
+            break;
+        case Aidge::DataType::Binary:
+            WIOp.setDataType(Aidge::DataType::Octo_Binary);
+            break;
+        default:
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type {} not supported for weight interleaving.", weightDataType);
+    }
+
+    WIOp.setDataFormat(Aidge::DataFormat::NHWC);
+    WIOp.setBackend(backend);
+
+    WIOp.forward();
+
+    WIOp.getOutput(0)->print();
+
+    // 3 - Replace the Weight Producer
+    auto newProducer = {Producer(WIOp.getOutput(0), weightProducer->name())};
+    auto oldProducer = {weightProducer};
+
+    GraphView::replace(oldProducer, newProducer);
+    
+}
\ No newline at end of file
-- 
GitLab