From c3b64dd1aada2be79f85a7c0cff066c3224d0c0b Mon Sep 17 00:00:00 2001 From: bhalimi <benjamin.halimi@cea.fr> Date: Tue, 18 Mar 2025 11:00:35 +0000 Subject: [PATCH] modify prepareNetwork() --- src/PTQ/PTQ.cpp | 11 +++-------- src/recipes/QuantRecipes.cpp | 3 +++ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index ea592a6..ee17991 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -112,24 +112,19 @@ bool checkArchitecture(std::shared_ptr<GraphView> graphView) void prepareNetwork(std::shared_ptr<GraphView> graphView) { - // XXX remove this ! - - sanitizeNodeNames(graphView); - // remove the flatten nodes removeFlatten(graphView); - std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); - // handle the MatMuls reorderMatMulInputs(graphView); - - matMulToFC(graphView); + // matMulToFC(graphView); // not working properly atm ! // tag the weighted nodes + std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); + for (std::shared_ptr<Node> node : nodeVector) { bool isWeighted = isAffine(node); diff --git a/src/recipes/QuantRecipes.cpp b/src/recipes/QuantRecipes.cpp index c184882..1806e3d 100644 --- a/src/recipes/QuantRecipes.cpp +++ b/src/recipes/QuantRecipes.cpp @@ -176,6 +176,9 @@ void reorderMatMulInputs(std::shared_ptr<GraphView> graphView) auto newMicroGraph = Sequential({Transpose({1, 0}), newMatMul, Transpose({1, 0})}); newMicroGraph->add(newMatMul->getParent(1)); + newMicroGraph->setDataType(prevTensor->dataType()); + newMicroGraph->setBackend(prevTensor->backend()); + graphView->replace(prevMicroGraph, newMicroGraph); } } -- GitLab