diff --git a/test/smoke_test.py b/test/smoke_test.py index d672d46ad9e..a8f36aacc47 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -1,5 +1,6 @@ """Run smoke tests""" +import re import sys from pathlib import Path @@ -79,11 +80,21 @@ def main() -> None: print(f"torchvision: {torchvision.__version__}") print(f"torch.cuda.is_available: {torch.cuda.is_available()}") - # Turn 1.11.0aHASH into 1.11 (major.minor only) - version = ".".join(torchvision.__version__.split(".")[:2]) - if version >= "0.16": - print(f"{torch.ops.image._jpeg_version() = }") - assert torch.ops.image._is_compiled_against_turbo() + # The "a0" after the semantic version should only be present on the main branch or nightly builds, + # but not release branches. + if re.match(r"\d+\.\d+\.\d+(?!a0)\+", torchvision.__version__): + try: + from torchvision import prototype + except ModuleNotFoundError: + pass + else: + raise AssertionError( + "torchvision.prototype available on a release version. " + "Run rm -r torchvision/prototype test/test_prototype* .github/workflows/prototype*" + ) + + print(f"{torch.ops.image._jpeg_version() = }") + assert torch.ops.image._is_compiled_against_turbo() smoke_test_torchvision() smoke_test_torchvision_read_decode()