NVIDIA A100 80GB的主要特性是什么?

目前能够使用的最高端的训练显卡,就是A100了(一般人用不起H100,这里就不纳入考虑了)。A100有两种配置:40GB显存版本和80GB显存版本。乍一看,80GB显存版本的A100训练效率应该是40GB显存版本A100的两倍。但是,经过仔细计算,我发现这一结论是粗糙的、经不起推敲的。根据我的实验数据,在训练大语言模型时,80GB显存版本A100的训练效率可以达到40GB显存版本A100的四倍甚至更高。


训练过程中的显存占用分为三部分,下面分别计算。


一、模型参数的显存占用


模型参数的显存占用最容易理解,直接数模型参数的总量就行了。Mini-GPT4的模型参数主要包括三大块:


1、LLaMA 13B微调得到的Vicuna 13B模型,训练过程为半精度浮点数,每个参数占用2Byte显存,因此这部分参数占用13Billion x 2Byte=26Billion Byte=26GB显存。


2、抽取图像特征的ViT模型,这部分采用的是ViT-G/14模型,显存占用大约为1GB。


3、将图像特征转化成与语言模型特征对齐的Q-Former、语言模型FlanT5-XXL等其它模块,显存占用大约0.2GB。


因此,Mini-GPT4的模型参数共占用大约27.2GB显存。


二、优化器显存占用


优化器的显存占用正比于模型中可训练参数的大小。对于常用的Adam/AdamW来说,这个比例系数为2。考虑到Mini-GPT4模型中可训练参数很少(只有一层线性层),因此这部分的显存占用基本可以忽略。(具体来说,这部分参数大约占用45MB显存,加上优化器占用的显存 45MB x 2 = 90 MB,一共占用135MB≈0.1GB显存)


三、训练过程的中间结果的显存占用


这部分是真正能用于训练的显存,而且和batchsize成正比。batchsize则决定了训练效率。LLaMA模型的序列长度为2048,经过测算,batchsize每增大1,显存占用增大约0.65GB。详细的实验数据见下图。


image.png

Mini-GPT4模型的显存占用(GB)与batchsize的变化关系


数据分析


对于A100的40GB显卡,最大batchsize为20,但是此时显存占用已经到了临界点,容易出现out of memory的错误。因此实际中使用batchsize为16更稳妥。


对于A100的80GB显卡,理论最大batchsize(将上图中的直线进行外推)能达到82,保险起见可以设置为80。我喜欢2的幂次方,因此只测到了64。


由此可见,要达到A100 80GB的训练效率,需要4块甚至5块A100 40GB。而即使有了4块A100 40GB,为了使用多张卡同时训练,需要进行数据并行,多卡之间还得频繁通讯,如果只是PCIe版本的显卡而没有NVLink或者SXM配置,通信效率会极大地拖累训练进度。

蓝海大脑 京ICP备18017748号-1