Skip to content
Snippets Groups Projects
Commit 3f669a98 authored by Benjamin Halimi's avatar Benjamin Halimi
Browse files

Merge branch 'add_matmul' into 'dev'

Add support for the MatMul operator

See merge request !45
parents f2713d00 541e7299
No related branches found
No related tags found
3 merge requests!54Update 0.3.1 -> 0.4.0,!49Forked from add_matmul (merged automatically),!45Add support for the MatMul operator
Pipeline #68196 failed
...@@ -74,6 +74,12 @@ namespace Aidge { ...@@ -74,6 +74,12 @@ namespace Aidge {
*/ */
bool isNotQuantized(std::shared_ptr<Node> node); bool isNotQuantized(std::shared_ptr<Node> node);
/**
* @brief Compute the absolute max of a tensor
* @param tensor The Tensor to process
*/
double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor);
/** /**
* @brief Retrieve the scheduled vector of node of a graphView, without the Producer nodes. * @brief Retrieve the scheduled vector of node of a graphView, without the Producer nodes.
* @param graphView The graphView containing the nodes * @param graphView The graphView containing the nodes
......
...@@ -9,30 +9,30 @@ ...@@ -9,30 +9,30 @@
* *
********************************************************************************/ ********************************************************************************/
#ifndef AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ #ifndef AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_
#define AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ #define AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_
#include <cstddef> // std::size_t #include <cstddef> // std::size_t
#include <memory> #include <memory>
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
namespace Aidge { namespace Aidge {
namespace QuantLSQ { namespace QuantLSQ {
/** /**
* @brief Given a GraphView with parameters properly initialized, insert * @brief Given a GraphView with parameters properly initialized, insert
* the LSQ quantizer nodes, and setup the adjustment their step-sizes. * the LSQ quantizer nodes, and setup the adjustment their step-sizes.
* @param graphView The GraphView containing the network to quantize. * @param graphView The GraphView containing the network to quantize.
* @param nbBits Number of quantization bits. * @param nbBits Number of quantization bits.
*/ */
void setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits); void setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits);
} // namespace QuantLSQ } // namespace QuantLSQ
} // namespace Aidge } // namespace Aidge
#endif /* AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ */ #endif /* AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ */
\ No newline at end of file
...@@ -40,6 +40,14 @@ namespace Aidge ...@@ -40,6 +40,14 @@ namespace Aidge
* @param graphView The GraphView to process. * @param graphView The GraphView to process.
*/ */
void sanitizeNodeNames(std::shared_ptr<GraphView> graphView); void sanitizeNodeNames(std::shared_ptr<GraphView> graphView);
/**
* @brief Given a GraphView, set all it's MatMul weights to index 1 (required for the PTQ)
* This operation involve the insertion of Transpose nodes as well as the transposition of
* the MatMul weight tensors.
* @param graphView The GraphView to process.
*/
void reorderMatMulInputs(std::shared_ptr<GraphView> graphView);
} }
#endif /* AIDGE_QUANTIZATION_QUANTRECIPES_H_ */ #endif /* AIDGE_QUANTIZATION_QUANTRECIPES_H_ */
...@@ -18,13 +18,15 @@ ...@@ -18,13 +18,15 @@
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge
{
void init_QuantRecipes(py::module &m) {
void init_QuantRecipes(py::module &m)
{
m.def("pop_softmax", &popSoftMax, py::arg("network")); m.def("pop_softmax", &popSoftMax, py::arg("network"));
m.def("insert_batchnorm_nodes", &insertBatchNormNodes, py::arg("network")); m.def("insert_batchnorm_nodes", &insertBatchNormNodes, py::arg("network"));
m.def("sanitize_node_names", &sanitizeNodeNames, py::arg("network")); m.def("sanitize_node_names", &sanitizeNodeNames, py::arg("network"));
m.def("reorder_matmul_inputs", &reorderMatMulInputs, py::arg("network"));
} }
} // namespace Aidge } // namespace Aidge
...@@ -52,68 +52,24 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node) ...@@ -52,68 +52,24 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2); return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2);
} }
static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling) static bool nodeHasBias(std::shared_ptr<Node> node)
{ {
auto mulOp = Mul_Op(); if (node->getParents().size() == 3) {
mulOp.setDataType(tensor->dataType()); std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
mulOp.setBackend(tensor->backend()); if (biasTensor)
return true;
std::shared_ptr<Aidge::Tensor> scalingTensor = std::make_shared<Aidge::Tensor>(Aidge::Array1D<double, 1> {scaling}); }
scalingTensor->setDataType(tensor->dataType()); return false;
scalingTensor->setBackend(tensor->backend());
mulOp.associateInput(0, tensor);
mulOp.associateInput(1, scalingTensor);
mulOp.forward();
auto outTensor = mulOp.getOutput(0);
*tensor = *outTensor;
//tensor->copyCast(*outTensor);
} }
// TODO : make the retreival of argmax values backend independant (refCastFrom) // What is this thing ???
static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) // Function used to extract the local tensor (from a ProducerScalingNode)
std::shared_ptr<Aidge::Tensor> getLocalTensor(std::shared_ptr<Node> node)
{ {
// get the abs tensor if (node->getParent(1)->attributes()->hasAttr("quantization.ptq.isProducerScaling"))
std::shared_ptr<Tensor> fallback; //Fallback tensor for refCastFR {
std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs());
// flatten the abs tensor
std::int64_t nbElement = tensor->size();
auto reshapeOp = Reshape_Op({nbElement});
reshapeOp.setDataType(tensor->dataType());
reshapeOp.setBackend(tensor->backend());
reshapeOp.associateInput(0, absTensor);
reshapeOp.forward();
std::shared_ptr<Tensor> flatTensor = reshapeOp.getOutput(0);
const Tensor& localFlatTensor = flatTensor->refCastFrom(fallback, DataType::Float64, "cpu");
// Get the argmax
auto argmaxOp = ArgMax_Op(0, true, false);
argmaxOp.setDataType(tensor->dataType());
argmaxOp.setBackend(tensor->backend());
argmaxOp.associateInput(0, flatTensor);
argmaxOp.forward();
const Tensor& argMaxTensor = argmaxOp.getOutput(0)->refCastFrom(fallback, DataType::Float64, "cpu");
// Return the max
int maxIndex = std::round(argMaxTensor.get<double>(0));
return localFlatTensor.get<double>(maxIndex);
}
//Function used to extraxt the local tensor (from a ProducerScalingNode)
std::shared_ptr<Aidge::Tensor> getLocalTensor(std::shared_ptr<Node> node) {
if (node->getParent(1)->attributes()->hasAttr("quantization.ptq.isProducerScaling")) {
std::shared_ptr<Aidge::OperatorTensor> operatorTensor = std::static_pointer_cast<OperatorTensor>(node->getParent(1)->getOperator()); std::shared_ptr<Aidge::OperatorTensor> operatorTensor = std::static_pointer_cast<OperatorTensor>(node->getParent(1)->getOperator());
operatorTensor->forward();// We need the forward pass to compute the scaled value of the Tensor operatorTensor->forward(); // We need the forward pass to compute the scaled value of the Tensor
return operatorTensor->getOutput(0); return operatorTensor->getOutput(0);
} else { } else {
return getWeightTensor(node); return getWeightTensor(node);
...@@ -129,16 +85,16 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD ...@@ -129,16 +85,16 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
for (std::shared_ptr<Node> node : nodeVector) for (std::shared_ptr<Node> node : nodeVector)
{ {
if (node->getChildren().size() > 1) { if (node->getChildren().size() > 1) {
Log::notice(" Network have multiple branches, skipping the CLE ... "); Log::warn(" Network have multiple branches, skipping the CLE ... ");
return; return;
} }
if (isNotQuantized(node)) { if (isNotQuantized(node)) {
Log::notice(" Network contains non linear nodes, skipping the CLE ... "); Log::warn(" Network contains non linear nodes, skipping the CLE ... ");
return; return;
} }
} }
Log::info(" Applying the Cross-Layer Equalization ... "); Log::notice(" Applying the Cross-Layer Equalization ... ");
// Get the vector of affine nodes // Get the vector of affine nodes
...@@ -148,13 +104,14 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD ...@@ -148,13 +104,14 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
affineNodeVector.push_back(node); affineNodeVector.push_back(node);
double maxRangeDelta; double maxRangeDelta;
do do
{ {
maxRangeDelta = 0.0; maxRangeDelta = 0.0;
for (size_t i = 0; i < (affineNodeVector.size() - 1); i++) for (size_t i = 0; i < (affineNodeVector.size() - 1); i++)
{ {
// Log::notice(" node index : {} ", i);
std::shared_ptr<Node> n1 = affineNodeVector[i]; std::shared_ptr<Node> n1 = affineNodeVector[i];
std::shared_ptr<Node> n2 = affineNodeVector[i+1]; std::shared_ptr<Node> n2 = affineNodeVector[i+1];
...@@ -168,8 +125,11 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD ...@@ -168,8 +125,11 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
double s2 = std::sqrt(r1 * r2) / r2; double s2 = std::sqrt(r1 * r2) / r2;
insertScalingBelowProducer(n1->getParent(1), s1, graphView); insertScalingBelowProducer(n1->getParent(1), s1, graphView);
if (nodeHasBias(n1))
insertScalingBelowProducer(n1->getParent(2), s1, graphView);
insertScalingBelowProducer(n2->getParent(1), s2, graphView); insertScalingBelowProducer(n2->getParent(1), s2, graphView);
insertScalingBelowProducer(n1->getParent(2), s1, graphView);
double rangeDelta = std::abs(r1 - r2); double rangeDelta = std::abs(r1 - r2);
if (rangeDelta > maxRangeDelta) if (rangeDelta > maxRangeDelta)
......
This diff is collapsed.
...@@ -9,164 +9,164 @@ ...@@ -9,164 +9,164 @@
* *
********************************************************************************/ ********************************************************************************/
#include "aidge/quantization/QAT/QAT_LSQ.hpp" #include "aidge/quantization/QAT/QAT_LSQ.hpp"
#include "aidge/operator/LSQ.hpp" #include "aidge/operator/LSQ.hpp"
#include "aidge/operator/ReLU.hpp" #include "aidge/operator/ReLU.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp" #include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/scheduler/Scheduler.hpp" #include "aidge/scheduler/Scheduler.hpp"
#include "aidge/graph/Matching.hpp" #include "aidge/graph/Matching.hpp"
#include "aidge/recipes/QuantRecipes.hpp" #include "aidge/recipes/QuantRecipes.hpp"
namespace Aidge namespace Aidge
{ {
static float getTensorAbsMean(std::shared_ptr<Tensor> tensor) static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
{ {
auto valueTensor = (*tensor).abs().mean(); auto valueTensor = (*tensor).abs().mean();
std::shared_ptr<Tensor> fallback; std::shared_ptr<Tensor> fallback;
const Tensor& localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu"); const Tensor& localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu");
return localTensor.get<float>(0); return localTensor.get<float>(0);
} }
static float getTensorStd(std::shared_ptr<Tensor> tensor) static float getTensorStd(std::shared_ptr<Tensor> tensor)
{ {
auto valueTensor = (*tensor); auto valueTensor = (*tensor);
auto skewedTensor = valueTensor - valueTensor.mean(); auto skewedTensor = valueTensor - valueTensor.mean();
auto squaredTensor = skewedTensor * skewedTensor; auto squaredTensor = skewedTensor * skewedTensor;
auto varianceTensor = squaredTensor.mean(); auto varianceTensor = squaredTensor.mean();
std::shared_ptr<Tensor> fallback; std::shared_ptr<Tensor> fallback;
auto localTensor = varianceTensor.refCastFrom(fallback, DataType::Float32, "cpu"); auto localTensor = varianceTensor.refCastFrom(fallback, DataType::Float32, "cpu");
float variance = localTensor.get<float>(0); float variance = localTensor.get<float>(0);
return std::sqrt(variance); return std::sqrt(variance);
} }
// INIT THE STEP SIZE OF A QUANTIZER NODE // INIT THE STEP SIZE OF A QUANTIZER NODE
static bool initStepSize(std::shared_ptr<Node> quantizer) static bool initStepSize(std::shared_ptr<Node> quantizer)
{ {
const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator()); const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator());
// This formula is the one proposed in the paper ... // This formula is the one proposed in the paper ...
// float inputAbsMean = getTensorAbsMean(quantizerOp->getInput(0)); // float inputAbsMean = getTensorAbsMean(quantizerOp->getInput(0));
// float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second)); // float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second));
// .. but this formula seems to work better !!! // .. but this formula seems to work better !!!
float inputStd = getTensorStd(quantizerOp->getInput(0)); float inputStd = getTensorStd(quantizerOp->getInput(0));
float stepSize = 8.0f * (inputStd / (quantizerOp->range().second)); float stepSize = 8.0f * (inputStd / (quantizerOp->range().second));
// TODO : use the scalar constructor // TODO : use the scalar constructor
auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
// XXX Manage backend here ? // XXX Manage backend here ?
stepSizeTensor->setBackend(quantizerOp->getInput(0)->backend()); stepSizeTensor->setBackend(quantizerOp->getInput(0)->backend());
stepSizeTensor->setDataType(quantizerOp->getInput(0)->dataType()); stepSizeTensor->setDataType(quantizerOp->getInput(0)->dataType());
auto stepSizeProducer = quantizer->getParent(1); auto stepSizeProducer = quantizer->getParent(1);
stepSizeProducer->getOperator()->setOutput(0, stepSizeTensor); stepSizeProducer->getOperator()->setOutput(0, stepSizeTensor);
Log::notice(" [ INIT STEP SIZE = {} ] ", stepSize); Log::notice(" [ INIT STEP SIZE = {} ] ", stepSize);
return false; return false;
} }
static void setupInputQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) static void setupInputQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
{ {
const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)"); const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
for (const auto& match : matches) for (const auto& match : matches)
{ {
auto linearNode = match.graph->rootNode(); auto linearNode = match.graph->rootNode();
// Log::notice(" SET INPUT QUANTIZER : {} ", linearNode->type()); // Log::notice(" SET INPUT QUANTIZER : {} ", linearNode->type());
std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1}; std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1}; std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1};
// Create the input quantizer node // Create the input quantizer node
auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView); auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView);
auto quantizerNode = LSQ(signedRange, quantizerName); auto quantizerNode = LSQ(signedRange, quantizerName);
// Init the step-size using the node call stack // Init the step-size using the node call stack
quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); }); quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
// Absorb the ReLU when possible ... // Absorb the ReLU when possible ...
bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]); // XXX is this safe ? bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]); // XXX is this safe ?
if (nodeHasParent) if (nodeHasParent)
{ {
bool allParentsAreReLU = true; bool allParentsAreReLU = true;
for (auto parentNode : linearNode->getParents()) for (auto parentNode : linearNode->getParents())
if (parentNode->type() != "ReLU") if (parentNode->type() != "ReLU")
allParentsAreReLU = false; allParentsAreReLU = false;
if (allParentsAreReLU) { if (allParentsAreReLU) {
auto quantizerOp = std::static_pointer_cast<LSQ_Op> (quantizerNode->getOperator()); auto quantizerOp = std::static_pointer_cast<LSQ_Op> (quantizerNode->getOperator());
quantizerOp->range() = unsignedRange; quantizerOp->range() = unsignedRange;
} }
// TODO : remove the ReLUs when possible // TODO : remove the ReLUs when possible
} }
// Insert the quantizer in the graphView ... // Insert the quantizer in the graphView ...
// (We need to handle the case where the linear node is the first one) // (We need to handle the case where the linear node is the first one)
if (nodeHasParent) { if (nodeHasParent) {
graphView->insertParent(linearNode, quantizerNode, 0, 0, 0); graphView->insertParent(linearNode, quantizerNode, 0, 0, 0);
} else { } else {
quantizerNode->addChild(graphView); quantizerNode->addChild(graphView);
graphView->add(quantizerNode); graphView->add(quantizerNode);
} }
} }
} }
// PARAM QUANTIZERS INSERTION // PARAM QUANTIZERS INSERTION
static void setupParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) static void setupParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
{ {
const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)"); const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1}; std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
for (const auto& match : matches) for (const auto& match : matches)
{ {
auto linearNode = match.graph->rootNode(); auto linearNode = match.graph->rootNode();
// Log::notice(" SET PARAM QUANTIZER : {} ", linearNode->type()); // Log::notice(" SET PARAM QUANTIZER : {} ", linearNode->type());
// TODO : double check this, and use createUniqueName() // TODO : double check this, and use createUniqueName()
auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView); auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView);
auto quantizerNode = LSQ(signedRange, quantizerName); auto quantizerNode = LSQ(signedRange, quantizerName);
// Init the step-size using the node call stack // Init the step-size using the node call stack
quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); }); quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
// Insert the quantizer in the graphView // Insert the quantizer in the graphView
graphView->insertParent(linearNode, quantizerNode, 1, 0, 0); graphView->insertParent(linearNode, quantizerNode, 1, 0, 0);
} }
} }
void QuantLSQ::setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) void QuantLSQ::setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
{ {
sanitizeNodeNames(graphView); sanitizeNodeNames(graphView);
setupInputQuantizers(graphView, nbBits); setupInputQuantizers(graphView, nbBits);
setupParamQuantizers(graphView, nbBits); setupParamQuantizers(graphView, nbBits);
} }
} }
\ No newline at end of file \ No newline at end of file
...@@ -53,11 +53,11 @@ void Aidge::LSQImpl_cuda::backward() { ...@@ -53,11 +53,11 @@ void Aidge::LSQImpl_cuda::backward() {
std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad(); std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
if (gra_int0->size() > mWorkspaceSize) { if (gra_int0->size() > mWorkspaceSize) {
// std::cout << " reallocation " << sizeof(gra_int0) << " " << gra_int0->size() << std::endl;
if (mWorkspace != nullptr) { if (mWorkspace != nullptr) {
cudaFree(mWorkspace); cudaFree(mWorkspace);
} }
CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, 8 * gra_int0->size())); // XXX This must be changed !!! std::size_t sizeOfData = getDataTypeBitWidth(gra_int0->dataType()) / 8;
CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, sizeOfData * gra_int0->size()));
mWorkspaceSize = gra_int0->size(); mWorkspaceSize = gra_int0->size();
} }
......
...@@ -70,8 +70,6 @@ static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, st ...@@ -70,8 +70,6 @@ static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, st
return mulNode; return mulNode;
} }
void updateScalingFactor(std::shared_ptr<Node> metaOpNode, double scalingFactor) void updateScalingFactor(std::shared_ptr<Node> metaOpNode, double scalingFactor)
{ {
if(metaOpNode->type() != "Scaling" && metaOpNode->type() != "Quantizer") if(metaOpNode->type() != "Scaling" && metaOpNode->type() != "Quantizer")
......
...@@ -9,12 +9,17 @@ ...@@ -9,12 +9,17 @@
* *
********************************************************************************/ ********************************************************************************/
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Conv.hpp" #include "aidge/operator/Conv.hpp"
#include "aidge/operator/Transpose.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/BatchNorm.hpp"
//#include "aidge/quantization/PTQ/PTQ.hpp" //#include "aidge/quantization/PTQ/PTQ.hpp"
#include "aidge/recipes/QuantRecipes.hpp" #include "aidge/recipes/QuantRecipes.hpp"
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/graph/Matching.hpp"
namespace Aidge namespace Aidge
...@@ -121,4 +126,66 @@ void sanitizeNodeNames(std::shared_ptr<GraphView> graphView) ...@@ -121,4 +126,66 @@ void sanitizeNodeNames(std::shared_ptr<GraphView> graphView)
} }
} }
} void reorderMatMulInputs(std::shared_ptr<GraphView> graphView)
\ No newline at end of file {
const auto matches = SinglePassGraphMatching(graphView).match("(MatMul#)");
for (auto match : matches)
{
auto node = match.graph->rootNode();
// Check if the MatMul inputs have to be permuted
bool permuteInputs = false;
if (node->getParent(0))
if (node->getParent(0)->type() == "Producer")
permuteInputs = true;
if (node->getParent(1))
if (node->getParent(1)->type() == "Producer")
permuteInputs = false;
// Perform the permutation of the inputs ...
if (permuteInputs)
{
auto prevMatMul = node;
auto prevTensor = (std::static_pointer_cast<OperatorTensor> (node->getOperator()))->getInput(0);
// Create the new MatMul op and it's Producer
auto newMatMul = MatMul();
auto newDims = prevTensor->dims();
std::swap(newDims[0], newDims[1]);
auto newTensor = std::make_shared<Tensor>(newDims);
newTensor->setDataType(prevTensor->dataType());
newTensor->setBackend(prevTensor->backend());
newTensor->copyTranspose(*prevTensor, std::vector<Aidge::DimSize_t>({1, 0}));
auto newProducer = Producer(newTensor, "");
newProducer->addChild(newMatMul, 0, 1);
// Replace the node by a micrograph
auto prevMicroGraph = Sequential({prevMatMul});
prevMicroGraph->add(prevMatMul->getParent(0));
auto newMicroGraph = Sequential({Transpose({1, 0}), newMatMul, Transpose({1, 0})});
newMicroGraph->add(newMatMul->getParent(1));
newMicroGraph->setDataType(prevTensor->dataType());
newMicroGraph->setBackend(prevTensor->backend());
graphView->replace(prevMicroGraph, newMicroGraph);
}
}
// TODO : fold the Transpose operators when possible ...
// USE REGEXPS !!!
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment