本文基于昇腾CANN和昇腾NPU围绕 ops-transformer 仓库的相关技术展开。LayerNorm 在大模型里被 RMSNorm 替换了。LayerNorm 做了减均值再除方差RMSNorm 只除均方根——去掉了减均值那一步。少一次 Reduce 操作在量产推理里省掉 15-20% 的归一化时间。LayerNorm 的计算流程# LayerNorm——先算均值、再算方差、再归一化deflayer_norm(x,gamma,beta,eps1e-6): x: [batch, seq_len, hidden_dim] gamma: [hidden_dim] —— 可学习缩放 beta: [hidden_dim] —— 可学习偏移 # Step 1: 算均值——一次 Reducemeanx.mean(dim-1,keepdimTrue)# [b, s, 1]# Step 2: 算方差——二次 Reducevar((x-mean)**2).mean(dim-1,keepdimTrue)# [b, s, 1]# Step 3: 归一化x_norm(x-mean)/torch.sqrt(vareps)# 减均值再除方差# Step 4: 缩放 偏移returnx_norm*gammabeta# 每次 LayerNorm 做 2 次全张量 Reduce 1 次逐元素 Scale# hidden_dim4096 时每次需要读 4096 个值 2 次 写 1 次LayerNorm 去均值那一步在 NLP 里不是必须的——Transformer 的残差连接已经做了中心化。RMSNorm 砍掉这步只做 Scale。RMSNorm 的数学差异# RMSNorm——只除 RMS过均值defrms_norm(x,gamma,eps1e-6): x: [batch, seq_len, hidden_dim] gamma: [hidden_dim] —— 可学习缩放无 beta RMSNorm(x) x / RMS(x) * gamma RMS(x) sqrt(mean(x^2) eps) # Step 1: 算均方——只有 1 次 Reducermstorch.sqrt((x**2).mean(dim-1,keepdimTrue)eps)# [b, s, 1]# Step 2: 归一化——不做减均值直接除x_normx/rms# Step 3: 缩放——有 gamma没有 betareturnx_norm*gamma# 跟 LayerNorm 的差异# 1. 没有 mean x.mean() → 省一次全张量 Reduce# 2. 没有 x - mean → 省一次逐元素减法# 3. 没有 beta → 省一次加法# 统计上RMSNorm 收敛到跟 LayerNorm 同等精度Llama 全系列用 RMSNorm——Llama-3.1-405B 也不例外。用 RMSNorm 替代 LayerNorm 后405B 模型单次 Forward 省掉 2 次大 Tensor 操作。CANN 上的 RMSNorm 融合实现// Ascend C 实现的 RMSNorm——融合了 Pow Reduce Sqrt DivclassRMSNormKernel:publicAscendC::Kernel{__aicore__inlinevoidProcess()override{// 一次性搞清 RMSNorm 的 Tile 策略constinttile_size1024;// 每次处理 1024 维constinttiles_per_blockhidden_dim/tile_size;AscendC::LocalTensorfloatx_local;AscendC::LocalAllocfloat(x_local,tile_size);AscendC::LocalTensorfloatsq_local;AscendC::LocalAllocfloat(sq_local,tile_size);// 分 Tile 计算 x^2 并在片上做部分累加// 这样不用把所有 x 搬完再算 RMS——减少 L1⇄DDR 往返floatpartial_sum0.0f;for(intt0;ttiles_per_block;t){// 搬一个 Tile 到 L1 Bufferinttile_offsett*tile_size;AscendC::DataCopy(x_local,gm_xtile_offset,tile_size);// x^2——用了 Vec 单元的通用计算指令AscendC::Mul(sq_local,x_local,x_local);// 片上的局部 ReduceSum——不走 DDRAscendC::ReduceAdd(partial_sum,sq_local,tile_size);}// RMS sqrt(partial_sum / hidden_dim eps)floatrmssqrtf(partial_sum/hidden_dim1e-6f);floatinv_rms1.0f/rms;// 用乘法代替除法// 第二遍 Tilex * inv_rms * gammafor(intt0;ttiles_per_block;t){inttile_offsett*tile_size;AscendC::DataCopy(x_local,gm_xtile_offset,tile_size);// 加载 gamma 参数AscendC::LocalTensorfloatgamma_local;AscendC::LocalAlloc(gamma_local,tile_size);AscendC::DataCopy(gamma_local,gm_gammatile_offset,tile_size);// x / rms * gamma——一次合并完成AscendC::Mul(x_local,x_local,inv_rms);AscendC::Mul(x_local,x_local,gamma_local);// 写回AscendC::DataCopy(gm_outtile_offset,x_local,tile_size);}}};比 LayerNorm 少了一个x - mean和一个 Reduce多出来的算力可以给 Batch 里的下一个请求。实测 Llama-7B 上把 Norm 替换为 RMSNorm 后Decode 速度从 28 tok/s 提到 32 tok/s。参考仓库RMSNorm 等 Transformer 算子神经网络基础算子库