AI 文摘

如何快速地设计并评估fewshot示例的效果:OpenICL上下文示例学习框架推荐及实现源码


  • By AiBard123
  • November 20, 2023 - 2 min read



作者: 老刘说NLP 来源: 老刘说NLP

今天是2023年11月18日,星期六,北京,天气晴。

在昨天的文章中,我们谈了谈**《IN-CONTEXT LEARNING WITH ITERATIVE DEMON- STRATION SELECTION》** 这一工作,该工作提出了迭代示例选择(IDS)的方案,IDS利用Zero-shot-CoT,迭代选择多样化但仍与测试样本密切相关的示例作为ICL示例,并对比了不同构造方式的效果。

常用的ICL示例构造方式如下:

而最近在做评估方面的工作,在看Opencompass源码的时候,看到了其中对于prompt的一些生成策略,主要来自于OpenICL框架,OpenICL是一个用于上下文学习(In-context learning)工具包,同时也是一个LLLM评估的开源工具包,其中具体实现方式以及源码读了一通,很有趣。‍‍‍‍‍‍‍‍‍‍‍

这对我们‍‍ 如何快速地设计并评估few shot示例的效果很有帮助。推荐大家看看,其中还牵引出了好几个构造方案,本文对这些工作进行介绍,供大家一起参考。

一、OpenICL上下文示例学习框架

传统的微调方法不同,ICL采用预先训练的模式,适应了无形的任务,而没有任何参数更新。然而,由于所涉及的检索和推断方法多种多样,不同模型、数据集和任务都有不同的预处理要求,ICL的实施十分复杂。迫切需要为ICL建立一个统一灵活的框架,以方便上述组成部分的执行。

《OpenICL: An Open-Source Framework for In-context Learning》这一工作,地址https://github.com/Shark-NLP/OpenICL项目。论文地址:https://arxiv.org/abs/2303.02913

基本的ICL流程如上图所示,OpenICL首先通过用户指定的检索方法(例如TopK或VoteK)从索引集中获取适当的上下文示例,针对每个测试输入或整个测试集。然后,根据提供的提示模板,将上下文示例和测试输入连接成单个序列。最后,所有提示都被馈送到语言模型中,通过定义的推理策略(例如Chain-of-thought)推断输出。

1、Retriever模块

Retriever负责从预先存在的训练数据中检索上下文示例。该模块支持语料库级别的(即仅为整个测试集检索一组示例)和实例级别的(即为每个测试输入分别检索示例)检索方法。包括以下方法:

随机方法(Random): 使用随机方法来选择例子构建上下文。虽然随机方法的性能变化范围较大,但在只有少量演示可用时仍然是常用的选择。

启发式方法(Heuristic method) :为了克服随机方法的缺点,包括基于语义相似性的检索方法,例如 BM25、TopK和 VoteK,已经显示出很大的潜力。

基于模型的方法(Model-based method) :使用模型对输出的置信度来选择和排序示例,例如熵(Entropy)和 MDL(https://aclanthology.org/2023.acl-long.79.pdf)。

2、Inferencer模块

Inferencer会调用预训练语言模型,基于上下文示例和测试输入的串联生成预测。

主要包括两种,

一种是直接方法,使用词汇表中的标记来表示候选答案,并选择概率最高的作为最终预测结果。

另一种是困惑度方法,Brown等人计算输入和候选答案的序列连接的句子困惑度,并选择困惑度最低的作为最终预测结果。

例如,对于ppl方式,ppl方式旨在解决base模型指令不遵循以及生成方式评估太慢的问题,可以应用于选择题以及生成任务上。

假设由ABCD四个选项。对于一个样本,构造四个prompt,分别由问题+”答案:“+[ABCD]中的一个组成。然后将这四个prompt分别输入到大模型中,调用大模型生成logits。

huggingface中关于ppl的实现:

import torch  
from tqdm import tqdm  
  
max_length = model.config.n_positions  
stride = 512  
seq_len = encodings.input_ids.size(1)  
  
nlls = []  
prev_end_loc = 0  
for begin_loc in tqdm(range(0, seq_len, stride)):  
    end_loc = min(begin_loc + max_length, seq_len)  
    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop  
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)  
    target_ids = input_ids.clone()  
    target_ids[:, :-trg_len] = -100  
  
    with torch.no_grad():  
        outputs = model(input_ids, labels=target_ids)  
  
        # loss is calculated using CrossEntropyLoss which averages over valid labels  
        # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels  
        # to the left by 1.  
        neg_log_likelihood = outputs.loss  
  
    nlls.append(neg_log_likelihood)  
  
    prev_end_loc = end_loc  
    if end_loc == seq_len:  
        break  
  
