用PyTorch手把手实现一个带注意力机制的UNet:从残差块到Skip Connection的完整代码拆解
用PyTorch手把手实现一个带注意力机制的UNet从残差块到Skip Connection的完整代码拆解在计算机视觉领域图像分割任务一直面临着如何有效捕捉局部细节与全局上下文信息的挑战。传统卷积神经网络CNN虽然擅长提取局部特征但在处理长距离依赖关系时表现有限。本文将带你从零开始构建一个融合注意力机制的UNet模型通过代码级别的拆解深入理解每个模块的设计原理与实现细节。1. 基础模块构建残差块与注意力机制1.1 残差块(ResidualBlock)实现残差连接是深度神经网络中的重要设计它通过跨层连接缓解了梯度消失问题。在UNet中我们使用改进版的残差块class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels, time_channels, n_groups32): super().__init__() # 第一组归一化激活卷积 self.norm1 nn.GroupNorm(n_groups, in_channels) self.act1 nn.SiLU() self.conv1 nn.Conv2d(in_channels, out_channels, kernel_size3, padding1) # 第二组归一化激活卷积 self.norm2 nn.GroupNorm(n_groups, out_channels) self.act2 nn.SiLU() self.conv2 nn.Conv2d(out_channels, out_channels, kernel_size3, padding1) # 短路连接处理维度不匹配 self.shortcut ( nn.Conv2d(in_channels, out_channels, kernel_size1) if in_channels ! out_channels else nn.Identity() ) # 时间嵌入处理 self.time_emb nn.Linear(time_channels, out_channels) self.time_act nn.SiLU() def forward(self, x, t): h self.conv1(self.act1(self.norm1(x))) # 时间嵌入相加前需要调整维度 h self.time_emb(self.time_act(t))[:, :, None, None] h self.conv2(self.act2(self.norm2(h))) return h self.shortcut(x)关键设计要点组归一化(GroupNorm)相比批归一化对batch size不敏感时间嵌入将时间步信息融入网络适合时序相关任务短路连接当输入输出通道数不同时使用1x1卷积调整维度1.2 注意力机制(AttentionBlock)实现自注意力机制能够捕捉长距离依赖关系我们实现一个轻量级的空间注意力模块class AttentionBlock(nn.Module): def __init__(self, n_channels, n_heads1, d_kNone, n_groups32): super().__init__() self.norm nn.GroupNorm(n_groups, n_channels) self.projection nn.Linear(n_channels, n_heads * d_k * 3) self.output nn.Linear(n_heads * d_k, n_channels) self.scale (d_k or n_channels) ** -0.5 self.n_heads n_heads self.d_k d_k or n_channels def forward(self, x, tNone): batch, channels, height, width x.shape # 重塑为序列形式 (batch, seq_len, channels) x_flat x.view(batch, channels, -1).permute(0, 2, 1) # 计算QKV qkv self.projection(self.norm(x_flat)) qkv qkv.view(batch, -1, self.n_heads, 3 * self.d_k) q, k, v qkv.chunk(3, dim-1) # 注意力计算 attn torch.einsum(bihd,bjhd-bijh, q, k) * self.scale attn attn.softmax(dim2) out torch.einsum(bijh,bjhd-bihd, attn, v) # 输出投影 out out.reshape(batch, -1, self.n_heads * self.d_k) out self.output(out) x_flat # 恢复原始形状 return out.permute(0, 2, 1).view(batch, channels, height, width)维度变换过程解析步骤张量形状说明输入(B,C,H,W)原始特征图展平(B,H*W,C)准备计算注意力QKV投影(B,HW,3n_heads*d_k)线性变换分头处理3×(B,H*W,n_heads,d_k)拆分为Q,K,V注意力得分(B,HW,HW,n_heads)QK点积缩放输出(B,H*W,C)加权求和后还原2. UNet骨干网络设计2.1 下采样模块(DownBlock)下采样路径由多个DownBlock组成每个包含残差连接和可选注意力机制class DownBlock(nn.Module): def __init__(self, in_channels, out_channels, time_channels, has_attn): super().__init__() self.res ResidualBlock(in_channels, out_channels, time_channels) self.attn AttentionBlock(out_channels) if has_attn else nn.Identity() def forward(self, x, t): x self.res(x, t) return self.attn(x)2.2 上采样模块(UpBlock)上采样路径需要处理skip connection关键实现细节class UpBlock(nn.Module): def __init__(self, in_channels, out_channels, time_channels, has_attn): super().__init__() # 输入通道包含skip connection的拼接 self.res ResidualBlock(in_channels out_channels, out_channels, time_channels) self.attn AttentionBlock(out_channels) if has_attn else nn.Identity() def forward(self, x, t): x self.res(x, t) return self.attn(x)2.3 采样操作实现UNet需要精确控制特征图尺寸变化class Downsample(nn.Module): 使用步长2卷积实现下采样 def __init__(self, n_channels): super().__init__() self.conv nn.Conv2d(n_channels, n_channels, 3, stride2, padding1) def forward(self, x, tNone): return self.conv(x) class Upsample(nn.Module): 使用转置卷积实现上采样 def __init__(self, n_channels): super().__init__() self.conv nn.ConvTranspose2d(n_channels, n_channels, 4, stride2, padding1) def forward(self, x, tNone): return self.conv(x)采样策略对比类型实现方式优点缺点下采样步长2卷积参数少计算高效可能丢失高频信息上采样转置卷积可学习的上采样可能产生棋盘效应替代方案双线性插值卷积减少伪影计算量稍大3. 完整UNet集成3.1 网络架构配置通过配置字典灵活控制网络结构default_unet_config { image_channels: 3, # 输入图像通道数 n_channels: 64, # 初始通道数 channel_mults: [1, 2, 4], # 各层通道倍增系数 attn_levels: [False, True, True], # 哪些层使用注意力 n_blocks: 2, # 每层残差块数量 dropout: 0.1 # 丢弃率 }3.2 核心实现代码完整UNet类的初始化与前向传播class UNet(nn.Module): def __init__(self, configdefault_unet_config): super().__init__() # 初始化各组件 self.image_proj nn.Conv2d(config[image_channels], config[n_channels], 3, padding1) self.time_emb TimeEmbedding(config[n_channels] * 4) # 下采样路径 self.down_blocks nn.ModuleList() in_ch config[n_channels] for i, mult in enumerate(config[channel_mults]): out_ch config[n_channels] * mult for _ in range(config[n_blocks]): self.down_blocks.append( DownBlock(in_ch, out_ch, config[n_channels]*4, config[attn_levels][i])) in_ch out_ch if i ! len(config[channel_mults])-1: self.down_blocks.append(Downsample(in_ch)) # 中间层 self.middle MiddleBlock(in_ch, config[n_channels]*4) # 上采样路径 self.up_blocks nn.ModuleList() for i, mult in reversed(list(enumerate(config[channel_mults]))): out_ch config[n_channels] * mult for _ in range(config[n_blocks]): self.up_blocks.append( UpBlock(in_ch, out_ch, config[n_channels]*4, config[attn_levels][i])) # 通道数减半的过渡块 if i ! 0: self.up_blocks.append(Upsample(out_ch)) in_ch out_ch // 2 # 输出层 self.final nn.Sequential( nn.GroupNorm(8, in_ch), nn.SiLU(), nn.Conv2d(in_ch, config[image_channels], 3, padding1) ) def forward(self, x, t): # 时间嵌入 t self.time_emb(t) # 初始投影 x self.image_proj(x) # 存储跳连 skip_connections [x] # 下采样 for block in self.down_blocks: x block(x, t) if not isinstance(block, Downsample): skip_connections.append(x) # 中间层 x self.middle(x, t) # 上采样 for block in self.up_blocks: if isinstance(block, Upsample): x block(x, t) else: skip skip_connections.pop() x torch.cat([x, skip], dim1) x block(x, t) return self.final(x)4. 实战技巧与调试建议4.1 维度匹配检查表实现UNet时最常见的错误是维度不匹配以下检查点需特别注意下采样路径每次下采样后特征图尺寸应减半通道数变化需与配置一致跳连存储的特征图应在对应上采样时能正确拼接上采样路径转置卷积输出尺寸应与对应跳连特征图匹配拼接操作前应确保通道数正确最终输出尺寸应与输入图像一致注意力机制确保QKV计算后能正确还原维度多头注意力的头数应能整除通道数4.2 性能优化技巧# 启用混合精度训练 scaler torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): output model(inputs, timesteps) loss criterion(output, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()其他优化建议使用梯度裁剪防止梯度爆炸采用学习率预热策略对注意力层使用Flash Attention实现合理设置组归一化的组数(通常32或16)4.3 常见问题排查训练不收敛检查残差连接是否正常工作验证时间嵌入是否正确融入网络确保注意力权重计算正确归一化显存不足降低批处理大小使用梯度累积在注意力层启用内存高效实现输出质量差检查跳连是否正确拼接验证上/下采样实现是否正确调整注意力层的放置位置