Skip to content
Snippets Groups Projects
Commit 3ab0a780 authored by Thibault Allenet's avatar Thibault Allenet
Browse files

Add recipe apply weightInterleaving

parent 006e8f4b
No related branches found
No related tags found
2 merge requests!318[Upd] release verision 0.5.0,!262Low bit support for ARM Cortex-M export
...@@ -180,19 +180,24 @@ size_t convToMatMul(std::shared_ptr<GraphView> graph); ...@@ -180,19 +180,24 @@ size_t convToMatMul(std::shared_ptr<GraphView> graph);
*/ */
void adaptToBackend(std::shared_ptr<GraphView> graph); void adaptToBackend(std::shared_ptr<GraphView> graph);
// /**
// * @brief The node passed contains an operator whose input of index 1 is supposed to be weights of type Int4, Int3, Int2, binary.
// * This recipe only operates memory transformations on the weight tensor.
// * First, permutes the dimensions to match the dataformat NHWC
// * Second, compact the last dimension (Channel dimension) into int8_t
// *
// * @param node Node
// */
// void applyWeightInterleaving(std::shared_ptr<Node> node);
/**
* @brief Create a GenericOp from the Operator held by a node and replace that Operator with it.
*
* @param node Node which Operator will be changed into a generic Operator
*/
void toGenericOp(std::shared_ptr<Node> node); void toGenericOp(std::shared_ptr<Node> node);
/**
* @brief Interleave the low-bit weights connected to a node.
* The node passed contains an operator whose input of index 1 is supposed to be weights of type
* Int4/UInt4, Int3/UInt3, Int2/UInt2 or Binary.
* This recipe only operates memory transformations on the weight tensor.
* First, it permutes the dimensions to match the data format NHWC.
* Second, it compacts the last dimension of the weights (channel dimension) into 8 bits.
*
* @param node Node whose weights (parent connected to input of index 1) are transformed
*/
void applyWeightInterleaving(std::shared_ptr<Node> node);
} // namespace Aidge } // namespace Aidge
#endif /* AIDGE_CORE_UTILS_RECIPES_H_ */ #endif /* AIDGE_CORE_UTILS_RECIPES_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include "aidge/data/Data.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/WeightInterleaving.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp"
void Aidge::applyWeightInterleaving(std::shared_ptr<Node> node){
auto weightProducer = node->getParent(1);
AIDGE_ASSERT(weightProducer, "Cannot Apply Weight Interleaving on {} because it has no weights linked", node->name())
auto weightTensor = std::make_shared<Aidge::Tensor>(std::static_pointer_cast<Aidge::OperatorTensor>(weightProducer->getOperator())->getOutput(0)->clone());
auto backend = node->getOperator()->backend();
const Aidge::DataType weightDataType = weightTensor->dataType();
weightTensor->print();
// 1 - Apply dataformat NHWC to match the custom kernel implementation for ARM cortexM
// Issue : If the dataFormat is Default then setting it to NHWC won't permute dimensions
// Fix : If the datatype is at default then set it to NCHW THEN set it to NHWC
if (weightTensor->dataFormat() == Aidge::DataFormat::Default) {
weightTensor->setDataFormat(Aidge::DataFormat::NCHW);
}
// Apply permutation for NHWC format
if (weightTensor->dataFormat() != Aidge::DataFormat::NHWC) {
weightTensor->setDataFormat(Aidge::DataFormat::NHWC);
}
weightTensor->print();
// 2 - Apply Weight interleaving
// Instanciate weight Interleaving operator
auto WIOp = WeightInterleaving_Op();
// Forward the Weight INterleaving op
WIOp.associateInput(0, weightTensor);
switch (weightDataType) {
case Aidge::DataType::Int4:
WIOp.setDataType(Aidge::DataType::Dual_Int4);
break;
case Aidge::DataType::UInt4:
WIOp.setDataType(Aidge::DataType::Dual_UInt4);
break;
case Aidge::DataType::Int3:
WIOp.setDataType(Aidge::DataType::Dual_Int3);
break;
case Aidge::DataType::UInt3:
WIOp.setDataType(Aidge::DataType::Dual_UInt3);
break;
case Aidge::DataType::Int2:
WIOp.setDataType(Aidge::DataType::Quad_Int2);
break;
case Aidge::DataType::UInt2:
WIOp.setDataType(Aidge::DataType::Quad_UInt2);
break;
case Aidge::DataType::Binary:
WIOp.setDataType(Aidge::DataType::Octo_Binary);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Data type {} not supported for weight interleaving.", weightDataType);
}
WIOp.setDataFormat(Aidge::DataFormat::NHWC);
WIOp.setBackend(backend);
WIOp.forward();
WIOp.getOutput(0)->print();
// 3 - Replace the Weight Producer
auto newProducer = {Producer(WIOp.getOutput(0), weightProducer->name())};
auto oldProducer = {weightProducer};
GraphView::replace(oldProducer, newProducer);
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment