AI 文摘

Mixtral8x7B(MistralMoE)模型解析





作者: AINLP 来源: AINLP

本文特别鸣谢字节跳动 Crane佬解答了我对SWA的疑惑

0 前言

1 Mistral 7B 模型

1.1 SWA(Sliding Window Attention)

2 Mixtral 8x7B(MoE)模型

3 Llama2 70B vs Mixtral 8x7B

0 前言

从前段时间Mistral AI 公司发布全球首款MoE(Mixture-of-Experts)大模型——Mixtral-8x7B 以来,就在AI界引起了不小的轰动,从一众科技自媒体的报道中我注意到了一个关键信息点:比Llama-2 70B具有更少的参数 ,却有更高的精度 。这一点燃起了我的兴趣,故特来学习一下Mixtral 8x7B 相对于Llama 2 70B有何不同。还是老样子

  • paper :https://arxiv.org/pdf/2401.04088.pdf

  • code :https://github.com/mistralai/mistral-src

首先,通过Mistral AI 公司的主页我发现他一共发布了两个模型:Mistral 7B 和 Mixtral-8x7B ,后者为基于前者的MoE模型 。从其公布的测试结果可以发现Mistral 7B 以7B的参数量在所有benchmarks超越了Llama-2 13B 并且与Llama-2 34B性能相当

而使用MoE策略的 Mixtral-8x7B 模型则以46.7B参数量,在多数benchmarks上超越Llama 2 70B模型。

如此优异的表现,本文就来看看这两个模型相对于Llama 2做了哪些改变,以及相对于Llama 2 这两个模型的参数量和FLOPs

这里再多说一句,因为Mistral 模型是基于Llama 2模型的,所以对Llama 2模型结构不了解的可以先去看看我之前写的Llama 2详解

1 Mistral 7B模型

llama

Mistral 7B模型与Llama 2 7B模型结构整体上是相似的,其结构参数如下所示

model-arch

具体而言,就是存在以下几点差异:

  • 对于Attention部分使用GQA (Group Query Attention)来计算注意力机制,其中Q的头数为32,而KV 的头数为8,换句话说就是每4组Q共享一组KV。这一点与Llama 2 不同,Llama 2 是在34B和70B中才使用了GQA,在7B中依然使用的是MHA(Multi-Head-Attention)

  • 使用SWA(Sliding Window Attention) 。GQA和SWA叠加来降低显存占用提高推理速度。

  • 增大FeedForward HiddenDim的值,由Llama-2 7B的11008 ,改为14336

GQA和更改FFN HiddenDim的值 这两个改动都比较容易理解,那么接下来就主要来看看SWA(Sliding Window Attention)的原理和实现细节

1.1 SWA(Sliding Window Attention)

Mistral 使用了GQA和SWA两种方法来加速计算Attention,GQA在Llama 2详解的文章中说明过,这里主要讲解一下SWA。我们知道在Attention的计算一般是Q 与shape为[bst, multi-head,seq_len, head_dim] 的KV 进行注意力计算,其中seq_len 为已处理所有tokens总数,GQA在多头上做文章使得多组Q共享一组KV;而SWA则是在seq_len 这个维度做文章,不在将Q与所有seq-len的KV直接 “计算注意力,而是只与Sliding Window SizeKV直接 “计算注意力,如下示意图,为Sliding Window Size为3的情况

注意:这里用的是直接 计算注意力,下文会说明直接的含义

swa

举个例子,在on 单词所对应的token计算Attention时,对于普通Attention 可以与前面所有单词对应的 计算Attention,而对于SWA, 只能直接与、、计算。

我们知道在LLM推理时,一般分为prompting 和 generation两个阶段,为了满足SWA,prompting阶段可以通过一个mask的掩码操作实现,如下

if input_ids.shape[1] > 1:  
    # seqlen推理时在prompt阶段为n,在generation阶段为1  
    seqlen = input_ids.shape[1]  
    # mask在推理时也只在prompt阶段有,  
    #定义一个全1方阵  
    tensor = torch.full((seqlen, seqlen),fill_value=1)  
    # 上三角部分全为0  
    mask = torch.tril(tensor, diagonal=0).to(h.dtype)  
    # make the mask banded to account for sliding window  
    # 这里代码diagonal应该等于(-self.args.sliding_window+1)才能满足window size为    
    # self.args.sliding_window,这应该是官方代码的一个小bug?  
    mask = torch.triu(mask, diagonal=-self.args.sliding_window)  
    mask = torch.log(mask)  
"""  
举个例子,tensor.shape : [10,10]  
self.args.sliding_window = 5,则mask为  
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],  
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],  
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],  
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],  
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],  
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],  
        [0, 1, 1, 1, 1, 1, 1, 0, 0, 0],  
        [0, 0, 1, 1, 1, 1, 1, 1, 0, 0],  
        [0, 0, 0, 1, 1, 1, 1, 1, 1, 0],  
        [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]])  
"""  

