别再被数学劝退!用PyTorch从零实现DDPM扩散模型(附完整代码)
# 用PyTorch实战DDPM：无需深究数学也能玩转扩散模型

当你在社交媒体上看到AI生成的艺术作品时，是否好奇过它们背后的技术原理？扩散模型（Diffusion Models）作为当前最热门的生成式AI技术之一，正以惊人的速度改变着内容创作的格局。本文将带你绕过复杂的数学推导，直接进入代码实践环节，用PyTorch从零构建一个完整的DDPM（Denoising Diffusion Probabilistic Models）模型。

## 1. 环境准备与数据加载

在开始之前，我们需要配置好开发环境。推荐使用Python 3.8+和PyTorch 1.12+版本：

```bash
pip install torch torchvision matplotlib tqdm
```

对于数据集，我们将使用经典的CIFAR-10，它包含60,000张32x32的彩色图像：

```python
import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,
    shuffle=True
)
```

> 提示：如果你的GPU显存较小，可以将batch_size调整为64或32。

## 2. DDPM核心组件实现

### 2.1 噪声调度器

扩散模型的核心在于如何合理地添加和去除噪声。我们需要定义一个噪声调度器来控制不同时间步的噪声强度：

```python
import math

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

timesteps = 1000
betas = cosine_beta_schedule(timesteps)  # 使用余弦调度器效果更好

# 预计算有用的值
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
```

### 2.2 前向加噪过程

前向过程逐步将数据转换为高斯噪声，这个过程是固定的，不需要训练：

```python
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
```

3. 
构建U-Net模型

U-Net是DDPM中用于预测噪声的核心网络结构。下面我们实现一个简化版的U-Net：

```python
import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm = nn.GroupNorm(8, out_ch)

    def forward(self, x, t):
        h = self.conv1(x)
        h = self.norm(h)
        h = F.silu(h)
        time_emb = F.silu(self.time_mlp(t))
        h = h + time_emb[:, :, None, None]
        h = self.conv2(h)
        h = self.norm(h)
        h = F.silu(h)
        return h

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(100),
            nn.Linear(100, 256),
            nn.SiLU(),
            nn.Linear(256, 256)
        )

        self.down1 = Block(3, 64, 256)
        self.down2 = Block(64, 128, 256)
        self.down3 = Block(128, 256, 256)

        self.mid = Block(256, 256, 256)

        self.up1 = Block(512, 128, 256)
        self.up2 = Block(256, 64, 256)
        self.up3 = Block(128, 64, 256)

        self.out = nn.Conv2d(64, 3, 1)

    def forward(self, x, t):
        t = self.time_mlp(t)

        # 下采样
        h1 = self.down1(x, t)
        h2 = self.down2(F.max_pool2d(h1, 2), t)
        h3 = self.down3(F.max_pool2d(h2, 2), t)

        # 中间层
        h = self.mid(F.max_pool2d(h3, 2), t)

        # 上采样
        h = F.interpolate(h, scale_factor=2, mode="nearest")
        h = self.up1(torch.cat([h, h3], dim=1), t)
        h = F.interpolate(h, scale_factor=2, mode="nearest")
        h = self.up2(torch.cat([h, h2], dim=1), t)
        h = F.interpolate(h, scale_factor=2, mode="nearest")
        h = self.up3(torch.cat([h, h1], dim=1), t)

        return self.out(h)

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
```

由于网络中有三次2倍下采样和对应的上采样拼接，输入图像的边长需要能被8整除（CIFAR-10的32x32满足这一要求）。

4. 
训练与采样

### 4.1 训练循环

DDPM的训练目标是让网络学会预测添加到图像中的噪声：

```python
def train(model, dataloader, optimizer, epochs, device):
    model.train()
    for epoch in range(epochs):
        for step, (images, _) in enumerate(dataloader):
            images = images.to(device)

            # 随机采样时间步
            t = torch.randint(0, timesteps, (images.shape[0],), device=device).long()

            # 生成随机噪声
            noise = torch.randn_like(images)

            # 前向加噪过程
            noisy_images = q_sample(images, t, noise)

            # 预测噪声
            predicted_noise = model(noisy_images, t)

            # 计算损失
            loss = F.mse_loss(noise, predicted_noise)

            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                print(f"Epoch {epoch} | Step {step} | Loss: {loss.item():.4f}")
```

### 4.2 采样生成

训练完成后，我们可以使用模型从纯噪声开始，逐步去噪生成新图像。

`p_sample` 还依赖两个在2.1节中没有预计算的量：`sqrt_recip_alphas`（即 $1/\sqrt{\alpha_t}$）和后验方差 `posterior_variance`（即 $\tilde\beta_t = \beta_t \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}$）。需要先在已有的预计算代码之后补上它们，否则采样时会报 `NameError`：

```python
# 采样所需的额外预计算量（依赖 2.1 节中的 betas、alphas、alphas_cumprod）
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
```

```python
@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)

    # 计算模型均值
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    device = next(model.parameters()).device

    # 从随机噪声开始
    img = torch.randn((batch_size, channels, image_size, image_size), device=device)

    for i in reversed(range(0, timesteps)):
        t = torch.full((batch_size,), i, device=device, dtype=torch.long)
        img = p_sample(model, img, t, i)

    # 将图像从[-1,1]转换到[0,1]；模型输出可能略微越界，因此再截断到[0,1]
    img = ((img + 1) * 0.5).clamp(0, 1)
    return img
```

5. 
模型优化与技巧

在实际应用中，我们可以采用以下几种策略来提升DDPM的性能：

- **学习率调度**：使用余弦退火学习率可以显著提升模型收敛速度
- **混合精度训练**：通过FP16训练可以节省显存并加快训练速度
- **EMA模型**：使用指数移动平均的模型参数可以提高生成质量
- **渐进式训练**：从低分辨率开始训练，逐步提高分辨率

```python
# 示例：EMA模型实现
class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ema_model, current_model):
        for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()):
            old_weight, new_weight = ema_params.data, current_params.data
            ema_params.data = self.update_average(old_weight, new_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new
```

在CIFAR-10数据集上训练约50个epoch后，你应该能够看到模型开始生成可识别的物体图像。虽然32x32的分辨率不高，但这个完整的实现已经包含了DDPM的所有关键组件。