如何快速地设计并评估fewshot示例的效果:OpenICL上下文示例学习框架推荐及实现源码
作者: 老刘说NLP 来源: 老刘说NLP
今天是2023年11月18日,星期六,北京,天气晴。
在昨天的文章中,我们谈了谈**《IN-CONTEXT LEARNING WITH ITERATIVE DEMON- STRATION SELECTION》** 这一工作,该工作提出了迭代示例选择(IDS)的方案,IDS利用Zero-shot-CoT,迭代选择多样化但仍与测试样本密切相关的示例作为ICL示例,并对比了不同构造方式的效果。
常用的ICL示例构造方式如下:
而最近在做评估方面的工作,在看Opencompass源码的时候,看到了其中对于prompt的一些生成策略,主要来自于OpenICL框架,OpenICL是一个用于上下文学习(In-context learning)工具包,同时也是一个LLLM评估的开源工具包,其中具体实现方式以及源码读了一通,很有趣。
这对我们 如何快速地设计并评估few shot示例的效果很有帮助。推荐大家看看,其中还牵引出了好几个构造方案,本文对这些工作进行介绍,供大家一起参考。
一、OpenICL上下文示例学习框架
传统的微调方法不同,ICL采用预先训练的模式,适应了无形的任务,而没有任何参数更新。然而,由于所涉及的检索和推断方法多种多样,不同模型、数据集和任务都有不同的预处理要求,ICL的实施十分复杂。迫切需要为ICL建立一个统一灵活的框架,以方便上述组成部分的执行。
《OpenICL: An Open-Source Framework for In-context Learning》这一工作,地址https://github.com/Shark-NLP/OpenICL项目。论文地址:https://arxiv.org/abs/2303.02913 。
基本的ICL流程如上图所示,OpenICL首先通过用户指定的检索方法(例如TopK或VoteK)从索引集中获取适当的上下文示例,针对每个测试输入或整个测试集。然后,根据提供的提示模板,将上下文示例和测试输入连接成单个序列。最后,所有提示都被馈送到语言模型中,通过定义的推理策略(例如Chain-of-thought)推断输出。
1、Retriever模块
Retriever负责从预先存在的训练数据中检索上下文示例。该模块支持语料库级别的(即仅为整个测试集检索一组示例)和实例级别的(即为每个测试输入分别检索示例)检索方法。包括以下方法:
随机方法(Random): 使用随机方法来选择例子构建上下文。虽然随机方法的性能变化范围较大,但在只有少量演示可用时仍然是常用的选择。
启发式方法(Heuristic method) :为了克服随机方法的缺点,包括基于语义相似性的检索方法,例如 BM25、TopK和 VoteK,已经显示出很大的潜力。
基于模型的方法(Model-based method) :使用模型对输出的置信度来选择和排序示例,例如熵(Entropy)和 MDL(https://aclanthology.org/2023.acl-long.79.pdf)。
2、Inferencer模块
Inferencer会调用预训练语言模型,基于上下文示例和测试输入的串联生成预测。
主要包括两种,
一种是直接方法,使用词汇表中的标记来表示候选答案,并选择概率最高的作为最终预测结果。
另一种是困惑度方法,Brown等人计算输入和候选答案的序列连接的句子困惑度,并选择困惑度最低的作为最终预测结果。
例如,对于ppl方式,ppl方式旨在解决base模型指令不遵循以及生成方式评估太慢的问题,可以应用于选择题以及生成任务上。
假设由ABCD四个选项。对于一个样本,构造四个prompt,分别由问题+”答案:“+[ABCD]中的一个组成。然后将这四个prompt分别输入到大模型中,调用大模型生成logits。
huggingface中关于ppl的实现:
import torch
from tqdm import tqdm
max_length = model.config.n_positions
stride = 512
seq_len = encodings.input_ids.size(1)
nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
end_loc = min(begin_loc + max_length, seq_len)
trg_len = end_loc - prev_end_loc # may be different from stride on last loop
input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
target_ids = input_ids.clone()
target_ids[:, :-trg_len] = -100
with torch.no_grad():
outputs = model(input_ids, labels=target_ids)
# loss is calculated using CrossEntropyLoss which averages over valid labels
# N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
# to the left by 1.
neg_log_likelihood = outputs.loss
nlls.append(neg_log_likelihood)
prev_end_loc = end_loc
if end_loc == seq_len:
break
ppl = torch.exp(torch.stack(nlls).mean())
其思想在于:遵循了滑窗的策略,即设置一个窗口,按照某个步长向右滑动,每次在窗口内计算ppl,直到窗口滑到结尾,然后将所有的ppl进行求平均得到最终的ppl。
3、几个例子
通过串联rtriver,inference、lm、metric等,可以快速地进行模型评估验证。
1)采用PPL推理策略的 GPT2-XL (1.5B)在SST-2数据集上的ICL性能。
2) 基于直接生成方式-bleu推理策略评估XGLM (7.5B)在WMT16 (de-en) 数据集上的ICL性能
3)评估GPT-3 (175B) text-davinci-003 版本在GSM8K数据集上使用思维链推理策略的ICL性能
二、Opencompass中几个内置的fewshot源码实现
Opencompass中几个内置的fewshot主要依靠retriever模块,主要用于大模型从数据集中选择N个固定数量的样本加入到prompt中。
其预置了多种不同的样本选择方式。这个和基于知识库的问答应用差不多,主要是通过一些方法选择最适合当前数据样本的上下文召回,帮助大模型更好得理解当前数据样本。
典型的方法有:基于BM25的检索方式、基于KNN的检索方式、随机采样方式等,当然也有zeroshot方式,即不带任何先验样本。
地址https://github.com/open-compass/opencompass/blob/main/opencompass/openicl/icl_retriever中提供了bm25、random、topk、votek以及dpp等多种构造方法的实现。
1、重要参数
数据集(BaseDataset ): 任何 BaseDataset 实例。将使用reader 、train 和test 的属性。
ice_separator (可选[str] ): 每个上下文示例模板之间的分隔符。默认值为’\n’。
ice_eos_token (Optional[str] ):上下文示例模板的句末标记。默认为"\n"。
ice_num (Optional[int] ):上下文示例模板的数量。默认为 1。
index_split(Optional[str] ): 用于检索内示例索引的数据集分割。默认为 “train”。
test_split(Optional[str] ): 数据集的分割,用于检索示例样本。默认为 “test”。
2、zeroRetriver方案
zeroshot检索,所有示例均置为空。
class ZeroRetriever(BaseRetriever):
"""Zeroshot Retriever. The retriever returns empty list for all queries.
"""
def __init__(self, dataset, ice_eos_token: Optional[str] = '') -> None:
super().__init__(dataset, '', ice_eos_token, 0)
def retrieve(self, id_list: List[int] = None) -> List[List]:
if id_list is not None:
get_logger().warning('id_list is not empty, but will be ignored.')
rtr_idx_list = [[] for _ in range(len(self.test_ds))]
return rtr_idx_list
3、FixKRetriever方案
测试提示的每个上下文示例都是从索引集中检索出相同的K个示例。
class FixKRetriever(BaseRetriever):
"""
Fix-K Retriever. Each in-context example of the test prompts is
retrieved as the same K examples from the index set.
"""
def __init__(self,
dataset,
fix_id_list: List[int],
ice_separator: Optional[str] = '\n',
ice_eos_token: Optional[str] = '\n',
ice_num: Optional[int] = 1) -> None:
super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
self.fix_id_list = fix_id_list
def retrieve(self):
"""Retrieve the in-context example index for each test example."""
num_idx = len(self.index_ds)
for idx in self.fix_id_list:
assert idx < num_idx, f'Index {idx} is out of range of {num_idx}'
rtr_idx_list = []
for _ in trange(len(self.test_ds), disable=not self.is_main_process):
rtr_idx_list.append(self.fix_id_list)
return rtr_idx_list
4、RandomRetriever方案
RandomRetriever以随机方式检索测试提示的每个上下文示例。
class RandomRetriever(BaseRetriever):
"""Random Retriever. Each in-context example of the test prompts is
retrieved in a random way.
"""
def __init__(self,
dataset,
ice_separator: Optional[str] = '\n',
ice_eos_token: Optional[str] = '\n',
ice_num: Optional[int] = 1,
seed: Optional[int] = 43) -> None:
super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
self.seed = seed
def retrieve(self):
np.random.seed(self.seed)
num_idx = len(self.index_ds)
rtr_idx_list = []
logger.info('Retrieving data for test set...')
for _ in trange(len(self.test_ds), disable=not self.is_main_process):
idx_list = np.random.choice(num_idx, self.ice_num,
replace=False).tolist()
rtr_idx_list.append(idx_list)
return rtr_idx_list
5、BM25Retriever方案
BM25Retriever中测试提示的每个测试提示的每个上下文示例都是通过BM25算法检索的。
class BM25Retriever(BaseRetriever):
"""
Random Retriever. Each in-context example of the test prompts is
retrieved in a random way.
"""
bm25 = None
index_corpus = None
test_corpus = None
def __init__(self,
dataset,
ice_separator: Optional[str] = '\n',
ice_eos_token: Optional[str] = '\n',
ice_num: Optional[int] = 1) -> None:
super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
self.index_corpus = [
word_tokenize(data) for data in
self.dataset_reader.generate_input_field_corpus(self.index_ds)
]
self.bm25 = BM25Okapi(self.index_corpus)
self.test_corpus = [
word_tokenize(data) for data in
self.dataset_reader.generate_input_field_corpus(self.test_ds)
]
def retrieve(self) -> List[List]:
"""Retrieve the in-context example index for each test example."""
rtr_idx_list = []
logger.info('Retrieving data for test set...')
for idx in trange(len(self.test_corpus),
disable=not self.is_main_process):
query = self.test_corpus[idx]
scores = self.bm25.get_scores(query)
near_ids = list(np.argsort(scores)[::-1][:self.ice_num])
near_ids = [int(a) for a in near_ids]
rtr_idx_list.append(near_ids)
return rtr_idx_list
6、TopkRetriever方案
TopkRetriever,使用 基本knn实现。SentenceTransformer用于计算嵌入。Faiss 用于进行近邻搜索
class TopkRetriever(BaseRetriever):
model = None
"""Base class for Topk In-context Learning Retriever, implemented with
basic knn. SentenceTransformer is used to calculate embeddings. Faiss is
used to do the nearest neighbor search.
"""
def __init__(self,
dataset,
ice_separator: Optional[str] = '\n',
ice_eos_token: Optional[str] = '\n',
ice_num: Optional[int] = 1,
sentence_transformers_model_name: Optional[
str] = 'all-mpnet-base-v2',
tokenizer_name: Optional[str] = 'gpt2-xl',
batch_size: Optional[int] = 1) -> None:
super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.batch_size = batch_size
self.tokenizer_name = tokenizer_name
gen_datalist = self.dataset_reader.generate_input_field_corpus(
self.test_ds)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.tokenizer.padding_side = 'right'
self.encode_dataset = DatasetEncoder(gen_datalist,
tokenizer=self.tokenizer)
co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer,
device=self.device)
self.dataloader = DataLoader(self.encode_dataset,
batch_size=self.batch_size,
collate_fn=co)
self.model = SentenceTransformer(sentence_transformers_model_name)
self.model = self.model.to(self.device)
self.model.eval()
self.index = self.create_index()
## 使用knn进行索引
def create_index(self):
import faiss
self.select_datalist = self.dataset_reader.generate_input_field_corpus(
self.index_ds)
encode_datalist = DatasetEncoder(self.select_datalist,
tokenizer=self.tokenizer)
co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer,
device=self.device)
dataloader = DataLoader(encode_datalist,
batch_size=self.batch_size,
collate_fn=co)
index = faiss.IndexIDMap(
faiss.IndexFlatIP(self.model.get_sentence_embedding_dimension()))
res_list = self.forward(dataloader,
process_bar=True,
information='Creating index for index set...')
id_list = np.array([res['metadata']['id'] for res in res_list])
self.embed_list = np.stack([res['embed'] for res in res_list])
index.add_with_ids(self.embed_list, id_list)
return index
## 使用knn进行搜索
def knn_search(self, ice_num):
res_list = self.forward(self.dataloader,
process_bar=True,
information='Embedding test set...')
rtr_idx_list = [[] for _ in range(len(res_list))]
logger.info('Retrieving data for test set...')
for entry in tqdm.tqdm(res_list, disable=not self.is_main_process):
idx = entry['metadata']['id']
embed = np.expand_dims(entry['embed'], axis=0)
near_ids = self.index.search(embed, ice_num)[1][0].tolist()
rtr_idx_list[idx] = near_ids
return rtr_idx_list
## 利用sentence transformer进行embedding
def forward(self, dataloader, process_bar=False, information=''):
res_list = []
_dataloader = copy.deepcopy(dataloader)
if process_bar:
logger.info(information)
_dataloader = tqdm.tqdm(_dataloader,
disable=not self.is_main_process)
for _, entry in enumerate(_dataloader):
with torch.no_grad():
metadata = entry.pop('metadata')
raw_text = self.tokenizer.batch_decode(
entry['input_ids'],
skip_special_tokens=True,
verbose=False)
res = self.model.encode(raw_text, show_progress_bar=False)
res_list.extend([{
'embed': r,
'metadata': m
} for r, m in zip(res, metadata)])
return res_list
def retrieve(self):
"""Retrieve the in-context example index for each test example."""
return self.knn_search(self.ice_num)
总结
本文主要介绍了OpenICL上下文示例学习框架的几种选择策略以及opencompass的实现,这对我们加强对prompt的认知有很大的帮助。
实际上,这些prompt的挑选,与我们现在进行大模型文档中相关文档的召回有很大的相似性,而如何快速的比对不同召回方式,快速地进行评估,进行模块化也是我们需要关注的问题。两种存在着本质上的共性。
参考文献
1、https://github.com/open-compass/opencompass
2、https://github.com/Shark-NLP/OpenICL
3、https://arxiv.org/abs/2303.02913
关于我们
老刘,刘焕勇,NLP开源爱好者与践行者,主页:https://liuhuanyong.github.io。
老刘说NLP,将定期发布语言资源、工程实践、技术总结等内容,欢迎关注。
对于想加入更优质的知识图谱、事件图谱、大模型AIGC实践、相关分享的,可关注公众号,在后台菜单栏中点击会员社区->会员入群加入。
更多AI工具,参考Github-AiBard123,国内AiBard123