在蛋白质工程和定向进化中,对成百上千个突变体进行结构预测是一项常见的任务。传统的 AlphaFold2 尽管精度极高,但由于需要进行耗时的 MSA(多序列比对)检索,在面对高通量突变体筛选时,算力成本和时间周期往往难以接受。
Meta 开发的 ESMFold 彻底改变了这一现状。它直接利用预训练的蛋白质语言模型(ESM-2)的表征来预测三维结构,免去了 MSA 步骤,使预测速度提升了近 60 倍。
然而,在本地配置 ESMFold 时,许多研究者会遇到原版 Facebook 仓库依赖冲突严重、显存溢出(OOM)、长序列运行缓慢等问题。本文将分享一套基于 Hugging Face transformers 构建的本地轻量化 ESMFold 部署方案,并提供一套可直接用于高通量突变筛选的管道脚本。
一、 环境配置:避开原生 Repo 的“依赖地狱”
ESM 官方早期的 GitHub 仓库(facebookresearch/esm)由于依赖的 PyTorch 和 Fairseq 版本较旧,在如今的 CUDA 12.x 环境下极难配置。
最轻量、最稳定的方式是使用 Hugging Face 维护的 transformers 库。它已经原生集成了 ESMFold,且与现代 PyTorch 2.x 完美兼容。
1. 创建 Conda 虚拟环境
建议使用 Python 3.10 兼顾兼容性与性能:
conda create -n esmfold python=3.10 -y
conda activate esmfold
2. 安装 PyTorch (根据你的 CUDA 版本选择)
如果使用 RTX 3090/4090 或 A100/H800,强烈建议安装支持 CUDA 11.8 或 12.1 的 PyTorch 2.0+,以便启用 PyTorch 2.x 的高效算子和 BF16 精度:
# 以 CUDA 12.1 为例
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
3. 安装依赖库
安装 Hugging Face 核心库、加速库以及生物信息学处理常用库:
pip install transformers accelerate biopython pandas tqdm
二、 核心机制:如何在消费级显卡上降低显存?
ESMFold 默认的参数量巨大(约 30 亿参数),预测 500aa 以上的蛋白时极易撑爆显存。要在本地(如 24G 显存的 RTX 3090/4090)顺畅运行,必须开启以下三个优化:
1. 采用 BFloat16 / Float16 混合精度
ESMFold 的语言模型部分对精度降低不敏感,但 Folding Trunk(结构预测头)在纯 FP16 下可能会出现 NaN(数值溢出)。因此,强烈推荐在支持的 GPU 上使用 bfloat16。
2. 启用 Chunking(分块计算)
这是 ESMFold 降低显存的最关键武器。通过限制注意力机制中一次性计算的残基数量,将空间复杂度从 $O(N^2)$ 降低。
model.esmfold.set_chunk_size(64) # 默认是 None,设置为 64 或 128 可以大幅削减显存需求,而精度几乎无损
3. 定期释放显存
在高通量预测中,PyTorch 的显存垃圾回收不够及时,会导致显存积压最终 OOM。在处理完每个突变体后,必须显式调用:
import torch
import gc
gc.collect()
torch.cuda.empty_cache()
三、 高通量突变筛选脚本实现
以下是一个完整的突变体结构预测与筛选脚本。它读取包含突变体序列的 CSV 文件,自动利用本地 ESMFold 进行推理,提取每个突变体的平均 pLDDT 和 pTM 分数(用于评估结构稳定性和置信度),并将高置信度的结构保存为 PDB 文件。
import os
import gc
import torch
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, EsmForProteinFolding
# ================= 配置参数 =================
MODEL_DIR = "facebook/esmfold-v1" # 首次运行会自动从 HF 下载,亦可指向本地路径
INPUT_CSV = "mutants.csv" # 包含 'mutant_name' 和 'sequence' 两列
OUTPUT_DIR = "./output_pdbs"
CHUNK_SIZE = 64 # 显存紧张时设为 64,显存充足(>40G)可设为 None 或 128
device = "cuda" if torch.cuda.is_available() else "cpu"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ================= 1. 加载模型 =================
print("正在加载 ESMFold 模型...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
# 推荐在 Ampere 及以上架构 GPU 上使用 bfloat16
model = EsmForProteinFolding.from_pretrained(
MODEL_DIR,
torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
)
model = model.to(device)
model.eval()
# 开启 Chunking 优化显存
if CHUNK_SIZE is not None:
model.esmfold.set_chunk_size(CHUNK_SIZE)
# ================= 2. 批量推理函数 =================
def predict_structure(sequence, name):
inputs = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
# 提取评估指标
# pLDDT:残基级别的置信度 (0-1)
# pTM:整链拓扑置信度 (0-1)
token_plddt = outputs.plddt.cpu().numpy()[0]
# 排除开头和结尾的特殊 Token 对应的值
mean_plddt = token_plddt.mean() * 100 # 转为 0-100 标准尺度
ptm = outputs.ptm.cpu().item()
# 转换为 PDB 文本并保存
pdb_string = model.output_to_pdb(outputs)[0]
pdb_path = os.path.join(OUTPUT_DIR, f"{name}.pdb")
with open(pdb_path, "w") as f:
f.write(pdb_string)
return mean_plddt, ptm
# ================= 3. 执行高通量筛选 =================
df = pd.read_csv(INPUT_CSV)
results = []
print(f"开始筛选,共 {len(df)} 个突变体...")
for idx, row in tqdm(df.iterrows(), total=len(df)):
name = row['mutant_name']
seq = row['sequence']
# 过滤掉过长序列,避免突发 OOM
if len(seq) > 700:
print(f"\n跳过 {name}: 序列长度 {len(seq)} 超过单卡阈值 700aa")
results.append({"mutant_name": name, "pLDDT": None, "pTM": None, "status": "Too Long"})
continue
try:
plddt, ptm = predict_structure(seq, name)
results.append({
"mutant_name": name,
"pLDDT": round(plddt, 2),
"pTM": round(ptm, 3),
"status": "Success"
})
except RuntimeError as e:
if "out of memory" in str(e).lower():
print(f"\n{name} 预测时显存溢出,尝试清理显存...")
results.append({"mutant_name": name, "pLDDT": None, "pTM": None, "status": "OOM"})
else:
print(f"\n{name} 预测失败: {str(e)}")
results.append({"mutant_name": name, "pLDDT": None, "pTM": None, "status": "Error"})
finally:
# 强制清理垃圾,稳定显存占用
gc.collect()
if device == "cuda":
torch.cuda.empty_cache()
# ================= 4. 保存筛选报告 =================
df_res = pd.DataFrame(results)
df_res.to_csv("screening_results.csv", index=False)
print("\n筛选完成!结果已保存至 screening_results.csv")
四、 进阶技巧:如何定量评估突变带来的“结构破坏”?
高通量突变体筛选的核心在于对比:突变体与野生型(Wild Type, WT)相比,结构发生了什么变化?
仅仅看 pLDDT 和 pTM 的数值变化是不够的,我们通常需要计算突变体 PDB 与野生型 PDB 之间的 RMSD(均方根偏差)。
可以使用 Biopython 的 Superimposer 模块实现自动比对和 RMSD 计算,快速筛选出结构发生剧烈变化(可能导致失活)或结构保持高刚性(可能有利于表达)的突变:
from bio.PDB import PDBParser, Superimposer
def calculate_rmsd(wt_pdb_path, mutant_pdb_path):
parser = PDBParser(QUIET=True)
ref_struct = parser.get_structure("WT", wt_pdb_path)
sample_struct = parser.get_structure("Mutant", mutant_pdb_path)
# 提取 C-alpha 原子进行对齐
ref_atoms = [atom for atom in ref_struct.get_atoms() if atom.get_name() == "CA"]
sample_atoms = [atom for atom in sample_struct.get_atoms() if atom.get_name() == "CA"]
# 确保原子数量一致(适用于单点或多点置换突变,不适用于插入/缺失突变)
if len(ref_atoms) == len(sample_atoms):
super_imposer = Superimposer()
super_imposer.set_atoms(ref_atoms, sample_atoms)
super_imposer.apply(sample_struct.get_atoms())
return super_imposer.rms
else:
return None # 长度不一致时需使用局部比对算法
五、 避坑指南与硬件性能参考
- 多长的序列可以跑?
在 24G 显存(RTX 3090/4090)下,开启chunk_size=64且使用bfloat16后,极限可以跑 700-800aa 左右的单链蛋白。再长的序列建议切片或者租用 80G 的 A100。 - 多进程能加速吗?
对于 ESMFold 这类重型模型,单张显卡不建议开多进程推理。因为显存分配是动态的,多进程极易导致瞬间的 OOM。建议采用串行Batch预测,或者多卡分布式推理(每个卡独立分配一个子序列集)。 - 本地权重缓存:
如果你在内网集群等无网环境下运行,可以先在外网将facebook/esmfold-v1模型文件夹打包下载下来,然后将脚本中的MODEL_DIR修改为本地的绝对路径。