CANN-昇腾NPU-量化训练-QAT和PTQ怎么选
模型量化有两种时机训练时做QATQuantization-Aware Training和训练后做PTQPost-Training Quantization。在昇腾NPU上QAT 用 torch_npu 的量化感知训练PTQ 用 CANN 的 AMCT 工具。这篇讲清楚两者的适用场景和操作步骤。PTQ训练后量化PTQ 不需要重新训练直接把 fp16 模型量化成 int8/w8a8。适合快速上线、不想重新训练的场景。fromamct_npuimportcreate_quant_config,quantize_model# 1. 准备校准数据集100-1000 条代表性数据calib_dataloaderget_calib_dataloader(num_samples500)# 2. 创建量化配置configcreate_quant_config(model_filemodel.onnx,config_file./quant_config.json,dst_json_path./quant_ready.json,)# 3. 校准跑一遍校准数据统计激活分布quant_modelquantize_model(model_filemodel.onnx,quant_config_file./quant_config.json,calib_dataloadercalib_dataloader,)# 4. 导出量化模型quant_model.export_quant_onnx(model_quant.onnx)PTQ 的关键校准数据集要跟真实推理数据分布一致。用训练集做校准推理时分布不同精度损失会放大。QAT量化感知训练QAT 在训练时模拟量化误差让模型适应量化。精度损失比 PTQ 小 30-50%但需要重新训练。importtorchfromtorch_npu.contribimportQATWrapper modelAutoModelForCausalLM.from_pretrained(meta-llama/Llama-2-7b-hf,torch_dtypetorch.bfloat16,device_mapnpu:0,)# 包装成 QAT 模型qat_modelQATWrapper(model,qconfig{weight:int8,activation:int8,quantize_per_tensor:True,})# 正常训练QAT 在 forward 时插入伪量化节点optimizertorch.optim.AdamW(qat_model.parameters(),lr1e-5)fordataindataloader:lossqat_model(data)loss.backward()optimizer.step()# 训练完成后转成真正量化模型quant_modeltorch.ao.quantization.convert(qat_model)torch.save(quant_model.state_dict(),model_qat.pt)精度损失对比Llama2-7BCANN 8.5Atlas 800I A2量化方案WNLI (准确率)GSM8K (准确率)推理速度fp16 (基准)78.5%56.2%1.0×PTQ int876.1% (-2.4%)53.8% (-2.4%)1.8×QAT int877.9% (-0.6%)55.6% (-0.6%)1.8×PTQ int468.2% (-10.3%)44.1% (-12.1%)2.5×QAT int474.8% (-3.7%)51.3% (-4.9%)2.5×QAT 的精度损失只有 PTQ 的 1/4。如果精度敏感评测集、生产环境优先 QAT。选择建议场景推荐方案理由快速原型验证PTQ不需要训练10 分钟完成生产环境精度敏感QAT精度损失小训练成本可接受显存严重不足PTQ int4权重 4bit显存减半已有训练流水线QAT插入 QAT wrapper 即可改动小跟 ATB 的配合ATB 的 LLM 接口直接支持量化模型fromatbimportLLM# PTQ 量化模型model_ptqLLM(model_quant.onnx,devicenpu:0,quantizew8a8,# 对应 PTQ 的配置)# QAT 量化模型model_qatLLM(model_qat.pt,devicenpu:0,quantizew8a8_qat,)ATB 内部会自动调用对应的量化 GEMM kernel。w8a8 的 GEMM 吞吐是 fp16 的 1.8-2.0×。PTQ 快但精度损失大QAT 慢但精度高。如果你的模型要上生产多花 1-2 天做 QAT 是值得的。PTQ 适合快速验证和显存极度受限的场景。仓库在这里https://atomgit.com/cann/AMCT