AI 文摘

大语言模型的参数高效微调





作者: HOME of Being 来源: HOME of Being

大模型微调技术的入门文章,欢迎批评指正。特别感谢上海人工智能实验室(Shanghai AI Laboratory)的算力支持及其书生·浦语大模型(InternLM2)的模型基座。

我们期待回答下列问题:

  • 什么是大模型和它的微调

  • 替代方法:检索生成增强(RAG)

  • 需要调整全部参数吗

  • 参数高效微调的分类和原理

  • LoRA的代码实现:基于InternLM2-1.8b和XTuner

下游任务和迁移学习:什么是微调(Fine Tuning)

闻道有先后,术业有专攻,如是而已。

——韩愈:师说

大模型,即大语言模型(Large Language Models, LLM),主要是指生成式的(generative,在Transformer范式下,也即decoder-only架构的)、参数规模巨大(一般认为在1B-100B不等是门槛)的语言模型。

这包含了两个角度,“生成式模型”直指它的架构(相应地,它的功能)本质,“大参数模型”直指它的数学本质——合起来看,大模型是一组特定模型架构的参数权重(weights),这组权重被一组模型配置文件(model configuration)来描述。这个架构-参数二元结构的本质决定了它的训练、推理方式和物理存储方式。

我们经常会说大模型是一个预训练(pre-trained)产品,这一方面是说大模型是被训练好等待用户提问的,即等待用户输入提示词(prompt)由大模型进行推理的,另一方面是说大模型是等待被进一步调整来适配特定的下游任务(downstream tasks)的。

这个适配特定下游任务的过程也叫做迁移学习(transfer learning):顾名思义,在领域A的数据集上训练的模型用于推理领域B的任务。之所以需要迁移学习,是因为大模型被预训练好后,其权重是完全固定的,而我们的世界时刻在变,大模型的知识可能会落后了,所谓“闻道有先后”;另一方面,它面对特定领域可能不够专业,所谓“术业有专攻”。两者都会造成大模型的幻觉(hallucination),即说胡话。

微调是迁移学习的一种形式。所谓微调,指的是我们固定基座模型(foundation model)的架构,利用特定数据集来调整其权重,即下图的增量调整(delta-tuning),以反映特定领域知识(domain knowledge)。如果说基座大模型好比在一切领域都懂一点、但都不够深入的学生,微调后的大模型则好比是在某一特定领域特别专长,(作为代价)在其他领域则更弱一些的偏科生。

微调还是检索生成增强(RAG):权衡与讨论

前面提到,我们的目标是实现大模型在实际任务上的迁移学习以减少大模型的幻觉。微调并非唯一的解决办法,一种竞争性替代是检索生成增强(retrieval augmented generation, RAG)。

仍然把大模型比作一个学生。微调是让大模型对特定领域数据进一步学习(continual learning, 持续学习),让他考试前突击复习,重塑他的知识体系,但往往会造成偏科;检索生成增强则为大模型外挂一个知识库(vector DB, 向量数据库),让闭卷考试变成开卷考试,但查找参考答案会在作答环节需要更多时间。

站在产品使用者的角度考虑这个问题。RAG的优势在于灵活,它不改变模型权重,训练成本较低,知识库是即插即用的,便于更换和更新,且不少框架(例如LangChain)已经集成了其实现方法。其缺点在于推理环节的检索-召回程序造成了额外的计算时间。简而言之,和微调相比,RAG训练成本低、推理成本高。关于RAG技术我们将在后面的文章中介绍。

当然,我们可以同时做微调和检索生成增强,这叫做RAFT(retrieval augmented fine tuning),也可以两者都不做,只把领域知识放到提示词中做提示工程(prompt engineering)。

理论上讲,提示工程更好更便捷。但制约提示工程的现实因素在于大模型支持的上下文长度。如果大模型能支持100K上下文长度,我们不需要做向量库,直接把知识库输入提示词就可以了。

大模型上下文长度的限制主要来自于其主流的底层算法,即Transformer架构使用的注意力机制的平方时间复杂度。如今,例如KimiChat的产品已经能支持上面提到的超长上下文了。目前的研究进展有两块:

第一,针对Transformer注意力机制的改进:谷歌的无限注意力机制(Munkhdalai, T., Faruqui, M., Gopal, S., 2024, Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention)。

第二,从底层算法上替代Transformer:线性复杂度的、基于RNN的Mamba(Gu, A., Dao, T, 2023, Mamba: A Selective State Space Model for Long Sequence Modeling)。

需要调整全部参数吗?

