diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index 39d46267e4..fb76038f22 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -136,12 +136,12 @@ def test_weight_stripped_engine_sizes(self): ) assertions.assertTrue( len(bytes(weight_included_engine)) > len(bytes(weight_stripped_engine)), - msg=f"Weight-stripped engine size is not smaller than the weight included engine size. Weight included engine size: {len(bytes(weight_included_engine))}, weight stripped engine size: {len(bytes(weight_stripped_engine))}", + msg=f"Weight-stripped engine size is not smaller than the weight included engine size. Weight included engine size: {len(bytes(weight_included_engine))}, weight-stripped engine size: {len(bytes(weight_stripped_engine))}", ) assertions.assertTrue( - len(bytes(weight_stripped_engine)) + len(bytes(weight_included_engine)) > len(bytes(weight_stripped_refit_identical_engine)), - msg=f"Weight-stripped refit-identical engine size is not smaller than the weight-stripped engine size. Weight-stripped engine size: {len(bytes(weight_stripped_engine))}, weight-stripped refit-identical engine size: {len(bytes(weight_stripped_refit_identical_engine))}", + msg=f"Weight-stripped refit-identical engine size is not smaller than the weight included engine size. Weight included engine size: {len(bytes(weight_included_engine))}, weight-stripped refit-identical engine size: {len(bytes(weight_stripped_refit_identical_engine))}", ) def test_weight_stripped_engine_results(self): @@ -200,6 +200,9 @@ def test_weight_stripped_engine_results(self): msg=f"refitted_output doesn't match with compiled_model_output. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + @unittest.skip( + "For now, torch-trt will save weighted engine if strip_engine_weights is False. In the near future, we plan to save weight-stripped engine regardless of strip_engine_weights, which is pending on TRT's feature development: NVBug #4914602" + ) def test_engine_caching_saves_weight_stripped_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)