HOOOS

单卡跑通万级突变:本地轻量化 ESMFold 部署与高通量筛选实战

0 10 BioInformaticsPro ESMFold突变体筛选蛋白质结构预测
Apple

在蛋白质工程和定向进化中,对成百上千个突变体进行结构预测是一项常见的任务。传统的 AlphaFold2 尽管精度极高,但由于需要进行耗时的 MSA(多序列比对)检索,在面对高通量突变体筛选时,算力成本和时间周期往往难以接受。

Meta 开发的 ESMFold 彻底改变了这一现状。它直接利用预训练的蛋白质语言模型(ESM-2)的表征来预测三维结构,免去了 MSA 步骤,使预测速度提升了近 60 倍。

然而,在本地配置 ESMFold 时,许多研究者会遇到原版 Facebook 仓库依赖冲突严重、显存溢出(OOM)、长序列运行缓慢等问题。本文将分享一套基于 Hugging Face transformers 构建的本地轻量化 ESMFold 部署方案,并提供一套可直接用于高通量突变筛选的管道脚本。


一、 环境配置:避开原生 Repo 的“依赖地狱”

ESM 官方早期的 GitHub 仓库(facebookresearch/esm)由于依赖的 PyTorch 和 Fairseq 版本较旧,在如今的 CUDA 12.x 环境下极难配置。

最轻量、最稳定的方式是使用 Hugging Face 维护的 transformers 库。它已经原生集成了 ESMFold,且与现代 PyTorch 2.x 完美兼容。

1. 创建 Conda 虚拟环境

建议使用 Python 3.10 兼顾兼容性与性能:

conda create -n esmfold python=3.10 -y
conda activate esmfold

2. 安装 PyTorch (根据你的 CUDA 版本选择)

如果使用 RTX 3090/4090 或 A100/H800,强烈建议安装支持 CUDA 11.8 或 12.1 的 PyTorch 2.0+,以便启用 PyTorch 2.x 的高效算子和 BF16 精度:

# 以 CUDA 12.1 为例
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

3. 安装依赖库

安装 Hugging Face 核心库、加速库以及生物信息学处理常用库:

pip install transformers accelerate biopython pandas tqdm

二、 核心机制:如何在消费级显卡上降低显存?

ESMFold 默认的参数量巨大(约 30 亿参数),预测 500aa 以上的蛋白时极易撑爆显存。要在本地(如 24G 显存的 RTX 3090/4090)顺畅运行,必须开启以下三个优化:

1. 采用 BFloat16 / Float16 混合精度

ESMFold 的语言模型部分对精度降低不敏感,但 Folding Trunk(结构预测头)在纯 FP16 下可能会出现 NaN(数值溢出)。因此,强烈推荐在支持的 GPU 上使用 bfloat16

2. 启用 Chunking(分块计算)

这是 ESMFold 降低显存的最关键武器。通过限制注意力机制中一次性计算的残基数量,将空间复杂度从 $O(N^2)$ 降低。

model.esmfold.set_chunk_size(64) # 默认是 None,设置为 64 或 128 可以大幅削减显存需求,而精度几乎无损

3. 定期释放显存

在高通量预测中,PyTorch 的显存垃圾回收不够及时,会导致显存积压最终 OOM。在处理完每个突变体后,必须显式调用:

import torch
import gc

gc.collect()
torch.cuda.empty_cache()

三、 高通量突变筛选脚本实现

以下是一个完整的突变体结构预测与筛选脚本。它读取包含突变体序列的 CSV 文件,自动利用本地 ESMFold 进行推理,提取每个突变体的平均 pLDDTpTM 分数(用于评估结构稳定性和置信度),并将高置信度的结构保存为 PDB 文件。

import os
import gc
import torch
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, EsmForProteinFolding

# ================= 配置参数 =================
MODEL_DIR = "facebook/esmfold-v1"  # 首次运行会自动从 HF 下载,亦可指向本地路径
INPUT_CSV = "mutants.csv"          # 包含 'mutant_name' 和 'sequence' 两列
OUTPUT_DIR = "./output_pdbs"
CHUNK_SIZE = 64                    # 显存紧张时设为 64,显存充足(>40G)可设为 None 或 128
device = "cuda" if torch.cuda.is_available() else "cpu"

os.makedirs(OUTPUT_DIR, exist_ok=True)