从现在开始,让我们专注微调本身。我们前面只说了要做迁移学习、增量学习,对参数“微微地”调整,从计算角度看,一个现实问题是:我们需要调整全部参数吗?调整部分参数能不能达到目的呢?

前者被称作全参微调(full tuning),后者被称作参数高效微调(parameter efficient fine tuning, PEFT)。

在AdaMix: Mixture-of-Adaptations for Parameter-efficient Model Tuning (Wang et al., 2022)中,研究者指出,调整约0.2%参数的LoRA、AdaMix等参数高效微调方法的效果得分和调整100%参数的全参微调效果相近,如下图所示。

这是一个很棒的结果。我们只需要调整很少的参数就能达到相近甚至更好(例如使用AdaMix或下面替到的P-tuning)的效果!这建立了参数高效微调有效性的基础,也即回答了为什么参数高效微调是“高效”的。

参数高效微调的分类和原理

依据其原理,参数高效微调分为下面几类。

整体来看,参数高效微调分为可加的、选择性的和重参数化的三种。

一种常见范式是适配器(adapters),它就像大模型的一个额外插件。更广义地讲,它属于第一种范式——可加的(additive),即把训练好的适配器添加到基座大模型上。

代表性的可加的方法还包括(IA)^3,即为Transformer架构加入额外的三类块进行训练,如下图所示。

另一种属于可加性方法的是软提示词方法(soft prompt),它也属于可加的,例如P-Tuning为decoder-only的架构外挂了一个encoder架构,通过提示词模板生成虚词元(pseudo tokens)辅助调参;Prefix-Tuning使得decoder和encoder都具备可训练参数。

还有一种是选择的(selective),即选择极少一部分关键参数进行调整。

最后,我们重点介绍重参数化的(reparameterization based)——低秩法(low rank, LoRA)。

LoRA和QLoRA

LoRA的原理是把参数的增量学习矩阵做一个低秩分解。其背后的数学假设是:深度神经网络看似高维的参数空间,其实只需要在其一个低维子空间上进行训练就可以了,也就是这个高维参数空间存在一个低秩表示。

在Lora中,模型的权重被分解为两部分:一部分是低秩的适配器权重,另一部分是原始模型参数。适配器权重是通过随机初始化和微调得到的,它们是稀疏的,并且与原始模型参数相比,它们的数量非常少。这样设计的好处是,它保持了模型的稀疏性,同时允许模型在特定任务上进行适应。

QLoRA则是LoRA的一种性能优化,我们将在后续大模型的量化中介绍。下面,我们通过XTuner来实现InternLM2的LoRA微调。

InternLM2-1.8b的LoRA微调:基于XTuner实现

环境和数据准备

XTuner的安装


#terminal
conda create --name xtuner0.1.17 python=3.10 -y
conda activate xtuner0.1.17
cd ~
mkdir -p /root/xtuner0117 && cd /root/xtuner0117
git clone -b v0.1.17  https://github.com/InternLM/xtuner
cd /root/xtuner0117/xtuner
pip install -e '.[all]'

微调数据集准备


mkdir -p /root/ft && cd /root/ft
mkdir -p /root/ft/data && cd /root/ft/data
touch /root/ft/data/generate_data.py

下列代码写入generate_data.py


import json
n =  10000
# 初始化OpenAI格式的数据结构
data = [
    {
        "messages": [
            {
                "role": "user",
                "content": "请做一下自我介绍"
            },
            {
                "role": "assistant",
                "content": "我是的小助手,内在是上海AI实验室书生·浦语的1.8B大模型哦"
            }
        ]
    }
]
for i in range(n):
    data.append(data[0])
with open('personal_assistant.json', 'w', encoding='utf-8') as f:
    json.dump(data, f, ensure_ascii=False, indent=4)   

运行数据集生成代码

cd /root/ft/data
python /root/ft/data/generate_data.py

模型准备


