在生产环境中部署 BERT、GPT 等 NLP 模型时,我们几乎都会开启 Triton Inference Server 的 Dynamic Batching(动态批处理)。这个功能很香,能把多个客户端的单条请求攒成一个 Batch 发给 GPU,榨干算力。
但很快你就会遇到一个堪称“显存刺客”的噩梦:过度 Padding。
假设你的模型最大支持 512 长度。某一个瞬间,Triton 攒了一个 Batch,里面有 15 条长度只有 10 的短文本,但偏偏混进来 1 条长度为 512 的长文本。
在默认的动态批处理机制下,Triton 为了拼出规整的 2D Tensor [BatchSize, MaxSeqLen],会强制把那 15 条短序列全部 Padding(补零)到 512 长度。
这时候,你的 Batch 矩阵从原本只需计算的少量 Token,瞬间膨胀成了 $16 \times 512$ 的庞然大物。不仅计算耗时(Latency)因为无意义的补零暴涨,显存也会因为 Attention 矩阵中 $O(N^2)$ 的空间复杂度瞬间拉爆,直接触发 OOM。
怎么优雅地解决这个问题?我们有三个递进的解法,从 Triton 原生配置到架构设计。
终极解法:启用 Ragged Batching(非对称批处理)
这是 Triton 官方为了彻底解决动态长度输入而设计的硬核特性。
核心思想:既然把变长输入拼成 2D 矩形会引入大量 Padding,那我们干脆不拼 2D。我们把所有请求的 Token 展平(Flatten),拼成一个 1D 的大向量,再配合一个记录每条请求实际长度的辅助 Tensor(batch_input),一起塞给后端。
这就叫 Ragged Batching。
1. 怎么配置 config.pbtxt?
假设你的模型输入是 INPUT_IDS,普通的配置是 dims: [ -1 ] 且开启 max_batch_size。要启用 Ragged Batching,你需要两步:
- 在输入张量中声明
allow_ragged_batch: true。 - 引入
batch_input,让 Triton 自动为你生成每条数据的边界信息。
name: "nlp_ragged_model"
platform: "onnxruntime_onnx" # 或者 tensorrt_llm / pytorch
max_batch_size: 32
input [
{
name: "INPUT_IDS"
data_type: TYPE_INT32
dims: [ -1 ]
# 核心:允许不规则 Batch
allow_ragged_batch: true
}
]
# Triton 会自动拼出这个 Tensor,告诉你这个 Batch 里每个样本的真实长度
batch_input [
{
kind: BATCH_ELEMENT_COUNT
target_name: "SEQ_LEN"
data_type: TYPE_INT32
}
]
2. 后端(Backend)如何对接?
启用 Ragged Batching 后,如果 Batch 内有 3 条请求,实际长度分别是 3, 5, 2:
- 原来你收到的张量形状是
[3, 5](全部 Pad 到 5),内容包含大量的补零。 - 现在你的模型收到的是一个展平的 1D Tensor,形状为
[10](即 3+5+2),没有任何 Padding。 - 同时,你会收到一个名为
SEQ_LEN的辅助 Tensor,内容为[3, 5, 2]。
注意:这需要你的模型本身支持这种“展平式”输入。
如果你使用的是 TensorRT-LLM 或者配合了 FlashAttention 的自定义 PyTorch/C++ Backend,它们原生就支持这种输入(通常称为 VarLen 或 Unpadded 输入)。它们在算子内部通过 SEQ_LEN 来计算每个序列的 Attention,彻底规避了对 Padding 区域的无效计算和显存占用。
降维打击:网关层/客户端分桶(Bucketing)
如果你的模型后端比较老(比如普通的 ONNX Runtime),不支持 Ragged 这种 1D 展平输入,必须吃 [B, S] 这种 2D Padded 结构,该怎么办?
这时候,最好的解法是分桶(Bucketing)。
不要让长度差异极大的请求混在同一个 Triton 队列里。我们可以在 Triton 前面挂一个轻量级的代理(比如基于 Python/Go 的 API 网关,或者使用 Triton 的 BLS 编排),根据输入文本的 Token 长度,将请求分流到不同的 Triton 模型实例(或者不同的模型队列)中。
典型分桶策略:
- 桶 A (Short):长度 $0 \sim 64$
- 桶 B (Medium):长度 $65 \sim 256$
- 桶 C (Long):长度 $257 \sim 512$
在 Triton 中,你可以部署三个一模一样的模型实例(或者在 config.pbtxt 中通过配置不同的模型版本、不同名字的模型),分别对应这三个桶。
┌─── [Len 0-64] ───> Triton Model_Short (Max_Len: 64)
│
Client ──> Router ──┼─── [Len 65-256] ──> Triton Model_Medium (Max_Len: 256)
│
└─── [Len 257-512] ──> Triton Model_Long (Max_Len: 512)
为什么这样能省显存?
即使发生了动态攒 Batch,由于进到“桶 A”的请求长度都在 64 以内,Triton 最多也只会把它们 Pad 到 64,绝对不会出现因为一条长文本把整个 Batch 撑到 512 的情况。
这种方案虽然牺牲了一点点部署复杂度,但对不支持 Ragged Batching 的老模型来说,是性价比最高的保命手段。
细节微调:榨干 Dynamic Batcher 的参数
如果你不想改代码,也不想改架构,只想调参,那么在 config.pbtxt 中,有几个参数必须精细化控制:
dynamic_batching {
# 1. 限制首选 Batch 大小
preferred_batch_size: [ 8, 16, 32 ]
# 2. 控制攒批的最大等待时间
max_queue_delay_microseconds: 5000
}
这里有一个权衡(Trade-off):
max_queue_delay_microseconds不能设得太大。
如果你的 QPS 不够高,而你把这个值设得太大(比如 50ms),Triton 会为了凑齐preferred_batch_size强行等待。等待的时间越长,攒出来的请求长度跨度可能就越大,越容易发生“短文本被极少数长文本强行拉大”的惨剧。- 合理降低
max_batch_size。
在 NLP 变长场景下,一味追求大 BatchSize 是危险的。建议针对长文本模型,将max_batch_size限制在一个保守的值(如 8 或 16),优先保证 Latency 的稳定,防止单次大 Batch 运算超时或直接 OOM。
总结:你的技术栈该怎么选?
- 新项目/自研 LLM/高并发:毫不犹豫选择 Ragged Batching,配合 TensorRT-LLM 或支持变长 Attention 的后端。这是目前业界性能天花板的解法。
- 传统老项目/ONNX/PyTorch 原生导出:采用 Gateway 分桶路由。虽然架构变复杂了,但对模型代码零侵入,防 OOM 效果立竿见影。
- 紧急救火:立刻调小
max_queue_delay_microseconds和max_batch_size,牺牲一点点吞吐量,换取系统不崩溃。