AI 文摘

Hands-onLLM谷歌Gemma模型分布式微调和推理





作者: Mist君的风控与数据成长路 来源: Mist君的风控与数据成长路

在上一篇推文中,介绍了如何在Kaggle平台使用并微调 Gemma 模型,本文将介绍在Colab平台如何对Gemma模型进行分布式微调,由Google技术文档及Build with AI in Shanghai Coding Time 使用的Notebook代码(colabtools/notebooks/Gemma_Distributed_Fine_tuning_on_TPU.ipynb at main · googlecolab/colabtools · GitHub,可点击阅读原文查看或下载或导入您的Colab中)整理而成。

使用 Keras 对 Gemma 进行分布式调优

*概述

Gemma 是一个轻量级、最先进的开放模型系列,由用于创建 Google Gemini 模型的研究和技术构建而成。Gemma 可以进一步微调以满足特定需求。但是大型语言模型LLM(如 Gemma)的尺寸可能非常大,其中一些可能不适合在单个加速器上进行微调。在这种情况下,有两种通用方法可以对它们进行微调:

  • 参数高效微调 (PEFT),旨在通过牺牲一些保真度来缩小有效模型大小。LoRA 属于这一类,使用 LoRA 在 Keras 中微调 Gemma 模型教程演示了如何在单个 GPU 上使用 KerasNLP 微调 Gemma 7B 模型gemma_instruct_7b_en与 LoRA。

  • 通过模型并行性进行全参数微调。模型并行性将单个模型的权重分布在多个设备上,并实现水平缩放。您可以在此 Keras 指南中找到有关分布式训练的更多信息。

本教程将指导您使用 Keras 和 JAX 后端来微调 Gemma 7B 模型,并在 Google 的张量处理单元 (TPU) 上进行 LoRA 和模型视变分布式训练。请注意,在本教程中可以关闭 LoRA,以实现更慢但更准确的全参数优化。

*使用加速器

从技术上讲,您可以在本教程中使用 TPU 或 GPU。

  • 关于 TPU 环境的注意事项

Google 有 3 种提供 TPU 的产品:

  • Colab 免费提供 TPU v2,适用于本教程。

  • Kaggle 免费提供 TPU v3,这也适用于本教程。

  • Cloud TPU 提供 TPU v3 和更新版本。一种设置方法是:

  1. 创建新的 TPU VMCreate a new TPU VM

  2. 为预期的 Jupyter 服务器端口设置 SSH 端口转发

  3. 安装 Jupyter 并在 TPU VM 上启动它,然后通过“连接到本地运行时”连接到 Colab。请参见:https://research.google.com/colaboratory/local-runtimes.html

多 GPU 设置注意事项

虽然本教程重点介绍 TPU 用例,但如果您拥有多 GPU 计算机,则可以轻松地根据自己的需求对其进行调整。

如果您更喜欢使用 Colab,也可以直接通过 Colab Connect 菜单中的“连接到自定义 GCE VM”为 Colab 预配多 GPU VM。

*准备工作

  • Kaggle 凭据

Gemma 模特由 Kaggle 托管。要使用 Gemma,请在 Kaggle 上请求访问权限:

  • 登录或注册 kaggle.com

  • 打开 Gemma 模型卡并选择“请求访问”

  • 填写同意书并接受条款和条件

然后,要使用 Kaggle API,请创建一个 API 令牌:

  • 打开 Kaggle 设置

  • 选择“创建新令牌”

  • 将下载kaggle.json文件。它包含您的 Kaggle 凭据


import os
from google.colab import userdata
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

*安装

使用 Gemma 模型安装 Keras 和 KerasNLP。

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install tensorflow-cpu~=2.16.0 keras-nlp==0.8.2 tensorflow-hub==0.16.1 keras==3.0.5 tensorflow-text==2.16.1

*设置 Keras JAX 后端

导入 JAX 并对 TPU 运行健全性检查。Colab 提供 TPUv2-8 设备,这些设备具有 8 个 TPU 内核和 8GB 高带宽内存。


