AI语音降噪 - DeepFilterNet 训练设计技巧总结
1. 随机种子固定(可复现性)
seed = config("SEED", 42, int, section="train")
check_manual_seed(seed)
设计作用:check_manual_seed 同时固定 Python random、numpy、torch(CPU 和 CUDA)的随机种子,确保每次训练的数据增强、参数初始化、dropout mask 等随机行为完全一致,使实验结果可复现。
启发:
- 所有训练脚本应将 seed 作为配置项(而非硬编码),便于对比不同 seed 下模型的稳定性
- seed 应写入 checkpoint,续点训练时可恢复完全相同的随机状态
- 可复现性是科研和生产的基础保障,尤其在多次消融实验时,控制 seed 才能排除随机因素的干扰
2. 主机特定批大小配置(多机适配)
key = get_host() + "_" + config.get("model") + "_" + config.get("fft_size")
set_batch_size(config_file, args.host_batchsize_config, host_key=key)
设计作用:
不同服务器 GPU 显存不同,对同一模型的最优 batch size 也不同。该机制通过 hostname_modelname_fftsize 作为 key,在外部 JSON 配置中查找当前机器的最优 batch size,动态覆盖默认配置,无需修改代码即可适配多台机器。
启发:
- 训练配置应分为"模型配置"和"机器配置"两层,机器配置随环境变化,不应混入模型配置
- 在多人协作或多机环境中,这种机制避免了因忘记修改 batch size 导致的显存 OOM 或显存浪费
- 可扩展为更通用的"环境感知配置":自动检测 GPU 数量、显存大小,自动推荐超参
3. Batch Size 动态调度(渐进式训练)
# 格式: '0/4,5/8,10/16' -> epoch0: bs=4, epoch5: bs=8, epoch10: bs=16
batch_size_scheduling: List[str] = config("BATCH_SIZE_SCHEDULING", [], Csv(str))
for e, b in batch_size_scheduling:
if e <= epoch:
scheduling_bs = min(b, bs) # 不超过配置上限
if prev_scheduling_bs != scheduling_bs:
dataloader.set_batch_size(scheduling_bs, "train")
lrs = setup_lrs(len(dataloader)) # 重新计算 lr schedule!
wds = setup_wds(len(dataloader))
设计作用:
训练初期用小 batch size,模型更新更频繁,有助于跳出局部最优、加快收敛;训练后期增大 batch size,梯度估计更稳定,有助于精细收敛。同时,batch size 变化会改变 epoch 内的 step 数,因此必须重新计算 lr/wd schedule 数组,保证学习率曲线在时间轴上的形状不变。
启发:
- Batch size 和 learning rate 是强耦合的(线性缩放规则),调整任一个都需联动调整另一个
- 渐进式 batch size 在数据量较少时尤其有效,可以用小 batch 先探索再用大 batch 精化
- 动态调度参数应以人类可读的字符串格式配置(如
0/4,5/8),降低使用门槛
4. SIGUSR1 信号处理(优雅超时退出)
signal.signal(signal.SIGUSR1, get_sigusr1_handler(args.base_dir))
def get_sigusr1_handler(base_dir):
def h(*args):
global should_stop
should_stop = True # 标记停止
open(os.path.join(base_dir, "continue"), "w").close() # 写 continue 文件
return h
在每个 epoch 末检查:
if should_stop:
logger.info("Stopping training due to timeout")
exit(0)
设计作用:
在 HPC 集群(如 SLURM)中,作业有时间配额限制,到时会被强制 kill。SIGUSR1 信号通常作为"即将超时"的预警信号提前发出。捕获后设置 should_stop=True,等当前 epoch 训练完毕再安全退出——此时 checkpoint 已写入磁盘,不会丢失进度。同时写入 continue 文件,供 job scheduler 判断是否需要重新提交任务继续训练。
启发:
- 长时间训练必须处理被动中断,至少要能保证"当前 epoch 结束后安全保存"
should_stop标志比直接exit()更安全,让代码在自然检查点处停止- 对于云训练环境(抢占式实例),同样的机制可用来处理 SIGTERM
5. 分模块训练(mask_only / df_only)
mask_only: bool = config("MASK_ONLY", False, bool, section="train")
train_df_only: bool = config("DF_ONLY", False, bool, section="train")
if mask_only:
params = [p for n, p in model.named_parameters()
if "dfrnn" not in n and "df_dec" not in n]
elif df_only:
params = (p for n, p in model.named_parameters() if "df" in n.lower())
else:
params = model.parameters()
设计作用:
DeepFilterNet 由两个主要分支组成:ERB mask 分支(粗粒度噪声抑制)和 DF deep filter 分支(细粒度频谱整形)。分模块训练允许:
- 先单独训练某个分支,让它充分收敛,再联合训练
- 当某分支已有预训练权重时,仅微调另一个分支
- 节省计算量(只计算需要更新参数的梯度)
这是一种**课程学习(Curriculum Learning)**思想的工程实现。
启发:
- 复杂模型(多分支、多任务)应设计模块化训练接口,通过参数名过滤实现精细控制
- 训练策略(全部/部分参数)应作为配置项而非代码修改
- 分阶段训练时需注意学习率的匹配(已固定的分支有时需要极小 lr 防止被带偏)
6. TorchScript JIT 编译
jit = config("JIT", False, cast=bool, section="train")
# 注意:必须在 log_model_summary 之后再 jit
if jit:
model = torch.jit.script(model)
设计作用:torch.jit.script 将 Python 代码编译为静态计算图,消除 Python 解释器开销、使能算子融合,可加速训练和推理。同时 JIT 模型可序列化为 .pt 文件,在不依赖 Python 环境的情况下(如 C++ 推理框架)运行。代码中刻意在 log_model_summary 之后才执行 JIT,因为 JIT 模型的接口与普通 Module 不同,无法正常打印结构。
启发:
- JIT 应作为可选项,在调试阶段关闭(便于断点调试),在生产/测试阶段开启
- 对于计算密集型模型,JIT 可带来 10~30% 的推理加速,值得在性能敏感场景使用
- 须注意 JIT 对 Python 动态特性(如条件导入、类型动态推断)有严格限制,需提前验证模型兼容性
7. 余弦退火 LR/WD 逐步调度
# 预先计算所有 step 的 lr 值(numpy 数组,按全局 step 索引)
lrs = setup_lrs(steps_per_epoch) # shape: [total_steps]
wds = setup_wds(steps_per_epoch) # shape: [total_steps] or None
# 每个 step 内更新
it = start_steps + i # 全局 step 编号
param_group["lr"] = lr_scheduler_values[it] * param_group.get("lr_scale", 1)
param_group["weight_decay"] = wd_scheduler_values[it]
setup_lrs 使用带 warmup 的余弦退火,支持多周期:
lr
▲
| warmup | cosine decay | cycle 2 |
| \ /\ /
| \ / \ /
| \____________________/ \_____/
└────────────────────────────────────────────→ step
setup_wds 同样用余弦曲线从 weight_decay 衰减到 weight_decay_end(若不配置则不启用)。
启发:
- 逐 step 调度比逐 epoch 调度更平滑,对小数据集或 batch 数少的训练尤其重要
- 预计算成 numpy 数组再按 index 取值,比每步调用 scheduler.step() 更灵活,可在 batch size 变化时重新生成,且方便可视化和调试
lr_scale支持差分学习率(不同参数组使用不同倍率),是微调预训练模型的常用技巧- weight decay 的余弦衰减有助于训练后期减少正则化强度,让模型充分拟合
8. 早停机制(Early Stopping)
val_criteria_type = config("VALIDATION_CRITERIA", "loss", section="train")
val_criteria_rule = config("VALIDATION_CRITERIA_RULE", "min", section="train")
patience = config("EARLY_STOPPING_PATIENCE", 5, int, section="train")
val_criteria = metrics[val_criteria_type]
write_cp(model, "model", checkpoint_dir, epoch+1,
metric=val_criteria, cmp=val_criteria_rule) # 只保存最优
if not check_patience(checkpoint_dir, max_patience=patience,
new_metric=val_criteria, cmp=val_criteria_rule,
raise_=False):
break
设计作用:
若验证集指标连续 patience 个 epoch 没有改善(支持 min/max 两种方向),则停止训练,防止过拟合和浪费计算资源。验证指标支持任意 metric(loss、PESQ、DNSMOS 等),由配置决定,不耦合于 loss。write_cp 中带 metric 参数时只在指标更优时覆盖"best"检查点,训练结束后加载的就是历史最优模型。
启发:
- 早停的判断指标应与最终评估目标一致(如语音增强用 DNSMOS 而非 loss)
patience值需根据学习率曲线设置:使用余弦退火时 lr 会周期性回升,patience 需足够大- 最优检查点保存逻辑应与每轮保存逻辑分离,避免覆盖"最新"和"最优"混淆
9. 训练前先跑验证集(START_EVAL 基准线)
if config("START_EVAL", False, cast=bool, section="train"):
val_loss = run_epoch(model, epoch=epoch-1, split="valid", ...)
log_metrics(f"[{epoch-1}] [valid]", metrics)
设计作用:
在正式训练第 0 epoch 之前,先跑一遍验证集,得到初始模型(随机初始化或预加载权重)的基准性能。这样:
- 训练曲线从 epoch -1 开始,可以直观看出训练带来的提升量
- 若基准性能已经很好(如加载了预训练权重),可以判断是否需要训练
- 提前暴露数据 pipeline、loss 计算的 bug,而不是等到第一轮训练完再发现问题
启发:
- 任何训练都应建立基准线,尤其是 finetune 或迁移学习场景
- "先验证再训练"可以节省时间:如果初始模型已满足要求,无需训练
- 验证集 run 同时也是对整个 pipeline(数据加载、模型推理、指标计算)的端到端冒烟测试
10. 断点续训设计
检查点保存(每轮 + 最优):
write_cp(model, "model", checkpoint_dir, epoch + 1) # 每轮保存模型
write_cp(opt, "opt", checkpoint_dir, epoch + 1) # 每轮保存优化器状态
write_cp(model, "model", checkpoint_dir, epoch + 1,
metric=val_criteria, cmp=val_criteria_rule) # 条件保存最优模型
续点加载:
# 加载模型(含 epoch 号)
model, epoch = load_model(checkpoint_dir if args.resume else None, ...)
# 加载优化器(含 Adam 矩估计 m1/m2)
read_cp(opt, "opt", cp_dir)
# 从上次停止的 epoch 继续
for epoch in range(epoch, max_epochs):
...
设计亮点:
--no-resume标志可显式从头开始,不依赖是否有 checkpoint 文件- 优化器状态(Adam 的一阶/二阶矩)也被保存,续训时学习率动量状态延续,不会出现"重新热身"现象
for epoch in range(epoch, max_epochs)利用 epoch 变量天然实现从断点位置继续,无需额外逻辑
启发:
- 优化器状态和模型权重必须同时保存,否则续训效果不如从头训练
- 每轮都保存(而非只保存最优)可以防止最优 checkpoint 损坏时无法恢复
- 建议同时保存训练配置文件,确保续训时超参与原始训练完全一致
11. 核心训练循环(主 epoch 循环)
flowchart TD
A[开始 epoch 循环] --> B{有 batch_size 调度?}
B -- 是 --> C[更新 batch_size\n联动重算 lrs/wds]
B -- 否 --> D
C --> D[run_epoch: split=train\n传入 lr/wd 调度数组]
D --> E[记录 train 指标\nlog_metrics]
E --> F[write_cp: 保存模型+优化器\n每轮均保存]
F --> G[run_epoch: split=valid\n不传 lr/wd 调度]
G --> H[write_cp: 若指标更优则保存 best]
H --> I[log_metrics 验证指标]
I --> J{check_patience\n连续N轮无提升?}
J -- 继续训练 --> K{should_stop?\nSIGUSR1触发?}
J -- 触发早停 --> L[break 退出循环]
K -- 是 --> M[exit 优雅退出]
K -- 否 --> N[losses.reset_summaries]
N --> A
L --> O[load_model best checkpoint]
O --> P[run_epoch: split=test]
P --> Q[记录测试指标 结束]
设计要点:
- 训练和验证
run_epoch共用同一函数,通过split参数区分行为 losses.reset_summaries()在每个 train/valid 边界调用,确保日志不跨阶段积累- 训练结束后自动加载 best checkpoint 跑测试,保证测试用的是最优模型而非最后一轮模型
12. run_epoch 具体实现
flowchart TD
A[run_epoch 开始] --> B[初始化\nis_train / model.train / seed]
B --> C[for i, batch in dataloader]
C --> D[opt.zero_grad]
D --> E{有 lr/wd 调度?}
E -- 是 --> F[按全局 step it\n更新 param_group lr/wd]
E -- 否 --> G
F --> G[数据移至 device]
G --> H[set_grad_enabled is_train\nmodel.forward]
H --> I[losses.forward 计算 loss]
I --> J{loss 含 NaN?}
J -- 是 --> K[跳过本 batch\nn_nans++]
K --> C
J -- 否 --> L{is_train?}
L -- 是 --> M[err.backward\nclip_grad_norm_\nopt.step]
M --> N{梯度 NaN?}
N -- 是 --> O[保存 NaN 样本音频\n跳过本 batch]
O --> C
N -- 否 --> P
L -- 否 --> P
P[detach_hidden model] --> Q[l_mem.append err]
Q --> R{i % log_freq?}
R -- 是 --> S[log_metrics\nsummary_write 保存音频]
R -- 否 --> C
S --> C
C -- epoch结束 --> T[cleanup 释放显存]
T --> U[return mean loss]
设计要点:
| 机制 | 实现 | 目的 |
|---|---|---|
| NaN 容忍 | n_nans 计数 + continue,超过 MAX_NANS 才 raise |
避免偶发 NaN 中断训练 |
| 梯度裁剪 | clip_grad_norm_(..., 1.0) |
防止梯度爆炸,稳定 RNN 训练 |
| NaN 样本记录 | 保存 NaN 时的音频到 summary/nan/ |
方便事后排查问题数据 |
| 推理时 clone | input = as_real(noisy).clone() |
防止 no_grad 模式下 in-place 操作破坏数据 |
| 周期性日志 | i % log_freq |
避免频繁 IO 拖慢训练 |
13. detach_hidden —— Truncated BPTT
核心原理:
RNN 的隐状态 h 在 batch 之间持续传递(模拟流式处理)。PyTorch 中每个 tensor 同时携带两样东西:
- 数值:隐状态本身,需要跨 batch 传递(保留时序上下文)
- 计算图节点(grad_fn):backward 时用于传递梯度,不应跨 batch
# 训练 batch_i 结束后:
detach_hidden(model)
# 效果:h.data 保留(数值继续传) + h.grad_fn = None(计算图截断)
| 场景 | 若不 detach | detach 后 |
|---|---|---|
| 训练 | 梯度穿越多个 batch,显存无限增长,OOM | 梯度只在当前 batch 内反传,显存稳定 |
| 验证/测试 | 隐状态持有旧 tensor 引用,内存缓慢泄漏 | 引用断开,GC 可正常回收 |
两个概念的区分:
- 隐状态(数值):跨 batch 传递 ✓(前向上下文)
- 梯度(计算图):在 batch 边界截断 ✓(Truncated BPTT)
启发:
- 任何含状态的流式模型(RNN、状态空间模型)在批次训练时都需要 Truncated BPTT
detach应封装为工具函数(如detach_hidden),调用位置应在 backward 之后、下一次 forward 之前- 训练和推理路径应统一调用,防止推理时内存泄漏
14. 多损失函数组合设计(setup_losses)
def setup_losses() -> Loss:
istft = Istft(fft_size, hop_size, fft_window).to(device)
loss = Loss(state, istft).to(device)
return loss
Loss 内部组合多个子损失,每个通过 factor 配置权重(0 表示禁用):
| 子损失 | 作用 |
|---|---|
MaskLoss |
ERB mask 与 ideal mask 的 MSE,监督粗粒度降噪 |
SpectralLoss |
复数频谱域 L1/L2,监督增强频谱形状 |
MultiResSpecLoss |
多分辨率 STFT loss,捕捉不同时频尺度的失真 |
SdrLoss / SegSdrLoss |
时域 SDR,直接优化感知质量 |
LocalSnrLoss |
局部 SNR 估计精度,辅助任务 |
ASRLoss |
语音识别特征一致性,保护语音可懂度 |
启发:
- 多损失组合是音频增强任务的标配:单一损失(如 MSE)难以同时保证频谱精度和感知质量
- 每个子损失的
factor=0禁用设计,使得消融实验(ablation study)无需改代码,只改配置 - 辅助损失(如
LocalSnrLoss、ASRLoss)提供额外监督信号,有助于训练早期稳定收敛 losses.reset_summaries()/losses.get_summaries()机制使各子损失的值可独立监控,方便调试
补充:其他值得关注的设计点
过拟合模式(OVERFIT):overfit = config("OVERFIT", False, bool) 传给 DataLoader 后,DataLoader 只重复加载少量样本。用于快速验证模型能否在小数据集上过拟合,是"模型是否有足够容量"的快速检测方法。
音频摘要保存(summary_write):
每隔 log_freq 步保存 clean/noisy/enhanced 的 wav 文件和 lsnr 曲线,训练过程可随时监听音质变化,而不必等到训练结束才评估。
@logger.catch 装饰器:
捕获 main() 中所有未处理异常并通过 loguru 格式化输出,确保异常信息写入训练日志文件而不是只打印到 stdout 后消失。