而在generation阶段,因为是自回归生成所以mask起不到作用,那此时mistral则使用了RotatingBufferCache来实现此操作,具体而言,就是采用一种循环右移的存储方式,剔除离得远的KV,保存靠近的KV 。

rotationcache

如上图展示了一个Window Size为4的Cache,循环右移的写Cache的示意图。

RotatingBufferCache代码实现如下

# The cache is a rotating buffer  
# positions[-self.sliding_window:] 取最后w个位置的索引,取余  
# [None, :, None, None]操作用于扩维度[1,w,1,1]  
scatter_pos = (positions[-self.sliding_window:] % self.sliding_window)[None, :, None, None]  
# repeat操作repeat维度 [bsz, w, kv_head, head_dim]  
scatter_pos = scatter_pos.repeat(bsz, 1, self.n_kv_heads, self.args.head_dim)  
# src取[:,-w,:,:] 所以src.shape=[bsz,w,kv_head,head_dim]  
# 根据scatter_pos作为index 将src写入cache  
self.cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk[:, -self.sliding_window:])  
self.cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv[:, -self.sliding_window:])  

我相信多数读者读到这里会跟我有一样的疑问,只让Q与前面Window Size的KV计算Attention,不会影响最终的预测精度吗?因为我们知道当前生成的token是由前面所有token共同决定的。而且论文中并没有特别详细说明,且给出的示意图(下图) 也有些让人费解。

这里结合Crane佬的解答和mistral官方repo的 issuse (https://github.com/mistralai/mistral-src/issues/40),我大概弄明白了:

SWA确实限制了每个token的Q只能关注固定大小(Window Size)内的其他token,然而,信息通过网络的传播并不仅仅局限于Window Size的大小,它还设计多层Transformer之间的信息传递。

举个例子,假设我有一组tokens ,并且我假设此时Sliding Window Size为3,当前处理的token为

那么对于而言,此时KV cache中存的分别是和 ,所以此时的能直接 获得最远的token信息是 ,而又是由前层输出计算而来,而中又汇总了tokens-的信息,同理类推又是由前前层的输出计算而来,所以他们又带了tokens-的信息

综上所述,对于而言虽然只能直接与tokens - 直接进行注意力机制计算,但是却可以间接与更早 的tokens - 参与注意力机制运算,以此类推,只要层数足够大,配合这种传递方式就可以覆盖整个序列。论文中还举例说明,对于一个序列长度是16k,Window Size为4K的SWA,只需要四层,最后一个token就能看到之前的全部token信息

2 Mixtral 8x7B (MoE)模型

前文说过 Mixtral-8x7B就是Mistral 7B的MoE模型,除了上述Mistral 7B中的特性以外,Mixtral-8x7B还引入了MoE结构。MoE(Mixture-of-Experts) 其实也不是一个新技术,早在1991年就已经被Michael Jordan 和 Geoffrey Hinton所提出 Adaptive mixtures of local experts , 而且关于MoE的发展在深度学习界也从未停止过 (所谓经典永不过时说的便是如此),相关的papers综述这里提供一个写的不错的Blog供大家参考一下:Mixture-of-Experts (MoE) 经典论文一览

这里简单的解释一下什么是MoE,简单点说就是我让一个网络模型结构有多条分支,每条分支代表一个Expert(专家),每个Expert都有其擅长的领域,当具体任务来临时,可以通过一个门空位Gate来具体选择采用哪一个或者哪几个Experts进行计算,这样的好处就是让每个Expert更专注特定领域,降低了不同领域数据对权重学习的干扰。当然在训练MoE模型时也要注意各个Experts负载均衡,防止赢者通吃,达不到想要的目的。

具体到Mixtral 8x7B模型中,其MoE的结构示意图如下所示

MoE 图源自@OpenCompass

可以发现,相对于Llama ,Mixtral 8x7B模型将FFN替换为MoE FFN,还是直接看代码

class MoeLayer(nn.Module):  
    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):  
        super().__init__()  
        assert len(experts) > 0  
        # 定义experts,就是一组(8个)Llama FFN,  
        # Llama FFN就是两个Linear + Silu + Linear  
        self.experts = nn.ModuleList(experts)  
        # gate也是一个Linear,这个Linear weight的维度是[hidden_dim , num_experts]  
        self.gate = gate    
        self.args = moe_args  
    def forward(self, inputs: torch.Tensor):  
        # 更改input shape [bst,seq_len,hidden-dim] -> [bst*seq_len,hidden-dim]  
        inputs_squashed = inputs.view(-1, inputs.shape[-1])  
        # Gate Linear 将输入线性映射到num_experts  
        # 即[bst*seq_len,hidden-dim] -> [bst*seq_len,num_experts]  
        gate_logits = self.gate(inputs_squashed)  
        # topk排序  
        # weights返回topk的值  
        # selected_experts 返回topk的index  
        weights, selected_experts = torch.topk(  
            gate_logits, self.args.num_experts_per_tok  
        )  
        # 对每个weight做softmax,归一化  
        weights = nn.functional.softmax(  
            weights,  
            dim=1,  
            dtype=torch.float,  
        ).type_as(inputs)  
        results = torch.zeros_like(inputs_squashed)  
        for i, expert in enumerate(self.experts):  
            # 根据selected_experts确定weight的行id和列id  
            batch_idx, nth_expert = torch.where(selected_experts == i)  
            # 通过上述id选择对应的加权数据 以及执行对应的expert,并将结果加权求和  
            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(  
                inputs_squashed[batch_idx]  
            )  
        return results.view_as(inputs)  