ppl = torch.exp(torch.stack(nlls).mean())  

其思想在于:遵循了滑窗的策略,即设置一个窗口,按照某个步长向右滑动,每次在窗口内计算ppl,直到窗口滑到结尾,然后将所有的ppl进行求平均得到最终的ppl。

3、几个例子

通过串联rtriver,inference、lm、metric等,可以快速地进行模型评估验证。

1)采用PPL推理策略的 GPT2-XL (1.5B)在SST-2数据集上的ICL性能。

2) 基于直接生成方式-bleu推理策略评估XGLM (7.5B)在WMT16 (de-en) 数据集上的ICL性能

3)评估GPT-3 (175B) text-davinci-003 版本在GSM8K数据集上使用思维链推理策略的ICL性能

二、Opencompass中几个内置的fewshot源码实现

Opencompass中几个内置的fewshot主要依靠retriever模块,主要用于大模型从数据集中选择N个固定数量的样本加入到prompt中。

其预置了多种不同的样本选择方式。这个和基于知识库的问答应用差不多,主要是通过一些方法选择最适合当前数据样本的上下文召回,帮助大模型更好得理解当前数据样本。

典型的方法有:基于BM25的检索方式、基于KNN的检索方式、随机采样方式等,当然也有zeroshot方式,即不带任何先验样本。

地址https://github.com/open-compass/opencompass/blob/main/opencompass/openicl/icl_retriever中提供了bm25、random、topk、votek以及dpp等多种构造方法的实现。

1、重要参数

数据集(BaseDataset ): 任何 BaseDataset 实例。将使用reader 、train 和test 的属性。

ice_separator (可选[str] ): 每个上下文示例模板之间的分隔符。默认值为’\n’。

ice_eos_token (Optional[str] ):上下文示例模板的句末标记。默认为"\n"。

ice_num (Optional[int] ):上下文示例模板的数量。默认为 1。

index_split(Optional[str] ): 用于检索内示例索引的数据集分割。默认为 “train”。

test_split(Optional[str] ): 数据集的分割,用于检索示例样本。默认为 “test”。

2、zeroRetriver方案

zeroshot检索,所有示例均置为空。

class ZeroRetriever(BaseRetriever):  
    """Zeroshot Retriever. The retriever returns empty list for all queries.  
    """  
    def __init__(self, dataset, ice_eos_token: Optional[str] = '') -> None:  
        super().__init__(dataset, '', ice_eos_token, 0)  
  
    def retrieve(self, id_list: List[int] = None) -> List[List]:  
        if id_list is not None:  
            get_logger().warning('id_list is not empty, but will be ignored.')  
        rtr_idx_list = [[] for _ in range(len(self.test_ds))]  
        return rtr_idx_list  

3、FixKRetriever方案

测试提示的每个上下文示例都是从索引集中检索出相同的K个示例。

class FixKRetriever(BaseRetriever):  
"""  
    Fix-K Retriever. Each in-context example of the test prompts is  
    retrieved as the same K examples from the index set.  
"""  
    def __init__(self,  
                 dataset,  
                 fix_id_list: List[int],  
                 ice_separator: Optional[str] = '\n',  
                 ice_eos_token: Optional[str] = '\n',  
                 ice_num: Optional[int] = 1) -> None:  
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)  
        self.fix_id_list = fix_id_list  
  
    def retrieve(self):  
        """Retrieve the in-context example index for each test example."""  
        num_idx = len(self.index_ds)  
        for idx in self.fix_id_list:  
            assert idx < num_idx, f'Index {idx} is out of range of {num_idx}'  
        rtr_idx_list = []  
        for _ in trange(len(self.test_ds), disable=not self.is_main_process):  
            rtr_idx_list.append(self.fix_id_list)  
        return rtr_idx_list  

