Skip to content

Commit 47ff3b1

Browse files
committed
Check default=None only in schema command
We have existing models that violate this check. So only enable this check in schema command which is called during cog build.
1 parent ed355a2 commit 47ff3b1

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

python/coglet/inspector.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,9 @@ def _find_coders(module: ModuleType) -> None:
327327
)
328328

329329

330-
def create_predictor(module_name: str, predictor_name: str) -> adt.Predictor:
330+
def create_predictor(
331+
module_name: str, predictor_name: str, inspect_ast: bool = False
332+
) -> adt.Predictor:
331333
module = importlib.import_module(module_name)
332334
fullname = f'{module_name}.{predictor_name}'
333335
assert hasattr(module, predictor_name), f'predictor not found: {fullname}'
@@ -354,7 +356,9 @@ def create_predictor(module_name: str, predictor_name: str) -> adt.Predictor:
354356
predictor = _predictor_adt(module_name, predictor_name, predict_fn, is_class_fn)
355357

356358
# AST checks at the end after all other checks pass
357-
if module.__file__ is not None:
359+
# Only check when running from cog.command.openapi_schema -> coglet.schema
360+
# So that old models that violate this check can still run
361+
if inspect_ast and module.__file__ is not None:
358362
asts.inspect(module.__file__)
359363

360364
return predictor

python/coglet/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _write(s: str) -> int:
2929
# - Bad dependencies
3030
# - Bad input/output types
3131
# - Libraries downloading weights on init
32-
p = inspector.create_predictor(sys.argv[1], sys.argv[2])
32+
p = inspector.create_predictor(sys.argv[1], sys.argv[2], inspect_ast=True)
3333

3434
# Check that test_inputs exists and is valid
3535
module = importlib.import_module(p.module_name)

python/tests/test_bad_predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def run(module_name: str, predictor_name: str) -> None:
2020
m = importlib.import_module(module_name)
2121
err_msg = getattr(m, 'ERROR')
2222
with pytest.raises(AssertionError, match=re.escape(err_msg)):
23-
inspector.create_predictor(module_name, predictor_name)
23+
inspector.create_predictor(module_name, predictor_name, inspect_ast=True)
2424
except PythonVersionError as e:
2525
pytest.skip(reason=str(e))
2626

0 commit comments

Comments
 (0)