从新的角度看待大模型微调
作者: 小白学视觉 来源: 小白学视觉
点击上方“**小白学视觉** ”,选择加"**星标** "或“**置顶** ”
重磅干货,第一时间送达![](https://api.allorigins.win/raw?url=https://mmbiz.qpic.cn/mmbiz_jpg/ow6przZuPIENb0m5iawutIf90N2Ub3dcPuP2KXHJvaR1Fv2FnicTuOy3KcHuIEJbd9lUyOibeXqW8tEhoJGL98qOw/640?wx_fmt=jpeg&wxfrom=5&wx_lazy=1&wx_co=1)
一、前言
一切要从最近大火的Lora(《LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS》 )说起,该文章在ICLR2022中提出。说的是利用低秩适配(low-rank adaptation)的方法,可以在使用大模型适配下游任务时只需要训练少量的参数即可达到一个很好的效果。
LoRA是怎么去微调适配下游任务的?流程很简单,LoRA利用对应下游任务的数据,只通过训练新加部分参数来适配下游任务。而当训练好新的参数后,利用重参的方式,将新参数和老的模型参数合并,这样既能在新任务上到达fine-tune整个模型的效果,又不会在推断的时候增加推断的耗时。
LoRA的示意图如下:
图中蓝色部分为预训练好的模型参数,LoRA在预训练好的模型结构旁边加入了A和B两个结构,这两个结构的参数分别初始化为高斯分布和0,那么在训练刚开始时附加的参数就是0。A的输入维度和B的输出维度分别与原始模型的输入输出维度相同,而A的输出维度和B的输入维度是一个远小于原始模型输入输出维度的值,这也就是low-rank的体现(有点类似Resnet的结构),这样做就可以极大的减少待训练的参数了。在训练时只更新A,B的参数,预训练好的模型参数是固定不变的。在推断时可以利用重参数(reparametrization)思想,将AB与W合并,这样就不会在推断时引入额外的计算了。而且对于不同的下游任务只需要在预训练模型基础上重新训练AB就可以了,这样也能加快大模型的训练节奏。
由于本文不具体介绍LoRA,所以详细信息可以查看LoRA原文。我们只需要知道LoRA文章后续的实验已经论证该方法的有效性。那么进一步思考,为什么LoRA的这种思路能work的不错呢?
答案就是接下来要讲的Intrinsic dimension了。这点LoRA原文也提到过,该文章灵感来源于下面两篇文章:
-
MEASURING THE INTRINSIC DIMENSION OF OBJECTIVE LANDSCAPES , 发表在ICLR2018,为了方便接下来该论文称为【论文1】
-
INTRINSIC DIMENSIONALITY EXPLAINS THE EFFECTIVENESS OF LANGUAGE MODEL FINE-TUNING , 发表在ACL2021,为了方便接下来该论文称为【论文2】
二、本征维度(Intrinsic dimension)定义
本征维度的概念在由【论文1】提出。
训练一个神经网络往往包含如下几步:
-
对于一个给定的数据集,先设计网络的结构和选择对应的loss
-
对网络中的参数进行随机的初始化
-
训练网络使得loss越来越低
而训练阶段可以认为是在一个固定的目标图(objective landscape)上,寻找出有效的路径。
这里解释一下为什么是固定的目标图。因为在数据集和网络结构固定下来后,待优化的问题就已经定义好了,所以目标图也就是确定的了。
如下图所示:
那么对于一个参数量为D的模型,我们训练该模型,也就意味着在D维空间上寻找有效的解。文章认为D可能是冗余的,可能实际上只需要优化其中的d个参数就可以找到一个有效的解。
用公式表示如下:
其中 表示D维的优化参数,表示随机初始化的一个参数并且在训练时是不进行更新的, P是一个随机初始化的大小的矩阵且训练时也不进行更新,表示待优化的d维参数。
也就是说可以在训练网络时只更新d维参数,就可以达到该网络应有的效果。那么这个d就是所谓的该模型的本征维度。
这里讲完可能还有点晕,我们看一下如下这张图:
上图中,蓝色部分为初始化好的网络参数,绿色为, 红色为。网络训练的时候只训练红色部分,其它参数都是固定的。d就是本征维度。
上面讲的只更新d维参数,让网络达到应有的效果,那么什么“应有的效果”呢?文章定义,在只更新d维参数的情况下,网络效果达到训练原始模型时效果的90%时,那么久认为达到了“应有的效果”,并且的d就为本征维度。
例如在做mnist这个数字分类任务时,如果原始模型精度能到0.9,那么在只更新d维参数的时候,精度能够达到90% x 0.9 = 0.81,就认为这时候的d为本征维度记为。
三、使用本征维度思考大模型微调的有效性
【论文2】将之前提出的本征维度用来思考大模型微调的有效性,为什么现在用几百或者几千张图片就可以对大模型进行有效的微调?
根据【论文1】阐述,对于某一类问题,在一定精度上(比如达到90%的精度)有本征特征的存在。对于大模型而言,进行本征维度的测试就能知道在解决某一类下游问题时,需要调整多少参数就能近似的解决当前的问题。如果真的有实验能证明仅仅调整少数的参数就能很好的解决下游问题,那么也就能回答上述问题,即对大模型做少量的微调(调整少量的参数),就能解决当前的问题。
下面无特殊说明的话,“文章”指的都是【论文2】
3.1 实验一:对于大模型而言,是否存在本征维度
同【论文1】,【论文2】中也利用公式来进行模型的训练,即训练时只调整d维参数 。但与【论文1】中的实验有点不同的是,【论文1】中是随机初始化的,而【论文2】中是预训练好的参数。
【论文2】首先选择BERT-Base\BERT-Large\RoBERTa-Base\RoBERTa-Large四个模型,并选择GLUE benchmark中的 MRPC和QQP两个数据集(两个数据集都是用来测试句子对是否相同意义的任务)。
实验结果如下图所示:
上下两个子图分别表示MRPC和QQP两个任务,每个子图有四条实线表示四个模型的准确率,四条虚线表示达到fine-tune整个模型90%的准确率的值,横坐标表示训练d维的大小。从图中可以看出两个任务,四个不同的模型,只需要训练较小的d维参数就可以达到90%的精度。本征维度这个概念在大模型中是成立的。
所以在训练某个下游任务时,只需要训练少量参数就能达到不错的效果了。这时文章开头的问题就已经解决了。但是作者做了一些其他的实验,发现了一些有意思的结论。
3.2 预训练的好坏与本征维度的关系
文章提出这样一个假设,预训练模型能够隐式的降低模型在NLP各个任务的本征维度。
基于这个猜想,文章做了下面实验,在预训练RoBERTa-base模型的时候,每隔10K保存下对应的预训练模型,然后测试保存下来的预训练模型在MRPC, QQP, Yelp Polarity,SST-2, MNLI,ANLI六个数据集本征维度。
结果如下:
可以看出,在不同数据集上,有相同的趋势,就是预训练次数越多,模型在各个任务上的本征维度越低。实验并没有特意去优化所谓的本征维度,只是预训练久一点而已。所以印证了预训练模型的表征能力越强(训练的越好)本征维度越小。
3.3 预训练模型参数与本征维度的关系
本来在做预训练参数与本征维度关系的时候,需要统一模型的结构,这样更有说服力。但是作者说,这样要训练很多大模型的实验,为了更方便的对比文章根据已有的结构来做实验。从实验结果的趋势来看,不同结构也能得到有效的结论。
文章利用已有的预训练模型,在MRPC数据集上计算本征维度。
实验结果如下:
上图中纵坐标表示本征维度的值,很坐标表示模型的参数量。从图中的趋势可以明显看出,模型越大本征维度越小,即越强的模型本征维度越低。
3.4 本征维度与泛化能力的关系
上面介绍了fine-tune(3.1)、预训练(3.2)和本征维度的关系,但本征维度与泛化能力的关系还没有验证。即我们现在知道了让本征维度小的方式,但是本征维度小了,泛化能力就能上去吗?
文章又做了下面的实验,把3.2保存下来的模型,在对应的的本征维度上,进行不同数据集的测试,结果如下:
可以看出本征维度低的模型,训练出来的模型准确率是更高的。也就是说本征维度越低,泛化性能越好。
回到引言的问题:为什么LoRA思路能work?
因为大模型存在本征维度的概念,只需要调整少量参数就能在下游任务上得到很好的效果。
参考文献
本文为粉丝投稿,转载自:https://michaelliudev.blog.csdn.net/article/details/131745794
未经原作者授权禁止转载。
下载1:OpenCV-Contrib扩展模块中文版教程
在「**小白学视觉** 」公众号后台回复:**扩展模块中文教程****,** 即可下载全网第一份OpenCV扩展模块教程中文版,涵盖**扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理** 等二十多章内容。
下载2:Python视觉实战项目52讲
在「**小白学视觉** 」公众号后台回复:**Python视觉实战项目****,** 即可下载包括**图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别** 等31个视觉实战项目,助力快速学校计算机视觉。
下载3:OpenCV实战项目20讲
在「**小白学视觉** 」公众号后台回复:**OpenCV实战项目20讲****,** 即可下载含有**20** 个基于**OpenCV** 实现20个**实战项目** ,实现OpenCV学习进阶。
交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。**请勿** 在群内发送广告,否则会请出群,谢谢理解~
![](https://api.allorigins.win/raw?url=https://mmbiz.qpic.cn/sz_mmbiz_jpg/4AqSEnNUeric5L4iaE0Ev62IMvcQhyKHgGtrLkTO4pa7aMsoWkoT5wOkrHDiaNfPK1tTfTKLpxC8fyVxnjwjV3AuA/640?wx_fmt=other&wxfrom=5&wx_lazy=1&wx_co=1&random=0.23600528518335206&random=0.7536859237629767&random=0.8531908591825383&random=0.4806503559435422&random=0.6129608204402515&random=0.8200900513725702&random=0.03307512416643643&random=0.6886017910001143&random=0.7395588037951473&random=0.6352837921902186&random=0.30347866949269453&random=0.7038012468080515&random=0.8284857819643325&random=0.9749264331008225&random=0.8250331877047652&random=0.11700576547150043&random=0.9975773186161709&random=0.4379201216736548&random=0.4195828464834974&random=0.5340342947166905&random=0.6876807338631687&random=0.3939837984539263&random=0.3019324750763477&random=0.27349620953709985&random=0.6411853932393252&random=0.5529938696288133&random=0.6605551596662009&random=0.5885758351314165&random=0.4827946502272662&random=0.6150317866736881&random=0.5353939627000315&random=0.06823854245939232&random=0.22599635624892134&random=0.3499311427827825&random=0.22063153819233294&random=0.9951194897438997&random=0.1338193031411008&random=0.44384832775905214&tp=webp)
![](https://api.allorigins.win/raw?url=https://mmbiz.qpic.cn/mmbiz_png/4AqSEnNUer8Co8oDDJzFWWECpytMibasA1TMFTFGku5JPkwhp0ywxsmibkOhmichYM6iah0YrIYSsRn9YFWYXuic5Iw/640?wx_fmt=other&wxfrom=5&wx_lazy=1&wx_co=1&random=0.5774347616294973&random=0.9978291741892458&random=0.9372589662953192&random=0.11191579995970558&random=0.10262418582549904&random=0.6380495845214269&random=0.7784165226526307&random=0.27856596734359274&random=0.02697602976562341&random=0.26581795011892173&random=0.039996022172145596&random=0.21759946619282666&random=0.8468474982053957&random=0.3121106256513675&random=0.5055025089425877&random=0.6802497835412251&random=0.1505023208163898&random=0.5453281379147104&random=0.6268842835345501&random=0.6012592853181984&random=0.8864916262303144&random=0.8114129344469361&random=0.9263710705779902&random=0.21959750537418787&random=0.3046588216352746&random=0.5847112869442301&random=0.5403279438473796&random=0.5338488046258207&random=0.7656235328882937&random=0.9462572093978854&random=0.2295769393686915&random=0.7200559950640175&random=0.2379898552892521&random=0.018725790142357424&random=0.8102885619723987&random=0.987922209417774&random=0.7009694790628918&random=0.9359105428691536&tp=webp)
更多AI工具,参考Github-AiBard123,国内AiBard123