4、RandomRetriever方案

RandomRetriever以随机方式检索测试提示的每个上下文示例。

class RandomRetriever(BaseRetriever):  
  """Random Retriever. Each in-context example of the test prompts is  
    retrieved in a random way.  
  """  
  
    def __init__(self,  
                 dataset,  
                 ice_separator: Optional[str] = '\n',  
                 ice_eos_token: Optional[str] = '\n',  
                 ice_num: Optional[int] = 1,  
                 seed: Optional[int] = 43) -> None:  
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)  
        self.seed = seed  
  
    def retrieve(self):  
        np.random.seed(self.seed)  
        num_idx = len(self.index_ds)  
        rtr_idx_list = []  
        logger.info('Retrieving data for test set...')  
        for _ in trange(len(self.test_ds), disable=not self.is_main_process):  
            idx_list = np.random.choice(num_idx, self.ice_num,  
                                        replace=False).tolist()  
            rtr_idx_list.append(idx_list)  
        return rtr_idx_list  

5、BM25Retriever方案

BM25Retriever中测试提示的每个测试提示的每个上下文示例都是通过BM25算法检索的。

class BM25Retriever(BaseRetriever):  
  """  
  Random Retriever. Each in-context example of the test prompts is  
    retrieved in a random way.  
  """  
    bm25 = None  
    index_corpus = None  
    test_corpus = None  
  
    def __init__(self,  
                 dataset,  
                 ice_separator: Optional[str] = '\n',  
                 ice_eos_token: Optional[str] = '\n',  
                 ice_num: Optional[int] = 1) -> None:  
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)  
        self.index_corpus = [  
            word_tokenize(data) for data in  
            self.dataset_reader.generate_input_field_corpus(self.index_ds)  
        ]  
        self.bm25 = BM25Okapi(self.index_corpus)  
        self.test_corpus = [  
            word_tokenize(data) for data in  
            self.dataset_reader.generate_input_field_corpus(self.test_ds)  
        ]  
  
    def retrieve(self) -> List[List]:  
        """Retrieve the in-context example index for each test example."""  
        rtr_idx_list = []  
        logger.info('Retrieving data for test set...')  
        for idx in trange(len(self.test_corpus),  
                          disable=not self.is_main_process):  
            query = self.test_corpus[idx]  
            scores = self.bm25.get_scores(query)  
            near_ids = list(np.argsort(scores)[::-1][:self.ice_num])  
            near_ids = [int(a) for a in near_ids]  
            rtr_idx_list.append(near_ids)  
        return rtr_idx_list  

6、TopkRetriever方案

TopkRetriever,使用 基本knn实现。SentenceTransformer用于计算嵌入。Faiss 用于进行近邻搜索

