AI 文摘

Proxy-Tuning:大模型无须调整权重的大幅提高效果的调优方法





作者: NLP前沿 来源: NLP前沿

这里只放了原理,学习一下方法,还有2个示例code没放过来

https://lightning.ai/lightning-ai/studios/improve-llms-with-proxy-tuning  
https://arxiv.org/abs/2401.08565  

Proxy-tuning 是一种在不改变模型权重的情况下调整LLM的方法。如果给定的LLM训练资源过于昂贵,或者用户无法访问LLM的权重,这种方法尤其具有吸引力。举几个具体的示例:

  • 假设目前还不存在Llama 2 70B Chat模型。相反,我们只有Llama 2 70B基础模型。Proxy-tuning使我们能够使这个基础模型表现得和聊天模型一样好,而无需改变基础模型的权重。

  • 在比如,我们有2个7B模型,通过proxy-tunning可以达到13B模型的效果。

  • 又或者我们把13B模型的code能力注入到7B的模型里边

Understanding Proxy-Tuning

Proxy-tuning 提供了一个目标 LLM,具有经过调优的版本的能力,而实际上并没有对其进行调优。如下图

下面是步骤:

  • 选择一个比目标LLM(例如,未调整的70B Llama 2模型)更小更便宜的基础LLM(例如,未调整的7B Llama 2模型)

  • 微调这个较小的基础LLM模型,以获得一个小型的微调LLM模型(例如,对一个7B Llama 2模型进行指令微调,以获得一个微调后的7B模型)。

  • 计算基本模型(步骤1)和调整模型(步骤2)之间的输出差异。

  • 将这个差异加到目标LLM的输出上

  • 将第4步中修改后的输出进行规范化处理,然后生成答案。

如果上面的步骤过于抽象,可以考虑下面这个PyTorch伪代码的具体示例:

generated_tokens = []  
  
input_txt = (  
  "If I have 5 apples and eat 2, but then find 3 more"  
  " on my way home, how many do I have?"  
)  
input_ids = tokenizer.encode(input_text)  
  
for _ in range(max_length):  
    # Obtain logits  
    logits_base = model_base(input_ids).logits # Llama 7B Base  
    logits_tuned = model_tuned(input_ids).logits # Llama 7B Chat  
    logits_target = model_target(input_ids).logits # Llama 70B Base  
                                 
    # Apply proxy-tuning                              
    logits = (  
        logits_target + (logits_tuned - logits_base)  
    )  
  
    # Normalize logits and obtain token  
    predictions = torch.softmax(logits[:, -1, :], dim=-1)  
    next_token_id = torch.argmax(predictions).unsqueeze(0)  
    generated_tokens.append(next_token_id.item())  
      
generated_text = tokenizer.decode(generated_tokens)  
print(generated_text)  
  
# Output:   
# You start with 5 apples and eat 2,  
# so you have 5 - 2 = 3 apples left.  
# Then, you find 3 more apples on your way home,   
# so you have 3 + 3 = 6 apples in total.  

上面的代码片段获取了编码输入文本,并使用三种模型分别获得了每个输出token的logits。然后,使用之前描述的Proxy-tuning逻辑,计算了调整模型和基础模型之间logits的差异。然后,将这种差异应用到目标模型的logits上。随后,我们像往常一样对logits进行归一化,并对下一个标记进行采样。(这里,我们使用PyTorch的贪婪采样,选择具有最高概率的标记,但我们也可以使用其他采样技术,比如top-k采样或nucleus采样。)logits_tuned - logits_base argmax

根据原论文,Proxy-tuning效果非常好。

例如,在AlpacaFarm和GSM数据集上,Proxy-tuning 70B Llama 2基础模型导致了显著的性能提升(AlpacaFarm为88.0%,GSM为32.0%)。此外,在TruthfulQA数据集上,代理调整的模型比直接调整的模型更真实。这种方法在领域特定任务中也非常有效,比如编码和任务特定调整。

更多AI工具,参考Github-AiBard123国内AiBard123

可关注我们的公众号:每天AI新工具