Proxy-Tuning:大模型无须调整权重的大幅提高效果的调优方法
作者: NLP前沿 来源: NLP前沿
“
这里只放了原理,学习一下方法,还有2个示例code没放过来
https://lightning.ai/lightning-ai/studios/improve-llms-with-proxy-tuning
https://arxiv.org/abs/2401.08565
Proxy-tuning 是一种在不改变模型权重的情况下调整LLM的方法。如果给定的LLM训练资源过于昂贵,或者用户无法访问LLM的权重,这种方法尤其具有吸引力。举几个具体的示例:
-
假设目前还不存在Llama 2 70B Chat模型。相反,我们只有Llama 2 70B基础模型。Proxy-tuning使我们能够使这个基础模型表现得和聊天模型一样好,而无需改变基础模型的权重。
-
在比如,我们有2个7B模型,通过proxy-tunning可以达到13B模型的效果。
-
又或者我们把13B模型的code能力注入到7B的模型里边
Understanding Proxy-Tuning
Proxy-tuning 提供了一个目标 LLM,具有经过调优的版本的能力,而实际上并没有对其进行调优。如下图
下面是步骤:
-
选择一个比目标LLM(例如,未调整的70B Llama 2模型)更小更便宜的基础LLM(例如,未调整的7B Llama 2模型)
-
微调这个较小的基础LLM模型,以获得一个小型的微调LLM模型(例如,对一个7B Llama 2模型进行指令微调,以获得一个微调后的7B模型)。
-
计算基本模型(步骤1)和调整模型(步骤2)之间的输出差异。
-
将这个差异加到目标LLM的输出上
-
将第4步中修改后的输出进行规范化处理,然后生成答案。
如果上面的步骤过于抽象,可以考虑下面这个PyTorch伪代码的具体示例:
generated_tokens = []
input_txt = (
"If I have 5 apples and eat 2, but then find 3 more"
" on my way home, how many do I have?"
)
input_ids = tokenizer.encode(input_text)
for _ in range(max_length):
# Obtain logits
logits_base = model_base(input_ids).logits # Llama 7B Base
logits_tuned = model_tuned(input_ids).logits # Llama 7B Chat
logits_target = model_target(input_ids).logits # Llama 70B Base
# Apply proxy-tuning
logits = (
logits_target + (logits_tuned - logits_base)
)
# Normalize logits and obtain token
predictions = torch.softmax(logits[:, -1, :], dim=-1)
next_token_id = torch.argmax(predictions).unsqueeze(0)
generated_tokens.append(next_token_id.item())
generated_text = tokenizer.decode(generated_tokens)
print(generated_text)
# Output:
# You start with 5 apples and eat 2,
# so you have 5 - 2 = 3 apples left.
# Then, you find 3 more apples on your way home,
# so you have 3 + 3 = 6 apples in total.
上面的代码片段获取了编码输入文本,并使用三种模型分别获得了每个输出token的logits。然后,使用之前描述的Proxy-tuning逻辑,计算了调整模型和基础模型之间logits的差异。然后,将这种差异应用到目标模型的logits上。随后,我们像往常一样对logits进行归一化,并对下一个标记进行采样。(这里,我们使用PyTorch的贪婪采样,选择具有最高概率的标记,但我们也可以使用其他采样技术,比如top-k采样或nucleus采样。)logits_tuned - logits_base argmax
根据原论文,Proxy-tuning效果非常好。
例如,在AlpacaFarm和GSM数据集上,Proxy-tuning 70B Llama 2基础模型导致了显著的性能提升(AlpacaFarm为88.0%,GSM为32.0%)。此外,在TruthfulQA数据集上,代理调整的模型比直接调整的模型更真实。这种方法在领域特定任务中也非常有效,比如编码和任务特定调整。
更多AI工具,参考Github-AiBard123,国内AiBard123