From 3ab0a780df92e09a9acf18d02d1725d90fac14f4 Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Wed, 4 Dec 2024 13:29:14 +0000 Subject: [PATCH] Add recipe apply weightInterleaving --- include/aidge/recipes/Recipes.hpp | 25 ++++--- src/recipes/ApplyWeightInterleaving.cpp | 97 +++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 10 deletions(-) create mode 100644 src/recipes/ApplyWeightInterleaving.cpp diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index 0fb405bfe..5f16c480c 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -180,19 +180,24 @@ size_t convToMatMul(std::shared_ptr<GraphView> graph); */ void adaptToBackend(std::shared_ptr<GraphView> graph); -// /** -// * @brief The node passed contains an operator which input of index 1 is supposed be be weights of type Int4, Int3, Int2, binary. -// * This recipie only operates memory transformations on the weight tensor. -// * First, permutes the dimensions to match the dataformat NHWC -// * Second, compact the last dimension (Channel dimension) into int8_t -// * -// * @param node Node -// */ -// void applyWeightInterleaving(std::shared_ptr<Node> node); - +/** + * @brief Create a GenericOp from an Operator and replace it + * + * @param node Node which Operator will be changed into a generic Operator + */ void toGenericOp(std::shared_ptr<Node> node); +/** + * @brief The node passed contains an operator which input of index 1 is supposed be be weights of type Int4, Int3, Int2, binary. + * This recipie only operates memory transformations on the weight tensor. + * First, permutes the dimensions to match the dataformat NHWC + * Second, compact the last dimension of the weights (Channel dimension) into 8bits + * + * @param node Node + */ +void applyWeightInterleaving(std::shared_ptr<Node> node); + } // namespace Aidge #endif /* AIDGE_CORE_UTILS_RECIPES_H_ */ diff --git a/src/recipes/ApplyWeightInterleaving.cpp b/src/recipes/ApplyWeightInterleaving.cpp new file mode 100644 index 000000000..42d65788b --- /dev/null +++ b/src/recipes/ApplyWeightInterleaving.cpp @@ -0,0 +1,97 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <memory> + +#include "aidge/data/Data.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/WeightInterleaving.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/recipes/Recipes.hpp" + + + + +void Aidge::applyWeightInterleaving(std::shared_ptr<Node> node){ + auto weightProducer = node->getParent(1); + AIDGE_ASSERT(weightProducer, "Cannot Apply Weight Interleaving on {} because it has no weights linked", node->name()) + + auto weightTensor = std::make_shared<Aidge::Tensor>(std::static_pointer_cast<Aidge::OperatorTensor>(weightProducer->getOperator())->getOutput(0)->clone()); + auto backend = node->getOperator()->backend(); + + const Aidge::DataType weightDataType = weightTensor->dataType(); + + weightTensor->print(); + + // 1 - Apply dataformat NHWC to match the custom kernel implementation for ARM cortexM + // Issue : If the dataFormat is Default then setting it to NHWC won't permute dimensions + // Fix : If the datatype is at default then set it to NCHW THEN set it to NHWC + + if (weightTensor->dataFormat() == Aidge::DataFormat::Default) { + weightTensor->setDataFormat(Aidge::DataFormat::NCHW); + } + + // Apply permutation for NHWC format + if (weightTensor->dataFormat() != Aidge::DataFormat::NHWC) { + weightTensor->setDataFormat(Aidge::DataFormat::NHWC); + } + + weightTensor->print(); + + // 2 - Apply Weight interleaving + // Instanciate weight Interleaving operator + auto WIOp = WeightInterleaving_Op(); + + + // Forward the Weight INterleaving op + WIOp.associateInput(0, weightTensor); + + switch (weightDataType) { + case Aidge::DataType::Int4: + WIOp.setDataType(Aidge::DataType::Dual_Int4); + break; + case Aidge::DataType::UInt4: + WIOp.setDataType(Aidge::DataType::Dual_UInt4); + break; + case Aidge::DataType::Int3: + WIOp.setDataType(Aidge::DataType::Dual_Int3); + break; + case Aidge::DataType::UInt3: + WIOp.setDataType(Aidge::DataType::Dual_UInt3); + break; + case Aidge::DataType::Int2: + WIOp.setDataType(Aidge::DataType::Quad_Int2); + break; + case Aidge::DataType::UInt2: + WIOp.setDataType(Aidge::DataType::Quad_UInt2); + break; + case Aidge::DataType::Binary: + WIOp.setDataType(Aidge::DataType::Octo_Binary); + break; + default: + AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type {} not supported for weight interleaving.", weightDataType); + } + + WIOp.setDataFormat(Aidge::DataFormat::NHWC); + WIOp.setBackend(backend); + + WIOp.forward(); + + WIOp.getOutput(0)->print(); + + // 3 - Replace the Weight Producer + auto newProducer = {Producer(WIOp.getOutput(0), weightProducer->name())}; + auto oldProducer = {weightProducer}; + + GraphView::replace(oldProducer, newProducer); + +} \ No newline at end of file -- GitLab