从零开始,用PyTorch复现一个迷你版Llama(附GQA和RoPE代码实现)
从零构建迷你LlamaPyTorch实战中的GQA与RoPE核心实现当Meta在2023年初开源LLaMA模型时很少有人预料到这个羊驼会引发大模型开源生态的连锁反应。作为首个能在消费级显卡上运行的高性能开源大模型LLaMA不仅证明了小规模模型通过高质量数据训练可以达到惊人效果更重要的是其精心设计的架构为后续模型提供了可复用的模块化组件。本文将带您从零开始用PyTorch实现一个包含Group Query Attention和RoPE位置编码的简化版LLaMA通过代码解剖那些让LLaMA高效运转的核心机制。1. 环境准备与基础架构在开始构建模型前我们需要配置合适的开发环境。推荐使用Python 3.9和PyTorch 2.0环境这些版本对Transformer架构有更好的原生支持conda create -n minillama python3.9 conda activate minillama pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 pip install transformers matplotlib numpyLLaMA的基础架构仍然是Transformer的堆叠但有几个关键改进点前置RMSNorm替代传统LayerNorm提升训练稳定性RoPE位置编码为Q/K矩阵注入相对位置信息Group Query Attention平衡MHA和MQA的优势SwiGLU MLP更高效的激活函数设计我们先定义模型的基础配置类这将作为后续所有组件的参数来源class LlamaConfig: def __init__( self, vocab_size32000, hidden_size4096, intermediate_size11008, num_hidden_layers32, num_attention_heads32, num_key_value_heads8, # GQA参数 max_position_embeddings2048, rms_norm_eps1e-6, rope_theta10000.0, ): self.vocab_size vocab_size self.hidden_size hidden_size self.intermediate_size intermediate_size self.num_hidden_layers num_hidden_layers self.num_attention_heads num_attention_heads self.num_key_value_heads num_key_value_heads self.max_position_embeddings max_position_embeddings self.rms_norm_eps rms_norm_eps self.rope_theta rope_theta2. 实现RoPE旋转位置编码RoPE(Rotary Position Embedding)是LLaMA位置编码的核心创新它通过旋转矩阵的方式将位置信息注入到Q/K向量中。与传统的绝对位置编码不同RoPE能在Attention计算中自然形成相对位置关系。2.1 RoPE数学原理RoPE的核心思想是对查询向量q和键向量k进行旋转变换。对于位置m的q向量和位置n的k向量旋转操作可以表示为q̃ₘ qₘ * cos(mθ) rotate(qₘ) * sin(mθ) k̃ₙ kₙ * cos(nθ) rotate(kₙ) * sin(nθ)其中rotate()表示将向量后半部分取负后与前半部分交换位置的操作。这种设计的精妙之处在于当计算q̃ₘ和k̃ₙ的点积时结果会自动包含相对位置(m-n)的信息。2.2 PyTorch实现我们首先实现旋转操作的辅助函数def rotate_half(x): 将输入张量的后一半维度旋转 x1 x[..., : x.shape[-1] // 2] x2 x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim-1)接下来是完整的RoPE模块实现class LlamaRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings2048, base10000): super().__init__() self.dim dim self.max_position_embeddings max_position_embeddings self.base base # 计算频率倒数 inv_freq 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer(inv_freq, inv_freq, persistentFalse) self._set_cos_sin_cache() def _set_cos_sin_cache(self, seq_lenNone): seq_len self.max_position_embeddings if seq_len is None else seq_len t torch.arange(seq_len, deviceself.inv_freq.device) # 计算外积得到位置*频率 freqs torch.einsum(i,j-ij, t, self.inv_freq) emb torch.cat((freqs, freqs), dim-1) self.register_buffer(cos_cache, emb.cos(), persistentFalse) self.register_buffer(sin_cache, emb.sin(), persistentFalse) def forward(self, x, position_ids): # x: [bs, num_heads, seq_len, head_dim] # position_ids: [bs, seq_len] seq_len x.shape[-2] if seq_len self.max_position_embeddings: self._set_cos_sin_cache(seq_len) cos self.cos_cache[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin self.sin_cache[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] x_rot (x * cos) (rotate_half(x) * sin) return x_rot在实际应用中RoPE有以下几个关键优势长序列友好旋转操作不会引入额外的参数计算复杂度稳定相对位置编码点积结果自然包含相对位置信息可扩展性可通过调整base参数控制位置编码的衰减速度3. Group Query Attention实现Group Query Attention(GQA)是LLaMA V2引入的重要改进它在多头注意力(MHA)和多查询注意力(MQA)之间找到了平衡点。3.1 GQA核心思想GQA将查询头分成若干组每组共享相同的键/值头。这种设计带来了两个好处相比MHA显著减少了KV缓存的内存占用相比MQA保留了更多的模型容量和表达能力注意力类型查询头数键/值头数KV缓存大小表达能力MHAHH最大最强MQAH1最小最弱GQAHG (1GH)中等平衡3.2 PyTorch实现首先实现KV投影的共享逻辑def repeat_kv(hidden_states: torch.Tensor, n_rep: int) - torch.Tensor: 将键/值头重复n_rep次以匹配查询头数 batch, num_kv_heads, seq_len, head_dim hidden_states.shape if n_rep 1: return hidden_states hidden_states hidden_states[:, :, None, :, :].expand( batch, num_kv_heads, n_rep, seq_len, head_dim ) return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)完整的GQA实现如下class LlamaAttention(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() self.config config self.hidden_size config.hidden_size self.num_heads config.num_attention_heads self.head_dim self.hidden_size // self.num_heads self.num_key_value_heads config.num_key_value_heads self.num_key_value_groups self.num_heads // self.num_key_value_heads # 投影层 self.q_proj nn.Linear(self.hidden_size, self.num_heads * self.head_dim) self.k_proj nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim) self.v_proj nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim) self.o_proj nn.Linear(self.hidden_size, self.hidden_size) # RoPE初始化 self.rotary_emb LlamaRotaryEmbedding( self.head_dim, max_position_embeddingsconfig.max_position_embeddings, baseconfig.rope_theta, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] None, position_ids: Optional[torch.Tensor] None, past_key_value: Optional[Tuple[torch.Tensor]] None, ): batch_size, seq_len, _ hidden_states.shape # 投影查询、键、值 query_states self.q_proj(hidden_states) key_states self.k_proj(hidden_states) value_states self.v_proj(hidden_states) # 重塑为多头形式 query_states query_states.view( batch_size, seq_len, self.num_heads, self.head_dim ).transpose(1, 2) key_states key_states.view( batch_size, seq_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) value_states value_states.view( batch_size, seq_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) # 应用RoPE位置编码 query_states self.rotary_emb(query_states, position_ids) key_states self.rotary_emb(key_states, position_ids) # KV缓存处理 if past_key_value is not None: key_states torch.cat([past_key_value[0], key_states], dim2) value_states torch.cat([past_key_value[1], value_states], dim2) # 重复KV头以匹配查询头数 key_states repeat_kv(key_states, self.num_key_value_groups) value_states repeat_kv(value_states, self.num_key_value_groups) # 计算注意力分数 attn_weights torch.matmul( query_states, key_states.transpose(2, 3) ) / math.sqrt(self.head_dim) if attention_mask is not None: attn_weights attn_weights attention_mask # 注意力概率和上下文计算 attn_weights torch.softmax(attn_weights, dim-1) attn_output torch.matmul(attn_weights, value_states) # 合并头并输出 attn_output attn_output.transpose(1, 2).contiguous() attn_output attn_output.reshape(batch_size, seq_len, self.hidden_size) return self.o_proj(attn_output)在实际应用中GQA需要注意以下几点头数比例通常设置num_key_value_heads为num_attention_heads的1/4到1/8KV缓存推理时需正确管理past_key_value以支持自回归生成计算效率相比MHA可节省30-40%的显存占用同时保持相近的模型质量4. 完整模型集成与验证现在我们将所有组件集成为一个完整的迷你LLaMA模型并验证其正确性。4.1 模型架构实现class LlamaMLP(nn.Module): def __init__(self, config): super().__init__() self.config config self.hidden_size config.hidden_size self.intermediate_size config.intermediate_size self.gate_proj nn.Linear(self.hidden_size, self.intermediate_size) self.up_proj nn.Linear(self.hidden_size, self.intermediate_size) self.down_proj nn.Linear(self.intermediate_size, self.hidden_size) self.act_fn nn.SiLU() # SwiGLU激活 def forward(self, x): gate self.act_fn(self.gate_proj(x)) up self.up_proj(x) return self.down_proj(gate * up) class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps1e-6): super().__init__() self.weight nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon eps def forward(self, hidden_states): variance hidden_states.pow(2).mean(-1, keepdimTrue) hidden_states hidden_states * torch.rsqrt(variance self.variance_epsilon) return self.weight * hidden_states class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size config.hidden_size self.self_attn LlamaAttention(config) self.mlp LlamaMLP(config) self.input_layernorm LlamaRMSNorm(config.hidden_size, epsconfig.rms_norm_eps) self.post_attention_layernorm LlamaRMSNorm(config.hidden_size, epsconfig.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] None, position_ids: Optional[torch.Tensor] None, past_key_value: Optional[Tuple[torch.Tensor]] None, ): # 自注意力 residual hidden_states hidden_states self.input_layernorm(hidden_states) hidden_states self.self_attn( hidden_stateshidden_states, attention_maskattention_mask, position_idsposition_ids, past_key_valuepast_key_value, ) hidden_states residual hidden_states # MLP residual hidden_states hidden_states self.post_attention_layernorm(hidden_states) hidden_states self.mlp(hidden_states) hidden_states residual hidden_states return hidden_states class MiniLlama(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() self.config config self.embed_tokens nn.Embedding(config.vocab_size, config.hidden_size) self.layers nn.ModuleList( [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)] ) self.norm LlamaRMSNorm(config.hidden_size, epsconfig.rms_norm_eps) def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] None, position_ids: Optional[torch.Tensor] None, past_key_values: Optional[List[torch.Tensor]] None, ): # 嵌入层 hidden_states self.embed_tokens(input_ids) # 解码层 for idx, decoder_layer in enumerate(self.layers): past_key_value past_key_values[idx] if past_key_values is not None else None hidden_states decoder_layer( hidden_states, attention_maskattention_mask, position_idsposition_ids, past_key_valuepast_key_value, ) # 最终归一化 hidden_states self.norm(hidden_states) return hidden_states4.2 模型验证与测试为了验证我们的实现是否正确我们可以进行以下几个测试形状检查确保各层输入输出维度匹配RoPE测试验证位置编码是否正确影响注意力模式GQA测试检查KV头的共享逻辑是否正确def test_model(): config LlamaConfig( hidden_size256, num_attention_heads8, num_key_value_heads2, intermediate_size512, num_hidden_layers4, ) model MiniLlama(config) # 测试输入 input_ids torch.randint(0, config.vocab_size, (2, 16)) position_ids torch.arange(16).unsqueeze(0).expand(2, -1) attention_mask torch.ones(2, 16) # 前向传播 outputs model(input_ids, attention_mask, position_ids) print(fOutput shape: {outputs.shape}) # 应为[2, 16, 256] # 测试RoPE rotary_emb LlamaRotaryEmbedding(dim32) q torch.randn(1, 8, 16, 32) # [bs, heads, seq_len, dim] q_rot rotary_emb(q, position_ids[:, :16]) print(fRotated Q shape: {q_rot.shape}) # 应与输入相同 # 测试GQA attn LlamaAttention(config) kv torch.randn(2, 2, 16, 32) # [bs, kv_heads, seq_len, dim] repeated_kv repeat_kv(kv, n_rep4) print(fRepeated KV shape: {repeated_kv.shape}) # 应为[2, 8, 16, 32] test_model()4.3 性能优化技巧在实际部署中我们可以用以下优化手段Flash Attention使用优化的注意力实现提升计算效率KV缓存量化对past_key_value进行低精度存储算子融合将RMSNorm与后续线性层融合为单个核函数# Flash Attention示例需安装flash-attn包 try: from flash_attn import flash_attn_func def flash_attention_forward( q, k, v, attention_maskNone, dropout_p0.0, softmax_scaleNone ): return flash_attn_func( q, k, v, dropout_pdropout_p, softmax_scalesoftmax_scale, causalTrue ) except ImportError: print(Flash Attention not available, using default implementation)通过本实现我们不仅理解了LLaMA的核心机制更重要的是掌握了如何将这些创新设计应用到自己的模型中。RoPE和GQA已经成为现代大语言模型的标准配置它们的实现思路值得每一位NLPer深入理解。