Skip to content

Commit 589b352

Browse files
authored
Update instruction to install jax on GPU (#1470)
1 parent f48e341 commit 589b352

File tree

2 files changed

+3
-9
lines changed

2 files changed

+3
-9
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ pip install numpyro[cpu]
238238

239239
To use **NumPyro on the GPU**, you need to install CUDA first and then use the following pip command:
240240
```
241-
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html
241+
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
242242
```
243243
If you need further guidance, please have a look at the [JAX GPU installation instructions](https://github.com/google/jax#pip-installation-gpu-cuda).
244244

docker/dev/Dockerfile

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@ FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04
88

99
# declare the image name
1010
# note that this image uses Python 3.8
11-
ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04 \
12-
# declare the cuda version for pulling appropriate jaxlib wheel
13-
JAXLIB_CUDA=111
11+
ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04
1412

1513
# install python3 and pip on top of the base Ubuntu image
1614
# unlike for release, we need to install git and setuptools too
@@ -22,11 +20,7 @@ RUN apt update && \
2220
ENV PATH=/root/.local/bin:$PATH
2321

2422
# install python packages via pip
25-
# install pip-versions to detect the latest version of jax and jaxlib
26-
RUN pip3 install pip-versions
27-
# this uses latest version of jax and jaxlib available from pypi
28-
RUN pip-versions latest jaxlib | xargs -I{} pip3 install jaxlib=={}+cuda${JAXLIB_CUDA} -f https://storage.googleapis.com/jax-releases/jax_releases.html \
29-
jax
23+
RUN pip3 install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
3024

3125
# clone the numpyro git repository and run pip install
3226
RUN git clone https://github.com/pyro-ppl/numpyro.git && \

0 commit comments

Comments
 (0)