HOOOS

用好 Ragged Batching,解决 Triton 动态批处理中的 NLP 显存爆炸

0 24 极客推理 TritonNLP推理显存优化
Apple

在生产环境中部署 BERT、GPT 等 NLP 模型时,我们几乎都会开启 Triton Inference Server 的 Dynamic Batching(动态批处理)。这个功能很香,能把多个客户端的单条请求攒成一个 Batch 发给 GPU,榨干算力。

但很快你就会遇到一个堪称“显存刺客”的噩梦:过度 Padding

假设你的模型最大支持 512 长度。某一个瞬间,Triton 攒了一个 Batch,里面有 15 条长度只有 10 的短文本,但偏偏混进来 1 条长度为 512 的长文本。
在默认的动态批处理机制下,Triton 为了拼出规整的 2D Tensor [BatchSize, MaxSeqLen],会强制把那 15 条短序列全部 Padding(补零)到 512 长度。

这时候,你的 Batch 矩阵从原本只需计算的少量 Token,瞬间膨胀成了 $16 \times 512$ 的庞然大物。不仅计算耗时(Latency)因为无意义的补零暴涨,显存也会因为 Attention 矩阵中 $O(N^2)$ 的空间复杂度瞬间拉爆,直接触发 OOM。

怎么优雅地解决这个问题?我们有三个递进的解法,从 Triton 原生配置到架构设计。


终极解法:启用 Ragged Batching(非对称批处理)

这是 Triton 官方为了彻底解决动态长度输入而设计的硬核特性。

核心思想:既然把变长输入拼成 2D 矩形会引入大量 Padding,那我们干脆不拼 2D。我们把所有请求的 Token 展平(Flatten),拼成一个 1D 的大向量,再配合一个记录每条请求实际长度的辅助 Tensor(batch_input),一起塞给后端。

这就叫 Ragged Batching

1. 怎么配置 config.pbtxt?

假设你的模型输入是 INPUT_IDS,普通的配置是 dims: [ -1 ] 且开启 max_batch_size。要启用 Ragged Batching,你需要两步:

  • 在输入张量中声明 allow_ragged_batch: true
  • 引入 batch_input,让 Triton 自动为你生成每条数据的边界信息。
name: "nlp_ragged_model"
platform: "onnxruntime_onnx" # 或者 tensorrt_llm / pytorch
max_batch_size: 32

input [
  {
    name: "INPUT_IDS"
    data_type: TYPE_INT32
    dims: [ -1 ]
    # 核心:允许不规则 Batch
    allow_ragged_batch: true
  }
]

# Triton 会自动拼出这个 Tensor,告诉你这个 Batch 里每个样本的真实长度
batch_input [
  {
    kind: BATCH_ELEMENT_COUNT
    target_name: "SEQ_LEN"
    data_type: TYPE_INT32
  }
]

2. 后端(Backend)如何对接?

启用 Ragged Batching 后,如果 Batch 内有 3 条请求,实际长度分别是 3, 5, 2:

  • 原来你收到的张量形状是 [3, 5](全部 Pad 到 5),内容包含大量的补零。
  • 现在你的模型收到的是一个展平的 1D Tensor,形状为 [10](即 3+5+2),没有任何 Padding
  • 同时,你会收到一个名为 SEQ_LEN 的辅助 Tensor,内容为 [3, 5, 2]

注意:这需要你的模型本身支持这种“展平式”输入。
如果你使用的是 TensorRT-LLM 或者配合了 FlashAttention 的自定义 PyTorch/C++ Backend,它们原生就支持这种输入(通常称为 VarLen 或 Unpadded 输入)。它们在算子内部通过 SEQ_LEN 来计算每个序列的 Attention,彻底规避了对 Padding 区域的无效计算和显存占用。


降维打击:网关层/客户端分桶(Bucketing)

如果你的模型后端比较老(比如普通的 ONNX Runtime),不支持 Ragged 这种 1D 展平输入,必须吃 [B, S] 这种 2D Padded 结构,该怎么办?

这时候,最好的解法是分桶(Bucketing)

不要让长度差异极大的请求混在同一个 Triton 队列里。我们可以在 Triton 前面挂一个轻量级的代理(比如基于 Python/Go 的 API 网关,或者使用 Triton 的 BLS 编排),根据输入文本的 Token 长度,将请求分流到不同的 Triton 模型实例(或者不同的模型队列)中。

典型分桶策略:

  • 桶 A (Short):长度 $0 \sim 64$
  • 桶 B (Medium):长度 $65 \sim 256$
  • 桶 C (Long):长度 $257 \sim 512$

在 Triton 中,你可以部署三个一模一样的模型实例(或者在 config.pbtxt 中通过配置不同的模型版本、不同名字的模型),分别对应这三个桶。

                    ┌─── [Len 0-64] ───> Triton Model_Short (Max_Len: 64)
                    │
Client ──> Router ──┼─── [Len 65-256] ──> Triton Model_Medium (Max_Len: 256)
                    │
                    └─── [Len 257-512] ──> Triton Model_Long (Max_Len: 512)

为什么这样能省显存?
即使发生了动态攒 Batch,由于进到“桶 A”的请求长度都在 64 以内,Triton 最多也只会把它们 Pad 到 64,绝对不会出现因为一条长文本把整个 Batch 撑到 512 的情况。
这种方案虽然牺牲了一点点部署复杂度,但对不支持 Ragged Batching 的老模型来说,是性价比最高的保命手段。


细节微调:榨干 Dynamic Batcher 的参数

如果你不想改代码,也不想改架构,只想调参,那么在 config.pbtxt 中,有几个参数必须精细化控制:

dynamic_batching {
  # 1. 限制首选 Batch 大小
  preferred_batch_size: [ 8, 16, 32 ]
  
  # 2. 控制攒批的最大等待时间
  max_queue_delay_microseconds: 5000
}

这里有一个权衡(Trade-off):

  • max_queue_delay_microseconds 不能设得太大。
    如果你的 QPS 不够高,而你把这个值设得太大(比如 50ms),Triton 会为了凑齐 preferred_batch_size 强行等待。等待的时间越长,攒出来的请求长度跨度可能就越大,越容易发生“短文本被极少数长文本强行拉大”的惨剧。
  • 合理降低 max_batch_size
    在 NLP 变长场景下,一味追求大 BatchSize 是危险的。建议针对长文本模型,将 max_batch_size 限制在一个保守的值(如 8 或 16),优先保证 Latency 的稳定,防止单次大 Batch 运算超时或直接 OOM。

总结:你的技术栈该怎么选?

  1. 新项目/自研 LLM/高并发:毫不犹豫选择 Ragged Batching,配合 TensorRT-LLM 或支持变长 Attention 的后端。这是目前业界性能天花板的解法。
  2. 传统老项目/ONNX/PyTorch 原生导出:采用 Gateway 分桶路由。虽然架构变复杂了,但对模型代码零侵入,防 OOM 效果立竿见影。
  3. 紧急救火:立刻调小 max_queue_delay_microsecondsmax_batch_size,牺牲一点点吞吐量,换取系统不崩溃。

点评评价

captcha
健康