[BUG] Adding model_checkpoint.dirpath to ModelCheckpoint callback results in validation error #311

@natwille1

🧠 Describe the Bug

Adding `dirpath` to the `model_checkpoint` callback arguments (the settings for Lightning's `ModelCheckpoint` callback) results in a validation error when running the following:

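```python
import lightly_train
import mlflow.pytorch

# checkpoint_dirname, checkpoint_filename, ssl_method, encoder, encoder_string,
# dataset_path, transform_args, max_epochs, batch_size, num_workers and
# model_kwargs are defined elsewhere in the training script (not shown).
checkpoint_callback = {
    "model_checkpoint": {
        "every_n_epochs": 10,
        "dirpath": checkpoint_dirname,  # triggers "Unknown key: 'model_checkpoint.dirpath'"
        # Note: Lightning 2.x's ModelCheckpoint has no `filepath` argument;
        # the checkpoint name pattern is passed as `filename`.
        "filepath": checkpoint_filename,
        "save_top_k": 3,
        "save_last": True,
        "monitor": "train_loss",
        "mode": "min",
    }
}
mlflow.pytorch.autolog()
lightly_train.train(
    out=f"experiments/{ssl_method}_{encoder_string}_{batch_size}",
    data=dataset_path,
    method=ssl_method,
    model=encoder,
    transform_args=transform_args,
    epochs=max_epochs,
    batch_size=batch_size,
    num_workers=num_workers,
    accelerator="gpu",
    model_args=model_kwargs,
    precision="bf16-mixed",
    callbacks=checkpoint_callback,
    loader_args={
        "pin_memory": True,
        "prefetch_factor": 8,
        "persistent_workers": True,
    },
)
```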

```
Traceback (most recent call last):
  File "/app/.venv/lib/python3.10/site-packages/lightly_train/_configs/validate.py", line 30, in pydantic_model_validate
    return model.model_validate(obj)
  File "/app/.venv/lib/python3.10/site-packages/pydantic/main.py", line 705, in model_validate
    return cls.__pydantic_validator__.validate_python(
pydantic_core._pydantic_core.ValidationError: 1 validation error for CallbackArgs
model_checkpoint.dirpath
  Extra inputs are not permitted [type=extra_forbidden, input_value='/mnt/azureml/cr/j/9db43a...y/wd/checkpoint_dirname', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/extra_forbidden

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/mnt/azureml/cr/j/9db43aa9856642f2b69aa8aa533a67b3/exe/wd/train_lightly_azure.py", line 152, in <module>
    train(
  File "/mnt/azureml/cr/j/9db43aa9856642f2b69aa8aa533a67b3/exe/wd/train_lightly_azure.py", line 105, in train
    lightly_train.train(
  File "/app/.venv/lib/python3.10/site-packages/lightly_train/_commands/train.py", line 238, in train
    train_from_config(config=config)
  File "/app/.venv/lib/python3.10/site-packages/lightly_train/_commands/train.py", line 329, in train_from_config
    config.callbacks = callback_helpers.get_callback_args(
  File "/app/.venv/lib/python3.10/site-packages/lightly_train/_callbacks/callback_helpers.py", line 55, in get_callback_args
    return validate.pydantic_model_validate(CallbackArgs, callback_args)
  File "/app/.venv/lib/python3.10/site-packages/lightly_train/_configs/validate.py", line 34, in pydantic_model_validate
    raise ConfigValidationError(
lightly_train.errors.ConfigValidationError: Found 1 errors in the config!
Unknown key: 'model_checkpoint.dirpath'
```
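
The first exception shows the mechanism: `CallbackArgs` is a pydantic model that forbids extra fields (`type=extra_forbidden`), so any key it does not declare is rejected, and lightly_train rewraps that as `Unknown key: 'model_checkpoint.dirpath'`. The stdlib-only sketch below reproduces the same rule. `ALLOWED_KEYS` and `find_unknown_keys` are hypothetical names, and the allow-list is illustrative, not lightly_train's actual field list.

```python
from typing import Any, Mapping

# Hypothetical allow-list for illustration only; lightly_train's real
# CallbackArgs / model_checkpoint fields may differ.
ALLOWED_KEYS: dict[str, set[str]] = {
    "model_checkpoint": {"every_n_epochs", "save_top_k", "save_last", "monitor", "mode"},
}


def find_unknown_keys(callbacks: Mapping[str, Any], allowed: Mapping[str, set[str]]) -> list[str]:
    """Return dotted paths (e.g. 'model_checkpoint.dirpath') of keys not in the allow-list."""
    unknown: list[str] = []
    for name, args in callbacks.items():
        if name not in allowed:
            # The whole callback section is unknown.
            unknown.append(name)
            continue
        for key in args:
            if key not in allowed[name]:
                unknown.append(f"{name}.{key}")
    return unknown


callbacks = {"model_checkpoint": {"every_n_epochs": 10, "dirpath": "checkpoints", "save_top_k": 3}}
print(find_unknown_keys(callbacks, ALLOWED_KEYS))  # ['model_checkpoint.dirpath']
```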

Could lightly_train accept all keyword arguments that the corresponding Lightning callbacks support (for example `dirpath` for `ModelCheckpoint`)?
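
One way the request could be met, sketched under assumptions: instead of a hand-written field list, derive the accepted keys from the callback constructor's signature with `inspect.signature`, so anything Lightning's `ModelCheckpoint` accepts (such as `dirpath`) passes, while genuinely invalid keys are still reported. `ModelCheckpointStub`, `accepted_kwargs` and `split_kwargs` are hypothetical names; the stub copies only a subset of `ModelCheckpoint`'s parameters so the example runs without Lightning installed. With Lightning available, `lightning.pytorch.callbacks.ModelCheckpoint` would be passed instead of the stub.

```python
import inspect
from typing import Any, Mapping, Optional


class ModelCheckpointStub:
    """Hypothetical stand-in mirroring a subset of the keyword arguments of
    lightning.pytorch.callbacks.ModelCheckpoint, so this sketch runs without Lightning."""

    def __init__(
        self,
        dirpath: Optional[str] = None,
        filename: Optional[str] = None,
        monitor: Optional[str] = None,
        save_last: Optional[bool] = None,
        save_top_k: int = 1,
        mode: str = "min",
        every_n_epochs: Optional[int] = None,
    ) -> None:
        self.dirpath = dirpath
        self.filename = filename
        self.monitor = monitor
        self.save_last = save_last
        self.save_top_k = save_top_k
        self.mode = mode
        self.every_n_epochs = every_n_epochs


def accepted_kwargs(callback_cls: type) -> set[str]:
    """Names that callback_cls.__init__ accepts as keyword arguments (self is excluded)."""
    params = inspect.signature(callback_cls).parameters.values()
    keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    return {p.name for p in params if p.kind in keyword_kinds}


def split_kwargs(kwargs: Mapping[str, Any], callback_cls: type) -> tuple[dict[str, Any], list[str]]:
    """Split user kwargs into those the callback accepts and unknown keys to report."""
    allowed = accepted_kwargs(callback_cls)
    valid = {key: value for key, value in kwargs.items() if key in allowed}
    unknown = [key for key in kwargs if key not in allowed]
    return valid, unknown


user_args = {"every_n_epochs": 10, "dirpath": "checkpoints", "filepath": "last", "save_top_k": 3}
valid, unknown = split_kwargs(user_args, ModelCheckpointStub)
print(unknown)  # ['filepath']
checkpoint = ModelCheckpointStub(**valid)
print(checkpoint.dirpath)  # checkpoints
```

With Lightning's real `ModelCheckpoint` in place of the stub, `dirpath` would pass while `filepath`, which Lightning 2.x no longer accepts, would still be flagged. The trade-off is that a signature check validates only argument names, not value types, which a declared pydantic model also covers.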

🔁 Steps to Reproduce

Run the code snippet from the bug description above.

🤖 Environment Details

  • OS: Ubuntu 22.04
  • Python version: 3.10
  • Frameworks/Libraries (with versions):

requirements.txt:
absl-py==2.3.1
accelerate==1.1.1
adal==1.2.7
aenum==3.1.16
aiohappyeyeballs==2.6.1
aiohttp==3.11.4
aiosignal==1.4.0
albumentations==1.3.1
alembic==1.16.4
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.10.0
applicationinsights==0.11.10
appnope==0.1.4 ; sys_platform == 'darwin'
argcomplete==3.6.2 ; sys_platform != 'win32'
asgiref==3.9.1
asttokens==3.0.0
async-timeout==5.0.1 ; python_full_version < '3.11'
attrs==25.3.0
azure-ai-ml==1.28.1
azure-common==1.1.28
azure-core==1.35.0
azure-core-tracing-opentelemetry==1.0.0b12
azure-graphrbac==0.61.2
azure-identity==1.24.0
azure-mgmt-authorization==4.0.0
azure-mgmt-containerregistry==13.0.0
azure-mgmt-core==1.6.0
azure-mgmt-keyvault==11.0.0
azure-mgmt-network==29.0.0 ; sys_platform != 'win32'
azure-mgmt-resource==8.0.1 ; sys_platform == 'win32'
azure-mgmt-resource==24.0.0 ; sys_platform != 'win32'
azure-mgmt-storage==23.0.0
azure-ml==0.0.1
azure-ml-component==0.9.1.post2 ; sys_platform == 'win32'
azure-ml-component==0.9.18.post2 ; sys_platform != 'win32'
azure-monitor-opentelemetry==1.6.13
azure-monitor-opentelemetry-exporter==1.0.0b40
azure-storage-blob==12.19.0
azure-storage-file-datalake==12.14.0
azure-storage-file-share==12.22.0
azureml-contrib-services==1.60.0
azureml-core==1.0.85.6 ; sys_platform == 'win32'
azureml-core==1.60.0.post1 ; sys_platform != 'win32'
azureml-dataprep==5.1.6 ; sys_platform != 'win32'
azureml-dataprep-native==41.0.0 ; sys_platform != 'win32'
azureml-dataprep-rslex==2.22.5 ; sys_platform != 'win32'
azureml-dataset-runtime==1.60.0 ; sys_platform != 'win32'
azureml-defaults==1.0.85 ; sys_platform == 'win32'
azureml-defaults==1.60.0 ; sys_platform != 'win32'
azureml-inference-server-http==1.4.1 ; sys_platform != 'win32'
azureml-mlflow==1.60.0.post1
azureml-model-management-sdk==1.0.1b6.post1 ; sys_platform == 'win32'
azureml-telemetry==1.0.85.2 ; sys_platform == 'win32'
azureml-telemetry==1.60.0 ; sys_platform != 'win32'
backports-tempfile==1.0
backports-weakref==1.0.post1
bcrypt==4.3.0 ; sys_platform != 'win32'
beartype==0.19.0
blinker==1.9.0 ; sys_platform != 'win32'
byol-pytorch==0.8.2
bytecode==0.16.2 ; sys_platform != 'win32'
cachetools==5.5.2
certifi==2025.8.3
cffi==1.17.1 ; implementation_name == 'pypy' or platform_python_implementation != 'PyPy' or sys_platform != 'win32'
charset-normalizer==3.4.3
cheroot==10.0.1
clearml==1.16.5
click==8.1.8 ; python_full_version < '3.10'
click==8.2.1 ; python_full_version >= '3.10'
cloudpickle==2.2.1
colorama==0.4.6
colorlog==6.9.0
comm==0.2.3
configparser==3.7.4 ; sys_platform == 'win32'
contextlib2==21.6.0
contourpy==1.3.0 ; python_full_version < '3.10'
contourpy==1.3.2 ; python_full_version >= '3.10'
cryptography==45.0.6
cycler==0.12.1
databricks-sdk==0.63.0
debugpy==1.8.16
decorator==5.2.1
dill==0.4.0 ; sys_platform == 'win32'
docker==7.1.0
etils==1.5.2 ; python_full_version < '3.10'
etils==1.13.0 ; python_full_version >= '3.10'
eval-type-backport==0.2.2
exceptiongroup==1.3.0 ; python_full_version < '3.11'
executing==2.2.0
fastapi==0.116.1
filelock==3.19.1
fixedint==0.1.6
flask==1.0.3 ; sys_platform == 'win32'
flask==3.1.1 ; sys_platform != 'win32'
flask-cors==6.0.1 ; sys_platform != 'win32'
fonttools==4.59.1
frozenlist==1.7.0
fsspec==2025.7.0
furl==2.1.3
fusepy==3.0.1 ; sys_platform != 'win32'
gcsfs==2025.7.0
gitdb==4.0.12
gitpython==3.1.45
google-api-core==2.25.1
google-auth==2.40.3
google-auth-oauthlib==1.2.2
google-cloud-core==2.4.3
google-cloud-storage==3.3.0
google-crc32c==1.7.1
google-resumable-media==2.7.2
googleapis-common-protos==1.70.0
graphene==3.4.3
graphql-core==3.2.6
graphql-relay==3.2.0
greenlet==3.2.4 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
grpcio==1.74.0
gunicorn==19.9.0 ; sys_platform == 'win32'
gunicorn==23.0.0 ; sys_platform != 'win32'
gviz-api==1.10.0
h11==0.16.0
hf-xet==1.1.7 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
huggingface-hub==0.34.4
humanfriendly==10.0 ; sys_platform != 'win32'
hydra-core==1.3.2
idna==3.10
imageio==2.37.0
importlib-metadata==8.6.1
importlib-resources==6.5.2
inference-schema==1.8 ; sys_platform != 'win32'
ipykernel==6.30.1
ipython==8.18.1 ; python_full_version < '3.10'
ipython==8.37.0 ; python_full_version == '3.10.*'
ipython==9.4.0 ; python_full_version >= '3.11'
ipython-pygments-lexers==1.1.1 ; python_full_version >= '3.11'
isodate==0.7.2
itsdangerous==2.2.0
jaraco-functools==4.3.0
jedi==0.19.2
jeepney==0.9.0
jinja2==3.1.6
jmespath==1.0.1
joblib==1.5.1
json-logging-py==0.2 ; sys_platform == 'win32'
jsonpickle==4.1.1
jsonschema==4.25.1
jsonschema-specifications==2025.4.1
jupyter-client==8.6.3
jupyter-core==5.8.1
kiwisolver==1.4.7 ; python_full_version < '3.10'
kiwisolver==1.4.9 ; python_full_version >= '3.10'
knack==0.12.0 ; sys_platform != 'win32'
lazy-loader==0.4
liac-arff==2.5.0 ; sys_platform == 'win32'
lightly==1.5.22
lightly-train==0.10.0
lightly-utils==0.0.2
lightning==2.4.0
lightning-utilities==0.15.2
mako==1.3.10
markdown==3.8.2
markupsafe==3.0.2
marshmallow==3.26.1
matplotlib==3.9.4 ; python_full_version < '3.10'
matplotlib==3.10.5 ; python_full_version >= '3.10'
matplotlib-inline==0.1.7
mlflow==2.22.1
mlflow-skinny==2.22.1
more-itertools==10.7.0
mpmath==1.3.0
msal==1.33.0
msal-extensions==1.3.1
msrest==0.7.1
msrestazure==0.6.4.post1
multidict==6.6.4
ndg-httpsclient==0.5.1
nest-asyncio==1.6.0
networkx==3.2.1 ; python_full_version < '3.10'
networkx==3.4.2 ; python_full_version == '3.10.*'
networkx==3.5 ; python_full_version >= '3.11'
numpy==1.23.5
nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
oauthlib==3.3.1
omegaconf==2.3.0
opencensus==0.11.4 ; sys_platform != 'win32'
opencensus-context==0.1.3 ; sys_platform != 'win32'
opencensus-ext-azure==1.1.15 ; sys_platform != 'win32'
opencv-python-headless==4.11.0.86
opentelemetry-api==1.36.0
opentelemetry-instrumentation==0.57b0
opentelemetry-instrumentation-asgi==0.57b0
opentelemetry-instrumentation-dbapi==0.57b0
opentelemetry-instrumentation-django==0.57b0
opentelemetry-instrumentation-fastapi==0.57b0
opentelemetry-instrumentation-flask==0.57b0
opentelemetry-instrumentation-psycopg2==0.57b0
opentelemetry-instrumentation-requests==0.57b0
opentelemetry-instrumentation-urllib==0.57b0
opentelemetry-instrumentation-urllib3==0.57b0
opentelemetry-instrumentation-wsgi==0.57b0
opentelemetry-resource-detector-azure==0.1.5
opentelemetry-sdk==1.36.0
opentelemetry-semantic-conventions==0.57b0
opentelemetry-util-http==0.57b0
optuna==4.5.0
orderedmultidict==1.0.1
packaging==24.2
pandas==2.2.2
paramiko==3.5.1 ; sys_platform != 'win32'
parso==0.8.4
pathlib2==2.3.7.post1
pathspec==0.12.1
pexpect==4.9.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')
pillow==11.0.0
pkginfo==1.12.1.2 ; sys_platform != 'win32'
platformdirs==4.3.8
prompt-toolkit==3.0.51
propcache==0.3.2
proto-plus==1.26.1
protobuf==6.32.0
psutil==6.1.0
ptyprocess==0.7.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')
pure-eval==0.2.3
pyarrow==19.0.1
pyasn1==0.6.1
pyasn1-modules==0.4.2
pycparser==2.22 ; implementation_name == 'pypy' or platform_python_implementation != 'PyPy' or sys_platform != 'win32'
pydantic==2.11.7
pydantic-core==2.33.2
pydantic-settings==2.10.1 ; sys_platform != 'win32'
pydash==8.0.5
pygments==2.19.2
pyjwt==2.8.0
pynacl==1.5.0 ; sys_platform != 'win32'
pyopenssl==25.1.0
pyparsing==3.2.3
pysocks==1.7.1 ; sys_platform != 'win32'
python-dateutil==2.9.0.post0
python-dotenv==1.1.1 ; sys_platform != 'win32'
pytorch-lightning==2.5.3
pytz==2025.2
pywin32==311 ; sys_platform == 'win32'
pyyaml==6.0.2
pyzmq==27.0.1
qudida==0.0.4
referencing==0.36.2
regex==2025.7.34
requests==2.32.4
requests-oauthlib==2.0.0
rpds-py==0.27.0
rsa==4.9.1
ruamel-yaml==0.15.89 ; sys_platform == 'win32'
ruamel-yaml==0.17.16 ; sys_platform != 'win32'
ruamel-yaml-clib==0.2.12 ; python_full_version < '3.10' and platform_python_implementation == 'CPython' and sys_platform != 'win32'
safetensors==0.6.2
scikit-image==0.24.0
scikit-learn==1.6.1 ; python_full_version < '3.10'
scikit-learn==1.7.1 ; python_full_version >= '3.10'
scipy==1.13.1 ; python_full_version < '3.10'
scipy==1.15.3 ; python_full_version >= '3.10'
secretstorage==3.3.3
setuptools==80.9.0
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
sqlalchemy==2.0.43
sqlparse==0.5.3
stack-data==0.6.3
starlette==0.47.2
strictyaml==1.7.3
sympy==1.13.1
tabulate==0.9.0 ; sys_platform != 'win32'
tensorboard==2.18.0
tensorboard-data-server==0.7.2
tensorboard-plugin-profile==2.20.6
threadpoolctl==3.6.0
tifffile==2024.8.30 ; python_full_version < '3.10'
tifffile==2025.5.10 ; python_full_version == '3.10.*'
tifffile==2025.6.11 ; python_full_version >= '3.11'
timm==1.0.19
tokenizers==0.21.4
tomli==2.2.1 ; python_full_version < '3.11'
torch==2.5.1 ; platform_machine == 'aarch64' and sys_platform == 'linux'
torch==2.5.1 ; sys_platform != 'linux'
torch==2.5.1+cu124 ; platform_machine != 'aarch64' and sys_platform == 'linux'
torchmetrics==1.8.1
torchvision==0.20.1 ; platform_machine != 'x86_64' or sys_platform != 'linux'
torchvision==0.20.1+cu124 ; platform_machine == 'x86_64' and sys_platform == 'linux'
tornado==6.5.2
tqdm==4.67.0
traitlets==5.14.3
transformers==4.55.2
triton==3.1.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
typing-extensions==4.14.1
typing-inspection==0.4.1
tzdata==2025.2
urllib3==2.5.0
uvicorn==0.35.0
waitress==3.0.2 ; sys_platform == 'win32'
wcwidth==0.2.13
werkzeug==3.1.3
wrapt==1.16.0
xprof==2.20.6
yarl==1.20.1
zipp==3.23.0

  • How did you install the package:

python -m pip install -r requirements.txt
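
Of the pinned packages, the ones that bear directly on this error are lightly-train 0.10.0, lightning 2.4.0, pytorch-lightning 2.5.3 and pydantic 2.11.7 (both `lightning` and `pytorch-lightning` are installed, at different versions). When re-testing a fix, a stdlib sketch like the one below prints just those versions; `RELEVANT_PACKAGES` and `installed_versions` are hypothetical names, and the package selection is an assumption about what matters for this error.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

# Distributions from the requirements list above that are most relevant here (assumed selection).
RELEVANT_PACKAGES = ["lightly-train", "lightning", "pytorch-lightning", "pydantic"]


def installed_versions(packages: list[str]) -> dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is not installed."""
    result: dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(RELEVANT_PACKAGES).items():
        print(f"{name}: {ver or 'not installed'}")
```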

Labels: bug (Something isn't working)
Assignees: none