阿里开源Lookahead:RAG场景LLM推理吞吐提升2-5倍
作者: 人工智能技术与时代人物风云 来源: 人工智能技术与时代人物风云
一、背景
LLM 发展很快,能力也很强大,而其巨大的推理成本成为了制约其发展的主要因素,因此出现了很多优化 LLM 推理相关的工作。
对于有损优化,相关工作主要集中在量化、蒸馏、剪枝等,比如常见的 LLM.int8()、GPTQ、AWQ 等,KV Cache INT8 量化也是使用非常多的一种手段。
对于无损优化,早期的优化重点主要集中在推理框架和引擎层面,比如 vLLM、LMdeploy 这些框架,相继集成了 PagedAttention、Continuous Batching、FlashAttention 和 FlashDecoding 等特性,也取得了不错的效果。
对于无损优化,最近也逐渐有一些尝试从算法层面优化,比如我们之前介绍的各种投机采样相关工作。然而之前的很多投机采样工作都需要额外的模型或对现有模型进行调整,导致其使用成本比较高,此外,其增加的计算量相比带来的加速比而言性价比不高,比如一次验证 100 个 Token,只接受 5 个 Token,则意味着有 20 倍的计算浪费。最近阿里的蚂蚁团队发布了 Lookahead 投机采样方案(注意和之前介绍的 Lookahead Decoding 不一样,只是重名而已),主要针对 RAG(Retrieval Augmented Generation,检索增强生成) 场景,获得了不错的效果。(PS:实际使用最好还是要综合数据分布,推理框架,batch size 等因素综合考虑)
对应的论文为:[2312.12728] Lookahead: An Inference Acceleration Framework for Large Language Model with Lossless Generation Accuracy
对应的代码库为:GitHub - alipay/PainlessInferenceAcceleration
其他 LLM 加速解码方案可参考:
其他 LLM 推理优化也可以参考:
GPU 相关具体参数可以参考:
多 LoRA 模型 LLM 推理可以参考:
二、 方案
2.1. RAG
LLM 受训练语料的限制,无法感知最新的内容,比如 LLM 训练后的新闻;此外,LLM 也容易产生幻觉,生成不正确的内容。为了解决这个问题,业界提出了通过检索外部知识库来获得额外语料,并使用 ICL(In-Context-Learning,上下文学习)来改进 LLM 生成效果的范式。当用户发起一个生成请求时,首先基于用户的 prompt 来检索相关信息,然后这些信息会被组合到 prompt 中,为生成过程提供额外的上下文,最后将组合后的 prompt 输入 LLM 生成结果。
RAG 的主要优势在于可以避免针对特定任务再次进行训练;用户可以额外附加外部知识库,丰富输入,从而优化模型的输出效果。RAG 因其高可用性和低门槛而成为 LLM 系统中最受欢迎的方案之一,许多 LLM 应用都会基于 RAG 构建。
在 RAG 系统中,LLM 生成的内容很可能来自 Prompt 中之前步骤检索的内容,这就很适合作为投机采样方案中猜测的 Token 序列,避免需要额外的模型或者额外的 Head 来生成待验证 Token 的过程。本文作者也是基于这个情况设计了 Lookahead 机制。
2.2. Lookahead 方案
其实整体的思路和之前的投机采样方案类似,主要就是待验证 Token 的来源,作者从 Prompt 构建待验证 Token 序列,与单序列相比,多序列可以提升接受率,Token 前缀树可以进一步降低成本。如下图 Figure 2 所示,第二行验证了 6 个 Token 只接受了 3 个,而第三行同样验证了 6 个但接受了 4 个。
具体来说,是通过设计了如下图 Figure 3 所示的 Mask 来实现一次验证多个 Token 序列或者 Token 前缀树,这种思路在之前的投机采样方案 SpecInfer 和 Medusa 等也有使用:
2.3. 前缀树
2.3.1. 前缀树维护
作者设计了前缀树,但不是一直不变的,而是不停地更新迭代,具体来说包含如下几个方面:
-
插入:除了 Prompt 中的 Token 会插入前缀树外,新生成的 Token 也会插入。
-
清除:本文的方案只针对单个用户请求,因此当某个 Prompt 的请求处理完之后会直接删除整个前缀树。(最近 LMSYS 团队发布的 SGLang 和 RadixAttention 进一步针对多个请求之间的重复前缀进行了优化,也取得了不错的效果)
-
删除:随着 LLM 不断生成输出,前缀树中节点会越来越多,当超过一定数值后会通过 LRU 机制动态删除最少被使用的节点。
2.3.2. 前缀树检索
通过提供一个 Token 前缀就可以在前缀树中检索相关分支,Token 前缀越长,检索到的分支越少,Token 前缀越短,检索到的分支越多。
为了在分支的数量和相关性之间取得平衡,作者采用多阶段检索策略。首先尝试匹配更长的前缀,如果与匹配分支关联的 Token 数量明显小于 CDL(Critical Decoding Length,最大解码长度),将缩短前缀的长度并重试匹配过程,直到获得足够的匹配分支的 Token。如果匹配到的分支的所有 Token 数小于阈值,则使用所有 Token 来验证,如果超过阈值,则挑选最高频的 Token 来验证。
三、实验
3.1. 实验设置
评估验证使用了如下的两个数据集,AntRAG 为蚂蚁内部数据集,Dolly 为开源数据集,输入和输出 Token 数如下图 Table 2 所示,评估都使用 1000 条 test 数据:
评估的两个模型配置如下 Table 3 所示,AntRAG 使用 Antglm-10B,Dolly 使用 LLama2-13B:
评估实验都在 A100-SXM(80G) GPU 上完成,没有特别说明的情况下 batch size 都为 1。
3.2. 评估结果
推理加速结果如下图 Table 4 所示,可见在 AntRAG 上很明显,达到 5 倍,在 Dolly 相对差点,也有 2 倍。其中的 Baseline 使用的是 Huggingface 的 Transformer 库(性能可能比较低,比如 vLLM 或 TensorRT-LLM 在这个数据上的性能如何?),LLMA 为微软发布的方案(GitHub - microsoft/unilm: Large-scale Self-supervised Pre-training Across Tasks, Languages, and Modalities),而 Lookahead(Parallel) 上面介绍的多分支的方案,Lookahead(Hierarchical) 为前缀树的方案:
作者在 GitHub 上也提供了其他模型在 Dolly-15k 和 GSM-8k 上的测试结果,提升同样只有 2 倍左右,其中 decoding length(生成 Token 长度)为 64,branch length(并行验证的 Token 数) 为 8:
如下图所示为不同 decoding length 和不同 branch length 下的吞吐,可以看出,branch length 越长,吞吐越高,在 20-40 左右趋于平稳。
如下图所示为不同 decoding length 和不同 branch length 下的 EDL(Effective Decoding Length,接受的 Token 数),可以看出,branch length 越长,接受的 Token 数越多。当 branch length 为 30-40 时,接受的 Token 数在 9-12,基本可以达到 1/4,冗余计算相比之前的方案少了很多。
3.3. 消融实验
如下图 Table 5 所示,作者也验证了不同配置的影响,比如是否使用 Prompt 中的 Token,是否使用输出的 Token 等:
此外作者也验证了前缀树中 Token 个数的影响,测试发现当前缀树中 Token 个数为 Decoding Length 的 16 或 32 倍时获得最好的吞吐,如下图 Table 6 所示:
作者也进一步验证了在不同数据场景的情况,如下图 Table 7 所示和数据密切相关,整体加速比还不错,实际使用时需进一步评估:
作者也进一步补充了 batch size 为 2 和 batch size 为 4 的情况,如下图 Table 8 可以看出,随着 batch size 增加,加速比开始降低,这是因为使用大 batch size 后 GPU 的访存瓶颈会降低,为投机采样并行验证留下的冗余算力空间也会降低:
四、参考链接
更多AI工具,参考Github-AiBard123,国内AiBard123