Conda를 이용한 간단한 GPU버전 JAX 설치법
이 포스트는 2024년 3월 12일을 기준으로 작성되었습니다.
Antibody design을 위해 JAX를 설치해야 하는 상황에 생겼다. (pytorch, tensorflow 정도로 끝날 줄 알았는데…)
공식 홈페이지의 Installation Guideline 대로 설치를 시도했다. (먼저 conda
환경 만들기!)
1
2
3
# Create Conda Environment (python 3.9)
conda create -n af python=3.9
conda activate af
[시도 1]
하지만, CUDA 버전 (v11.1)이 너무 낮아, GPU를 인식하지 못하는 문제가 발생하였다.
자세히
CUDA version이 다음과 같이 다르게 나타났지만(차이가 나는 이유), 11.x라 공식 Guideline인 Installing JAX에서 'CUDA 11 installation' 코드로 설치 시도함.
1
nvidia-smi
1
nvcc --version
1
2
3
# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
1
2
3
python
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
CUDA backend failed to initialize: Found cuBLAS version 11201, but JAX was built against version 111103, which is newer. The copy of cuBLAS that is installed must be at least as new as the version against which JAX built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
cpu
cpu
[시도 2]
CUDA 버전을 업데이트 하기 위해 공식 문서를 참고하여,
- 리눅스 버전
- GPU 모델
확인 후 CUDA toolkit v12.4와 cuDNN v8.9.7 for CUDA 12.x 를 설치하였으나, 어떤 이유에서 인지… 이번에도 GPU를 인식하지 못하는 문제가 발생하였다.
자세히
1
2
3
4
5
6
# in .bash_profile
export PATH=/usr/local/cuda-12.4/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
# after save
source ~/.bash_profile
[시도 3]
그리고 conda
를 활용해 설치하는 방법을 시도했지만, 어째서인지 명령어가 먹히지 않았다.
자세히
1
2
conda install jax -c conda-forge
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
no matches found: jaxlib=*=*cuda*
설치 성공한 방법
마지막으로 아래와 같이 시도해서 성공했다.
1
2
3
# create conda environment
conda create -n af python=3.9
conda activate af
처음에는 CUDA 12.4 버전을 설치하려고 했지만, conda
에서 설치 가능한 cudnn version이 CUDA 11을 기반으로 하기 때문에, 11.8 버전을 기준으로 설치하였다.
JAX currently ships two CUDA wheel variants:
- CUDA 12.3, cuDNN 8.9, NCCL 2.16
- CUDA 11.8, cuDNN 8.6, NCCL 2.16
1
2
3
4
5
# install cuda using conda
conda install nvidia/label/cuda-11.8.0::cuda
# install cudnn using conda
conda install anaconda::cudnn
그리고 JAX를 conda로 설치했는데, 설치된 jax 버전 (v0.4.25)과 jaxlib 버전 (v0.4.23)이 달라, 이를 맞추기 위해 최종적으로 아래와 같이 설치하였다.
1
2
# install jax using conda
conda install jax==0.4.23 -c conda-forge
gpu 버전인 jaxlib 설치를 위해 jaxlib만 pip를 사용해 설치해주었다.
1
2
# install jaxlib for gpu using conda
pip install jaxlib==0.4.23+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html