大力再出奇迹,1024 张TPU,65536 batch size,仅76分钟训练完BERT!

大数据文摘出品

作者:Andy

BERT 作为目前工业界中训练最耗时的应用,计算量甚至远大于机器视觉中的 ImageNet 训练。在 BERT 原论文中,Jacob Devlin 也是用了 16 台云TPU(64 个 TPU 芯片)花了整整四天,才训练完了 BERT-large 模型。

正因如此难以训练,导致大家也都几乎是直接拿 Google 开源出的模型,用 GPU 在上面 finetune 一下就来用。Github上的 BERT 开源库也因为其在 GPU 机器上预训练的困难和耗时,往往没给出预训练相关脚本。

论文下载地址:

https://arxiv.org/pdf/1904.00962.pdf

而在这篇愚人节当天于arxiv发出的 《Reducing BERT Pre-Training Time from 3 Days to 76 Minutes》论文中,正如标题所说,作者将BERT的训练时间缩减到了仅仅 76 分钟。

而所用的计算机器数也是惊人的,一个 TPUv3 pod (1024张TPU芯片)。不要以为这只是个靠堆积机器就能堆积上去的成果,这还涉及到具体的训练优化问题解决,比如说如何在增大batch size来提高计算通讯比的同时,又能保证其收敛。这篇论文中就是提出了一套训练方案,在新优化器 LAMB 的帮组下来解决这个问题。

这篇论文中就提出了一套训练方案,还有一个新的优化器LAMB来解决这个问题。

这是一作加大伯克利分校博士尤洋作为 Google Brain 实习生时完成的项目,其现在的主攻方向便是“高速并行分布深度学习算法”,这也是与这次论文密切相关的主题。

看他之前的论文也能看到类似的大规模训练项目,比如之前在英伟达实习时的技术报告《LARGE BATCH TRAINING OF CONVOLUTIONAL NETWORKS》,怎么用大 batch size 训练CNN网络。

他在里面提出的 LARS,也正是这篇论文中提出优化器 LAMB 的最初原型。论文中也有拿 LARS 做了实验。而说到在Google Brain实习也是有迹可循,之前他便已在这实习过一段时间研究关于TPU上使用Tensorflow的课题,也为此次项目打下铺垫。

论文下载地址:

https://arxiv.org/pdf/1708.03888.pdf

看他之前的论文也能看到类似的大规模训练项目,比如他之前在英伟达实习时的技术报告《LARGE BATCH TRAINING OF CONVOLUTIONAL NETWORKS》,怎么用大 batch size 训练CNN网络。

他在里面提出的 LARS,也正是这篇论文中提出优化器 LAMB 的最初原型。论文中也拿LARS做了实验。而说到在Google Brain实习也是有迹可循,之前他便已在这实习过一段时间研究关于TPU上使用Tensorflow的课题。

下面就来介绍论文吧。

首先在导言部分,作者稍稍介绍了大批量训练的困难,还有此次面对的训练对象BERT。为处理大批量 BERT 的训练,作者提出了 LAMB 优化器,通过这个优化器将BERT训练的batch size推到了64k(具体65536)的量级,而同时不损失精度,此外LAMB优化器还有一个优点就是只用调学习率。

最后的预训练包括两个阶段:前九个epoch用128的序列长度和64k的batch size;最后一个epoch用512的序列长度和32k的batch size,只用了惊人的8599个迭代便训练完了BERT。而它和基线模型还有其他batch size训练对比,如下图。

之后,在背景部分,作者给我们分享了关于大批量训练的知识。目前大批量训练常常会遇到的几个问题,还有如何对其中一些进行处理。

  • 大批量会导致测试精度丢失,所以需要调节超参,比如学习率。所以需要随着batch size的增大,线性或平方根级地增大学习率;
  • 但大学习率会导致训练初始不稳定,因此需要使用学习率预热技巧(learning rate warmup),先用一个小学习率然后慢慢增大,到一定点切换到正常的学习率策略;
  • 还有大批量训练里的泛化间距(generalization gap)问题,大批量训练模型会倾向于收敛到比较尖锐的局部最小点,这会导致训练很容易达到一个比较高的训练精度,却很难获得比较好的测试精度,也就是存在比较大的间距(gap)。目前还没找到合适方法解决。

然后再来详细了解一下本文最重要的创新,LAMB (Layer-wise Adaptive Moments optimizer for Batch training)优化器吧。它是基于作者之前提出的 LARS (Layer-wise Adaptive Rate Scaling)优化器,同时又参考了BERT原有优化器进行了改进。

首先为了理解 LAMB 我们需要对 LARS 有个粗略的了解。LARS 提出的背景是,作者发现对于网络各层,其权重和梯度的L2-norm的比值|w|/|g|变化会非常大,比如5.76和1345。

