一、前提条件
ubuntu 20.04
python 3.9
cuda 11.8
nvidia-cublas-cu11 11.11.3.6
nvidia-cuda-cupti-cu11 11.8.87
nvidia-cuda-nvcc-cu11 11.8.89
nvidia-cuda-nvrtc-cu11 11.8.89
nvidia-cuda-runtime-cu11 11.8.89
nvidia-cudnn-cu11 8.9.6.50
nvidia-cufft-cu11 10.9.0.58
nvidia-cusolver-cu11 11.4.1.48
nvidia-cusparse-cu11 11.7.5.86
nvidia-nccl-cu11 2.20.5
查看cuDNN版本命令
cat /usr/local/cuda/include/cudnn.h | grep CUDNN_MAJOR -A 2
这里最推荐安装cuda 11.8,安装最新的版本会与jax不兼容
二、安装Whisper-Jax
pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
三、安装GPU版本的jax
我们在执行上一步安装whisper-jax的时候,会自动安装cpu版本的jax和jaxlib
jaxlib是jax运行时所需的依赖库
所以我们在安装jax的时候还需要考虑jaxlib版本问题,一般而且,jax和jaxlib的版本号是相同的
我们需要把cpu版本的jax和jaxlib都卸载掉,重新安装GPU版本的
卸载命令
pip uninstall jax jaxlib
安装命令
pip install -U "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -i https://pypi.tuna.tsinghua.edu.cn/simple
这里推荐使用清华镜像源
四、解决运行时自动联网从huggingface下载模型资料的问题
找到pipeline.py
,添加下面内容
WhisperProcessor.from_pretrained(self.checkpoint) # 修改前
WhisperProcessor.from_pretrained(self.checkpoint, local_files_only=True) # 修改后
然后模型修改下面代码
pipeline = FlaxWhisperPipline("openai/whisper-base") # 修改前
pipeline = FlaxWhisperPipline("/root/.cache/huggingface/hub/models--openai--whisper-base") # 修改后
如何解决无法访问huggingface
参考文档
Q.E.D.