diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index ea592a65cb67ac04f6f02273d28de939df40b6d5..ee179912068f126409e98574221cf4dcf7916934 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -112,24 +112,19 @@ bool checkArchitecture(std::shared_ptr<GraphView> graphView) void prepareNetwork(std::shared_ptr<GraphView> graphView) { - // XXX remove this ! - - sanitizeNodeNames(graphView); - // remove the flatten nodes removeFlatten(graphView); - std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); - // handle the MatMuls reorderMatMulInputs(graphView); - - matMulToFC(graphView); + // matMulToFC(graphView); // not working properly atm ! // tag the weighted nodes + std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); + for (std::shared_ptr<Node> node : nodeVector) { bool isWeighted = isAffine(node); diff --git a/src/recipes/QuantRecipes.cpp b/src/recipes/QuantRecipes.cpp index c1848828dd3aef121819547d1e82174185decdab..1806e3d4fd4da402f76c9046b84fcb6acfe69606 100644 --- a/src/recipes/QuantRecipes.cpp +++ b/src/recipes/QuantRecipes.cpp @@ -176,6 +176,9 @@ void reorderMatMulInputs(std::shared_ptr<GraphView> graphView) auto newMicroGraph = Sequential({Transpose({1, 0}), newMatMul, Transpose({1, 0})}); newMicroGraph->add(newMatMul->getParent(1)); + newMicroGraph->setDataType(prevTensor->dataType()); + newMicroGraph->setBackend(prevTensor->backend()); + graphView->replace(prevMicroGraph, newMicroGraph); } }