Install JAX GPU version using Conda
This post reflects the state of JAX installation as of 12 March 2024.
To run an antibody design model, I needed to install JAX. I followed the official installation guide, starting by creating a conda environment.
```bash
# Create Conda Environment (python 3.9)
conda create -n af python=3.9
conda activate af
```
[Trial 1]
However, the GPU was not recognized because the installed CUDA version (v11.1) was too old.
The two commands below reported different CUDA versions, although both were 11.x. I was not sure why they differed, so I went ahead with the 'CUDA 11 installation' command from the official guide, Installing JAX.
```bash
nvidia-smi
nvcc --version
```
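As for why the two commands disagree: `nvidia-smi` reports the newest CUDA version the installed driver can support, whereas `nvcc --version` reports the version of the CUDA toolkit that is actually installed, so the two numbers can legitimately differ. The minimal Python sketch below runs both commands and extracts each number. The regular expressions assume the usual output formats (`CUDA Version: X.Y` in the `nvidia-smi` header, `release X.Y` in `nvcc --version`), and the helper names are hypothetical.

```python
import re
import subprocess
from typing import List, Optional, Tuple

Version = Tuple[int, int]


def parse_driver_cuda(smi_output: str) -> Optional[Version]:
    """Newest CUDA version the driver supports, from the `nvidia-smi` header."""
    m = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", smi_output)
    return (int(m.group(1)), int(m.group(2))) if m else None


def parse_toolkit_cuda(nvcc_output: str) -> Optional[Version]:
    """Installed CUDA toolkit version, from `nvcc --version`."""
    m = re.search(r"release\s+(\d+)\.(\d+)", nvcc_output)
    return (int(m.group(1)), int(m.group(2))) if m else None


def run(cmd: List[str]) -> str:
    """Return the command's stdout, or "" if it is missing or exits with an error."""
    try:
        return subprocess.run(cmd, capture_output=True, text=True, check=True).stdout
    except (OSError, subprocess.CalledProcessError):
        return ""


if __name__ == "__main__":
    print("driver supports up to CUDA", parse_driver_cuda(run(["nvidia-smi"])))
    print("installed toolkit is CUDA", parse_toolkit_cuda(run(["nvcc", "--version"])))
```

If the driver number is higher than the toolkit number, that is expected; the toolkit version is the one that has to satisfy the JAX wheel's requirements.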
```bash
# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
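Checking the backend afterwards showed that JAX was still running on the CPU:

```python
# run inside a `python` session
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
# output: cpu
```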
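The same check can be written with the public functions `jax.default_backend()` and `jax.devices()`. The sketch below is a minimal, hedged version: the function name is hypothetical, and it returns `None` instead of crashing when JAX cannot be imported at all (for example, when jaxlib is incompatible with the installed jax).

```python
from typing import List, Optional, Tuple


def jax_backend_report() -> Optional[Tuple[str, List[str]]]:
    """Return (default backend, device names), or None if JAX cannot be imported."""
    try:
        import jax
    except (ImportError, RuntimeError):
        # RuntimeError covers e.g. a jaxlib that is incompatible with the installed jax.
        return None
    return jax.default_backend(), [str(d) for d in jax.devices()]


if __name__ == "__main__":
    report = jax_backend_report()
    if report is None:
        print("JAX could not be imported in this environment")
    else:
        backend, devices = report
        print("backend:", backend)  # "cpu" means no GPU was picked up
        print("devices:", devices)
```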
[Trial 2]
To update CUDA, I followed the official documentation and first checked:
- the Linux version
- the GPU
I then installed CUDA toolkit v12.4 and cuDNN v8.9.7 for CUDA 12.x, but the GPU was still not recognized, and I could not figure out why.
```bash
# in .bash_profile
export PATH=/usr/local/cuda-12.4/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
# after saving, reload the profile
source ~/.bash_profile
```
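The `${PATH:+:${PATH}}` idiom expands to a colon followed by the old value only when the variable is set and non-empty, so the new directory is prepended without leaving an empty entry behind. The minimal Python sketch below mirrors that expansion and checks whether the CUDA 12.4 directories actually ended up on `PATH` and `LD_LIBRARY_PATH` of the current process. The helper names are hypothetical and the default `cuda_home` simply reuses the path from the profile above.

```python
from typing import Dict, Mapping, Optional


def prepend_path(new_dir: str, current: Optional[str]) -> str:
    # Mirrors NEW${VAR:+:${VAR}}: add ":" + old value only if VAR is set and non-empty.
    return new_dir + (":" + current if current else "")


def cuda_dirs_on_paths(env: Mapping[str, str], cuda_home: str = "/usr/local/cuda-12.4") -> Dict[str, bool]:
    # Split on ":" as the Linux shell does; a missing variable yields a single
    # empty entry, which never matches a CUDA directory.
    path_entries = env.get("PATH", "").split(":")
    lib_entries = env.get("LD_LIBRARY_PATH", "").split(":")
    return {
        "bin_on_PATH": f"{cuda_home}/bin" in path_entries,
        "lib64_on_LD_LIBRARY_PATH": f"{cuda_home}/lib64" in lib_entries,
    }


if __name__ == "__main__":
    import os
    print(cuda_dirs_on_paths(os.environ))
```

Since bash reads `~/.bash_profile` only for login shells, running this check in a fresh terminal is one way to see whether the exports actually took effect there.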
[Trial 3]
Next, I tried the conda-based installation from the JAX documentation, but it also failed for reasons I could not determine.
```bash
conda install jax -c conda-forge
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
```
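In the second command, `jaxlib=*=*cuda*` is a conda match spec of the form `name=version=build`: `*` accepts any version, and `*cuda*` accepts only builds whose build string contains `cuda`, which is meant to select a GPU build of jaxlib over a CPU one. The sketch below imitates that matching with Python's `fnmatch` purely for illustration; it is not conda's actual implementation, it handles only the three-part form, and the build strings are made up.

```python
from fnmatch import fnmatchcase


def matches_three_part_spec(spec: str, name: str, version: str, build: str) -> bool:
    """Illustrative only: glob-match a conda spec of the exact form name=version=build."""
    spec_name, spec_version, spec_build = spec.split("=")
    return (
        fnmatchcase(name, spec_name)
        and fnmatchcase(version, spec_version)
        and fnmatchcase(build, spec_build)
    )


if __name__ == "__main__":
    spec = "jaxlib=*=*cuda*"
    # Hypothetical build strings, for illustration only.
    for build in ("cuda120py39h0123456_200", "cpu_py39h0123456_0"):
        print(build, matches_three_part_spec(spec, "jaxlib", "0.4.23", build))
```

In bash, quoting the spec (`'jaxlib=*=*cuda*'`) keeps the shell from expanding the asterisks against file names in the current directory.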
How I installed it successfully
The following steps worked for me.
```bash
# create conda environment
conda create -n af python=3.9
conda activate af
```
Initially, I tried to install CUDA 12.4, but since the cuDNN package available from conda was built for CUDA 11, I installed CUDA 11.8 instead.
[reference](https://jax.readthedocs.io/en/latest/installation.
JAX currently ships two CUDA wheel variants:
- CUDA 12.3, cuDNN 8.9, NCCL 2.16
- CUDA 11.8, cuDNN 8.6, NCCL 2.16
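The two variants above suggest a simple compatibility check, sketched below under an assumed rule: the local CUDA and cuDNN versions must share the major version of a variant and be at least its minor version. The helper names are hypothetical, and the authoritative requirements are those in the JAX installation guide. Under this rule, CUDA 11.1 (as in Trial 1) matches neither variant, while CUDA 11.8 with cuDNN 8.6 matches the CUDA 11 wheel.

```python
from typing import Dict, Optional, Tuple

Version = Tuple[int, int]

# The two wheel variants listed above (JAX installation guide, March 2024).
WHEEL_VARIANTS: Dict[str, Dict[str, Version]] = {
    "cuda12": {"cuda": (12, 3), "cudnn": (8, 9)},
    "cuda11": {"cuda": (11, 8), "cudnn": (8, 6)},
}


def compatible(have: Version, need: Version) -> bool:
    # Assumed rule: same major version, and at least the required minor version.
    return have[0] == need[0] and have >= need


def pick_variant(cuda: Version, cudnn: Version) -> Optional[str]:
    """Return the first variant whose CUDA and cuDNN requirements are met, else None."""
    for name, req in WHEEL_VARIANTS.items():
        if compatible(cuda, req["cuda"]) and compatible(cudnn, req["cudnn"]):
            return name
    return None


if __name__ == "__main__":
    print(pick_variant((11, 1), (8, 6)))  # None
    print(pick_variant((11, 8), (8, 6)))  # cuda11
```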
```bash
# install cuda using conda
conda install nvidia/label/cuda-11.8.0::cuda
# install cudnn using conda
conda install anaconda::cudnn
```
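To confirm which cuDNN conda actually installed, one option is to read the version macros from `cudnn_version.h`, where cuDNN 8.x defines `CUDNN_MAJOR`, `CUDNN_MINOR` and `CUDNN_PATCHLEVEL`. The sketch below assumes the header sits under `$CONDA_PREFIX/include` in the active environment; that location depends on the package and is an assumption, and the function names are hypothetical.

```python
import re
from pathlib import Path
from typing import Optional, Tuple


def cudnn_version_from_header(text: str) -> Optional[Tuple[int, int, int]]:
    """Parse (major, minor, patch) from the #define lines of cudnn_version.h."""
    values = []
    for key in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        m = re.search(rf"#define\s+{key}\s+(\d+)", text)
        if m is None:
            return None
        values.append(int(m.group(1)))
    return values[0], values[1], values[2]


def find_cudnn_version(prefix: str) -> Optional[Tuple[int, int, int]]:
    # Assumed header location inside a conda environment.
    header = Path(prefix) / "include" / "cudnn_version.h"
    return cudnn_version_from_header(header.read_text()) if header.exists() else None


if __name__ == "__main__":
    import os
    prefix = os.environ.get("CONDA_PREFIX")
    print(find_cudnn_version(prefix) if prefix else "CONDA_PREFIX is not set")
```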
I then installed JAX via conda, but the installed jax (v0.4.25) and jaxlib (v0.4.23) versions did not match, so to align them I installed jax as follows:
```bash
# install jax using conda
conda install jax==0.4.23 -c conda-forge
```
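A quick way to confirm that the two versions line up is to query the installed package metadata with `importlib.metadata` (available since Python 3.8). The sketch below ignores a PEP 440 local label such as `+cuda11.cudnn86` and requires identical versions, mirroring what was done here. JAX's own check is looser, accepting a range of jaxlib versions rather than an exact match, so a mismatch flagged here is not necessarily fatal. The helper names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def public_version(v: str) -> str:
    # Drop a PEP 440 local label such as "+cuda11.cudnn86".
    return v.split("+", 1)[0]


def installed(pkg: str) -> Optional[str]:
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


def jax_jaxlib_aligned(jax_v: Optional[str], jaxlib_v: Optional[str]) -> bool:
    # Strict check used in this post: identical public versions.
    if jax_v is None or jaxlib_v is None:
        return False
    return public_version(jax_v) == public_version(jaxlib_v)


if __name__ == "__main__":
    j, jl = installed("jax"), installed("jaxlib")
    print(f"jax={j} jaxlib={jl} aligned={jax_jaxlib_aligned(j, jl)}")
```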
For the GPU build of jaxlib, I used pip to install jaxlib alone:
```bash
# install jaxlib for gpu using pip
pip install jaxlib==0.4.23+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
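The `+cuda11.cudnn86` suffix in the pip command is a PEP 440 local version label; for this wheel it encodes the CUDA major version (11) and the cuDNN version (86, i.e. 8.6) the build targets, matching the CUDA 11.8 / cuDNN 8.6 variant listed earlier. A small sketch that splits such a version string, assuming the `cudaNN.cudnnNN` naming seen here (the `JaxlibBuild` type and the function name are hypothetical):

```python
import re
from typing import NamedTuple, Optional


class JaxlibBuild(NamedTuple):
    public_version: str
    cuda_major: Optional[int]     # e.g. 11 from "cuda11"
    cudnn_digits: Optional[str]   # e.g. "86" from "cudnn86", i.e. cuDNN 8.6


def parse_jaxlib_version(v: str) -> JaxlibBuild:
    """Split a version such as "0.4.23+cuda11.cudnn86" into its parts."""
    public, _, local = v.partition("+")
    cuda = re.search(r"cuda(\d+)", local)
    cudnn = re.search(r"cudnn(\d+)", local)
    return JaxlibBuild(
        public,
        int(cuda.group(1)) if cuda else None,
        cudnn.group(1) if cudnn else None,
    )


if __name__ == "__main__":
    print(parse_jaxlib_version("0.4.23+cuda11.cudnn86"))
```

Running it on the installed version, e.g. `parse_jaxlib_version(importlib.metadata.version("jaxlib"))`, is a quick way to see whether pip picked the CUDA build (a CPU-only jaxlib has no local label, so both CUDA fields come back as `None`).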