1、安装jax
1.1、前提条件
已经安装好了NVIDIA
显卡驱动和CUDA。如果你还没安装,那么你可以参考我的这篇文章。
jax是谷歌推出的深度学习框架。
这里安装的是GPU版本的jax。
1.2、安装
源码地址:
https://github.com/google/jax
更新pip:
pip install --upgrade pip
安装jax:
cuda 11
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
cuda 12
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
测试是否安装成功,可以参考这篇文章。
import jax.numpy as np
from jax import random
import time
x = random.uniform(random.PRNGKey(0),[5000,5000])
st = time.time()
try:
y=np.matmul(x,x)
except Exception:
print("error")
print(time.time()-st)
print(y)
如果使用国外的源无法下载时,可以更换为国内的镜像,这里推荐更换为清华大学的镜像。
2、安装cuDNN
下载完成之后上传到服务器,然后解压cuDNN。
tar -xvf xxx.tar
然后进入到解压后的目录。
然后复制到CUDA-12.1目录。
sudo cp include/* /usr/local/cuda-12.1/include
sudo cp lib/libcudnn* /usr/local/cuda-12.1/lib64
sudo chmod a+r /usr/local/cuda-12.1/include/cudnn*
sudo chmod a+r /usr/local/cuda-12.1/lib64/libcudnn*
查看cuDNN版本。
cat /usr/local/cuda/include/cudnn_version.h | grep CUDNN_MAJOR -A 2
3、安装whisper-jax
源码:
https://github.com/sanchit-gandhi/whisper-jax
安装:这里推荐在Anaconda的虚拟环境中安装。如何安装Anaconda,可以去参考我的这篇文章。
pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
注意不能使用国内pip源安装。
修改pip配置文件。
vim ~/.pip/pip.conf
把国内镜像源注释掉。
验证jax是否使用GPU
import jax
print(jax.devices()[0])
如果返回gpu:0说明可以使用GPU。
4、whisper-jax的使用
whisper-jax没有提供命令行方式运行。
from whisper_jax import FlaxWhisperPipline
# instantiate pipeline
pipeline = FlaxWhisperPipline("openai/whisper-small")
# JIT compile the forward call - slow, but we only do once
text = pipeline("audio.mp3")
# used cached function thereafter - super fast!!
text = pipeline("audio.mp3")
个人感觉whisper-jax更吃显卡性能,经过我在Tesla T4 15GB显卡上的测试,无法跑large模型。
转录的速度要比whisper更快。
4.1、调整精度
可以修改精度,加快转录速度。
对于A100显卡和TPU:
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
# instantiate pipeline in bfloat16
pipeline = FlaxWhisperPipline("openai/whisper-small", dtype=jnp.bfloat16)
对于非A100显卡:
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
# instantiate pipeline in bfloat16
pipeline = FlaxWhisperPipline("openai/whisper-small", dtype=jnp.float16)
4.2、批处理
from whisper_jax import FlaxWhisperPipline
# instantiate pipeline with batching
pipeline = FlaxWhisperPipline("openai/whisper-small", batch_size=16)
4.3、输出带时间戳的文件
from whisper_jax import FlaxWhisperPipline
pipeline = FlaxWhisperPipline("openai/whisper-small")
text = pipeline("audio.mp3", return_timestamps=True)
chunks = text["chunks"]
with open('output', 'w') as f:
for item in chunks:
f.write(str(item) + "\n")
f.close()
Q.E.D.