LWM:一个基于RingAttention的1M长上下文,多模态(语言、图像、视频)的开源工作
作者: NLP前沿 来源: NLP前沿
“
填填坑,看看假期的一些热门的工作,sora的技术报告这里不发了,很多技术笔者不太熟悉。
https://arxiv.org/pdf/2402.08268v1.pdf
https://github.com/LargeWorldModel/LWM
文章中提到的训练步骤主要分为两个阶段:第一阶段是学习长上下文语言模型(Learning Long-Context Language Models),第二阶段是学习长上下文视觉-语言模型(Learning Long-Context Vision-Language Models)。
第一阶段:学习长上下文语言模型
1.上下文扩展(Extending Context) :
-
使用RingAttention技术,通过分块计算和序列并行,理论上可以扩展到无限上下文,仅受限于可用设备数量。
-
分块计算:将长序列分割成多个较小的块(blocks),每个块包含固定数量的标记(tokens)。这样,模型只需要计算每个块内的注意力权重,而不是整个序列。
-
序列并行:在训练过程中,可以并行处理多个块,每个块由不同的GPU处理。这种方法允许模型在多个设备上同时处理序列的不同部分,从而提高了训练效率。
-
RingAttention 使用一个环形结构来组织块,这样每个块只需要与其相邻的块进行通信。这种结构减少了通信开销,因为每个块只需要与其直接相邻的块交换信息
-
-
采用渐进式训练方法,从32K标记开始,逐步增加到1M标记,以有效扩展上下文大小。
- RingAttention 支持渐进式训练,这意味着模型可以从处理较短的序列开始,然后逐步增加序列长度。这种方法有助于模型逐步学习处理更长序列的能力,同时保持训练效率。
*训练步骤(Training Steps) :
-
初始化模型参数,然后逐步增加上下文长度,分为5个阶段:32K、128K、256K、512K和1M标记。
-
在每个阶段,使用不同版本的Books3数据集进行训练,这些数据集经过过滤,以适应当前的上下文长度。
*聊天微调(Chat Fine-tuning for Long-Context Learning) :
-
构建模型生成的问答(QA)数据集,通过将文档分割成固定大小的块,然后使用短上下文语言模型生成问题和答案对。
-
在长上下文长度(如32K标记)下,通过连接相邻的块和在序列末尾添加相关的QA对来构建单个32K标记的示例。
第二阶段:学习长上下文视觉-语言模型
1.视觉架构修改(Architectural Modifications For Vision) :
-
使用预训练的VQGAN将图像和视频帧转换为离散标记。
-
引入新的标记来区分文本生成的结束和视觉生成的开始,以及视频帧的结束。
*训练步骤(Training Steps) :
-
从LWM-Text-1M文本模型初始化,然后在大量结合文本-图像和文本-视频数据上进行渐进式训练。
-
分别在1K、8K、32K、128K和1M标记的序列长度上进行训练,每个阶段都从先前的较短序列长度阶段初始化。
*评估结果(Vision-Language Evaluation Results) :
-
在长视频理解、图像理解和短视频理解等任务上评估模型性能。
-
展示了LWM模型在处理长视频和图像生成方面的能力。
更多AI工具,参考Github-AiBard123,国内AiBard123