From 242f367def17eb118ba9b4810298d726711f8f36 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Wed, 3 Jul 2024 11:02:07 +0200 Subject: [PATCH] Added folding example --- unit_tests/recipies/Test_ConvToMatMul.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/unit_tests/recipies/Test_ConvToMatMul.cpp b/unit_tests/recipies/Test_ConvToMatMul.cpp index ef3920f0..e815b103 100644 --- a/unit_tests/recipies/Test_ConvToMatMul.cpp +++ b/unit_tests/recipies/Test_ConvToMatMul.cpp @@ -63,4 +63,14 @@ TEST_CASE("[ConvToMatMul] conv") { auto g1OutOp = std::static_pointer_cast<OperatorTensor>((*g1->outputNodes().cbegin())->getOperator()); auto g2OutOp = std::static_pointer_cast<OperatorTensor>((*g1->outputNodes().cbegin())->getOperator()); REQUIRE(*(g1OutOp->getOutput(0)) == *(g2OutOp->getOutput(0))); + + // Simplify the graph: freeze parameters to allow reshaping of the Producers + for (auto node : g2->getNodes()) { + if (node->type() == Producer_Op::Type && node->name() != "dataProvider") { + std::static_pointer_cast<Producer_Op>(node->getOperator())->getAttr<bool>("Constant") = true; + } + } + + constantFolding(g2); + g2->save("convToMatMul_after_folding"); } -- GitLab