diff --git a/src/recipes/ApplyWeightInterleaving.cpp b/src/recipes/ApplyWeightInterleaving.cpp index 42d65788bebc29359740a0fa652e151d1b40b4ba..b9c042a538bc1ece754c5f659048e9c5f6c0d249 100644 --- a/src/recipes/ApplyWeightInterleaving.cpp +++ b/src/recipes/ApplyWeightInterleaving.cpp @@ -12,10 +12,13 @@ #include <memory> #include "aidge/data/Data.hpp" +#include "aidge/data/Tensor.hpp" #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/operator/WeightInterleaving.hpp" +#include "aidge/operator/Transpose.hpp" #include "aidge/operator/Producer.hpp" +#include "aidge/operator/OperatorTensor.hpp" #include "aidge/recipes/Recipes.hpp" @@ -26,34 +29,55 @@ void Aidge::applyWeightInterleaving(std::shared_ptr<Node> node){ AIDGE_ASSERT(weightProducer, "Cannot Apply Weight Interleaving on {} because it has no weights linked", node->name()) auto weightTensor = std::make_shared<Aidge::Tensor>(std::static_pointer_cast<Aidge::OperatorTensor>(weightProducer->getOperator())->getOutput(0)->clone()); - auto backend = node->getOperator()->backend(); + // auto backend = node->getOperator()->backend(); + // Cover the case of Generic Operators + auto backend = node->getOperator()->backend().empty() ? "cpu" : node->getOperator()->backend(); const Aidge::DataType weightDataType = weightTensor->dataType(); - weightTensor->print(); - // 1 - Apply dataformat NHWC to match the custom kernel implementation for ARM cortexM // Issue : If the dataFormat is Default then setting it to NHWC won't permute dimensions // Fix : If the datatype is at default then set it to NCHW THEN set it to NHWC - if (weightTensor->dataFormat() == Aidge::DataFormat::Default) { - weightTensor->setDataFormat(Aidge::DataFormat::NCHW); + std::shared_ptr<Tensor> transposedWeightTensor; + + // Case 4D tensor (conv) + if (weightTensor->nbDims() == 4) + { + if (weightTensor->dataFormat() == Aidge::DataFormat::Default) { + weightTensor->setDataFormat(Aidge::DataFormat::NCHW); + } + + // Apply permutation for NHWC format + if (weightTensor->dataFormat() != Aidge::DataFormat::NHWC) { + weightTensor->setDataFormat(Aidge::DataFormat::NHWC); + } + + transposedWeightTensor = weightTensor; + } - - // Apply permutation for NHWC format - if (weightTensor->dataFormat() != Aidge::DataFormat::NHWC) { - weightTensor->setDataFormat(Aidge::DataFormat::NHWC); + else if (weightTensor->nbDims() == 2) + { + std::shared_ptr<Node> myTranspose = Transpose({1, 0}); + auto op = std::static_pointer_cast<OperatorTensor>(myTranspose -> getOperator()); + op->associateInput(0,weightTensor); + op->setDataType(weightDataType); + op->setBackend("cpu"); + myTranspose->forward(); + + transposedWeightTensor = op->getOutput(0); + transposedWeightTensor->setDataFormat(Aidge::DataFormat::NHWC); + + } else { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot transpose {} weights.", node->name()); } - - weightTensor->print(); - + // 2 - Apply Weight interleaving // Instanciate weight Interleaving operator auto WIOp = WeightInterleaving_Op(); - // Forward the Weight INterleaving op - WIOp.associateInput(0, weightTensor); + WIOp.associateInput(0, transposedWeightTensor); switch (weightDataType) { case Aidge::DataType::Int4: @@ -86,8 +110,6 @@ void Aidge::applyWeightInterleaving(std::shared_ptr<Node> node){ WIOp.forward(); - WIOp.getOutput(0)->print(); - // 3 - Replace the Weight Producer auto newProducer = {Producer(WIOp.getOutput(0), weightProducer->name())}; auto oldProducer = {weightProducer};