这导致的一个问题就是,一个学习率并不适应于所有层训练,它对于有些层或许适合,但对于有些层却可能太大。因此作者提出应该要按照层,来获得每层的学习率(Local LR),而这个每层学习率的计算则需要之后 LAMB 中多次提到的一个trust ratio,置信比,有多大的可能我们相信当前层会在这次更新中改变它的权重。

于是 LAMB 对 LARS 最大的改进有三点,都是与trust ratio相关。

1.通过在Tensorflow具体的LARS优化器实现中,移除其中一个当某层的|w|和|g|都非零时用于计算 trust ratio 的系数 eeta,从而避免了BERT大批量训练中的发散;

2.LARS 自身的trust ratio在一些自适应优化器,比如BERT里用到的ADAM with weight decay 或 ADAM中,会导致不准确的学习率矫正,因为它们用了元素级的更新策略。于是作者们将trust_ratio做了些修正,保证了它和自适应优化器的结合。

3.最后一点,在计算梯度的L2-norm的时候,还加入了梯度的一阶和二阶惯量的信息。

LAMB 算法具体如下:

最后具体实验部分便不累述,值得一提的一个细节是,在混合batch训练的时候,因为中间有一个阶段转换过程,将batch size从64k降到了32k,这里会给优化过程带来噪音,进而导致训练发散。为解决这个问题,在第二个阶段的时候作者又重新进行了一次学习率预热(re-warm-up)。

Image placeholder
YinJiongjie
未设置
  17人点赞

没有讨论,发表一下自己的看法吧

推荐文章
AI赌神升级!无惧bluff,6人局德扑完胜世界冠军,训练只用了8天

大数据文摘出品作者:曹培信、宁静2017年年初,BrainvsAI的德州扑克人机大战在卡耐基梅隆大学(CMU)落幕,由4名人类职业玩家组成的人类大脑不敌人工智能程序Libratus。获胜后人类还遭到了

秘籍在手,训练不愁!特斯拉AI负责人Karpathy的超全神经网络训练套路

大数据文摘出品编译:周素云、宋欣仪、熊琰、ZoeY、顾晨波训练神经网络到底有诀窍和套路吗?AndrejKarpathy认为,还的确有。这位特斯拉的人工智能研究负责人、李飞飞的斯坦福高徒刚刚难得更新了博

10分钟,用TensorFlow.js库,训练一个没有感情的“剪刀石头布”识别器

大数据文摘出品编译:Luciana、小七、宁静“剪刀石头布”是我们小时候经常玩的游戏,日常生活中做一些纠结的决策,有时候也常常使用这种规则得出最后的选择,我们人眼能很轻松地认知这些手势,“石头”呈握拳

万字长文|1分36秒,100亿,支付宝技术双11答卷:没有不可能

2019年双11来了。1分36秒100亿,5分25秒超过300亿,12分49秒超500亿……如果没有双11,中国的互联网技术要发展到今天的水平,或许要再多花20年。从双11诞生至今的11年里,有一个场

1万属性,100亿数据,每秒10万吞吐,架构如何设计?

有一类业务场景,没有固定的schema存储,却有着海量的数据行数,架构上如何来实现这类业务的存储与检索呢?58最核心的数据“帖子”的架构实现技术细节,今天和大家聊一聊。一、背景描述及业务介绍什么是58

pymysql fetchone () , fetchall () , fetchmany ()

最近在用python操作mysql数据库时,碰到了下面这两个函数,标记一下: 1.定义 1.1fetchone(): 返回单个的元组,也就是一条记录(row),如果没有结果则返回None 1.2fet

到2025年,全球VoIP市场将达到550亿美元

基于IP的语音传输(VoIP)是当今世界许多人和企业主现代生活中不可或缺的一部分。几十年来,该技术发展迅速,延伸出了VCaaS、CCaaS、UCaaS等。然而,即使在“VoIP”已经成为常用术语的世界

56岁潘石屹下决心学Python,60岁程序语言之父们还在敲代码,你呢

比你成功的人,比你还努力。上周,SOHO中国董事长、地产大亨 潘石屹,56岁生日当天发布微博宣布进军编程语言Python。 紧接着第二天,又更新微博解释为何会做出此举。潘石屹给出的解释大致就是,在不断

我的天!这是史上最烂的项目:苦撑12年,600多万行代码…

编译:欧剃来源:projectfailures.wordpress.com转载自:Java技术栈你见过最烂的项目,撑了多长时间才完蛋?六个月?一年?今天介绍的这个奇葩项目,不但一开始就烂得透透的,还硬

树莓派 4 正式发布!硬件性能大提升:CPU提升3倍,支持USB3.0、蓝牙5.0、千兆以太网、4G LPDDR4、H.265

