TensorRT MHA:输入尽量静态 + sequence length 对齐到 8¶
工程笔记 / 部署踩坑记录 — 2026-06-05
一句话结论 (TL;DR)
在 TensorRT 里跑 multi-head attention (MHA),输入 shape 最好是静态的,且 sequence length 对齐到 8 的倍数。否则可能同时踩到精度和性能两个坑——TRT 的 fused MHA kernel 对静态 shape 和对齐的 seq_len 有偏好,不满足时要么选不到 fused kernel(perf 掉),要么走到一条数值上和 PyTorch 不一致的路径(精度 diff)。
1. 背景:这条信息从哪来¶
今天得到一条经验:TensorRT 中 MHA 的输入最好是静态的,而且 sequence length 要对齐到 8,否则会有精度和 perf 的双重问题。
这不是凭空的规则——它和我们最近在排查的一个真实模型问题对上了号(见下)。先把规则记下来,再用 case 佐证。
2. 实际案例:DriveOS 703 vs 705 精度不一致¶
排查的现象:
- 同一个模型,用 DriveOS 703 版本的 TensorRT 转换之后,推理精度与 PyTorch 一致。
- 换到 DriveOS 705 版本的 TensorRT 转换之后,精度与 PyTorch 不一致了。
- 图、权重、输入都没变,唯一变量是 TensorRT 版本。
根因:模型里有一个 attention 的 seq_len = 900,900 不是 8 的倍数(900 / 8 = 112.5)。
修复:把这个 seq_len 做一下 padding 对齐到 8 的倍数(即 900 → 904,配合 attention mask 屏蔽掉 padding 位),705 上的精度就恢复了,和 PyTorch / 703 重新对齐。
3. 为什么会这样(分析 / 推测)¶
以下是对现象的工程解释,属于分析推测,不是已确证的结论;真正的 kernel 选择细节需要看 builder verbose log / engine inspector 才能坐实。
- TensorRT 对 attention 有专门的 fused MHA (fMHA) kernel(flash-attention 风格,把 QK^T、softmax、PV 融成一个 kernel),吞吐远高于拆开的 matmul + softmax。这些 fused kernel 通常对 seq_len 的对齐有要求(8 的倍数较常见,部分 kernel 要 16 / 64),并且偏好静态 shape——builder 要在构建期就能确定 shape,才能选到并 autotune 到最优的 fused tactic。
- 当 seq_len 不对齐或 shape 动态时,TRT 可能:
- 选不到 fused kernel,回退到更通用的实现 → perf 掉;
- 选到的 kernel 在边界 / padding 处理上和 PyTorch 语义有偏差,或走了不同的累加 / 精度路径 → 数值 diff(精度问题)。
- 为什么"换个版本就不一致":703 → 705 之间 TRT 版本不同,fMHA kernel 库和 tactic 选择启发式都变了。同一张图、同一个
seq_len = 900,703 恰好选到一个数值 OK 的 kernel,705 选到了另一个在 900(非 8 倍数)这个 case 上有问题的 kernel。把 seq_len pad 到 904 之后,命中了对齐 / 正确的 kernel,精度恢复。换句话说——根因是一个潜在的对齐假设,版本升级只是把它暴露出来,900 这个值一直是颗定时炸弹。
4. 实操建议 (checklist)¶
- 静态化 attention 输入:导出 ONNX / 构建 engine 时,把 attention 的
seq_len固定成静态值,避免 attention 子图走 dynamic shape。 - seq_len 对齐到 8 的倍数:必要时更保守地对齐到 16 / 64(取决于目标 kernel),用 attention mask 屏蔽 padding 位保证数值等价。
900 → 904就是这个案例的做法。 - 把 attention 子图当高风险点做版本回归:跨 DriveOS / TensorRT 版本升级时,重点对 attention 子图做逐层数值比对,别只看最终 loss / 端到端指标。
- 用 polygraphy 定位 diff:
polygraphy run model.onnx --trt --onnxrt(必要时加逐层比对 /--validate),找到第一处 diff 落在哪个 attention 节点,避免被下游放大误导。 - 看 builder log 里 fMHA / flash kernel 有没有被选中:没被选中往往是 perf 和精度问题的同源信号——通常意味着 shape 不静态或 seq_len 没对齐。
5. 待确认 / 开放问题¶
- 705 上具体选了哪个 fMHA kernel、和 703 差在哪?可以对比
--verbosebuilder log 或用 engine inspector /trtexec --dumpLayerInfo看 tactic。 - 对齐到 8 是否总是足够,还是这个模型 / 这个 kernel 实际需要 16 / 64?值得在本模型上验证一下对齐粒度的下界。
- padding 之外有没有副作用?尤其确认 softmax 是否对 padding 位做了正确的 mask(mask 不干净会让 padding 通过 softmax 污染有效位)。
- 这条规则的适用边界:是所有 attention 都需要,还是只有命中 fused kernel 的那部分?非 fused 路径(拆开的 matmul + softmax)大概率不受 seq_len 对齐影响,但 perf 也差。
关于这条笔记¶
工程踩坑记录,非论文总结。来源是 2026-06-05 排查 DriveOS TensorRT 版本间精度不一致问题时沉淀的经验 + 对现象的分析。