diff --git a/unit_tests/operator/Test_SliceImpl.cpp b/unit_tests/operator/Test_SliceImpl.cpp index 0bf12f9b0faa01798b041462a50ec7db07347130..bc129daeddbf0c04530e836fb9363f3fee684b24 100644 --- a/unit_tests/operator/Test_SliceImpl.cpp +++ b/unit_tests/operator/Test_SliceImpl.cpp @@ -33,8 +33,10 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice]") { mySlice->getOperator()->associateInput(0,input0); mySlice->getOperator()->computeOutputDims(); mySlice->forward(); - mySlice->getOperator()->output(0).print(); + // mySlice->getOperator()->output(0).print(); REQUIRE(mySlice->getOperator()->output(0) == *expectedOutput); + REQUIRE(mySlice->getOperator()->output(0).dims() == expectedOutput->dims()); + REQUIRE(mySlice->getOperator()->output(0).dataType() == expectedOutput->dataType()); } SECTION("2D Tensor") { @@ -57,8 +59,10 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice]") { mySlice->getOperator()->associateInput(0,input0); mySlice->getOperator()->computeOutputDims(); mySlice->forward(); - mySlice->getOperator()->output(0).print(); + // mySlice->getOperator()->output(0).print(); REQUIRE(*mySlice->getOperator()->getOutput(0) == *expectedOutput); + REQUIRE(mySlice->getOperator()->output(0).dims() == expectedOutput->dims()); + REQUIRE(mySlice->getOperator()->output(0).dataType() == expectedOutput->dataType()); } SECTION("3D Tensor") { @@ -88,8 +92,10 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice]") { mySlice->getOperator()->associateInput(0,input0); mySlice->getOperator()->computeOutputDims(); mySlice->forward(); - mySlice->getOperator()->output(0).print(); + // mySlice->getOperator()->output(0).print(); REQUIRE(mySlice->getOperator()->output(0) == *expectedOutput); + REQUIRE(mySlice->getOperator()->output(0).dims() == expectedOutput->dims()); + REQUIRE(mySlice->getOperator()->output(0).dataType() == expectedOutput->dataType()); } SECTION("4D Tensor") { @@ -148,7 +154,9 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice]") { mySlice->getOperator()->associateInput(0,input0); mySlice->getOperator()->computeOutputDims(); mySlice->forward(); - mySlice->getOperator()->output(0).print(); + // mySlice->getOperator()->output(0).print(); REQUIRE(mySlice->getOperator()->output(0) == *expectedOutput); + REQUIRE(mySlice->getOperator()->output(0).dims() == expectedOutput->dims()); + REQUIRE(mySlice->getOperator()->output(0).dataType() == expectedOutput->dataType()); } } \ No newline at end of file