社区供稿|快速上手谷歌Gemma模型中文指令微调
作者: Hugging Face 来源: Hugging Face
文/ 魏巍
谷歌在 2 月 21 日放出开放权重的 Gemma 系列大模型,包括 2B 和 7B 两个大小,并且有预训练和指令微调两个版本。虽然 Gemma 的预训练数据里面包含多种语言,不过在官方的技术报告里,明确指出了做指令微调的时候,只用了英文:
经过英文指令微调的 Gemma 模型,仍然保留一定程度的指令跟随能力,可以理解一部分中文指令,但有些时候我们未必希望使用官方的指令微调模型,而是希望将预训练模型重新进行中文指令微调,来达到我们的要求,所以在这里我们就分享 3 个方法来进行 Gemma 的中文指令微调。
注意在这里我们统一使用 gemma-2b 这个模型,同时我们选择了 Hello-SimpleAI/HC3-Chinese 数据集来作为微调数据。这个数据集有不同主题的问答,包括问题,人类回答和 ChatGPT 回答,涵盖了金融,百科,法律等等诸多题材,我们只使用 baike 这个子集,且不使用 ChatGPT 回答。
-
gemma-2bhttps://hf.co/google/gemma-2b-it
-
Hello-SimpleAI/HC3-Chinesehttps://hf.co/datasets/Hello-SimpleAI/HC3-Chinese
方法一:Hugging Face TRL + Colab GPU
Hugging Face 的工程师 Phil Schmid 在 2024 年 1 月底写了 一篇很好的博客,详细讲解如何用 Hugging Face TRL 库来进行模型微调。其中的代码绝大部分都可以复用,不过他使用的是 b-mc2/sql-create-contex 数据集(用于自然语言生成 SQL 语句),所以我们只需要将数据这部分换成我们想要的中文指令微调数据即可。
-
Phil 博文链接https://www.philschmid.de/fine-tune-llms-in-2024-with-trl
-
HF TRL 库https://hf.co/docs/trl/en/index
-
b-mc2/sql-create-contex 数据集https://hf.co/datasets/b-mc2/sql-create-context
from datasets import load_dataset
# Convert dataset to OAI messages
system_message = “““你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
"””def create_conversation(sample):
return {
“messages”: [
{“role”: “system”, “content”: system_message},
{“role”: “user”, “content”: sample[“question”]},
{“role”: “assistant”, “content”: sample[“human_answers”][0]} # for whatever reason the dataset uses a list of answers
]
}# Load dataset from the hub
dataset_dict = load_dataset(“Hello-SimpleAI/HC3-Chinese”, name=“baike”)
dataset = dataset_dict[’train']print(create_conversation(dataset[0]))
# # Convert dataset to OAI messages
dataset = dataset.map(create_conversation, remove_columns=[“id”, “chatgpt_answers”], batched=False)dataset = dataset.train_test_split(test_size=0.1)
# save datasets to disk
dataset[“train”].to_json(“train_dataset.json”, orient=“records”)
dataset[“test”].to_json(“test_dataset.json”, orient=“records”)
TRL 库可以自行处理 ChatML 格式,所以我们只需要将提问和人类回答填入模板即可。之后即可运行 trainer.train() 进行 LoRA 训练。在免费的 Colab T4 GPU 上大约需要 3.5 小时完成训练(当前设置为 3 个 epoch,可酌情减少)。
在进行指令微调之前,如果我们测试模型的中文能力:
Query:
你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
我有一个信息科学相关的问题,请用中文回答,什么是 rm
Generated Answer:
,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,请 rm,
模型很快陷入循环(注意这里使用的不是已经做过指令微调的模型 gemma-2b-it,所以陷入循环是正常的,也正是我们要进行指令微调的原因)。完成我们的微调之后再测试:
Query:
我有一个信息科学相关的问题,请用中文回答,什么是 rm
Original Answer:
RM 格式是 RealNetworks 公司开发的一种流媒体视频文件格式,可以根据网络数据传输的不同速率制定不同的压缩比率,从而实现低速率的 Internet 上进行视频文件的实时传送和播放。它主要包含 RealAudio、RealVideo 和 RealFlash 三部分。
Generated Answer:
rm 是 remove 的缩写,用于删除指定的目录或文件,语法形式为:rm -r /path/to/file。
这里生成的答案就很合理也没有陷入循环(虽然跟 label,也就是 Origianl answer 不一样,但也是正确答案)。
所以这就是方案一。注:Phil Schmid 比较高产,最近两天又放出了一个 Gemma 7B + Dolly 数据集的 微调代码,稍微改改数据部分也可进行中文指令微调。
- 微调代码链接https://hf.co/philschmid/gemma-7b-dolly-chatml
方法二:Keras (JAX 后端) + Kaggle TPU -> Hugging Face
这个方法的思路是修改谷歌官方的微调教程,使用 Keras 在 Kaggle TPU 上微调 gemma-2b 模型,然后将模型转换成 Hugging Face 模型。用于这里使用 TPU v3,速度比 Colab 上的 T4 GPU 快很多。
官方教程 使用的是 IMDB 数据集,我们替换为 Hello-SimpleAI/HC3-Chinese baike 数据集:
-
官方教程链接https://ai.google.dev/gemma/docs/distributed_tuning
!wget -O baike.jsonl https://huggingface.co/datasets/Hello-SimpleAI/HC3-Chinese/raw/main/baike.jsonl
import re
import json
data = []
context = “你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题。\n”
with open(“baike.jsonl”) as file:
for line in file:
features = json.loads(line)
template = context + “问题:\n{question}\n答案:\n{human_answers[0]}”
data.append(template.format(**features))# Manually construct a test case;
# Already made sure the finetuning dataset contains nothing about zsh
test_prompt = context + “问题:\n我有一个信息科学相关的问题,请用中文回答,什么是 zsh\n答案:\n”
# 4616 in total in baike split
train_data = data[:4600]
之后进行训练,完成之后将模型转化为 Hugging Face 模型:
# Download the conversion script from KerasNLP tools
!wget -nv -nc https://raw.githubusercontent.com/keras-team/keras-nlp/master/tools/gemma/export_gemma_to_hf.py
# Run the conversion script
# Note: it uses the PyTorch backend of Keras (hence the KERAS_BACKEND env variable)
!KERAS_BACKEND=torch python export_gemma_to_hf.py \
--weights_file $FINETUNED_WEIGHTS_PATH \
--size $MODEL_SIZE \
--vocab_path $FINETUNED_VOCAB_PATH \
--output_dir $HUGGINGFACE_MODEL_DIR
接下来就可以使用 Hugging Face 去调用微调过的模型:
import transformers
model = transformers.GemmaForCausalLM.from_pretrained(
HUGGINGFACE_MODEL_DIR,
local_files_only=True,
device_map="auto", # Library "accelerate" to auto-select GPU
)
tokenizer = transformers.GemmaTokenizer.from_pretrained(
HUGGINGFACE_MODEL_DIR,
local_files_only=True,
)
def test_transformers_model(
model: transformers.GemmaForCausalLM,
tokenizer: transformers.GemmaTokenizer,
) -> None:
inputs = tokenizer([test_prompt], return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_length=200)
output = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"{output}\n{'- '*40}")
# This run on CPU so it is a bit slow
test_transformers_model(model, tokenizer)
对比微调前的输出:
'你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题。\n问题:\n我有一个信息科学相关的问题,请用中文回答,什么是 zsh\n答案:\nzsh 是一个命令行界面(CLI)的 shell,它支持许多命令行工具,包括 bash, fish, ksh, mksh, pdksh, tcsh, zsh, 和 yash。\nzsh 是一个命令行界面(CLI)的 shell,它支持许多命令行工具,包括 bash, fish, ksh, mksh, pdksh, tcsh, zsh, 和 yash。\nzsh 是一个命令行界面(CLI)的 shell,它支持许多命令行工具,包括 bash, fish, ksh, mksh, pdksh, tcsh, zsh, 和 yash。\nzsh 是一个命令行界面(CLI)的 shell,它支持'
和微调后的输出:
你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题。
问题:
我有一个信息科学相关的问题,请用中文回答,什么是 zsh
答案:
zsh(Z Shell)是一个POSIX兼容的shell,它在BSD/OS和Linux系统上被广泛使用。
zsh是Z shell的缩写,Z shell是Unix shell的一种,它继承了Bourne shell的特性,并增加了许多新的特性。
zsh的特性包括:
1.支持多级目录
2.支持命令别名
3.支持命令补全
4.支持命令历史
5.支持命令行参数
方法三:使用苹果 MLX 框架
该方法只适用于拥有苹果 M 系列 Mac 的同学。苹果的 MLX 框架 是专门针对苹果 M 系列芯片打造的机器学习框架,在 Apple Silicon 上性能非常好,能充分利用 Mac CPU/GPU 统一的内存高效训练 / 推理模型。有兴趣的同学可以去看 MLX 文档深入学习,在这里我们主要关注使用 MLX 对 gemma-2b 进行微调。
MLX 现在已经对多个流行的模型进行了支持,包括 Bert,Whisper, Stable Diffusion 等等,对 LLM 也对各种模型,如 Gemma, Llama, Mistral 等有了支持。在 LLM 微调文档 里,详细介绍了微调的步骤,我们只需要将数据整理成 MLX 需要的格式即可。MLX 的数据格式要求为 jsonl,即每一行为一个 json,且只有一个子域 “text”,例如:
-
MLX 框架https://github.com/ml-explore/mlx
-
LLM 微调文档https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md
{“text”: “你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题。\n问题:\n我有一个信息科学相关的问题,请用中文回答,什么是 简单波长分配协议\n答案:\n简单波长分配协议在波分复用技术中,根据流量状况建立交换通路和支持信息流的合并与疏导的一种信令协议。 \n当有一个流进入网络时,SWAP为其保持两种类型的状态信息:入口和出口状态信息。每一个输入支流都有它特有的入口状态,而且只有一个出口状态。入口状态的改变会激发更多的出口状态的改变,反之亦然。”}
{“text”: “你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题。\n问题:\n我有一个信息科学相关的问题,请用中文回答,什么是 多标签分类算法\n答案:\n多标签分类问题可以正式表述如下: \n假设有K类标签y={c1,c2,..,cK},给定网络G=(V,E,Y),其中V是顶点集,E是边集,Yi⊆y是顶点vi⊆V的类标签,并且已知道一些顶点vi∈VL(VL⊆V)的值,我们如何推断其余顶点VU=V-VL的Yi值(或针对每个标签的概率分布)?”}
{“text”: “你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题。\n问题:\n我有一个信息科学相关的问题,请用中文回答,什么是 顺序队列\n答案:\n顺序队列是队列的顺序存储结构,顺序队列实际上是运算受限的顺序表。和顺序表一样,顺序队列用一个向量空间来存放当前队列中的元素。由于队列的队头和队尾的位置是变化的,设置两个指针front和rear分别指示队头元素和队尾元素在向量空间中的位置,它们的初值在队列初始化时均应设置为0。”}
我们可以从 Hugging Face 上下载 Hello-SimpleAI/HC3-Chinese baike 子集的 jsonl 文件,然后运行这个脚本处理数据:
import json, jsonlines
context = "你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题。\n"
data = []
with open("baike.jsonl") as file:
for line in file:
features = json.loads(line)
template = context + "问题:\n{question}\n答案:\n{human_answers[0]}"
data.append({'text' : template.format(**features)})
# 4616 in total in baike split
train_data = data[:4600]
valid_data = data[4600:]
with jsonlines.open('train.jsonl', 'w') as writer:
writer.write_all(train_data)
with jsonlines.open('valid.jsonl', 'w') as writer:
writer.write_all(valid_data)
之后即可运行 MLX 进行微调:
python -m mlx_lm.lora --model google/gemma-2b --train --batch-size 2 --data <DATA_FOLDER> --iters 6001
完成后会生成一个 LoRA adapter.npz 文件,可以用 MLX 将它 fuse 进 base 模型并上传到 Hugging Face 上:
python -m mlx_lm.fuse \
--model <MODEL_FOLDER> \
--upload-repo mlx-community/gemma-2b-cn-it \
--hf-path google/gemma-2b
总结
今天我们介绍了 3 个针对谷歌 gemma-2b 模型进行中文指令微调的简单方法,这里展示的 所有的代码 都在 GitHub 上,并且配有有 2 个简短的视频 讲解。
-
本文展示的代码https://github.com/windmaple/Gemma-Chinese-instruction-tuning
-
两个讲解视频https://www.bilibili.com/video/BV14x4y1C7wi/https://www.bilibili.com/video/BV1YH4y177t9/
当然这里我们使用的数据,模型和算力都是比较小的,所以完成以后的模型性能肯定不能和最先进的模型相比,我们更重要的是分享思路方便大家学习。如果有足够资源的同学可以自己去探索收集更多的中文数据,使用 7b 模型,使用多 GPU/TPU 分布式训练等途径来打造更强大的自由模型。
最后,我们建立了一个 Gemma 中文群,欢迎对 Gemma 有兴趣的同学加入。如果人数已满或者二维码过期请添加群主微信 hustwindmaple 后拉入。
请先我们的群聊行为守则 ,扫码入群即视为你认同并将遵守我们下面列出的行为守则:
-
💬 请只讨论与群聊主题相关的内容,勿讨论任何违反法律和 Hugging Face 社区规定的内容
-
📢 请勿将群用于 AI 文章链接 / 活动群发渠道
-
🧑🤝🧑 群内加好友前请在群里使用 at 的方式请求同意
-
🧹 以上情况如有一次发生即视为违规,违规者会被移出所有群聊
-
🚔 如果你发现自己的隐私被侵犯,请直接拨打 110 报警
更多AI工具,参考Github-AiBard123,国内AiBard123