class TopkRetriever(BaseRetriever):  
    model = None  
     """Base class for Topk In-context Learning Retriever, implemented with  
    basic knn. SentenceTransformer is used to calculate embeddings. Faiss is  
    used to do the nearest neighbor search.  
    """  
      
    def __init__(self,  
                 dataset,  
                 ice_separator: Optional[str] = '\n',  
                 ice_eos_token: Optional[str] = '\n',  
                 ice_num: Optional[int] = 1,  
                 sentence_transformers_model_name: Optional[  
                     str] = 'all-mpnet-base-v2',  
                 tokenizer_name: Optional[str] = 'gpt2-xl',  
                 batch_size: Optional[int] = 1) -> None:  
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)  
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'  
        self.batch_size = batch_size  
        self.tokenizer_name = tokenizer_name  
        gen_datalist = self.dataset_reader.generate_input_field_corpus(  
            self.test_ds)  
  
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)  
        self.tokenizer.pad_token = self.tokenizer.eos_token  
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id  
        self.tokenizer.padding_side = 'right'  
  
        self.encode_dataset = DatasetEncoder(gen_datalist,  
                                             tokenizer=self.tokenizer)  
        co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer,  
                                            device=self.device)  
        self.dataloader = DataLoader(self.encode_dataset,  
                                     batch_size=self.batch_size,  
                                     collate_fn=co)  
  
        self.model = SentenceTransformer(sentence_transformers_model_name)  
  
        self.model = self.model.to(self.device)  
        self.model.eval()  
  
        self.index = self.create_index()  
    ## 使用knn进行索引  
    def create_index(self):  
        import faiss  
  
        self.select_datalist = self.dataset_reader.generate_input_field_corpus(  
            self.index_ds)  
        encode_datalist = DatasetEncoder(self.select_datalist,  
                                         tokenizer=self.tokenizer)  
        co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer,  
                                            device=self.device)  
        dataloader = DataLoader(encode_datalist,  
                                batch_size=self.batch_size,  
                                collate_fn=co)  
        index = faiss.IndexIDMap(  
            faiss.IndexFlatIP(self.model.get_sentence_embedding_dimension()))  
        res_list = self.forward(dataloader,  
                                process_bar=True,  
                                information='Creating index for index set...')  
        id_list = np.array([res['metadata']['id'] for res in res_list])  
        self.embed_list = np.stack([res['embed'] for res in res_list])  
        index.add_with_ids(self.embed_list, id_list)  
        return index  
      
    ## 使用knn进行搜索  
    def knn_search(self, ice_num):  
        res_list = self.forward(self.dataloader,  
                                process_bar=True,  
                                information='Embedding test set...')  
        rtr_idx_list = [[] for _ in range(len(res_list))]  
        logger.info('Retrieving data for test set...')  
        for entry in tqdm.tqdm(res_list, disable=not self.is_main_process):  
            idx = entry['metadata']['id']  
            embed = np.expand_dims(entry['embed'], axis=0)  
            near_ids = self.index.search(embed, ice_num)[1][0].tolist()  
            rtr_idx_list[idx] = near_ids  
        return rtr_idx_list  
    ## 利用sentence transformer进行embedding  
    def forward(self, dataloader, process_bar=False, information=''):  
        res_list = []  
        _dataloader = copy.deepcopy(dataloader)  
        if process_bar:  
            logger.info(information)  
            _dataloader = tqdm.tqdm(_dataloader,  
                                    disable=not self.is_main_process)  
        for _, entry in enumerate(_dataloader):  
            with torch.no_grad():  
                metadata = entry.pop('metadata')  
                raw_text = self.tokenizer.batch_decode(  
                    entry['input_ids'],  
                    skip_special_tokens=True,  
                    verbose=False)  
                res = self.model.encode(raw_text, show_progress_bar=False)  
            res_list.extend([{  
                'embed': r,  
                'metadata': m  
            } for r, m in zip(res, metadata)])  
        return res_list  
  
    def retrieve(self):  
        """Retrieve the in-context example index for each test example."""  
        return self.knn_search(self.ice_num)  

总结

本文主要介绍了OpenICL上下文示例学习框架的几种选择策略以及opencompass的实现,这对我们加强对prompt的认知有很大的帮助。

实际上,这些prompt的挑选,与我们现在进行大模型文档中相关文档的召回有很大的相似性,而如何快速的比对不同召回方式,快速地进行评估,进行模块化也是我们需要关注的问题。两种存在着本质上的共性。

参考文献

1、https://github.com/open-compass/opencompass

2、https://github.com/Shark-NLP/OpenICL

3、https://arxiv.org/abs/2303.02913

关于我们

老刘,刘焕勇,NLP开源爱好者与践行者,主页:https://liuhuanyong.github.io。

老刘说NLP,将定期发布语言资源、工程实践、技术总结等内容,欢迎关注。

对于想加入更优质的知识图谱、事件图谱、大模型AIGC实践、相关分享的,可关注公众号,在后台菜单栏中点击会员社区->会员入群加入。

​​​

更多AI工具,参考Github-AiBard123国内AiBard123