Skip to content
Snippets Groups Projects
Commit 88558942 authored by Thibault Allenet's avatar Thibault Allenet
Browse files

Update applyWeightInterleaving recipe to handle 2D weight tensor (FC)

parent 309ab5a7
No related branches found
No related tags found
2 merge requests!318[Upd] release verision 0.5.0,!262Low bit support for ARM Cortex-M export
Pipeline #61006 passed
......@@ -12,10 +12,13 @@
#include <memory>
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/WeightInterleaving.hpp"
#include "aidge/operator/Transpose.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/recipes/Recipes.hpp"
......@@ -26,34 +29,55 @@ void Aidge::applyWeightInterleaving(std::shared_ptr<Node> node){
AIDGE_ASSERT(weightProducer, "Cannot Apply Weight Interleaving on {} because it has no weights linked", node->name())
auto weightTensor = std::make_shared<Aidge::Tensor>(std::static_pointer_cast<Aidge::OperatorTensor>(weightProducer->getOperator())->getOutput(0)->clone());
auto backend = node->getOperator()->backend();
// auto backend = node->getOperator()->backend();
// Cover the case of Generic Operators
auto backend = node->getOperator()->backend().empty() ? "cpu" : node->getOperator()->backend();
const Aidge::DataType weightDataType = weightTensor->dataType();
weightTensor->print();
// 1 - Apply dataformat NHWC to match the custom kernel implementation for ARM cortexM
// Issue : If the dataFormat is Default then setting it to NHWC won't permute dimensions
// Fix : If the datatype is at default then set it to NCHW THEN set it to NHWC
if (weightTensor->dataFormat() == Aidge::DataFormat::Default) {
weightTensor->setDataFormat(Aidge::DataFormat::NCHW);
std::shared_ptr<Tensor> transposedWeightTensor;
// Case 4D tensor (conv)
if (weightTensor->nbDims() == 4)
{
if (weightTensor->dataFormat() == Aidge::DataFormat::Default) {
weightTensor->setDataFormat(Aidge::DataFormat::NCHW);
}
// Apply permutation for NHWC format
if (weightTensor->dataFormat() != Aidge::DataFormat::NHWC) {
weightTensor->setDataFormat(Aidge::DataFormat::NHWC);
}
transposedWeightTensor = weightTensor;
}
// Apply permutation for NHWC format
if (weightTensor->dataFormat() != Aidge::DataFormat::NHWC) {
weightTensor->setDataFormat(Aidge::DataFormat::NHWC);
else if (weightTensor->nbDims() == 2)
{
std::shared_ptr<Node> myTranspose = Transpose({1, 0});
auto op = std::static_pointer_cast<OperatorTensor>(myTranspose -> getOperator());
op->associateInput(0,weightTensor);
op->setDataType(weightDataType);
op->setBackend("cpu");
myTranspose->forward();
transposedWeightTensor = op->getOutput(0);
transposedWeightTensor->setDataFormat(Aidge::DataFormat::NHWC);
} else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot transpose {} weights.", node->name());
}
weightTensor->print();
// 2 - Apply Weight interleaving
// Instanciate weight Interleaving operator
auto WIOp = WeightInterleaving_Op();
// Forward the Weight INterleaving op
WIOp.associateInput(0, weightTensor);
WIOp.associateInput(0, transposedWeightTensor);
switch (weightDataType) {
case Aidge::DataType::Int4:
......@@ -86,8 +110,6 @@ void Aidge::applyWeightInterleaving(std::shared_ptr<Node> node){
WIOp.forward();
WIOp.getOutput(0)->print();
// 3 - Replace the Weight Producer
auto newProducer = {Producer(WIOp.getOutput(0), weightProducer->name())};
auto oldProducer = {weightProducer};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment