【Anchor DETR论文阅读】:基于锚点查询设计的Transformer检测器,50epoch收敛且速度精度双升
论文信息标题Anchor DETR: Query Design for Transformer-Based Object Detection会议AAAI 2022单位MEGVII Technology旷视科技代码github.com/megvii-research/AnchorDETR论文https://arxiv.org/pdf/2109.07107.pdf一、引言DETR的查询黑盒之谜DETR开创了Transformer端到端检测的极简范式无Anchor、无NMS美得不像话。但它有两个老大难问题查询不可解释object query 是纯可学习向量不知道它“负责看哪里”。收敛极慢要训 500 epoch工业落地根本等不起。根本原因每个查询没有明确的空间责任范围全局乱看优化难度极大。于是旷视这篇 Anchor DETR 直接拍板让每个查询绑定一个锚点anchor point有明确的“地盘”学得又快又好效果直接拉满训练轮数从 500 →50 epoch提速 10 倍R50-DC5 单尺度特征AP 44.2%速度19 FPS超过 DETR、Deformable DETR、Conditional DETR还提出RCDA 行列解耦注意力省显存、无随机内存访问、硬件友好拉满二、核心动机让查询“有址可寻”2.1 DETR 查询的乱象图 1预测区域的可视化展示。请注意子图a源自 DETR 图Carion 等人2020 年。每个预测区域包含了查询值集上的所有框预测。每个有颜色的点代表一个预测的标准化中心位置。这些点通过颜色编码来区分绿色代表小框红色代表大水平框蓝色代表大垂直框。子图b最后一行中的黑色点表示锚点。我们所提出的预测区域与特定位置的关系比 DETR 更紧密。(a) DETR每个查询的预测框散布全图没有明确聚焦区域。(b) Anchor DETR每个查询的预测都紧紧围绕自己的锚点职责清晰。这就是位置模糊性查询不知道自己该管哪儿自然难训练。2.2 解决方案Anchor Point Multi-Pattern图 2锚点分布的可视化展示。每个点都代表了一个锚点的标准化位置。Anchor DETR 做了两件最关键的事查询 锚点编码每个查询明确“我守这个点”。一点多检测一个锚点配多个 pattern解决“同位置多物体”问题。三、方法详解全文精读无省略3.1 总体架构图 3所提出探测器的流程图。请注意编码器层和解码器层的结构与 DETR1 相同只是我们在编码器层中替换掉了自注意力机制在解码器层中替换掉了交叉注意力机制改用了我们提出的“行列解耦注意力”机制。流程Backbone → Encoder带RCDA→ Decoder带Anchor查询RCDA→ FFN分类回归创新点只有两个但贯穿全文基于锚点模式的查询设计行列解耦注意力 RCDA3.2 锚点与查询编码1锚点定义锚点就是图像上的坐标点Posq∈RNA×2Pos_q \in \mathbb{R}^{N_A \times 2}Posq∈RNA×2NAN_ANA锚点数量每个点存储(x,y)(x,y)(x,y)归一化 0~1支持两种锚点网格锚点均匀铺在图上可学习锚点随机初始化跟着训练一起学2锚点 → 物体查询把锚点坐标编码成查询位置嵌入QpEncode(Posq)Q_p Encode(Pos_q)QpEncode(Posq)最直接的方式就是和key共用编码函数Qpg(Posq),Kpg(Posk)Q_p g(Pos_q),\quad K_p g(Pos_k)Qpg(Posq),Kpg(Posk)文章直接用两层MLP做编码适配性更强。3.3 Multi-Pattern一點多檢測一个位置可能叠多个物体比如人抱小孩。于是给每个锚点配Np 个模式向量让一个点能出多个框。最终查询数量NqNp×NAN_q N_p \times N_ANqNp×NANpN_pNp模式数默认3NAN_ANA锚点数默认300所有锚点共享同一组模式向量保证平移不变性。图片3来自原文 Figure 4三个模式分别负责不同宽高比的物体分工明确。3.4 注意力公式回顾DETR 标准注意力Attention(Q,K,V)softmax(QKTdk)VAttention(Q,K,V) softmax(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)softmax(dkQKT)VQQfQp,KKfKpQQ_fQ_p,\quad KK_fK_pQQfQp,KKfKpQfQ_fQf查询内容特征QpQ_pQp查询位置本文锚点编码KfK_fKf键内容KpK_pKp键位置3.5 RCDA行列解耦注意力超级硬核这是本文第二个大创新把2D注意力拆成行注意力 列注意力。为什么要拆标准注意力的内存消耗Nq×H×W×MN_q \times H \times W \times MNq×H×W×M太高RCDA 先对 K 做全局池化拆成行特征Kf,x∈RW×CK_{f,x} \in \mathbb{R}^{W \times C}Kf,x∈RW×C列特征Kf,y∈RH×CK_{f,y} \in \mathbb{R}^{H \times C}Kf,y∈RH×C然后先做行注意力再做列注意力。最终内存消耗从O(HW)O(HW)O(HW)降到O(HW)O(HW)O(HW)省显存、速度快、硬件友好且不会引入随机内存访问。四、核心公式与符号解释4.1 锚点查询QpEncode(Posq)Q_p Encode(Pos_q)QpEncode(Posq)QpQ_pQp物体查询位置部分PosqPos_qPosq锚点坐标(x,y)(x,y)(x,y)EncodeEncodeEncode正弦编码 MLP4.2 多模式查询QfinitQ_f^{init}Qfinit由模式向量广播得到NqNp×NAN_q N_p × N_ANqNp×NANpN_pNp每个锚点的检测头数模式数NAN_ANA锚点数量4.3 RCDA 内存节省比StandardRCDAW×MC\frac{Standard}{RCDA} \frac{W×M}{C}RCDAStandardCW×MWWW特征图宽MMM头数CCC通道数默认256默认设置下可省2~4倍显存。五、核心代码PyTorch 风格# # 1. 锚点编码 → 查询# defencode_anchor_points(anchor_points,out_dim256):# anchor_points: [B, N, 2] (x,y)x,yanchor_points.unbind(-1)pos_xpositional_encoding(x,out_dim//2)pos_ypositional_encoding(y,out_dim//2)postorch.cat([pos_x,pos_y],dim-1)# MLP 编码posmlp(pos)returnpos# # 2. 多模式 pattern 扩展# defmultiply_pattern(anchor_embedding,pattern_embedding):# anchor_embedding: [B, NA, C]# pattern_embedding: [NP, C]NAanchor_embedding.shape[1]NPpattern_embedding.shape[0]# 每个锚点复制 NP 个模式q_posanchor_embedding.unsqueeze(2).repeat(1,1,NP,1)q_featpattern_embedding.unsqueeze(0).repeat(1,NA,1,1)# 合并成 [B, NA*NP, C]q_posq_pos.flatten(1,2)q_featq_feat.flatten(1,2)returnq_feat,q_pos# # 3. 行列解耦注意力 RCDA# classRowColumnDecoupledAttention(nn.Module):defforward(self,q,k,v,q_posNone,k_posNone):# 解耦行、列特征k_rowk.mean(1)# [B, W, C]k_colk.mean(2)# [B, H, C]# 行注意力attn_rowtorch.matmul(q,k_row.transpose(-2,-1))out_rowtorch.matmul(attn_row.softmax(-1),v)# 列注意力attn_coltorch.matmul(q,k_col.transpose(-2,-1))out_coltorch.matmul(attn_col.softmax(-1),v)# 融合returnout_rowout_col六、实验结果与深度分析6.1 与Transformer检测模型对比表格1来自原文 Table 1模型特征APFPSDETRDC543.312SMCAmulti43.710Deformable DETRmulti43.815Conditional DETRDC543.810Anchor DETRDC544.219结论单尺度特征吊打多尺度速度第一50epoch 超过 DETR 500epoch6.2 与主流检测器对比表格2来自原文 Table 2模型EpochAPDETR-DC550043.3Anchor DETR-DC55044.2真正意义上速度、精度、收敛、成本全维度超越。6.3 消融实验表格3来自原文 Table 3RCDAanchorpatternAP39.3✅42.6✅40.3✅40.3✅✅✅44.2三个组件缺一不可共同提升 4.9 AP七、全文总结最精髓5句话DETR慢的根源查询无明确空间责任注意力散乱难优化。Anchor DETR 解法查询锚点编码责任明确收敛狂快。Multi-Pattern一个锚点多模式解决同位置多物体。RCDA 行列解耦注意力省显存、速度快、硬件友好。最终效果50epoch、单尺度、无NMS、无Anchor、AP 44.2%、FPS 19。这篇是工业落地极其友好的一篇 DETR 改进结构干净、速度快、收敛快、可解释强。