3 Llama-2 70B vs Mixtral 8x7B

文章的最后,我们再来对比一下Llama-2 70B 和 Mixtral 8x7B 的参数量以及浮点运算量(FLOPs)

  • Params

ModelAttentionFeedForwardLayersOthersTotal

Llama-2 70B 8192* 8192 * 2+ 8192 * 1024 * 2 + 8192 = 151003136 8192 * 28672 * 3 + 8192=704651264 80 8192 * 32000 * 2+ 8192 = 524296192 68976648192=68.98B

Mixtral 8x7B 4096 * 4096 * 2 + 4096 * 512 * 2 + 4096 = 37752832 4096 * 8 + 8 * (4096 * 14336 *3) + 4096 = 1409323008 32 4096 * 32000 * 2+ 4096 = 262148096 46568574976=46.57B

  • FLOPs

计算FLOPs,我们就都以输入为2048的单batch作为基准计算,并且我们只计算矩阵乘法相关的FLOPs作为整体网络FLOPs的估算,Norm层的计算先忽略

ModelAttentionFeedForwardLayersOthersTotal

Llama-2 70B 2 * 2048 * 8192 * 8192 * 2 + 2* 2048 * 8192 * 1024 * 2 + 64 * 2 *2048 * 128 * 2048 * 2 = 7.55914244 10^{11} 3 * 8192 * 28672 * 2048 2 =2.88621802 10^{12} 80 2 * 8192 * 32000 2048 * 2 = 2.1474836510^{12} 2.9351806510^{14}=293.5TFLOPs

Mixtral 8x7B 2 * 2048 * 4096 * 4096 2 + 2 * 2048 * 4096 * 512 * 2 + 32 * 2 2048 * 128 * 2048 * 2 = 2.23338299 * 10^{11} 2048 * 4096 * 8 * 2 +3 * 4096 * 14336 * 2048 * 2 * 2 = 1.44324323 * 10^{12} 32 2 2048 * 4096 * 32000 * 2= 1.07374182 10^{12} 5.44043508* 10^{13} = 54.4TFLOPs

好啦完结撒花~

进技术交流群请添加AINLP小助手微信(id: ainlp2)

请备注具体方向+所用到的相关技术点

![](https://api.allorigins.win/raw?url=https://mmbiz.qpic.cn/mmbiz_jpg/nW2ZPfuYqSJADkmZ2IX6Z23znAibuEevotDMq9iaMxiapK7jfMibiauGFkycicAJEs6x5U9SGyDJZ0S1tRed9TPNUUDQ/640?wx_fmt=other&wxfrom=5&wx_lazy=1&wx_co=1&tp=webp)

关于AINLP

AINLP 是一个有趣有AI的自然语言处理社区,专注于 AI、NLP、机器学习、深度学习、推荐算法等相关技术的分享,主题包括LLM、预训练模型、自动生成、文本摘要、智能问答、聊天机器人、机器翻译、知识图谱、推荐系统、计算广告、招聘信息、求职经验分享等,欢迎关注!加技术交流群请添加AINLP小助手微信(id:ainlp2),备注工作/研究方向+加群目的。

  


![](https://api.allorigins.win/raw?url=https://mmbiz.qpic.cn/mmbiz_jpg/nW2ZPfuYqSKABHCqVVQkVYPrM4XY1vsd0iaeuXzyJnoFc8cibd5mYb4wdA3WMQtiaPVmr0XLZHMuVibqWncibpnTSnQ/640?wx_fmt=other&wxfrom=5&wx_lazy=1&wx_co=1&tp=webp)

更多AI工具,参考Github-AiBard123国内AiBard123

可关注我们的公众号:每天AI新工具