Allow PyTree inputs for coil objectives #3562
Workflow file for this run
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Runs the unit-test suite against a range of pinned JAX releases.
# Triggered manually (workflow_dispatch) or on pull requests that carry the
# `test_jax` label.
name: Dependency test JAX

on:
  pull_request:
    types: [labeled, synchronize]
  workflow_dispatch:

jobs:
  jax_tests:
    # Skip unless the PR has the `test_jax` label or the run was started by hand.
    if: ${{ contains(github.event.pull_request.labels.*.name, 'test_jax') || github.event_name == 'workflow_dispatch' }}
    runs-on: ubuntu-latest
    strategy:
      # Let every JAX version / group finish so all incompatibilities are reported.
      fail-fast: false
      matrix:
        # Versions are quoted so they are always treated as strings
        # (an unquoted value such as 0.10 would be read as the float 0.1).
        jax-version: ["0.5.0", "0.5.3", "0.6.1", "0.6.2", "0.7.0", "0.7.2"]
        # 0.4.x versions are not tested because they fail with jax-finufft
        # 0.5.1 and 0.5.2 installations are broken, see jax#26781
        # 0.6.0 has a bug with jax.grad, see jax#28144
        # 0.7.0 has performance issues but we still support it, see diffrax#680
        # 0.7.1 fails with equinox, see equinox#1081

        # One entry per pytest-split group. This list must cover 1..N where N
        # is the value passed to `--splits` in the pytest step; any group left
        # out here is a slice of the test suite that never runs.
        group: [1, 2, 3]
    steps:
      - uses: actions/checkout@v6
      - name: Set up Python 3.12
        uses: actions/setup-python@v6
        with:
          python-version: "3.12"
      - name: Upgrade pip
        run: |
          python -m pip install --upgrade pip
      - name: Install dependencies with given JAX version
        run: |
          # Drop the first line of requirements.txt if it is the jax pin ...
          sed -i '1{/^jax/d}' ./requirements.txt
          # ... and insert the pin for the JAX version under test in its place.
          sed -i '1i\jax[cpu] == ${{ matrix.jax-version }}' ./requirements.txt
          cat ./requirements.txt
          pip install -r ./devtools/dev-requirements.txt
          pip install matplotlib==3.9.2
      - name: Verify dependencies
        run: |
          python --version
          pip --version
          pip list
      - name: Test with pytest
        run: |
          pwd
          lscpu
          # `--splits` must equal the number of entries in matrix.group above.
          python -m pytest -v -m unit \
            --durations=0 \
            --mpl \
            --maxfail=1 \
            --splits 3 \
            --group ${{ matrix.group }} \
            --splitting-algorithm least_duration