Skip to content

Allow PyTree inputs for coil objectives #3562

Allow PyTree inputs for coil objectives

Allow PyTree inputs for coil objectives #3562

Workflow file for this run

name: Dependency test JAX

on:
  pull_request:
    types: [labeled, synchronize]
  workflow_dispatch:

jobs:
  jax_tests:
    # Run only when the PR carries the 'test_jax' label, or on manual dispatch.
    if: ${{ contains(github.event.pull_request.labels.*.name, 'test_jax') || github.event_name == 'workflow_dispatch' }}
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        # Quoted so every entry is always read as a version string.
        jax-version:
          ["0.5.0", "0.5.3", "0.6.1", "0.6.2", "0.7.0", "0.7.2"]
        # 0.4.x versions are not tested because they fail with jax-finufft
        # 0.5.1 and 0.5.2 installations are broken, see jax#26781
        # 0.6.0 has a bug with jax.grad, see jax#28144
        # 0.7.0 has performance issues but we still support it, see diffrax#680
        # 0.7.1 fails with equinox, see equinox#1081
        # Must list every group 1..N, where N is the --splits value passed to
        # pytest below; any group left out here is never run.
        group: [1, 2, 3]
    steps:
      - uses: actions/checkout@v6
      - name: Set up Python 3.12
        uses: actions/setup-python@v6
        with:
          python-version: "3.12"
      - name: Upgrade pip
        run: |
          python -m pip install --upgrade pip
      - name: Install dependencies with given JAX version
        run: |
          # Drop an existing jax requirement if it is on line 1, then insert
          # the matrix version as the new first line.
          sed -i '1{/^jax/d}' ./requirements.txt
          sed -i '1i\jax[cpu] == ${{ matrix.jax-version }}' ./requirements.txt
          cat ./requirements.txt
          pip install -r ./devtools/dev-requirements.txt
          pip install matplotlib==3.9.2
      - name: Verify dependencies
        run: |
          python --version
          pip --version
          pip list
          # Fail early if the installed JAX is not the version under test
          # (e.g. the requirements.txt edit above did not take effect).
          python - <<'EOF'
          import jax
          expected = "${{ matrix.jax-version }}"
          assert jax.__version__ == expected, (jax.__version__, expected)
          EOF
      - name: Test with pytest
        run: |
          pwd
          lscpu
          python -m pytest -v -m unit \
            --durations=0 \
            --mpl \
            --maxfail=1 \
            --splits 3 \
            --group ${{ matrix.group }} \
            --splitting-algorithm least_duration