This post reflects the situation as of 12 March 2024.

I needed to install JAX to run an antibody design model.

I first followed the official installation guide, starting by creating a conda environment:

  # Create Conda Environment (python 3.9)
  conda create -n af python=3.9
  conda activate af

[Trial 1]


However, JAX did not recognize the GPU because the installed CUDA version (v11.1) was too old.

Detail: the two commands below reported slightly different CUDA versions, both 11.x. This is expected: nvidia-smi reports the highest CUDA version the installed driver supports, while nvcc --version reports the version of the installed CUDA toolkit. Since the toolkit was 11.x, I tried the 'CUDA 11 installation' command from the official Installing JAX guide.
  nvidia-smi
  nvcc --version
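To compare the two numbers side by side, here is a minimal sketch that pulls the version out of each command's text output. It assumes the usual output formats ("release X.Y" in nvcc --version, "CUDA Version: X.Y" in the nvidia-smi header); the helper names and the sample lines are hypothetical and made up for illustration, not output from the machine in this post.

```python
import re
from typing import Optional


def nvcc_release(nvcc_output: str) -> Optional[str]:
    """Return the toolkit release from `nvcc --version` output, e.g. '11.1'."""
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


def driver_max_cuda(smi_output: str) -> Optional[str]:
    """Return the 'CUDA Version' from the nvidia-smi header.

    This is the highest CUDA version the driver supports, not the installed toolkit.
    """
    match = re.search(r"CUDA Version:\s*(\d+\.\d+)", smi_output)
    return match.group(1) if match else None


# Hypothetical sample lines (illustrative values only).
NVCC_SAMPLE = "Cuda compilation tools, release 11.1, V11.1.105"
SMI_SAMPLE = "| NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 11.4     |"

print("toolkit (nvcc):", nvcc_release(NVCC_SAMPLE))
print("driver max (nvidia-smi):", driver_max_cuda(SMI_SAMPLE))
```

On a real machine you would feed in the captured output of the two commands instead of the sample strings.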

  # CUDA 11 installation
  # Note: wheels only available on linux.
  pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  python
  >>> from jax.lib import xla_bridge
  >>> print(xla_bridge.get_backend().platform)
The log message showed that the installed cuBLAS (and therefore CUDA) version was too old for this JAX build:
  CUDA backend failed to initialize: Found cuBLAS version 11201, but JAX was built against version 111103, which is newer. The copy of cuBLAS that is installed must be at least as new as the version against which JAX built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
  cpu
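The check described in that warning can be reproduced with a minimal sketch that extracts both cuBLAS numbers from the log line and compares them. It assumes, as the wording "which is newer" suggests, that the two integers are compared as printed; the helper name cublas_versions is hypothetical.

```python
import re
from typing import Optional, Tuple

# The warning text from this post.
LOG = (
    "CUDA backend failed to initialize: Found cuBLAS version 11201, "
    "but JAX was built against version 111103, which is newer."
)


def cublas_versions(log_line: str) -> Optional[Tuple[int, int]]:
    """Return (installed, built_against) cuBLAS version integers, or None."""
    match = re.search(
        r"Found cuBLAS version (\d+), but JAX was built against version (\d+)",
        log_line,
    )
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


found, built = cublas_versions(LOG)
# The installed library must be at least as new as the one JAX was built with.
print(found, built, "too old" if found < built else "ok")
```

Here the installed number is smaller, which is why JAX fell back to the cpu backend.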


[Trial 2]


To update CUDA, I followed the official CUDA installation documentation and first checked:

  1. the Linux distribution and version
  2. the GPU model
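For the Linux check, a minimal sketch of reading the distribution and version from /etc/os-release (KEY=value lines, values optionally quoted). The environment here uses Python 3.9, and platform.freedesktop_os_release only arrived in Python 3.10, so the file is parsed by hand; the function names and the sample content are hypothetical.

```python
import shlex
from typing import Dict

# Hypothetical /etc/os-release content, for illustration only.
SAMPLE = '''NAME="Ubuntu"
VERSION_ID="20.04"
PRETTY_NAME="Ubuntu 20.04.6 LTS"
ID=ubuntu
'''


def parse_os_release(text: str) -> Dict[str, str]:
    """Parse KEY=value lines in the os-release format into a dict."""
    info: Dict[str, str] = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # shlex.split removes shell-style quotes around the value.
        parts = shlex.split(value)
        info[key] = parts[0] if parts else ""
    return info


def read_os_release(path: str = "/etc/os-release") -> Dict[str, str]:
    """Read and parse the os-release file on the current machine."""
    with open(path, encoding="utf-8") as handle:
        return parse_os_release(handle.read())


info = parse_os_release(SAMPLE)
print(info["PRETTY_NAME"], info["VERSION_ID"])
```

On the actual server, read_os_release() returns the same kind of dictionary for the real file.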

I then installed CUDA Toolkit v12.4 and cuDNN v8.9.7 for CUDA 12.x, but for reasons I could not determine, JAX still did not recognize the GPU.

Detail
  # in .bash_profile
  export PATH=/usr/local/cuda-12.4/bin${PATH:+:${PATH}}
  export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}

  # after saving, reload the profile in the current shell
  source ~/.bash_profile
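The ${PATH:+:${PATH}} idiom in those exports expands to ":$PATH" only when the variable is set and non-empty, so no stray leading or trailing colon appears. A minimal Python sketch of the same logic, plus a hypothetical helper that lists the CUDA directories on a path so an older CUDA install appearing earlier is easy to spot (both function names are illustrative):

```python
from typing import List, Optional


def prepend_path(entry: str, current: Optional[str]) -> str:
    """Python equivalent of  entry${VAR:+:${VAR}}  in bash.

    The colon separator is added only when the existing value is non-empty.
    """
    return entry + (":" + current if current else "")


def cuda_dirs(path_value: str) -> List[str]:
    """Return the entries of a colon-separated path that mention 'cuda', in order."""
    return [p for p in path_value.split(":") if "cuda" in p]


print(prepend_path("/usr/local/cuda-12.4/bin", "/usr/bin:/bin"))
print(prepend_path("/usr/local/cuda-12.4/lib64", None))
```

Passing os.environ.get("LD_LIBRARY_PATH", "") to cuda_dirs shows which CUDA library directories the current shell actually sees, and in which order.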


[Trial 3]


Next, I tried the conda-based installation described in the JAX documentation, but it also failed.

Detail
  conda install jax -c conda-forge
  conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
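  no matches found: jaxlib=*=*cuda*

This message most likely comes from the shell rather than from conda. By default, zsh treats unquoted * characters as file-name wildcards and aborts the whole command with "no matches found" when nothing matches, so conda never runs. Quoting the package spec ("jaxlib=*=*cuda*") should pass it to conda unchanged.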
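Assuming that error does come from shell globbing (the "no matches found" wording matches zsh's message), a minimal sketch with Python's shlex module shows how the same conda command looks once its wildcard argument is quoted for the shell:

```python
import shlex

# Arguments of the conda command from the JAX documentation, as a list.
argv = [
    "conda", "install", "jaxlib=*=*cuda*", "jax", "cuda-nvcc",
    "-c", "conda-forge", "-c", "nvidia",
]

# shlex.join quotes any argument containing shell metacharacters such as *,
# giving a command line that zsh or bash passes to conda unchanged.
print(shlex.join(argv))
```

Alternatively, subprocess.run(argv) passes the list straight to conda without a shell, so no globbing happens at all.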


How to install successfully

The following steps worked for me.

  # create conda environment
  conda create -n af python=3.9
  conda activate af

Initially, I tried to install CUDA 12.4, but because the cuDNN package available from conda was built for CUDA 11, I installed CUDA 11.8 instead [reference](https://jax.readthedocs.io/en/latest/installation.

JAX currently ships two CUDA wheel variants:

  • CUDA 12.3, cuDNN 8.9, NCCL 2.16
  • CUDA 11.8, cuDNN 8.6, NCCL 2.16
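The choice between the two variants can be written down as a small lookup. This is a minimal sketch, simplified to match on the CUDA major version only; the tag builder assumes the "cudaXX.cudnnYY" naming seen in the jaxlib wheel version used in this post, and all names here are hypothetical.

```python
from typing import NamedTuple, Optional


class CudaVariant(NamedTuple):
    cuda: str
    cudnn: str
    nccl: str


# The two variants listed in the JAX documentation at the time of this post.
VARIANTS = [
    CudaVariant(cuda="12.3", cudnn="8.9", nccl="2.16"),
    CudaVariant(cuda="11.8", cudnn="8.6", nccl="2.16"),
]


def pick_variant(cuda_major: int) -> Optional[CudaVariant]:
    """Return the variant built for the given CUDA major version, if any."""
    for variant in VARIANTS:
        if int(variant.cuda.split(".")[0]) == cuda_major:
            return variant
    return None


def local_tag(variant: CudaVariant) -> str:
    """Build a tag such as 'cuda11.cudnn86' (assumed naming scheme)."""
    major = variant.cuda.split(".")[0]
    return f"cuda{major}.cudnn{variant.cudnn.replace('.', '')}"


chosen = pick_variant(11)
print(chosen, local_tag(chosen))
```

With CUDA 11.8 installed, this picks the cuDNN 8.6 variant.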
  # install cuda using conda
  conda install nvidia/label/cuda-11.8.0::cuda

  # install cudnn using conda
  conda install anaconda::cudnn

Next, I installed JAX via conda. However, the installed jax (v0.4.25) and jaxlib (v0.4.23) versions did not match, so I pinned jax to the jaxlib version:

  # install jax using conda
  conda install jax==0.4.23 -c conda-forge
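To confirm the two packages line up, here is a minimal sketch that mirrors the approach of pinning jax and jaxlib to the same version. It is a strict equality check on the public version (the "+cuda11..." local label is ignored), not a statement of JAX's own compatibility rules; the helper names are hypothetical, while importlib.metadata is the standard library.

```python
from importlib import metadata
from typing import Optional


def public_version(version: str) -> str:
    """Drop a PEP 440 local label, e.g. '0.4.23+cuda11.cudnn86' -> '0.4.23'."""
    return version.split("+", 1)[0]


def versions_match(jax_version: str, jaxlib_version: str) -> bool:
    """Treat jax and jaxlib as aligned when their public versions are equal."""
    return public_version(jax_version) == public_version(jaxlib_version)


def installed(dist: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


print(versions_match("0.4.25", "0.4.23"))
print(versions_match("0.4.23", "0.4.23+cuda11.cudnn86"))
```

Inside the af environment, versions_match(installed("jax"), installed("jaxlib")) checks the real installation.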

For the GPU build of jaxlib, I used pip to install jaxlib alone:

  # install the CUDA 11 / cuDNN 8.6 build of jaxlib using pip
  pip install jaxlib==0.4.23+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
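As a final check, a minimal sketch of the same backend test used earlier, written as a reusable function. It uses jax.default_backend(), which returns the same platform string as xla_bridge.get_backend().platform; the function name and the "jax-not-installed" marker are hypothetical.

```python
def detect_backend() -> str:
    """Return JAX's default backend ('cpu', 'gpu', ...) or a marker if JAX is missing."""
    try:
        import jax
    except ImportError:
        return "jax-not-installed"
    return jax.default_backend()


if __name__ == "__main__":
    backend = detect_backend()
    print(backend)
    if backend != "gpu":
        print("GPU backend not available; check the CUDA/cuBLAS warnings printed by JAX.")
```

If the installation above succeeded, this prints gpu instead of the cpu seen in Trial 1.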