diff --git a/unit_tests/recipies/Test_ConvToMatMul.cpp b/unit_tests/recipies/Test_ConvToMatMul.cpp index ef3920f0af7e9c12873416c3f64b429bb9a8b947..e815b1035f4653a20d757798408ca53b6565f589 100644 --- a/unit_tests/recipies/Test_ConvToMatMul.cpp +++ b/unit_tests/recipies/Test_ConvToMatMul.cpp @@ -63,4 +63,14 @@ TEST_CASE("[ConvToMatMul] conv") { auto g1OutOp = std::static_pointer_cast<OperatorTensor>((*g1->outputNodes().cbegin())->getOperator()); auto g2OutOp = std::static_pointer_cast<OperatorTensor>((*g1->outputNodes().cbegin())->getOperator()); REQUIRE(*(g1OutOp->getOutput(0)) == *(g2OutOp->getOutput(0))); + + // Simplify the graph: freeze parameters to allow reshaping of the Producers + for (auto node : g2->getNodes()) { + if (node->type() == Producer_Op::Type && node->name() != "dataProvider") { + std::static_pointer_cast<Producer_Op>(node->getOperator())->getAttr<bool>("Constant") = true; + } + } + + constantFolding(g2); + g2->save("convToMatMul_after_folding"); }