# BN / LN / RMSNorm: A Summary of Normalization Methods

## 1. Background and Motivation

Common problems when training deep networks:

- vanishing / exploding gradients
- the input distribution of each layer keeps shifting (Internal Covariate Shift)
- slow convergence and unstable training

The core goal of normalization is to standardize features to a stable distribution, which speeds up training and improves model stability. The essential differences between BN, LN, and RMSNorm are the dimensions they normalize over and whether they center the data: BN depends on the batch, LN normalizes over the feature dimension of each sample, and RMSNorm only rescales, trading centering for efficiency and stability.

Unified form:

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

followed by an affine transform:

$$y = \gamma \hat{x} + \beta$$

## 2. Batch Normalization (BN)

### 1. Core Idea

Normalize over the batch and spatial dimensions, per channel; commonly used in CNNs.

$$\mu_c = \mathbb{E}_{B,H,W}[x], \qquad \sigma_c^2 = \mathrm{Var}_{B,H,W}[x]$$

$$\hat{x} = \frac{x - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}}$$

### 2. Inference (Important)

At inference time, use running (exponential moving) averages of the batch statistics:

$$\mu_{running} \leftarrow (1 - m)\,\mu_{running} + m\,\mu_{batch}$$

$$\sigma^2_{running} \leftarrow (1 - m)\,\sigma^2_{running} + m\,\sigma^2_{batch}$$

### 3. Properties

- depends on batch size
- behaves differently at training and inference time
- well suited to CNNs

## 3. Layer Normalization (LN)

### 1. Core Idea

Normalize over the feature dimension of each individual sample; the standard choice in Transformers.

$$\mu = \mathbb{E}_{C}[x], \qquad \sigma^2 = \mathrm{Var}_{C}[x]$$

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

### 2. Properties

- independent of batch size
- identical behavior at training and inference time
- well suited to NLP / Transformers

## 4. RMSNorm (Root Mean Square Norm)

### 1. Core Idea

Only rescale; do not subtract the mean.

$$\mathrm{RMS}(x) = \sqrt{\mathbb{E}[x^2]}$$

$$\hat{x} = \frac{x}{\sqrt{\mathbb{E}[x^2] + \epsilon}}, \qquad y = \gamma \hat{x}$$

### 2. Properties

- drops mean centering, so it is simpler and faster to compute
- works well in large language models (e.g., LLaMA)

## 5. Comparison of the Three (Interview Focus)

| Method  | Normalized dims | Subtracts mean | Depends on batch | Train/inference gap | Typical use |
|---------|-----------------|----------------|------------------|---------------------|-------------|
| BN      | B, H, W         | ✅             | ✅               | ✅                  | CNN         |
| LN      | C               | ✅             | ❌               | ❌                  | Transformer |
| RMSNorm | C               | ❌             | ❌               | ❌                  | Large language models |

## 6. Summary of the Essential Differences

1. Normalized dimensions: BN normalizes across samples; LN / RMSNorm normalize within a single sample.
2. Centering (mean subtraction): BN / LN subtract the mean; RMSNorm does not (see the sketch below).
3. Mathematical form: BN / LN compute $\frac{x - \mu}{\sigma}$, while RMSNorm computes $\frac{x}{\sqrt{\mathbb{E}[x^2]}}$.
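To make the last two points concrete, here is a minimal sketch (an illustration added here, not part of the implementation below; the helper names `ln` and `rms` are arbitrary): when the input already has zero mean along the feature dimension, LN and RMSNorm give the same result, because mean subtraction is the only step RMSNorm removes.

```python
# Minimal check: with a zero-mean input, LN and RMSNorm coincide,
# since mean subtraction is the only thing RMSNorm drops.
import torch

def ln(v, eps=1e-6):
    mu = v.mean(dim=-1, keepdim=True)
    var = v.var(dim=-1, keepdim=True, unbiased=False)
    return (v - mu) / torch.sqrt(var + eps)

def rms(v, eps=1e-6):
    return v / torch.sqrt((v ** 2).mean(dim=-1, keepdim=True) + eps)

x = torch.randn(4, 16)
x_centered = x - x.mean(dim=-1, keepdim=True)  # force zero mean per row

print(torch.allclose(ln(x_centered), rms(x_centered), atol=1e-5))  # True
print(torch.allclose(ln(x), rms(x), atol=1e-5))                    # generally False
```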
## 7. Code Implementation

```python
# NOTE: minimal BN / LN / RMSNorm implementations
import torch
import torch.nn as nn


class BatchNorm(nn.Module):
    # BN is usually used for CNNs; the expected input shape is (B, C, H, W).
    def __init__(self, channels_dim, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum  # NOTE: momentum is the update speed of running_mean and running_var
        self.register_buffer("running_mean", torch.zeros(1, channels_dim, 1, 1))
        self.register_buffer("running_var", torch.ones(1, channels_dim, 1, 1))
        self.gamma = nn.Parameter(torch.ones(1, channels_dim, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, channels_dim, 1, 1))

    def forward(self, x):
        if self.training:
            mean = x.mean(dim=[0, 2, 3], keepdim=True)                 # B,C,H,W -> 1,C,1,1
            var = x.var(dim=[0, 2, 3], keepdim=True, unbiased=False)   # B,C,H,W -> 1,C,1,1
            # update running stats (detach so the buffers do not track gradients)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.detach()
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.detach()
        else:
            mean = self.running_mean
            var = self.running_var
        x_normed = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * x_normed + self.beta
        return out


class LayerNorm(nn.Module):
    # LN is usually used for RNNs/Transformers; the expected input shape is (B, L, C).
    def __init__(self, channels_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, 1, channels_dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, channels_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)                  # B,L,C -> B,L,1
        var = x.var(dim=-1, keepdim=True, unbiased=False)    # B,L,C -> B,L,1
        x_normed = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * x_normed + self.beta
        return out


class RMSNorm(nn.Module):
    # RMSNorm is a variant of LN that only normalizes the scale (RMS) and does not subtract the mean.
    # It is usually used for RNNs/Transformers; the expected input shape is (B, L, C).
    def __init__(self, channels_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, 1, channels_dim))

    def forward(self, x):
        rms = torch.mean(x ** 2, dim=-1, keepdim=True)       # B,L,C -> B,L,1
        x_normed = x / torch.sqrt(rms + self.eps)
        out = self.gamma * x_normed
        return out


if __name__ == "__main__":
    x = torch.rand(10, 5, 768)
    LN = LayerNorm(768)
    x_LN = LN(x)
    print(x_LN.shape)

    RMSN = RMSNorm(768)
    x_RMS = RMSN(x)
    print(x_RMS.shape)

    cnn_x = torch.rand(4, 12, 512, 512)
    BN = BatchNorm(12)
    x_BN = BN(cnn_x)
    print(x_BN.shape)
```
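As a quick sanity check (a sketch added here, assuming the `LayerNorm` and `BatchNorm` classes above are in scope), the hand-written modules can be compared against PyTorch's built-in `F.layer_norm` and `F.batch_norm`. `nn.RMSNorm` only exists in recent PyTorch releases, so RMSNorm is easiest to verify directly against its formula, as in the earlier sketch.

```python
# Sanity check: compare the hand-written modules with PyTorch built-ins.
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# LayerNorm vs F.layer_norm (gamma/beta reshaped to the (C,) layout F expects)
x = torch.randn(2, 3, 8)
ln = LayerNorm(8)
ref = F.layer_norm(x, (8,), weight=ln.gamma.view(8), bias=ln.beta.view(8), eps=ln.eps)
print(torch.allclose(ln(x), ref, atol=1e-5))  # expected: True

# BatchNorm in training mode vs F.batch_norm computed from batch statistics
img = torch.randn(4, 3, 5, 5)
bn = BatchNorm(3)
ref = F.batch_norm(img, None, None, weight=bn.gamma.view(3), bias=bn.beta.view(3),
                   training=True, eps=bn.eps)
print(torch.allclose(bn(img), ref, atol=1e-5))  # expected: True
```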