import jax
  

jax.devices()






[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),  
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),  
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),  
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),  
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),  
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),  
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),  
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

import os
  

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

*导入模型

import keras
import keras_nlp
  • 关于在 NVIDIA GPU 上进行混合精度训练的说明

  • 在NVIDIA GPU 上训练时,可以使用混合精度 (keras.mixed_precision.set_global_policy(‘mixed_bfloat16’)) 来加快训练速度,同时对训练质量的影响最小。在大多数情况下,建议打开混合精度,因为它可以节省内存和时间。但是,请注意,在小批量中,它可能会将内存使用量增加 1.5 倍(权重将加载两次,半精度和全精度)。

  • 对于推理,半精度 (keras.config.set_floatx(“bfloat16”)) 将起作用并节省内存,而混合精度不适用。

keras.config.set_floatx("bfloat16")

要使用分布在 TPU 上的权重和张量加载模型,请首先创建一个新的 DeviceMesh。DeviceMesh 表示为分布式计算配置的硬件设备集合,并在 Keras 3 中作为统一分布式 API 的一部分引入。

分布式 API 支持数据和模型并行性,允许在多个加速器和主机上高效扩展深度学习模型。它利用底层框架(例如 JAX)通过称为单程序、多数据 (SPMD) 扩展的过程根据分片指令分发程序和张量。在新的 Keras 3 发行版 API 指南中查看更多详细信息。


# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

分布式 API 中的LayoutMap 指定应如何使用字符串键(例如下面的 token_embedding/embeddings)对权重和张量进行分片或复制,这些键被视为正则表达式以匹配张量路径。匹配的张量使用模型维度(8 TPU)进行分片;其他的将被完全复制。


model_dim = "model"
  

  

layout_map = keras.distribution.LayoutMap(device_mesh)
  

  

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (None, model_dim)
  

  

# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    None, model_dim, None)
layout_map["decoder_block.*attention_output.*kernel"] = (
    None, None, model_dim)
layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)

****ModelParallel 允许您在 DeviceMesh 上的所有开发中对模型权重或激活张量进行分片。在这种情况下,根据上述定义的layout_map,一些 Gemma 7B 模型权重被分片到 8 个 TPU 内核。现在以分布式方式加载模型。


model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")
  

  

keras.distribution.set_distribution(model_parallel)
  

  

# Download the Gemma 7B model.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_7b_en")

Downloading from https://www.kaggle.com/api/v1/models/keras/gemma/keras/gemma_instruct_7b_en/2/download/config.json… 100%|██████████| 552/552 [00:00<00:00, 529kB/s] Downloading from https://www.kaggle.com/api/v1/models/keras/gemma/keras/gemma_instruct_7b_en/2/download/model.weights.h5… 100%|██████████| 15.9G/15.9G [03:24<00:00, 83.6MB/s] Downloading from https://www.kaggle.com/api/v1/models/keras/gemma/keras/gemma_instruct_7b_en/2/download/tokenizer.json… 100%|██████████| 401/401 [00:00<00:00, 519kB/s] Downloading from https://www.kaggle.com/api/v1/models/keras/gemma/keras/gemma_instruct_7b_en/2/download/assets/tokenizer/vocabulary.spm...100%|██████████| 4.04M/4.04M [00:00<00:00, 17.0MB/s]

现在验证模型是否已正确分区。我们以decoder_block_1为例。


decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')

<class ‘keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock’> decoder_block_1/pre_attention_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/attention/query/kernel (16, 3072, 256) PartitionSpec(None, ‘model’, None) decoder_block_1/attention/key/kernel (16, 3072, 256) PartitionSpec(None, ‘model’, None) decoder_block_1/attention/value/kernel (16, 3072, 256) PartitionSpec(None, ‘model’, None) decoder_block_1/attention/attention_output/kernel (16, 256, 3072) PartitionSpec(None, None, ‘model’) decoder_block_1/pre_ffw_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/ffw_gating/kernel (3072, 24576) PartitionSpec(‘model’, None) decoder_block_1/ffw_gating_2/kernel (3072, 24576) PartitionSpec(‘model’, None) decoder_block_1/ffw_linear/kernel (24576, 3072) PartitionSpec(None, ‘model’)

