From 88558942a02bb98c8bba175ad9d7f51c718e3607 Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Thu, 5 Dec 2024 16:12:54 +0000 Subject: [PATCH] Update applyWeightInterleaving recipe to handle 2D weight tensor (FC) --- src/recipes/ApplyWeightInterleaving.cpp | 54 +++++++++++++++++-------- 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/src/recipes/ApplyWeightInterleaving.cpp b/src/recipes/ApplyWeightInterleaving.cpp index 42d65788b..b9c042a53 100644 --- a/src/recipes/ApplyWeightInterleaving.cpp +++ b/src/recipes/ApplyWeightInterleaving.cpp @@ -12,10 +12,13 @@ #include <memory> #include "aidge/data/Data.hpp" +#include "aidge/data/Tensor.hpp" #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/operator/WeightInterleaving.hpp" +#include "aidge/operator/Transpose.hpp" #include "aidge/operator/Producer.hpp" +#include "aidge/operator/OperatorTensor.hpp" #include "aidge/recipes/Recipes.hpp" @@ -26,34 +29,55 @@ void Aidge::applyWeightInterleaving(std::shared_ptr<Node> node){ AIDGE_ASSERT(weightProducer, "Cannot Apply Weight Interleaving on {} because it has no weights linked", node->name()) auto weightTensor = std::make_shared<Aidge::Tensor>(std::static_pointer_cast<Aidge::OperatorTensor>(weightProducer->getOperator())->getOutput(0)->clone()); - auto backend = node->getOperator()->backend(); + // auto backend = node->getOperator()->backend(); + // Cover the case of Generic Operators + auto backend = node->getOperator()->backend().empty() ? "cpu" : node->getOperator()->backend(); const Aidge::DataType weightDataType = weightTensor->dataType(); - weightTensor->print(); - // 1 - Apply dataformat NHWC to match the custom kernel implementation for ARM cortexM // Issue : If the dataFormat is Default then setting it to NHWC won't permute dimensions // Fix : If the datatype is at default then set it to NCHW THEN set it to NHWC - if (weightTensor->dataFormat() == Aidge::DataFormat::Default) { - weightTensor->setDataFormat(Aidge::DataFormat::NCHW); + std::shared_ptr<Tensor> transposedWeightTensor; + + // Case 4D tensor (conv) + if (weightTensor->nbDims() == 4) + { + if (weightTensor->dataFormat() == Aidge::DataFormat::Default) { + weightTensor->setDataFormat(Aidge::DataFormat::NCHW); + } + + // Apply permutation for NHWC format + if (weightTensor->dataFormat() != Aidge::DataFormat::NHWC) { + weightTensor->setDataFormat(Aidge::DataFormat::NHWC); + } + + transposedWeightTensor = weightTensor; + } - - // Apply permutation for NHWC format - if (weightTensor->dataFormat() != Aidge::DataFormat::NHWC) { - weightTensor->setDataFormat(Aidge::DataFormat::NHWC); + else if (weightTensor->nbDims() == 2) + { + std::shared_ptr<Node> myTranspose = Transpose({1, 0}); + auto op = std::static_pointer_cast<OperatorTensor>(myTranspose -> getOperator()); + op->associateInput(0,weightTensor); + op->setDataType(weightDataType); + op->setBackend("cpu"); + myTranspose->forward(); + + transposedWeightTensor = op->getOutput(0); + transposedWeightTensor->setDataFormat(Aidge::DataFormat::NHWC); + + } else { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot transpose {} weights.", node->name()); } - - weightTensor->print(); - + // 2 - Apply Weight interleaving // Instanciate weight Interleaving operator auto WIOp = WeightInterleaving_Op(); - // Forward the Weight INterleaving op - WIOp.associateInput(0, weightTensor); + WIOp.associateInput(0, transposedWeightTensor); switch (weightDataType) { case Aidge::DataType::Int4: @@ -86,8 +110,6 @@ void Aidge::applyWeightInterleaving(std::shared_ptr<Node> node){ WIOp.forward(); - WIOp.getOutput(0)->print(); - // 3 - Replace the Weight Producer auto newProducer = {Producer(WIOp.getOutput(0), weightProducer->name())}; auto oldProducer = {weightProducer}; -- GitLab