diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp index f55894c71beded9e6b7a20c53f9f22bbea671a01..1c911801c543cac8cb464acaab80e6061703e6e7 100644 --- a/include/aidge/quantization/PTQ/PTQ.hpp +++ b/include/aidge/quantization/PTQ/PTQ.hpp @@ -74,6 +74,12 @@ namespace Aidge { */ bool isNotQuantized(std::shared_ptr<Node> node); + /** + * @brief Compute the absolute max of a tensor + * @param tensor The Tensor to process + */ + double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor); + /** * @brief Retrieve the scheduled vector of node of a graphView, without the Producer nodes. * @param graphView The graphView containing the nodes diff --git a/include/aidge/quantization/QAT/QAT_LSQ.hpp b/include/aidge/quantization/QAT/QAT_LSQ.hpp index b1e7b6fcf99a50e707da2fdc7f7c35cdb2d778f7..7919b1af10647379f11d8819d1c3583a6c1fe9cb 100644 --- a/include/aidge/quantization/QAT/QAT_LSQ.hpp +++ b/include/aidge/quantization/QAT/QAT_LSQ.hpp @@ -9,30 +9,30 @@ * ********************************************************************************/ - #ifndef AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ - #define AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ - - #include <cstddef> // std::size_t - #include <memory> - - #include "aidge/data/Tensor.hpp" - #include "aidge/graph/GraphView.hpp" - - namespace Aidge { - namespace QuantLSQ { - - /** - * @brief Given a GraphView with parameters properly initialized, insert - * the LSQ quantizer nodes, and setup the adjustment their step-sizes. - * @param graphView The GraphView containing the network to quantize. - * @param nbBits Number of quantization bits. 
- */ - - void setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits); - - } // namespace QuantLSQ - } // namespace Aidge - - #endif /* AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ */ +#ifndef AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ +#define AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ + +#include <cstddef> // std::size_t +#include <memory> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" + +namespace Aidge { +namespace QuantLSQ { + +/** + * @brief Given a GraphView with parameters properly initialized, insert + * the LSQ quantizer nodes, and set up the adjustment of their step-sizes. + * @param graphView The GraphView containing the network to quantize. + * @param nbBits Number of quantization bits. + */ + +void setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits); + +} // namespace QuantLSQ +} // namespace Aidge + +#endif /* AIDGE_QUANTIZATION_QUANTIZATION_QAT_LSQ_H_ */ \ No newline at end of file diff --git a/include/aidge/recipes/QuantRecipes.hpp b/include/aidge/recipes/QuantRecipes.hpp index 39349f962d61970020741ba533403ba03559a53f..1e78699c579d53549ada884247ff545ac451f737 100644 --- a/include/aidge/recipes/QuantRecipes.hpp +++ b/include/aidge/recipes/QuantRecipes.hpp @@ -40,6 +40,14 @@ namespace Aidge * @param graphView The GraphView to process. */ void sanitizeNodeNames(std::shared_ptr<GraphView> graphView); + + /** + * @brief Given a GraphView, set all its MatMul weights to index 1 (required for the PTQ) + * This operation involves the insertion of Transpose nodes as well as the transposition of + * the MatMul weight tensors. + * @param graphView The GraphView to process. 
+ */ + void reorderMatMulInputs(std::shared_ptr<GraphView> graphView); } #endif /* AIDGE_QUANTIZATION_QUANTRECIPES_H_ */ diff --git a/python_binding/recipes/pybind_QuantRecipes.cpp b/python_binding/recipes/pybind_QuantRecipes.cpp index 0b96aef775a32cd362013998dd786a9985cc3fc1..15257b0a6b292d3205b6256fecb221ea0a7c7297 100644 --- a/python_binding/recipes/pybind_QuantRecipes.cpp +++ b/python_binding/recipes/pybind_QuantRecipes.cpp @@ -18,13 +18,15 @@ namespace py = pybind11; -namespace Aidge { - -void init_QuantRecipes(py::module &m) { +namespace Aidge +{ +void init_QuantRecipes(py::module &m) +{ m.def("pop_softmax", &popSoftMax, py::arg("network")); m.def("insert_batchnorm_nodes", &insertBatchNormNodes, py::arg("network")); m.def("sanitize_node_names", &sanitizeNodeNames, py::arg("network")); + m.def("reorder_matmul_inputs", &reorderMatMulInputs, py::arg("network")); } } // namespace Aidge diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp index 7115a2f17726c21666306aad8f75bd51eed3eb29..57787a8951a513cd0dc8660c6ef3a99b63e74729 100644 --- a/src/PTQ/CLE.cpp +++ b/src/PTQ/CLE.cpp @@ -52,68 +52,24 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node) return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2); } -static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling) +static bool nodeHasBias(std::shared_ptr<Node> node) { - auto mulOp = Mul_Op(); - mulOp.setDataType(tensor->dataType()); - mulOp.setBackend(tensor->backend()); - - std::shared_ptr<Aidge::Tensor> scalingTensor = std::make_shared<Aidge::Tensor>(Aidge::Array1D<double, 1> {scaling}); - scalingTensor->setDataType(tensor->dataType()); - scalingTensor->setBackend(tensor->backend()); - - mulOp.associateInput(0, tensor); - mulOp.associateInput(1, scalingTensor); - - mulOp.forward(); - - auto outTensor = mulOp.getOutput(0); - *tensor = *outTensor; - //tensor->copyCast(*outTensor); + if (node->getParents().size() == 3) { + std::shared_ptr<Tensor> biasTensor = 
getBiasTensor(node); + if (biasTensor) + return true; + } + return false; } -// TODO : make the retreival of argmax values backend independant (refCastFrom) -static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) +// What is this thing ??? +// Function used to extract the local tensor (from a ProducerScalingNode) +std::shared_ptr<Aidge::Tensor> getLocalTensor(std::shared_ptr<Node> node) { - // get the abs tensor - std::shared_ptr<Tensor> fallback; //Fallback tensor for refCastFR - std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs()); - - // flatten the abs tensor - - std::int64_t nbElement = tensor->size(); - - auto reshapeOp = Reshape_Op({nbElement}); - reshapeOp.setDataType(tensor->dataType()); - reshapeOp.setBackend(tensor->backend()); - - reshapeOp.associateInput(0, absTensor); - reshapeOp.forward(); - std::shared_ptr<Tensor> flatTensor = reshapeOp.getOutput(0); - const Tensor& localFlatTensor = flatTensor->refCastFrom(fallback, DataType::Float64, "cpu"); - - // Get the argmax - - auto argmaxOp = ArgMax_Op(0, true, false); - argmaxOp.setDataType(tensor->dataType()); - argmaxOp.setBackend(tensor->backend()); - - argmaxOp.associateInput(0, flatTensor); - argmaxOp.forward(); - - const Tensor& argMaxTensor = argmaxOp.getOutput(0)->refCastFrom(fallback, DataType::Float64, "cpu"); - - // Return the max - - int maxIndex = std::round(argMaxTensor.get<double>(0)); - - return localFlatTensor.get<double>(maxIndex); -} -//Function used to extraxt the local tensor (from a ProducerScalingNode) -std::shared_ptr<Aidge::Tensor> getLocalTensor(std::shared_ptr<Node> node) { - if (node->getParent(1)->attributes()->hasAttr("quantization.ptq.isProducerScaling")) { + if (node->getParent(1)->attributes()->hasAttr("quantization.ptq.isProducerScaling")) + { std::shared_ptr<Aidge::OperatorTensor> operatorTensor = std::static_pointer_cast<OperatorTensor>(node->getParent(1)->getOperator()); - operatorTensor->forward();// We need the forward pass to 
compute the scaled value of the Tensor + operatorTensor->forward(); // We need the forward pass to compute the scaled value of the Tensor return operatorTensor->getOutput(0); } else { return getWeightTensor(node); @@ -129,16 +85,16 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD for (std::shared_ptr<Node> node : nodeVector) { if (node->getChildren().size() > 1) { - Log::notice(" Network have multiple branches, skipping the CLE ... "); + Log::warn(" Network have multiple branches, skipping the CLE ... "); return; } if (isNotQuantized(node)) { - Log::notice(" Network contains non linear nodes, skipping the CLE ... "); + Log::warn(" Network contains non linear nodes, skipping the CLE ... "); return; } } - Log::info(" Applying the Cross-Layer Equalization ... "); + Log::notice(" Applying the Cross-Layer Equalization ... "); // Get the vector of affine nodes @@ -148,13 +104,14 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD affineNodeVector.push_back(node); double maxRangeDelta; - do { maxRangeDelta = 0.0; for (size_t i = 0; i < (affineNodeVector.size() - 1); i++) { + // Log::notice(" node index : {} ", i); + std::shared_ptr<Node> n1 = affineNodeVector[i]; std::shared_ptr<Node> n2 = affineNodeVector[i+1]; @@ -168,8 +125,11 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD double s2 = std::sqrt(r1 * r2) / r2; insertScalingBelowProducer(n1->getParent(1), s1, graphView); + + if (nodeHasBias(n1)) + insertScalingBelowProducer(n1->getParent(2), s1, graphView); + insertScalingBelowProducer(n2->getParent(1), s2, graphView); - insertScalingBelowProducer(n1->getParent(2), s1, graphView); double rangeDelta = std::abs(r1 - r2); if (rangeDelta > maxRangeDelta) diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index dbf4b6234d2a64c7fd9572c77882a1dcf64f3f49..0eecc450d7567b8eb0421cd95251ba8ace447a7e 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -29,6 +29,7 @@ #include 
"aidge/operator/Conv.hpp" #include "aidge/operator/ArgMax.hpp" #include "aidge/operator/Reshape.hpp" +#include "aidge/operator/MatMul.hpp" #include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/QuantRecipes.hpp" @@ -36,9 +37,25 @@ namespace Aidge { -bool isAffine(std::shared_ptr<Node> node) +static bool hasAttr(std::shared_ptr<Aidge::Node> node, std::string attr) +{ + return node->attributes()->hasAttr("quantization.ptq." + attr); +} + +static void addAttr(std::shared_ptr<Aidge::Node> node, std::string attr, double value = 0.0) { - return (affineNodeTypes.find(node->type()) != affineNodeTypes.end()); + node->attributes()->addAttr("quantization.ptq." + attr, value); +} + +bool isAffine(std::shared_ptr<Node> node) +{ + if (affineNodeTypes.find(node->type()) != affineNodeTypes.end()) + return true; + + if ((node->type() == "MatMul") && hasAttr(node, "isWeighted")) + return true; + + return false; } bool isSeamless(std::shared_ptr<Node> node) @@ -48,7 +65,13 @@ bool isSeamless(std::shared_ptr<Node> node) bool isMerging(std::shared_ptr<Node> node) { - return (mergingNodeTypes.find(node->type()) != mergingNodeTypes.end()); + if (mergingNodeTypes.find(node->type()) != mergingNodeTypes.end()) + return true; + + if ((node->type() == "MatMul") && !hasAttr(node, "isWeighted")) + return true; + + return false; } bool isNotQuantized(std::shared_ptr<Node> node) @@ -58,14 +81,17 @@ bool isNotQuantized(std::shared_ptr<Node> node) bool checkArchitecture(std::shared_ptr<GraphView> graphView) { - std::set<std::string> otherNodeTypes({"Flatten", "Softmax", "BatchNorm2D", "ReLU", "Producer"}); + std::set<std::string> removedNodeTypes({"Flatten", "Softmax", "BatchNorm2D"}); + + std::set<std::string> specialNodeTypes({"MatMul", "ReLU", "Producer"}); std::set<std::string> notQuantizedNodesTypes; for (std::shared_ptr<Node> node : graphView->getNodes()) { - bool isOther = otherNodeTypes.find(node->type()) != otherNodeTypes.end(); - if (!isOther && !isAffine(node) && !isSeamless(node) 
&& !isMerging(node) && !isNotQuantized(node)) { + bool isRemoved = removedNodeTypes.find(node->type()) != removedNodeTypes.end(); + bool isSpecial = specialNodeTypes.find(node->type()) != specialNodeTypes.end(); + if (!isRemoved && !isSpecial && !isAffine(node) && !isSeamless(node) && !isMerging(node) && !isNotQuantized(node)) { Log::warn(" GraphView can't be quantized : node type {} is not supported !", node->type()); return false; } @@ -86,25 +112,53 @@ bool checkArchitecture(std::shared_ptr<GraphView> graphView) void prepareNetwork(std::shared_ptr<GraphView> graphView) { + // remove the flatten nodes + removeFlatten(graphView); - sanitizeNodeNames(graphView); - bool containsBatchNorm = false; + // handle the MatMuls + + reorderMatMulInputs(graphView); + // matMulToFC(graphView); // not working properly atm ! + + // tag the weighted nodes + std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); for (std::shared_ptr<Node> node : nodeVector) + { + bool isWeighted = isAffine(node); + if (node->type() == "MatMul") + { + std::shared_ptr<Node> parent = node->getParent(1); + if (parent) + if (parent->type() == "Producer") + isWeighted = true; + } + + if (isWeighted) + addAttr(node, "isWeighted"); + } + + // fuse the batchnorms + + bool containsBatchNorm = false; + for (std::shared_ptr<Node> node : nodeVector) { if (node->type() == "BatchNorm") { containsBatchNorm = true; break; } + } if (containsBatchNorm) fuseBatchNorm(graphView); + // pop the softmax + popSoftMax(graphView); } -static std::shared_ptr<Aidge::Node> getUniqueChildren(std::shared_ptr<Aidge::Node> node) +static std::shared_ptr<Aidge::Node> getUniqueChild(std::shared_ptr<Aidge::Node> node) { std::set<std::shared_ptr<Aidge::Node>> childrenSet = node->getChildren(); AIDGE_ASSERT(childrenSet.size() == 1, " Attempted to access to a unique child while the parent have multiple ones ! 
"); @@ -152,8 +206,8 @@ static int getInputIndex(std::shared_ptr<Node> node, std::shared_ptr<Node> paren void multiplyScalingFactor(std::shared_ptr<Aidge::Node> node, double coeff) { - AIDGE_ASSERT(node->type() == "Mul" && (node->attributes()->hasAttr("quantization.ptq.isProducerScaling") || node->attributes()->hasAttr("quantization.ptq.isScaling")), - "Cannot update the scaling factor on Node of type {} with no scaling tag", node->type()); + AIDGE_ASSERT(node->type() == "Mul" && (hasAttr(node, "isProducerScaling") || hasAttr(node, "isScaling")), + "Cannot update the scaling factor on Node of type {} with no scaling tag", node->type()); auto scalingFactorTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(1); @@ -198,20 +252,21 @@ static void insertChildren(std::shared_ptr<Node> parent, std::shared_ptr<Node> n bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphView> graphView) { - if (node->attributes()->hasAttr("quantization.ptq.isProducerScaling") && node->type() != "Round") + if (hasAttr(node, "isProducerScaling") && node->type() != "Round") { std::shared_ptr<Aidge::Node> roundNode = Round(node->name() + "_Round"); roundNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) roundNode->getOperator()->setBackend(determineBackend(node)); insertChildren(node, roundNode, graphView); - roundNode->attributes()->addAttr("quantization.ptq.isProducerRounding", 0.0); + addAttr(roundNode, "isProducerRounding"); + return true; } return false; } -static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) +double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) { // get the abs tensor std::shared_ptr<Tensor> fallback; //Fallback tensor for refCastFR @@ -264,7 +319,7 @@ static std::vector<std::shared_ptr<Node>> removeProdScalingNodes(std::vector<std { std::vector<std::shared_ptr<Node>> remainingNodes; for (std::shared_ptr<Node> node : nodeVector) - if 
(!node->attributes()->hasAttr("quantization.ptq.isProducerScaling")) + if (!hasAttr(node, "isProducerScaling")) remainingNodes.push_back(node); return remainingNodes; @@ -300,14 +355,15 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node) std::vector<std::shared_ptr<Node>> retrieveNodeVector(std::shared_ptr<GraphView> graphView, bool newSchedule, bool verbose) { std::vector<std::shared_ptr<Node>> nodeVector = graphView->getOrderedNodes(); + + fixScheduling(nodeVector); - fixScheduling(nodeVector); nodeVector = removeMatchingNodes(nodeVector, "Producer"); nodeVector = removeProdScalingNodes(nodeVector); if (verbose) { - Log::info("NB OF NODES = {}", nodeVector.size()); + Log::info(" NB OF NODES = {}", nodeVector.size()); for (std::shared_ptr<Node> node : nodeVector) Log::info("{} {}", node->type(), node->name()); } @@ -331,8 +387,8 @@ static std::shared_ptr<Aidge::Node> createScalingNode(std::string name, std::vec { std::shared_ptr<Node> scalingNode = Mul(name); - for (std::string attr : attributes) - scalingNode->attributes()->addAttr("quantization.ptq." 
+ attr, 0.0); + for (std::string a : attributes) + addAttr(scalingNode, a); // Add the scaling factor as a producer of the node @@ -348,14 +404,14 @@ static std::shared_ptr<Aidge::Node> createScalingNode(std::string name, std::vec bool insertScalingBelowProducer(std::shared_ptr<Node> producerNode, double scalingFactor, std::shared_ptr<GraphView> graphView) { - if (producerNode->attributes()->hasAttr("quantization.ptq.isProducerRounding")) + if (hasAttr(producerNode, "isProducerRounding")) { // In this case we 'bump' the node to the one above him (an actual ProducerScaling) // because the round node is not usable (only used when SSA is enabled) producerNode = producerNode->getParent(0); } - if (producerNode->attributes()->hasAttr("quantization.ptq.isProducerScaling")) + if (hasAttr(producerNode, "isProducerScaling")) { // We accumulate the previous scaling factors by multiplying the SF of the ProducerScaling node // (adding new nodes each time would make the graph unusable) @@ -419,7 +475,7 @@ void insertResidualScalingNodes(std::shared_ptr<GraphView> graphView) static std::shared_ptr<Node> getPreviousScalingNode(std::shared_ptr<Node> node) { std::shared_ptr<Node> currNode = node; - while(!currNode->attributes()->hasAttr("quantization.ptq.isScaling")) + while(!hasAttr(currNode, "isScaling")) { if (currNode->getParents().size() == 0) { @@ -503,7 +559,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) for (std::shared_ptr<Node> node : nodeVector) { // Scaling nodes still have a ratio of 1, so they are seamless ... 
- if (node->type() == "ReLU" || node->attributes()->hasAttr("quantization.ptq.isScaling") || isSeamless(node)) + if (node->type() == "ReLU" || hasAttr(node, "isScaling") || isSeamless(node)) { if (node != firstNode) { @@ -557,37 +613,51 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) // Revert the canceling by using the next scaling node accumulatedRatios[node] = prevRatio; - std::shared_ptr<Node> nextScalingNode = getUniqueChildren(node); + std::shared_ptr<Node> nextScalingNode = getUniqueChild(node); multiplyScalingFactor(nextScalingNode, prevRatio); } if (isMerging(node)) { - std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents(); + if (node->type() == "MatMul") + { + // Multiply the input scaling factors ! - // Compute the max ratio ... + double leftRatio = accumulatedRatios[node->getParent(0)]; + double rightRatio = accumulatedRatios[node->getParent(1)]; - double maxRatio = 0; - for (std::shared_ptr<Node> mergingNode : mergingNodes) - { - double merginNodeRatio = accumulatedRatios[mergingNode]; - if (merginNodeRatio > maxRatio) - maxRatio = merginNodeRatio; + accumulatedRatios[node] = leftRatio * rightRatio; } + else + { + // Use a maximum arbitration ! - accumulatedRatios[node] = maxRatio; + std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents(); - // Rescale the previous scaling Nodes - for (std::shared_ptr<Node> mergingNode : mergingNodes) - { - double mergingNodeRatio = accumulatedRatios[mergingNode]; - double rescaling = mergingNodeRatio / maxRatio; + // Compute the max ratio ... 
+ + double maxRatio = 0; + for (std::shared_ptr<Node> mergingNode : mergingNodes) + { + double merginNodeRatio = accumulatedRatios[mergingNode]; + if (merginNodeRatio > maxRatio) + maxRatio = merginNodeRatio; + } + + accumulatedRatios[node] = maxRatio; + + // Rescale the previous scaling Nodes + for (std::shared_ptr<Node> mergingNode : mergingNodes) + { + double mergingNodeRatio = accumulatedRatios[mergingNode]; + double rescaling = mergingNodeRatio / maxRatio; - std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode); + std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode); - multiplyScalingFactor(scalingNode, 1 / rescaling); + multiplyScalingFactor(scalingNode, 1 / rescaling); - accumulatedRatios[mergingNode] /= rescaling; // optional ... + accumulatedRatios[mergingNode] /= rescaling; // optional ... + } } } } @@ -610,7 +680,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr< std::set<std::shared_ptr<Node>> nodeSet = graphView->getNodes(); for (std::shared_ptr<Node> node : nodeSet) { - if ((scalingNodesOnly && (node->attributes()->hasAttr("quantization.ptq.isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer"))) + if ((scalingNodesOnly && hasAttr(node, "isScaling")) || (!scalingNodesOnly && (node->type() != "Producer"))) { std::shared_ptr<Operator> nodeOperator = node->getOperator(); std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0)); @@ -632,7 +702,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr< // std::shared_ptr<Node> inputNode = getFirstNode(graphView); for (std::shared_ptr<Node> node : nodeSet) - if ((scalingNodesOnly && (node->attributes()->hasAttr("quantization.ptq.isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer"))) + if ((scalingNodesOnly && hasAttr(node, "isScaling")) || (!scalingNodesOnly && (node->type() != "Producer"))) valueRanges.insert(std::make_pair(node, 0)); 
if (useCuda) @@ -659,7 +729,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr< std::unordered_map<std::shared_ptr<Node>, double> sampleRanges; for (std::shared_ptr<Node> node : nodeSet) { - if ((scalingNodesOnly && (node->attributes()->hasAttr("quantization.ptq.isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer"))) + if ((scalingNodesOnly && hasAttr(node, "isScaling")) || (!scalingNodesOnly && (node->type() != "Producer"))) { std::shared_ptr<Operator> nodeOperator = node->getOperator(); std::shared_ptr<Tensor> valueTensor = std::static_pointer_cast<Tensor> (nodeOperator->getRawOutput(0)); @@ -681,7 +751,7 @@ std::unordered_map<std::shared_ptr<Node>, double> computeRanges(std::shared_ptr< for (std::shared_ptr<Node> node : nodeSet) { - if ((scalingNodesOnly && (node->attributes()->hasAttr("quantization.ptq.isScaling"))) || (!scalingNodesOnly && (node->type() != "Producer"))) + if ((scalingNodesOnly && hasAttr(node, "isScaling")) || (!scalingNodesOnly && (node->type() != "Producer"))) if (sampleRanges[node] > valueRanges[node]) valueRanges[node] = sampleRanges[node]; } @@ -727,7 +797,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m // Use the Scaling nodes to rescale the ranges ... - if (node->attributes()->hasAttr("quantization.ptq.isScaling")) + if (hasAttr(node, "isScaling")) { std::shared_ptr<Node> prevNode = node->getParent(0); @@ -751,26 +821,35 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::unordered_m if (isMerging(node)) { - std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents(); - - // Compute the max ratio ... 
- - double maxRatio = 0; - for (std::shared_ptr<Node> mergingNode : mergingNodes) + if (node->type() == "MatMul") { - double mergingNodeRatio = accumulatedRatios[mergingNode]; - if (mergingNodeRatio > maxRatio) - maxRatio = mergingNodeRatio; + double leftRatio = accumulatedRatios[node->getParent(0)]; + double rightRatio = accumulatedRatios[node->getParent(1)]; + accumulatedRatios[node] = leftRatio * rightRatio; } + else + { + std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents(); - accumulatedRatios[node] = maxRatio; + // Compute the max ratio ... - for (std::shared_ptr<Node> mergingNode : mergingNodes) - { - double mergingNodeRatio = accumulatedRatios[mergingNode]; - std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode); - multiplyScalingFactor(scalingNode, mergingNodeRatio / maxRatio); - // Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name()); + double maxRatio = 0; + for (std::shared_ptr<Node> mergingNode : mergingNodes) + { + double mergingNodeRatio = accumulatedRatios[mergingNode]; + if (mergingNodeRatio > maxRatio) + maxRatio = mergingNodeRatio; + } + + accumulatedRatios[node] = maxRatio; + + for (std::shared_ptr<Node> mergingNode : mergingNodes) + { + double mergingNodeRatio = accumulatedRatios[mergingNode]; + std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode); + multiplyScalingFactor(scalingNode, mergingNodeRatio / maxRatio); + // Log::notice(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name()); + } } } @@ -820,7 +899,7 @@ std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap( signMap[node].second = false; } - if (node->attributes()->hasAttr("quantization.ptq.isScaling")) + if (hasAttr(node, "isScaling")) { signMap[node].second = false; @@ -867,7 +946,7 @@ std::unordered_map<std::shared_ptr<Node>, std::pair<bool, bool>> computeSignMap( // Arbitration : Signed type wins ! 
for(std::shared_ptr<Node> parent : parentNodes) { - while (!parent->attributes()->hasAttr("quantization.ptq.isScaling")) + while (!hasAttr(parent, "isScaling")) { signMap[parent] = std::make_pair(false, false); // We are on a branch so nodes always have 1 parent ... @@ -975,7 +1054,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ rescaling /= inputIsUnsigned ? unsignedMax : signedMax; rescaling *= outputIsUnsigned ? unsignedMax : signedMax; - std::shared_ptr<Node> scalingNode = getUniqueChildren(node); // TODO : assert if scalingNode is a Scaling ... + std::shared_ptr<Node> scalingNode = getUniqueChild(node); // TODO : assert if scalingNode is a Scaling ... multiplyScalingFactor(scalingNode,rescaling) ; } @@ -990,8 +1069,12 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ rescaling /= inputIsUnsigned ? unsignedMax : signedMax; rescaling *= outputIsUnsigned ? unsignedMax : signedMax; - std::shared_ptr<Node> scalingNode = getUniqueChildren(node); // TODO : assert if scalingNode is a Scaling ... + std::shared_ptr<Node> scalingNode = getUniqueChild(node); // TODO : assert if scalingNode is a Scaling ... + // TODO : double check this ... + if (node->type() == "MatMul") + rescaling /= inputIsUnsigned ? unsignedMax : signedMax; + multiplyScalingFactor(scalingNode, rescaling) ; } @@ -1002,19 +1085,19 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ std::shared_ptr<Node> prevScalingNode = node->getParent(0); multiplyScalingFactor(prevScalingNode, rescaling); - std::shared_ptr<Node> nextScalingNode = getUniqueChildren(node); + std::shared_ptr<Node> nextScalingNode = getUniqueChild(node); multiplyScalingFactor(nextScalingNode, 1 / rescaling); } // Handle the Scaling Nodes ... - if (node->attributes()->hasAttr("quantization.ptq.isScaling")) + if (hasAttr(node, "isScaling")) { // Don't touch the scalings that precede non-linearities ... 
bool precedesNonLinearNode = false; if (node->getChildren().size() == 1) - if (isNotQuantized(getUniqueChildren(node))) + if (isNotQuantized(getUniqueChild(node))) precedesNonLinearNode = true; if (!noQuant && !precedesNonLinearNode) @@ -1080,7 +1163,8 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u std::string mulNodeName = makeUniqueName(node->name() + "_Mul", graphView); std::shared_ptr<Node> mulNode = Mul(mulNodeName); - mulNode->attributes()->addAttr("quantization.ptq.isCompensation", 0.0); + addAttr(mulNode, "isCompensation"); + mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode) mulNode->getOperator()->setBackend(determineBackend(node)); @@ -1142,7 +1226,7 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool static void printScalingFactors(std::shared_ptr<GraphView> graphView) { for (auto node : retrieveNodeVector(graphView)) - if (node->attributes()->hasAttr("quantization.ptq.isScaling") || node->type() == "Quantizer") + if (hasAttr(node, "isScaling") || node->type() == "Quantizer") { double scalingFactor = getScalingFactor(node); Log::info(" {:.6f} ({})", scalingFactor, node->name()); @@ -1182,6 +1266,7 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, Log::notice(" Inserting the scaling nodes ..."); insertScalingNodes(graphView); + // TODO : double check this ! 
crossLayerEqualization(graphView); Log::notice(" Normalizing the parameters ..."); @@ -1190,13 +1275,9 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, Log::notice(" Computing the value ranges ..."); std::unordered_map<std::shared_ptr<Node>, double> valueRanges = computeRanges(graphView, inputDataSet, true, useCuda); - //Log::info(" === RANGES (BEFORE ADJUST) ==="); - Log::notice(" Optimizing the clipping values ..."); valueRanges = adjustRanges(clippingMode, valueRanges, nbBits, graphView, inputDataSet, useCuda, verbose); - //Log:debug("=== RANGES (AFTER ADJUST) ==="); - //printRanges(graphView, valueRanges); Log::notice(" Normalizing the activations ..."); normalizeActivations(graphView, valueRanges); @@ -1248,7 +1329,7 @@ void clearBiases(std::shared_ptr<GraphView> graphView) if (node->type() == "FC" || node->type() == "Conv2D") { std::shared_ptr<Tensor> biasTensor = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2); //rescaleTensor(biasTensor, 0); - insertScalingBelowProducer(node->getParent(2),0,graphView); + insertScalingBelowProducer(node->getParent(2), 0, graphView); } } } diff --git a/src/QAT/QAT_LSQ.cpp b/src/QAT/QAT_LSQ.cpp index dcac6819365e134d777be7479a95d6b8e4093b5e..6eae077b060027eb4029f6b59f55376a1674df70 100644 --- a/src/QAT/QAT_LSQ.cpp +++ b/src/QAT/QAT_LSQ.cpp @@ -9,164 +9,164 @@ * ********************************************************************************/ - #include "aidge/quantization/QAT/QAT_LSQ.hpp" - #include "aidge/operator/LSQ.hpp" - #include "aidge/operator/ReLU.hpp" - - - #include "aidge/data/Tensor.hpp" - #include "aidge/graph/GraphView.hpp" - #include "aidge/scheduler/SequentialScheduler.hpp" - #include "aidge/scheduler/Scheduler.hpp" - #include "aidge/graph/Matching.hpp" - #include "aidge/recipes/QuantRecipes.hpp" - - - namespace Aidge - { - - static float getTensorAbsMean(std::shared_ptr<Tensor> tensor) - { - auto valueTensor = (*tensor).abs().mean(); - 
std::shared_ptr<Tensor> fallback; - const Tensor& localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu"); - return localTensor.get<float>(0); - } - - static float getTensorStd(std::shared_ptr<Tensor> tensor) - { - auto valueTensor = (*tensor); - - auto skewedTensor = valueTensor - valueTensor.mean(); - auto squaredTensor = skewedTensor * skewedTensor; - auto varianceTensor = squaredTensor.mean(); - - std::shared_ptr<Tensor> fallback; - auto localTensor = varianceTensor.refCastFrom(fallback, DataType::Float32, "cpu"); - - float variance = localTensor.get<float>(0); - return std::sqrt(variance); - } - - - // INIT THE STEP SIZE OF A QUANTIZER NODE - - static bool initStepSize(std::shared_ptr<Node> quantizer) - { - const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator()); - - // This formula is the one proposed in the paper ... - - // float inputAbsMean = getTensorAbsMean(quantizerOp->getInput(0)); - // float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second)); - - // .. but this formula seems to work better !!! - - float inputStd = getTensorStd(quantizerOp->getInput(0)); - float stepSize = 8.0f * (inputStd / (quantizerOp->range().second)); - - // TODO : use the scalar constructor - auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); - - // XXX Manage backend here ? 
- stepSizeTensor->setBackend(quantizerOp->getInput(0)->backend()); - stepSizeTensor->setDataType(quantizerOp->getInput(0)->dataType()); - - auto stepSizeProducer = quantizer->getParent(1); - - stepSizeProducer->getOperator()->setOutput(0, stepSizeTensor); - - Log::notice(" [ INIT STEP SIZE = {} ] ", stepSize); - - return false; - } - - static void setupInputQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) - { - const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)"); - - for (const auto& match : matches) - { - auto linearNode = match.graph->rootNode(); - - // Log::notice(" SET INPUT QUANTIZER : {} ", linearNode->type()); - - std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1}; - std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1}; - - // Create the input quantizer node - - auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView); - auto quantizerNode = LSQ(signedRange, quantizerName); - - // Init the step-size using the node call stack - - quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); }); - - // Absorb the ReLU when possible ... - - bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]); // XXX is this safe ? - - if (nodeHasParent) - { - bool allParentsAreReLU = true; - for (auto parentNode : linearNode->getParents()) - if (parentNode->type() != "ReLU") - allParentsAreReLU = false; - - if (allParentsAreReLU) { - auto quantizerOp = std::static_pointer_cast<LSQ_Op> (quantizerNode->getOperator()); - quantizerOp->range() = unsignedRange; - } - - // TODO : remove the ReLUs when possible - } - - // Insert the quantizer in the graphView ... 
- // (We need to handle the case where the linear node is the first one) - - if (nodeHasParent) { - graphView->insertParent(linearNode, quantizerNode, 0, 0, 0); - } else { - quantizerNode->addChild(graphView); - graphView->add(quantizerNode); - } - } - } - - // PARAM QUANTIZERS INSERTION - - static void setupParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) - { - const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)"); - - std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1}; - - for (const auto& match : matches) - { - auto linearNode = match.graph->rootNode(); - - // Log::notice(" SET PARAM QUANTIZER : {} ", linearNode->type()); - - // TODO : double check this, and use createUniqueName() - auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView); - auto quantizerNode = LSQ(signedRange, quantizerName); - - // Init the step-size using the node call stack - - quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); }); - - // Insert the quantizer in the graphView - - graphView->insertParent(linearNode, quantizerNode, 1, 0, 0); - } - } - - void QuantLSQ::setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) - { - sanitizeNodeNames(graphView); - setupInputQuantizers(graphView, nbBits); - setupParamQuantizers(graphView, nbBits); - } - - } \ No newline at end of file +#include "aidge/quantization/QAT/QAT_LSQ.hpp" +#include "aidge/operator/LSQ.hpp" +#include "aidge/operator/ReLU.hpp" + + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/scheduler/SequentialScheduler.hpp" +#include "aidge/scheduler/Scheduler.hpp" +#include "aidge/graph/Matching.hpp" +#include "aidge/recipes/QuantRecipes.hpp" + + +namespace Aidge +{ + +static float getTensorAbsMean(std::shared_ptr<Tensor> tensor) +{ + auto valueTensor = (*tensor).abs().mean(); + std::shared_ptr<Tensor> fallback; + const Tensor& 
localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu"); + return localTensor.get<float>(0); +} + +static float getTensorStd(std::shared_ptr<Tensor> tensor) +{ + auto valueTensor = (*tensor); + + auto skewedTensor = valueTensor - valueTensor.mean(); + auto squaredTensor = skewedTensor * skewedTensor; + auto varianceTensor = squaredTensor.mean(); + + std::shared_ptr<Tensor> fallback; + auto localTensor = varianceTensor.refCastFrom(fallback, DataType::Float32, "cpu"); + + float variance = localTensor.get<float>(0); + return std::sqrt(variance); +} + + +// INIT THE STEP SIZE OF A QUANTIZER NODE + +static bool initStepSize(std::shared_ptr<Node> quantizer) +{ + const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator()); + + // This formula is the one proposed in the paper ... + + // float inputAbsMean = getTensorAbsMean(quantizerOp->getInput(0)); + // float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second)); + + // .. but this formula seems to work better !!! + + float inputStd = getTensorStd(quantizerOp->getInput(0)); + float stepSize = 8.0f * (inputStd / (quantizerOp->range().second)); + + // TODO : use the scalar constructor + auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})); + + // XXX Manage backend here ? 
+ stepSizeTensor->setBackend(quantizerOp->getInput(0)->backend()); + stepSizeTensor->setDataType(quantizerOp->getInput(0)->dataType()); + + auto stepSizeProducer = quantizer->getParent(1); + + stepSizeProducer->getOperator()->setOutput(0, stepSizeTensor); + + Log::notice(" [ INIT STEP SIZE = {} ] ", stepSize); + + return false; +} + +static void setupInputQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) +{ + const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)"); + + for (const auto& match : matches) + { + auto linearNode = match.graph->rootNode(); + + // Log::notice(" SET INPUT QUANTIZER : {} ", linearNode->type()); + + std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1}; + std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1}; + + // Create the input quantizer node + + auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView); + auto quantizerNode = LSQ(signedRange, quantizerName); + + // Init the step-size using the node call stack + + quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); }); + + // Absorb the ReLU when possible ... + + bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]); // XXX is this safe ? + + if (nodeHasParent) + { + bool allParentsAreReLU = true; + for (auto parentNode : linearNode->getParents()) + if (parentNode->type() != "ReLU") + allParentsAreReLU = false; + + if (allParentsAreReLU) { + auto quantizerOp = std::static_pointer_cast<LSQ_Op> (quantizerNode->getOperator()); + quantizerOp->range() = unsignedRange; + } + + // TODO : remove the ReLUs when possible + } + + // Insert the quantizer in the graphView ... 
+ // (We need to handle the case where the linear node is the first one) + + if (nodeHasParent) { + graphView->insertParent(linearNode, quantizerNode, 0, 0, 0); + } else { + quantizerNode->addChild(graphView); + graphView->add(quantizerNode); + } + } +} + +// PARAM QUANTIZERS INSERTION + +static void setupParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) +{ + const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)"); + + std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1}; + + for (const auto& match : matches) + { + auto linearNode = match.graph->rootNode(); + + // Log::notice(" SET PARAM QUANTIZER : {} ", linearNode->type()); + + // TODO : double check this, and use createUniqueName() + auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView); + auto quantizerNode = LSQ(signedRange, quantizerName); + + // Init the step-size using the node call stack + + quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); }); + + // Insert the quantizer in the graphView + + graphView->insertParent(linearNode, quantizerNode, 1, 0, 0); + } +} + +void QuantLSQ::setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) +{ + sanitizeNodeNames(graphView); + setupInputQuantizers(graphView, nbBits); + setupParamQuantizers(graphView, nbBits); +} + +} \ No newline at end of file diff --git a/src/backend/cuda/operator/LSQImpl.cpp b/src/backend/cuda/operator/LSQImpl.cpp index fa45f211e72f6742b72584aadf2a109c3bdca594..bb30cc10b6e87f3d6797918d02874ebca48d47ea 100644 --- a/src/backend/cuda/operator/LSQImpl.cpp +++ b/src/backend/cuda/operator/LSQImpl.cpp @@ -53,11 +53,11 @@ void Aidge::LSQImpl_cuda::backward() { std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad(); if (gra_int0->size() > mWorkspaceSize) { - // std::cout << " reallocation " << sizeof(gra_int0) << " " << gra_int0->size() << std::endl; if (mWorkspace != nullptr) { 
cudaFree(mWorkspace); } - CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, 8 * gra_int0->size())); // XXX This must be changed !!! + std::size_t sizeOfData = getDataTypeBitWidth(gra_int0->dataType()) / 8; + CHECK_CUDA_STATUS(cudaMalloc(&mWorkspace, sizeOfData * gra_int0->size())); mWorkspaceSize = gra_int0->size(); } diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp index f86d454245a7fe088edd027732a91f5775cd2acf..c70a7726c143ed4cd028099f849de25a16ab11d3 100644 --- a/src/operator/PTQMetaOps.cpp +++ b/src/operator/PTQMetaOps.cpp @@ -70,8 +70,6 @@ static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, st return mulNode; } - - void updateScalingFactor(std::shared_ptr<Node> metaOpNode, double scalingFactor) { if(metaOpNode->type() != "Scaling" && metaOpNode->type() != "Quantizer") diff --git a/src/recipes/QuantRecipes.cpp b/src/recipes/QuantRecipes.cpp index f03eb462088b16645fe600769e2a5e2c990f21b6..1806e3d4fd4da402f76c9046b84fcb6acfe69606 100644 --- a/src/recipes/QuantRecipes.cpp +++ b/src/recipes/QuantRecipes.cpp @@ -9,12 +9,17 @@ * ********************************************************************************/ - +#include "aidge/graph/OpArgs.hpp" +#include "aidge/operator/Producer.hpp" #include "aidge/operator/Conv.hpp" +#include "aidge/operator/Transpose.hpp" +#include "aidge/operator/MatMul.hpp" #include "aidge/operator/BatchNorm.hpp" //#include "aidge/quantization/PTQ/PTQ.hpp" #include "aidge/recipes/QuantRecipes.hpp" #include "aidge/graph/Node.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/graph/Matching.hpp" namespace Aidge @@ -121,4 +126,66 @@ void sanitizeNodeNames(std::shared_ptr<GraphView> graphView) } } -} \ No newline at end of file +void reorderMatMulInputs(std::shared_ptr<GraphView> graphView) +{ + const auto matches = SinglePassGraphMatching(graphView).match("(MatMul#)"); + + for (auto match : matches) + { + auto node = match.graph->rootNode(); + + // Check if the MatMul inputs have to be permuted + + bool 
permuteInputs = false; + + if (node->getParent(0)) + if (node->getParent(0)->type() == "Producer") + permuteInputs = true; + + if (node->getParent(1)) + if (node->getParent(1)->type() == "Producer") + permuteInputs = false; + + // Perform the permutation of the inputs ... + + if (permuteInputs) + { + auto prevMatMul = node; + auto prevTensor = (std::static_pointer_cast<OperatorTensor> (node->getOperator()))->getInput(0); + + // Create the new MatMul op and it's Producer + + auto newMatMul = MatMul(); + + auto newDims = prevTensor->dims(); + std::swap(newDims[0], newDims[1]); + auto newTensor = std::make_shared<Tensor>(newDims); + + newTensor->setDataType(prevTensor->dataType()); + newTensor->setBackend(prevTensor->backend()); + newTensor->copyTranspose(*prevTensor, std::vector<Aidge::DimSize_t>({1, 0})); + + auto newProducer = Producer(newTensor, ""); + newProducer->addChild(newMatMul, 0, 1); + + // Replace the node by a micrograph + + auto prevMicroGraph = Sequential({prevMatMul}); + prevMicroGraph->add(prevMatMul->getParent(0)); + + auto newMicroGraph = Sequential({Transpose({1, 0}), newMatMul, Transpose({1, 0})}); + newMicroGraph->add(newMatMul->getParent(1)); + + newMicroGraph->setDataType(prevTensor->dataType()); + newMicroGraph->setBackend(prevTensor->backend()); + + graphView->replace(prevMicroGraph, newMicroGraph); + } + } + + // TODO : fold the Transpose operators when possible ... + + // USE REGEXPS !!! +} + +}