*微调前的推理

gemma_lm.generate("Best comedy movies: ", max_length=64)

‘Best comedy movies: \n\n1.The Hangover (2009)\n2.Bridesmaids (2011)\n3.Anchorman (2007)\n4.Superbad (2007)\n5.**The Room (2’

该模型生成了 90 年代要观看的精彩喜剧电影列表。现在,我们对 Gemma 模型进行微调以更改输出样式。

*使用IMDB进行微调


import tensorflow_datasets as tfds
  

  

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)
  

  

imdb_train.unbatch().take(1).get_single_element().numpy()

Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0…

Dl Completed…: 100%

1/1 [00:08<00:00, 8.37s/ url]

Dl Size…: 100%

80/80 [00:08<00:00, 11.08 MiB/s]

Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.  






b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

使用低秩自适应 (LoRA) 执行微调。

LoRA 是一种微调技术,它通过冻结模型的全部权重并在模型中插入少量新的可训练权重,大大减少了下游任务的可训练参数数量。基本上,LoRA 通过 2 个较小的低秩矩阵 AxB 重新参数化较大的全权重矩阵进行训练,这种技术使训练速度更快、记忆效率更高。

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)

# Fine-tune on the IMDb movie reviews dataset.
  

  

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
  

  

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)






Preprocessor: "gemma_causal_lm_preprocessor"  


┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓  
┃ Tokenizer (type)                                   ┃                                             Vocab # ┃  
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩  
│ gemma_tokenizer (GemmaTokenizer)                   │                                             256,000 │  
└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘  


Model: "gemma_causal_lm"  


┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓  
┃ Layer (type)                  ┃ Output Shape              ┃         Param # ┃ Connected to               ┃  
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩  
│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │  
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤  
│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │  
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤  
│ gemma_backbone                │ (None, None, 3072)        │   8,548,748,288 │ padding_mask[0][0],        │  
│ (GemmaBackbone)               │                           │                 │ token_ids[0][0]            │  
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤  
│ token_embedding               │ (None, None, 256000)      │     786,432,000 │ gemma_backbone[0][0]       │  
│ (ReversibleEmbedding)         │                           │                 │                            │  
└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘  


 Total params: 8,548,748,288 (15.92 GB)  


 Trainable params: 11,067,392 (21.11 MB)  


 Non-trainable params: 8,537,680,896 (15.90 GB)  


2000/2000 ━━━━━━━━━━━━━━━━━━━━ 324s 141ms/step - loss: 14.7736 - sparse_categorical_accuracy: 0.4807  


<keras.src.callbacks.history.History at 0x7e7e3c589cc0>

注意到,启用 LoRA 会显著减少可训练参数的数量,从 70 亿个减少到仅 1100 万个。

*微调后的推理

gemma_lm.generate("Best comedy movies: ", max_length=256)

‘Best comedy movies: \n\n1.The Hangover (2009)\n2.Bridesmaids (2011)\n3.Anchorman (2007)\n4.Superbad (2007)\n5.Tallahassee (2009)\n\nWhat is the common thread between the movies on this list?\n\nThey are all American comedy films.’

经过微调,该模型已经学会了电影评论的风格,现在正在 90 年代喜剧电影的背景下以这种风格生成输出。

*后续步骤

在本教程中,您学习了如何使用 KerasNLP JAX 后端在强大的 TPU 上以分布式方式微调 IMDb 数据集上的 Gemma 模型。以下是一些关于其他学习内容的建议:

  • 了解如何开始使用 Keras Gemma。

  • 了解如何在 GPU 上微调 Gemma 模型。

更多AI工具,参考Github-AiBard123国内AiBard123

可关注我们的公众号:每天AI新工具