mkdir -p /root/ft/model
cp -r /root/share/new_models/Shanghai_AI_Laboratory/internlm2-chat-1_8b/* /root/ft/model/
ln -s /root/share/new_models/Shanghai_AI_Laboratory/internlm2-chat-1_8b /root/ft/model

模型配置文件

xtuner list-cfg
xtuner list-cfg -p internlm2_1_8b

mkdir -p /root/ft/config
xtuner copy-cfg internlm2_1_8b_qlora_alpaca_e3 /root/ft/confi

配置文件修改

下列代码覆盖internlm2_1_8b_qlora_alpaca_e3_copy.py


# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)
  

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import openai_map_fn, template_map_fn_factory
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                 VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.parallel.sequence import SequenceParallelSampler
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
  

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = '/root/ft/model'
use_varlen_attn = False
  

# Data
alpaca_en_path = '/root/ft/data/personal_assistant.json'
prompt_template = PROMPT_TEMPLATE.default
max_length = 1024
pack_to_max_length = True
  

# parallel
sequence_parallel_size = 1
  

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 16
accumulative_counts *= sequence_parallel_size
dataloader_num_workers = 0
max_epochs = 2
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03
  

# Save
save_steps = 300
save_total_limit = 3  # Maximum checkpoints to keep (-1 means unlimited)
  

# Evaluate the generation performance during the training
evaluation_freq = 300
SYSTEM = ''
evaluation_inputs = ['请你介绍一下你自己', '你是谁', '你是我的小助手吗']
  

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')
  

model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))
  

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path='json', data_files=dict(train=alpaca_en_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=openai_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn)
  

sampler = SequenceParallelSampler \
    if sequence_parallel_size > 1 else DefaultSampler
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=alpaca_en,
    sampler=dict(type=sampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
  

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')
  

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]
  

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
  

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]
  

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
  

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed evrionment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)
  

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)
  

# set visualizer
visualizer = None
  

# set log level
log_level = 'INFO'
  

# load from which checkpoint
load_from = None
  

# whether to resume training from the loaded checkpoint
resume = False
  

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)
  

# set log processor
log_processor = dict(by_epoch=False)

代码解释:

PART 1 Settings:涵盖了模型基本设置,如预训练模型的选择、数据集信息和训练过程中的一些基本参数(如批大小、学习率等)。

PART 2 Model & Tokenizer:指定了用于训练的模型和分词器的具体类型及其配置,包括预训练模型的路径和是否启用特定功能(如可变长度注意力),这是模型训练的核心组成部分。

PART 3 Dataset & Dataloader:描述了数据处理的细节,包括如何加载数据集、预处理步骤、批处理大小等,确保了模型能够接收到正确格式和质量的数据。

PART 4 Scheduler & Optimizer:配置了优化过程中的关键参数,如学习率调度策略和优化器的选择,这些是影响模型训练效果和速度的重要因素。

PART 5 Runtime:定义了训练过程中的额外设置,如日志记录、模型保存策略和自定义钩子等,以支持训练流程的监控、调试和结果的保存。

模型训练

xtuner train /root/ft/config/internlm2_1_8b_qlora_alpaca_e3_copy.py --work-dir /root/ft/train

模型续训

xtuner train /root/ft/config/internlm2_1_8b_qlora_alpaca_e3_copy.py --work-dir /root/ft/train --resume /root/ft/train/iter_600.pth

模型格式转换为hugging face格式


mkdir -p /root/ft/huggingface
xtuner convert pth_to_hf /root/ft/train/internlm2_1_8b_qlora_alpaca_e3_copy.py /root/ft/train/iter_768.pth /root/ft/huggingface

把微调结果整合到基座模型上


mkdir -p /root/ft/final_model
export MKL_SERVICE_FORCE_INTEL=1
xtuner convert merge /root/ft/model /root/ft/huggingface /root/ft/final_model

可以看到适配多种基座大模型的提示词模板templates脚本,然后进行对话

xtuner chat /root/ft/final_model --prompt-template internlm2_chat

也可以这样匹配不同的适配器,便于进行比较和选择

xtuner chat /root/ft/model --adapter /root/ft/huggingface --prompt-template internlm2_chat

Web demo部署


pip install streamlit==1.24.0
mkdir -p /root/ft/web_demo && cd /root/ft/web_demo
git clone https://github.com/InternLM/InternLM.git
cd /root/ft/web_demo/InternLM

参考文献

  1. Wang, Y., Agarwal, S., Mukherjee, S., Liu, X., Gao, J., Awadallah, A. H., & Gao, J. (2022). AdaMix: Mixture-of-Adaptations for Parameter-efficient Model Tuning. In Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing, pages 5744-5760, Abu Dhabi, United Arab Emirates. Association for Computational Linguistics.

  2. Munkhdalai, T., Faruqui, M., Gopal, S. (2024). Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention. arXiv preprint arXiv:2404.07143.

  3. Gu, A., Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv preprint arXiv:2312.00752.

  4. https://www.mercity.ai/blog-post/fine-tuning-llms-using-peft-and-lora.

更多AI工具,参考Github-AiBard123国内AiBard123

可关注我们的公众号:每天AI新工具