# ================= 1. 加载模型 =================
print("正在加载 ESMFold 模型...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

# 推荐在 Ampere 及以上架构 GPU 上使用 bfloat16
model = EsmForProteinFolding.from_pretrained(
    MODEL_DIR, 
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
)

model = model.to(device)
model.eval()

# 开启 Chunking 优化显存
if CHUNK_SIZE is not None:
    model.esmfold.set_chunk_size(CHUNK_SIZE)

# ================= 2. 批量推理函数 =================
def predict_structure(sequence, name):
    inputs = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model(**inputs)
        
    # 提取评估指标
    # pLDDT:残基级别的置信度 (0-1)
    # pTM:整链拓扑置信度 (0-1)
    token_plddt = outputs.plddt.cpu().numpy()[0]
    # 排除开头和结尾的特殊 Token 对应的值
    mean_plddt = token_plddt.mean() * 100 # 转为 0-100 标准尺度
    ptm = outputs.ptm.cpu().item()
    
    # 转换为 PDB 文本并保存
    pdb_string = model.output_to_pdb(outputs)[0]
    pdb_path = os.path.join(OUTPUT_DIR, f"{name}.pdb")
    with open(pdb_path, "w") as f:
        f.write(pdb_string)
        
    return mean_plddt, ptm

# ================= 3. 执行高通量筛选 =================
df = pd.read_csv(INPUT_CSV)
results = []

print(f"开始筛选,共 {len(df)} 个突变体...")
for idx, row in tqdm(df.iterrows(), total=len(df)):
    name = row['mutant_name']
    seq = row['sequence']
    
    # 过滤掉过长序列,避免突发 OOM
    if len(seq) > 700:
        print(f"\n跳过 {name}: 序列长度 {len(seq)} 超过单卡阈值 700aa")
        results.append({"mutant_name": name, "pLDDT": None, "pTM": None, "status": "Too Long"})
        continue
        
    try:
        plddt, ptm = predict_structure(seq, name)
        results.append({
            "mutant_name": name, 
            "pLDDT": round(plddt, 2), 
            "pTM": round(ptm, 3), 
            "status": "Success"
        })
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            print(f"\n{name} 预测时显存溢出,尝试清理显存...")
            results.append({"mutant_name": name, "pLDDT": None, "pTM": None, "status": "OOM"})
        else:
            print(f"\n{name} 预测失败: {str(e)}")
            results.append({"mutant_name": name, "pLDDT": None, "pTM": None, "status": "Error"})
    finally:
        # 强制清理垃圾,稳定显存占用
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

# ================= 4. 保存筛选报告 =================
df_res = pd.DataFrame(results)
df_res.to_csv("screening_results.csv", index=False)
print("\n筛选完成!结果已保存至 screening_results.csv")

四、 进阶技巧:如何定量评估突变带来的“结构破坏”?

高通量突变体筛选的核心在于对比:突变体与野生型(Wild Type, WT)相比,结构发生了什么变化?

仅仅看 pLDDT 和 pTM 的数值变化是不够的,我们通常需要计算突变体 PDB 与野生型 PDB 之间的 RMSD(均方根偏差)

可以使用 BiopythonSuperimposer 模块实现自动比对和 RMSD 计算,快速筛选出结构发生剧烈变化(可能导致失活)或结构保持高刚性(可能有利于表达)的突变:

from bio.PDB import PDBParser, Superimposer

def calculate_rmsd(wt_pdb_path, mutant_pdb_path):
    parser = PDBParser(QUIET=True)
    ref_struct = parser.get_structure("WT", wt_pdb_path)
    sample_struct = parser.get_structure("Mutant", mutant_pdb_path)
    
    # 提取 C-alpha 原子进行对齐
    ref_atoms = [atom for atom in ref_struct.get_atoms() if atom.get_name() == "CA"]
    sample_atoms = [atom for atom in sample_struct.get_atoms() if atom.get_name() == "CA"]
    
    # 确保原子数量一致(适用于单点或多点置换突变,不适用于插入/缺失突变)
    if len(ref_atoms) == len(sample_atoms):
        super_imposer = Superimposer()
        super_imposer.set_atoms(ref_atoms, sample_atoms)
        super_imposer.apply(sample_struct.get_atoms())
        return super_imposer.rms
    else:
        return None  # 长度不一致时需使用局部比对算法

五、 避坑指南与硬件性能参考

  1. 多长的序列可以跑?
    在 24G 显存(RTX 3090/4090)下,开启 chunk_size=64 且使用 bfloat16 后,极限可以跑 700-800aa 左右的单链蛋白。再长的序列建议切片或者租用 80G 的 A100。
  2. 多进程能加速吗?
    对于 ESMFold 这类重型模型,单张显卡不建议开多进程推理。因为显存分配是动态的,多进程极易导致瞬间的 OOM。建议采用串行 Batch 预测,或者多卡分布式推理(每个卡独立分配一个子序列集)。
  3. 本地权重缓存:
    如果你在内网集群等无网环境下运行,可以先在外网将 facebook/esmfold-v1 模型文件夹打包下载下来,然后将脚本中的 MODEL_DIR 修改为本地的绝对路径。

点评评价

captcha
健康