训练最基础的transformer模型用多大的gpu就行?

8gb或者12gb就够训练 12层的 encoder-decoder 架构 transformer 模型了。序列长度在512左右。batch size什么的可以通过 gradient checkpoint 或者 accumulate gradient 等操作间接提升。小显存推荐开混合精度训练,或者开bf16缓解一下显存压力(如果卡支持的话)。有能力可以用个 fp16/bf16 算力大点的。



蓝海大脑 京ICP备18017748号-1