diff --git a/unit_tests/operator/Test_ScalingMeta.cpp b/unit_tests/operator/Test_ScalingMeta.cpp index 592531b81b9ac133445a1a4f0bafc283a24a59c8..ae39ee0ac53c3c1b9c2300035f92444855bdcfb8 100644 --- a/unit_tests/operator/Test_ScalingMeta.cpp +++ b/unit_tests/operator/Test_ScalingMeta.cpp @@ -63,13 +63,16 @@ TEST_CASE("ScalingNodeMeta", "[ScalingMeta][CPU]") { auto scal = MulPTQ(2.001); auto scalop = std::static_pointer_cast<OperatorTensor>(scal->getOperator()); - t0->setBackend("cpu"); - scalop->associateInput(0,t0); - scalop->setBackend("cpu"); - scalop->forwardDims(); - //scalop->forward(); - scal->forward(); + scal->getOperator()->associateInput(0,t0); + + auto g = getConnectedGraphView(scal); + g->setDataType(DataType::Float32); + g->setBackend("cpu"); + + auto scheduler = SequentialScheduler(g); + scheduler.forward(); + auto out0 = scalop->getOutput(0); auto in0 = scalop->getInput(0); auto in1 = scalop->getInput(1);