Jetson Orin AGX INT4 推理优化实践:super 分支从 9 tok/s 到 24 tok/s
Jetson Orin AGX INT4 推理优化实践super 分支从 9 tok/s 到 24 tok/s项目地址https://github.com/luogantt/LLM-inference-engine本文总结jetson-orin-agx-super分支上的一次端侧大模型推理优化实践。目标设备是 Jetson Orin AGX目标模型是 DeepSeek-R1-Distill-Qwen-7B目标场景是单 batch、单 token decode。这次优化的核心结论很明确INT4 不是只把权重压成 4 bit 就会自动变快。在 Jetson Orin AGX 上INT4 要真正加速必须配合 INT8 activation 和 DP4A 整数点积不能走 float 解包。最终保留的最快版本是lib-int4-o4-all。在实际 decode 日志里速度推进到约24 tokens/s单 token forward 延迟约43 ms。测试环境Device: Jetson Orin AGX CUDA arch: sm_87 Model: DeepSeek-R1-Distill-Qwen-7B Branch: jetson-orin-agx-super batch: 1 max_seq: 800 max_new_tokens: 512运行命令CUDA_VISIBLE_DEVICES0python python_infer.py\--model/data/project/deepseek-r1-7b\--lib./build/libllm_cuda.so\--prompt你好 deepseek 介绍一下黑格尔的思想\--max-new-tokens512\--max-seq800当前推荐编译命令make-fMakefile.cuda_lib clean-libmake-fMakefile.cuda_lib lib-int4-o4-allAsm_87模型尺寸和 decode 的真实瓶颈当前代码里的关键模型尺寸为N_LAYERS 28 HIDDEN 3584 KV_DIM 512 INTERMEDIATE 18944 VOCAB_SIZE 152064在单 token decode 阶段每一步只处理一个新的 token。此时最主要的计算不是大 batch GEMM而是大量 GEMVmatrix weight x vector activation也就是y W x b对第j个输出通道y_j \sum_{i0}^{H-1} W_{j,i} x_i b_j这里H 3584。每个输出行都要和长度为 3584 的 hidden vector 做一次点积。在每一层里主要 linear 包括Q projection: HIDDEN x HIDDEN K projection: KV_DIM x HIDDEN V projection: KV_DIM x HIDDEN O projection: HIDDEN x HIDDEN Gate projection: INTERMEDIATE x HIDDEN Up projection: INTERMEDIATE x HIDDEN Down projection: HIDDEN x INTERMEDIATE其中 MLP 的 gate/up/down 计算量很大QKV projection 和普通 linear 也会在每个 decode step 反复出现。只优化某一个 linear整体速度提升有限。super分支里真正有效的版本是把普通 linear、QKV、gate/up 这些主路径都切到 INT4 INT8 activation DP4A。从 FP 线性层到 INT4 DP4A 的数学推导原始 float 或 half 线性层为y_j \sum_i W_{j,i} x_i b_j如果做 weight-only INT4通常对每个输出行保存一个 scaleW_{j,i} \approx s^W_j q^W_{j,i}其中q^W_{j,i} \in [-8, 7]量化过程可以写成q^W_{j,i} \operatorname{clip} \left( \operatorname{round}\left(\frac{W_{j,i}}{s^W_j}\right), -8, 7 \right)如果 activation 仍然保持 float那么计算会变成y_j \approx \sum_i s^W_j q^W_{j,i} x_i b_j这条路径看似使用了 INT4 权重但每个权重在计算时仍然要load packed int4 unpack nibble sign extend convert to float float multiply-add所以它只是减少了权重带宽没有把计算本身切到整数点积。早期 INT4 版本速度不理想根本原因就在这里。要让 INT4 真正加速需要把 activation 也量化成 INT8x_i \approx s^x q^x_i其中q^x_i \in [-127, 127]单 token 动态 activation 量化为s^x \frac{\max_i |x_i|}{127}q^x_i \operatorname{clip} \left( \operatorname{round}\left(\frac{x_i}{s^x}\right), -127, 127 \right)代回线性层y_j \approx \sum_i \left(s^W_j q^W_{j,i}\right) \left(s^x q^x_i\right) b_j把 scale 提出来y_j \approx s^W_j s^x \sum_i q^W_{j,i} q^x_i b_j中间累加项是一个整数点积acc_j \sum_i q^W_{j,i} q^x_i最终反量化y_j \approx s^W_j s^x acc_j b_j这就是super分支 INT4 DP4A 路径的数学本质。INT32 accumulator 是否安全对当前 hidden sizeH 3584最坏情况下|q^W_{j,i}| \le 8|q^x_i| \le 127单项乘积最大约为8 \times 127 1016一个输出行的最坏累加绝对值上界为3584 \times 1016 3,641,344这个值远小于 int32 的范围2^{31} - 1 2,147,483,647所以在当前模型尺寸下用 int32 accumulator 保存 INT4 x INT8 点积是安全的。DP4A 做了什么NVIDIA GPU 的 DP4A 指令可以在一条指令中完成 4 组 int8 乘加acc \leftarrow acc a_0 b_0 a_1 b_1 a_2 b_2 a_3 b_3其中a_k和b_k都是 int8。对于 INT4 权重存储时一个 byte 可以放 2 个权重一个uint32_t可以放 8 个 INT4 权重uint32 packed [w7 w6 w5 w4 w3 w2 w1 w0]计算时可以把 8 个 INT4 权重拆成两组 int8x4[w0, w1, w2, w3] - int8x4 [w4, w5, w6, w7] - int8x4activation 已经是 INT8连续 8 个 activation 可以看作两组 int8x4[x0, x1, x2, x3] - int8x4 [x4, x5, x6, x7] - int8x4于是 8 个权重和 8 个 activation 的点积可以用两次 DP4A 完成acc \leftarrow acc \operatorname{DP4A}(w_{0:3}, x_{0:3})acc \leftarrow acc \operatorname{DP4A}(w_{4:7}, x_{4:7})这条路径避免了逐元素 float 解包和 float FMA把核心计算变成整数指令。为什么 INT4 float 解包不快INT4 的理论带宽优势很明显。以一个输出行为例H 3584权重格式每个权重字节数单输出行权重读取FP162 bytes7168 bytesINT81 byte3584 bytesINT40.5 byte1792 bytesINT4 相比 FP16权重读取量变成 1/4。相比 INT8权重读取量变成 1/2。但如果 INT4 每个元素都走unpack - sign extend - convert float - fmaf那么额外指令会吃掉带宽收益。实际日志也验证了这一点版本计算路径实测速度Weight-only INT8INT8 weight float/普通路径约 14 tok/s早期 INT4INT4 weight float 解包约 9 tok/sINT4 DP4AINT4 weight INT8 activation DP4A约 20 tok/s 以上所以 INT4 的关键不是“存得小”而是“算得对”。在 Orin AGX 上必须让 INT4 权重进入整数点积路径。super 分支的优化路线这次jetson-orin-agx-super分支主要经历了几轮版本核心思路实测表现初始 INT4INT4 存储但计算路径不够整数化约 9 tok/sINT4 DP4Aactivation INT8权重 INT4整数点积约 20 tok/slib-int4-o2-all一个 block 同时算 2 个输出行覆盖普通 linear、QKV、gate/up约 22.5 tok/slib-int4-o4-all一个 block 同时算 4 个输出行继续提高 activation 复用约 24 tok/slib-int4-o8-all一个 block 同时算 8 个输出行掉到约 18 tok/s已回滚最终保留的是lib-int4-o4-all。o2-all 和 o4-all 为什么能加速原始做法可以理解为一个 block 只算一个输出行。block 0 - y0 block 1 - y1 block 2 - y2 ...每个 block 都要读取同一份 activation vectorx只是读取的权重 row 不同。对 GEMV 来说activation 是所有输出行共享的y_j \sum_i W_{j,i} x_i这里的x_i对所有j都相同。于是可以让一个 block 同时算多个输出行block 0 - y0, y1, y2, y3 block 1 - y4, y5, y6, y7 ...对 4-output 版本一个 block 内维护 4 个 accumulatoracc_0 \sum_i q^W_{0,i} q^x_iacc_1 \sum_i q^W_{1,i} q^x_iacc_2 \sum_i q^W_{2,i} q^x_iacc_3 \sum_i q^W_{3,i} q^x_i每次读取一组 activation 后可以同时喂给 4 个权重 rowload x int8x4 load row0 int4x4 - dp4a - acc0 load row1 int4x4 - dp4a - acc1 load row2 int4x4 - dp4a - acc2 load row3 int4x4 - dp4a - acc3这样做有几个好处block 数量减少调度开销下降。activation 读取被多个输出行复用。每个 block 做的工作更饱满。仍然只维护 4 个主要 accumulator寄存器压力可控。可以用一个简化成本模型理解T(r) \approx T_{launch/block}(r) T_{weight} T_{activation/reuse}(r) T_{reduction}(r) T_{register/occupancy}(r)其中r表示一个 block 同时计算的输出行数。当r从 1 增加到 2、4block 数量下降 activation 复用提高 整体吞吐提高但当r继续增加到 8每个线程 accumulator 变多 row pointer 和 scale pointer 变多 寄存器使用变多 shared memory reduction 变重 occupancy 下降所以r不是越大越好。对 Jetson Orin AGX 和当前 hidden size 来说r 4是这次实测中最平衡的点。为什么 o8-all 失败并回滚lib-int4-o8-all的想法很自然既然 4-output 更快那 8-output 会不会更快实测结果是否定的。o8-all的 decode 速度掉到了约18 tok/sforward_ms ≈ 57.6 ms decode_tokens_per_s ≈ 18.0 tok/s这说明瓶颈已经从 block 调度和 activation 复用转移到了寄存器压力、occupancy 和 reduction 成本。8-output kernel 里每个线程需要同时维护8 个 accumulator 8 个 row pointer 更多 scale/local/output 指针 更多写回分支 更多 shared memory reduction 数据这些都会降低 SM 上可同时驻留的 block 数量。对 Orin AGX 这种端侧 GPU 来说occupancy 一旦下降整数 DP4A 指令也喂不满最后性能反而下降。所以o8-all被回滚当前super分支保留o4-all作为推荐路径。实测结果lib-int4-o2-all的一次记录forward_ms ≈ 46.4326 decode_tokens 474 decode_tokens_per_s ≈ 22.5382lib-int4-o4-all的一次记录forward_ms ≈ 43.7584 decode_tokens 474 decode_tokens_per_s ≈ 23.9898lib-int4-o8-all的一次记录forward_ms ≈ 57.5994 decode_tokens 474 decode_tokens_per_s ≈ 18.0096对比可以看到\frac{23.99}{22.54} \approx 1.064o4-all相比o2-all继续提升约 6.4%。而o8-all相比o4-all\frac{18.01}{23.99} \approx 0.751也就是掉了约 25%。这说明o4-all已经接近当前 kernel 结构下的甜点区间。与主流端侧推理引擎的关系MLC、llama.cpp、TensorRT-LLM 等主流推理引擎都有更完整的工程体系例如模型转换、图优化、跨平台 runtime、更多量化格式和更成熟的算子调度。这个项目的目标不是替代它们而是做一条更透明、更直接的 CUDA decode 优化路线不依赖 PyTorch 推理 不依赖大型 runtime 直接手写 C / CUDA decode 路径 针对 Jetson Orin AGX 的单 batch 场景优化这次super分支的意义在于它证明了一个小型手写 CUDA 推理引擎只要抓住端侧 decode 的真实瓶颈也可以把 7B 模型推到和主流端侧引擎同量级的速度区间。更重要的是这个过程把 INT4 加速的关键讲清楚了INT4 weight-only compression 只解决存储和带宽问题 INT8 activation quantization 让计算进入整数域 DP4A 让整数点积真正被硬件高效执行 4-output GEMV layout 在复用和 occupancy 之间取得平衡当前推荐使用方式切换到jetson-orin-agx-super分支后gitpull origin jetson-orin-agx-super编译make-fMakefile.cuda_lib clean-libmake-fMakefile.cuda_lib lib-int4-o4-allAsm_87运行CUDA_VISIBLE_DEVICES0python python_infer.py\--model/data/project/deepseek-r1-7b\--lib./build/libllm_cuda.so\--prompt你好 deepseek 介绍一下黑格尔的思想\--max-new-tokens512\--max-seq800后续还能继续优化什么当前o4-all已经是这轮实验里最好的版本但后面仍然有一些方向可以继续尝试。1. 更精细的 kernel fusion现在已经优化了多个 linear 的 INT4 DP4A 路径但 RMSNorm、量化、linear、SwiGLU、residual 之间仍然存在 kernel 边界。后续可以继续研究是否能减少中间写回。2. activation quantization 优化当前 activation 每步动态量化s^x \frac{\max_i |x_i|}{127}这一步需要先求 max再写出 int8 activation。后续可以研究更快的归约、近似 scale、或者和前一个算子融合。3. KV cache 访存优化decode 越往后attention 对 KV cache 的读取越重。当前max_seq800下后段 token 的 forward_ms 会逐步上升说明 KV cache 和 attention 访存仍然值得优化。4. 针对固定尺寸生成专用 kernel当前模型尺寸固定HIDDEN 3584 INTERMEDIATE 18944 KV_DIM 512可以为这些尺寸生成更激进的专用 kernel减少通用分支和边界判断。5. 更严格的同条件 benchmark后续如果要和 MLC、llama.cpp 等引擎对标需要统一同一模型 同一量化方式 同一 prompt 同一 max_seq 同一 max_new_tokens 同一 Jetson 电源模式和频率设置 同一 prefill/decode 统计口径只有这样速度对比才足够严谨。总结jetson-orin-agx-super分支这次实践说明INT4 不等于自动加速。INT4 如果走 float 解包会浪费掉 4 bit 权重的优势。真正有效的路径是 INT4 权重、INT8 activation、DP4A 整数点积。单 token decode 的核心瓶颈是 GEMV不是大 batch GEMM。一个 block 同时算 4 个输出行是当前 Jetson Orin AGX 上更合理的平衡点。更激进的 8-output kernel 会因为寄存器压力和 occupancy 下降而变慢。最终lib-int4-o4-all把 DeepSeek-R1-Distill-Qwen-7B 在 Jetson Orin AGX 上的 decode 速度推进到约24 tokens/s。这不是靠框架黑盒得到的结果而是从线性层数学公式、量化公式、DP4A 指令到 GEMV kernel layout 一步步压出来的结果。这也是这个项目最有价值的地方它把端侧 LLM 推理的性能问题拆开让每一次优化都能被解释、被验证、被继续推进。