本文转自|EETOP树莓派(RaspberryPi)基金会,6月24日正式发布了RaspberryPi4ModelB。树莓派是全球知名的基本计算微型电脑,深受全球开发者、编程者、极客等人士的追捧和喜爱

算法-最小栈-LeetCode155

题目最小栈设计一个支持push,pop,top操作,并能在常数时间内检索到最小元素的栈。push(x) --将元素x推入栈中。pop() --删除栈顶的元素。top() --获取栈顶元素。getMin

制定机器学习训练数据策略的6个技巧

人工智能(AI)和机器学习(ML)如今已经十分常见。AI指的是机器模仿人类进行认知的概念,ML是一种用于构建AI的方法。如果AI是指计算机可以根据指令执行一组任务,那么ML就是机器从数据中摄取、解析和

一份关于机器学习“模型再训练”的终极指南

机器学习模型的训练,通常是通过学习某一组输入特征与输出目标之间的映射来进行的。一般来说,对于映射的学习是通过优化某些成本函数,来使预测的误差最小化。在训练出最佳模型之后,将其正式发布上线,再根据未来生

Python可视化 | Seaborn5分钟入门(二)——barplot&countplot&pointplot

微信公众号:「Python读财」如有问题或建议,请公众号留言Seaborn是基于matplotlib的Python可视化库。它提供了一个高级界面来绘制有吸引力的统计图形。Seaborn其实是在matp

聊聊chronos的pullFromDefaultCFAndPush

序本文主要研究一下chronos的pullFromDefaultCFAndPushpullFromDefaultCFAndPushDDMQ/carrera-chronos/src/main/java/

Github标星十万+!愤怒的程序员发起996.ICU,小本本投诉过度加班公司

大数据文摘出品作者:蒋宝尚哪里有压迫,哪里就有反抗。作为程序员的你,这几天一定被名为996.ICU的github项目刷屏。3月27日,有开发者在GitHub上建了一个名为996.ICU的repo,该r

TPC-C解析系列02_OceanBase如何做TPC-C测试

导语:蚂蚁金服自研数据库OceanBase登顶TPC-C引起业内广泛关注,为了更清楚的展示其中的技术细节,我们特意邀请OceanBase核心研发人员对本次测试进行技术解读,共包括五篇:1)TPC-C基

Python可视化 | Seaborn5分钟入门(一)——kdeplot和distplot

微信公众号:「Python读财」如有问题或建议,请公众号留言Seaborn是基于matplotlib的Python可视化库。它提供了一个高级界面来绘制有吸引力的统计图形。Seaborn其实是在matp

10分钟搞懂:亿级用户的分布式数据存储解决方案!

来源:IT进阶思维原创,转载请注明原出处内容提供:李智慧,前阿里巴巴技术专家,《大型网站技术架构》作者6月6日晚,林志玲与Akira公布婚讯、徐蔡坤祝福高考同学超常发挥,粉丝们百万的转发和点赞造成微博

10 分钟彻底理解 Redis 的持久化和主从复制

在这篇文章,我们继续有关Redis方面知识的学习,一起了解一下其中一个非常重要的内容:Redis的持久化机制。什么是Redis持久化?Redis作为一个键值对内存数据库(NoSQL),数据都存储在内存

PHP 内核:讲下 PHP 7 底层虚拟机工作原理 —— Zend Virtual Machine 7.2 版本

本文旨在提供Zend虚拟机的概述,如php7所示。这不是一个全面的描述,但我试图涵盖大部分重要部分,以及一些更精细的细节。 此描述针对的是PHP7.2版(目前正在开发中),但几乎所有内容都适用于PHP

大神程序员,夜夜coding到天明?Python之父昼伏夜出,PHP创始人24小时都在线

栗子鱼羊 发自凹非寺转自量子位 |公众号QbitAI大神程序员,夜夜coding到天明?有位名叫IvanBessarabov(简称“伊万”)的好事者,刚刚统计了各路大佬的代码提交(gitcommit)

蓦然回首,Java 已经 24 岁了!

01、真没想到,Java竟然24岁了(算是90后)!提起Java,印象最深刻的当然就是:classCmower{ publicstaticvoidmain(String[]args){  System

简化企业组网 H3C S1224F以太网交换机评测

相信很多企业都有自己的无线组网,且企业业务和办公越来越依赖于无线网络。有一些企业认为,在BYOD时代,无线组网满足了大部分的办公需求,相比有线来讲更加便捷易用。但在网络的稳定性方面,有线网络有着更明显

124. Binary Tree Maximum Path Sum - 二叉树中的最大路径和

1描述给定一个非空二叉树,返回其最大路径和。路径:一条从树中任意节点出发,达到任意节点的序列。该路径至少包含一个节点,且不一定经过根节点。用例 输入:[1,2,3] 1 /\ 23 输出:6输入: