Skip to content
Snippets Groups Projects
Commit 88558942 authored by Thibault Allenet's avatar Thibault Allenet
Browse files

Update applyWeightInterleaving recipe to handle 2D weight tensor (FC)

parent 309ab5a7
No related branches found
No related tags found
2 merge requests!318[Upd] release verision 0.5.0,!262Low bit support for ARM Cortex-M export
Pipeline #61006 passed
...@@ -12,10 +12,13 @@ ...@@ -12,10 +12,13 @@
#include <memory> #include <memory>
#include "aidge/data/Data.hpp" #include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/operator/WeightInterleaving.hpp" #include "aidge/operator/WeightInterleaving.hpp"
#include "aidge/operator/Transpose.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/Recipes.hpp"
...@@ -26,34 +29,55 @@ void Aidge::applyWeightInterleaving(std::shared_ptr<Node> node){ ...@@ -26,34 +29,55 @@ void Aidge::applyWeightInterleaving(std::shared_ptr<Node> node){
AIDGE_ASSERT(weightProducer, "Cannot Apply Weight Interleaving on {} because it has no weights linked", node->name()) AIDGE_ASSERT(weightProducer, "Cannot Apply Weight Interleaving on {} because it has no weights linked", node->name())
auto weightTensor = std::make_shared<Aidge::Tensor>(std::static_pointer_cast<Aidge::OperatorTensor>(weightProducer->getOperator())->getOutput(0)->clone()); auto weightTensor = std::make_shared<Aidge::Tensor>(std::static_pointer_cast<Aidge::OperatorTensor>(weightProducer->getOperator())->getOutput(0)->clone());
auto backend = node->getOperator()->backend(); // auto backend = node->getOperator()->backend();
// Cover the case of Generic Operators
auto backend = node->getOperator()->backend().empty() ? "cpu" : node->getOperator()->backend();
const Aidge::DataType weightDataType = weightTensor->dataType(); const Aidge::DataType weightDataType = weightTensor->dataType();
weightTensor->print();
// 1 - Apply dataformat NHWC to match the custom kernel implementation for ARM cortexM // 1 - Apply dataformat NHWC to match the custom kernel implementation for ARM cortexM
// Issue : If the dataFormat is Default then setting it to NHWC won't permute dimensions // Issue : If the dataFormat is Default then setting it to NHWC won't permute dimensions
// Fix : If the datatype is at default then set it to NCHW THEN set it to NHWC // Fix : If the datatype is at default then set it to NCHW THEN set it to NHWC
if (weightTensor->dataFormat() == Aidge::DataFormat::Default) { std::shared_ptr<Tensor> transposedWeightTensor;
weightTensor->setDataFormat(Aidge::DataFormat::NCHW);
// Case 4D tensor (conv)
if (weightTensor->nbDims() == 4)
{
if (weightTensor->dataFormat() == Aidge::DataFormat::Default) {
weightTensor->setDataFormat(Aidge::DataFormat::NCHW);
}
// Apply permutation for NHWC format
if (weightTensor->dataFormat() != Aidge::DataFormat::NHWC) {
weightTensor->setDataFormat(Aidge::DataFormat::NHWC);
}
transposedWeightTensor = weightTensor;
} }
else if (weightTensor->nbDims() == 2)
// Apply permutation for NHWC format {
if (weightTensor->dataFormat() != Aidge::DataFormat::NHWC) { std::shared_ptr<Node> myTranspose = Transpose({1, 0});
weightTensor->setDataFormat(Aidge::DataFormat::NHWC); auto op = std::static_pointer_cast<OperatorTensor>(myTranspose -> getOperator());
op->associateInput(0,weightTensor);
op->setDataType(weightDataType);
op->setBackend("cpu");
myTranspose->forward();
transposedWeightTensor = op->getOutput(0);
transposedWeightTensor->setDataFormat(Aidge::DataFormat::NHWC);
} else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot transpose {} weights.", node->name());
} }
weightTensor->print();
// 2 - Apply Weight interleaving // 2 - Apply Weight interleaving
// Instanciate weight Interleaving operator // Instanciate weight Interleaving operator
auto WIOp = WeightInterleaving_Op(); auto WIOp = WeightInterleaving_Op();
// Forward the Weight INterleaving op // Forward the Weight INterleaving op
WIOp.associateInput(0, weightTensor); WIOp.associateInput(0, transposedWeightTensor);
switch (weightDataType) { switch (weightDataType) {
case Aidge::DataType::Int4: case Aidge::DataType::Int4:
...@@ -86,8 +110,6 @@ void Aidge::applyWeightInterleaving(std::shared_ptr<Node> node){ ...@@ -86,8 +110,6 @@ void Aidge::applyWeightInterleaving(std::shared_ptr<Node> node){
WIOp.forward(); WIOp.forward();
WIOp.getOutput(0)->print();
// 3 - Replace the Weight Producer // 3 - Replace the Weight Producer
auto newProducer = {Producer(WIOp.getOutput(0), weightProducer->name())}; auto newProducer = {Producer(WIOp.getOutput(0), weightProducer->name())};
auto oldProducer = {weightProducer}; auto oldProducer = {weightProducer};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment