AI语音降噪 - DeepFilterNet 训练设计技巧总结


1. 随机种子固定(可复现性)

seed = config("SEED", 42, int, section="train")
check_manual_seed(seed)

设计作用:
check_manual_seed 同时固定 Python randomnumpytorch(CPU 和 CUDA)的随机种子,确保每次训练的数据增强、参数初始化、dropout mask 等随机行为完全一致,使实验结果可复现。

启发:


2. 主机特定批大小配置(多机适配)

key = get_host() + "_" + config.get("model") + "_" + config.get("fft_size")
set_batch_size(config_file, args.host_batchsize_config, host_key=key)

设计作用:
不同服务器 GPU 显存不同,对同一模型的最优 batch size 也不同。该机制通过 hostname_modelname_fftsize 作为 key,在外部 JSON 配置中查找当前机器的最优 batch size,动态覆盖默认配置,无需修改代码即可适配多台机器。

启发:


3. Batch Size 动态调度(渐进式训练)

# 格式: '0/4,5/8,10/16'  -> epoch0: bs=4, epoch5: bs=8, epoch10: bs=16
batch_size_scheduling: List[str] = config("BATCH_SIZE_SCHEDULING", [], Csv(str))
for e, b in batch_size_scheduling:
    if e <= epoch:
        scheduling_bs = min(b, bs)  # 不超过配置上限
if prev_scheduling_bs != scheduling_bs:
    dataloader.set_batch_size(scheduling_bs, "train")
    lrs = setup_lrs(len(dataloader))  # 重新计算 lr schedule!
    wds = setup_wds(len(dataloader))

设计作用:
训练初期用小 batch size,模型更新更频繁,有助于跳出局部最优、加快收敛;训练后期增大 batch size,梯度估计更稳定,有助于精细收敛。同时,batch size 变化会改变 epoch 内的 step 数,因此必须重新计算 lr/wd schedule 数组,保证学习率曲线在时间轴上的形状不变。

启发:


4. SIGUSR1 信号处理(优雅超时退出)

signal.signal(signal.SIGUSR1, get_sigusr1_handler(args.base_dir))

def get_sigusr1_handler(base_dir):
    def h(*args):
        global should_stop
        should_stop = True                          # 标记停止
        open(os.path.join(base_dir, "continue"), "w").close()  # 写 continue 文件
    return h

在每个 epoch 末检查:

if should_stop:
    logger.info("Stopping training due to timeout")
    exit(0)

设计作用:
在 HPC 集群(如 SLURM)中,作业有时间配额限制,到时会被强制 kill。SIGUSR1 信号通常作为"即将超时"的预警信号提前发出。捕获后设置 should_stop=True,等当前 epoch 训练完毕再安全退出——此时 checkpoint 已写入磁盘,不会丢失进度。同时写入 continue 文件,供 job scheduler 判断是否需要重新提交任务继续训练。

启发:


5. 分模块训练(mask_only / df_only)

mask_only: bool = config("MASK_ONLY", False, bool, section="train")
train_df_only: bool = config("DF_ONLY", False, bool, section="train")

if mask_only:
    params = [p for n, p in model.named_parameters()
              if "dfrnn" not in n and "df_dec" not in n]
elif df_only:
    params = (p for n, p in model.named_parameters() if "df" in n.lower())
else:
    params = model.parameters()

设计作用:
DeepFilterNet 由两个主要分支组成:ERB mask 分支(粗粒度噪声抑制)和 DF deep filter 分支(细粒度频谱整形)。分模块训练允许:

这是一种**课程学习(Curriculum Learning)**思想的工程实现。

启发:


6. TorchScript JIT 编译

jit = config("JIT", False, cast=bool, section="train")
# 注意:必须在 log_model_summary 之后再 jit
if jit:
    model = torch.jit.script(model)

设计作用:
torch.jit.script 将 Python 代码编译为静态计算图,消除 Python 解释器开销、使能算子融合,可加速训练和推理。同时 JIT 模型可序列化为 .pt 文件,在不依赖 Python 环境的情况下(如 C++ 推理框架)运行。代码中刻意在 log_model_summary 之后才执行 JIT,因为 JIT 模型的接口与普通 Module 不同,无法正常打印结构。

启发:


7. 余弦退火 LR/WD 逐步调度

# 预先计算所有 step 的 lr 值(numpy 数组,按全局 step 索引)
lrs = setup_lrs(steps_per_epoch)   # shape: [total_steps]
wds = setup_wds(steps_per_epoch)   # shape: [total_steps] or None

# 每个 step 内更新
it = start_steps + i  # 全局 step 编号
param_group["lr"] = lr_scheduler_values[it] * param_group.get("lr_scale", 1)
param_group["weight_decay"] = wd_scheduler_values[it]

setup_lrs 使用带 warmup 的余弦退火,支持多周期:

lr
▲
|  warmup  |      cosine decay       | cycle 2 |
|          \                        /\         /
|           \                      /  \       /
|            \____________________/    \_____/
└────────────────────────────────────────────→ step

setup_wds 同样用余弦曲线从 weight_decay 衰减到 weight_decay_end(若不配置则不启用)。

启发:


8. 早停机制(Early Stopping)

val_criteria_type = config("VALIDATION_CRITERIA", "loss", section="train")
val_criteria_rule = config("VALIDATION_CRITERIA_RULE", "min", section="train")
patience = config("EARLY_STOPPING_PATIENCE", 5, int, section="train")

val_criteria = metrics[val_criteria_type]
write_cp(model, "model", checkpoint_dir, epoch+1,
         metric=val_criteria, cmp=val_criteria_rule)  # 只保存最优

if not check_patience(checkpoint_dir, max_patience=patience,
                      new_metric=val_criteria, cmp=val_criteria_rule,
                      raise_=False):
    break

设计作用:
若验证集指标连续 patience 个 epoch 没有改善(支持 min/max 两种方向),则停止训练,防止过拟合和浪费计算资源。验证指标支持任意 metric(loss、PESQ、DNSMOS 等),由配置决定,不耦合于 loss。write_cp 中带 metric 参数时只在指标更优时覆盖"best"检查点,训练结束后加载的就是历史最优模型。

启发:


9. 训练前先跑验证集(START_EVAL 基准线)

if config("START_EVAL", False, cast=bool, section="train"):
    val_loss = run_epoch(model, epoch=epoch-1, split="valid", ...)
    log_metrics(f"[{epoch-1}] [valid]", metrics)

设计作用:
在正式训练第 0 epoch 之前,先跑一遍验证集,得到初始模型(随机初始化或预加载权重)的基准性能。这样:

  1. 训练曲线从 epoch -1 开始,可以直观看出训练带来的提升量
  2. 若基准性能已经很好(如加载了预训练权重),可以判断是否需要训练
  3. 提前暴露数据 pipeline、loss 计算的 bug,而不是等到第一轮训练完再发现问题

启发:


10. 断点续训设计

检查点保存(每轮 + 最优):

write_cp(model, "model", checkpoint_dir, epoch + 1)      # 每轮保存模型
write_cp(opt,   "opt",   checkpoint_dir, epoch + 1)      # 每轮保存优化器状态
write_cp(model, "model", checkpoint_dir, epoch + 1,
         metric=val_criteria, cmp=val_criteria_rule)     # 条件保存最优模型

续点加载:

# 加载模型(含 epoch 号)
model, epoch = load_model(checkpoint_dir if args.resume else None, ...)
# 加载优化器(含 Adam 矩估计 m1/m2)
read_cp(opt, "opt", cp_dir)
# 从上次停止的 epoch 继续
for epoch in range(epoch, max_epochs):
    ...

设计亮点:

启发:


11. 核心训练循环(主 epoch 循环)

flowchart TD
    A[开始 epoch 循环] --> B{有 batch_size 调度?}
    B -- 是 --> C[更新 batch_size\n联动重算 lrs/wds]
    B -- 否 --> D
    C --> D[run_epoch: split=train\n传入 lr/wd 调度数组]
    D --> E[记录 train 指标\nlog_metrics]
    E --> F[write_cp: 保存模型+优化器\n每轮均保存]
    F --> G[run_epoch: split=valid\n不传 lr/wd 调度]
    G --> H[write_cp: 若指标更优则保存 best]
    H --> I[log_metrics 验证指标]
    I --> J{check_patience\n连续N轮无提升?}
    J -- 继续训练 --> K{should_stop?\nSIGUSR1触发?}
    J -- 触发早停 --> L[break 退出循环]
    K -- 是 --> M[exit 优雅退出]
    K -- 否 --> N[losses.reset_summaries]
    N --> A
    L --> O[load_model best checkpoint]
    O --> P[run_epoch: split=test]
    P --> Q[记录测试指标 结束]

设计要点:


12. run_epoch 具体实现

flowchart TD
    A[run_epoch 开始] --> B[初始化\nis_train / model.train / seed]
    B --> C[for i, batch in dataloader]
    C --> D[opt.zero_grad]
    D --> E{有 lr/wd 调度?}
    E -- 是 --> F[按全局 step it\n更新 param_group lr/wd]
    E -- 否 --> G
    F --> G[数据移至 device]
    G --> H[set_grad_enabled is_train\nmodel.forward]
    H --> I[losses.forward 计算 loss]
    I --> J{loss 含 NaN?}
    J -- 是 --> K[跳过本 batch\nn_nans++]
    K --> C
    J -- 否 --> L{is_train?}
    L -- 是 --> M[err.backward\nclip_grad_norm_\nopt.step]
    M --> N{梯度 NaN?}
    N -- 是 --> O[保存 NaN 样本音频\n跳过本 batch]
    O --> C
    N -- 否 --> P
    L -- 否 --> P
    P[detach_hidden model] --> Q[l_mem.append err]
    Q --> R{i % log_freq?}
    R -- 是 --> S[log_metrics\nsummary_write 保存音频]
    R -- 否 --> C
    S --> C
    C -- epoch结束 --> T[cleanup 释放显存]
    T --> U[return mean loss]

设计要点:

机制 实现 目的
NaN 容忍 n_nans 计数 + continue,超过 MAX_NANS 才 raise 避免偶发 NaN 中断训练
梯度裁剪 clip_grad_norm_(..., 1.0) 防止梯度爆炸,稳定 RNN 训练
NaN 样本记录 保存 NaN 时的音频到 summary/nan/ 方便事后排查问题数据
推理时 clone input = as_real(noisy).clone() 防止 no_grad 模式下 in-place 操作破坏数据
周期性日志 i % log_freq 避免频繁 IO 拖慢训练

13. detach_hidden —— Truncated BPTT

核心原理:

RNN 的隐状态 h 在 batch 之间持续传递(模拟流式处理)。PyTorch 中每个 tensor 同时携带两样东西:

# 训练 batch_i 结束后:
detach_hidden(model)
# 效果:h.data 保留(数值继续传) + h.grad_fn = None(计算图截断)
场景 若不 detach detach 后
训练 梯度穿越多个 batch,显存无限增长,OOM 梯度只在当前 batch 内反传,显存稳定
验证/测试 隐状态持有旧 tensor 引用,内存缓慢泄漏 引用断开,GC 可正常回收

两个概念的区分:

启发:


14. 多损失函数组合设计(setup_losses)

def setup_losses() -> Loss:
    istft = Istft(fft_size, hop_size, fft_window).to(device)
    loss = Loss(state, istft).to(device)
    return loss

Loss 内部组合多个子损失,每个通过 factor 配置权重(0 表示禁用):

子损失 作用
MaskLoss ERB mask 与 ideal mask 的 MSE,监督粗粒度降噪
SpectralLoss 复数频谱域 L1/L2,监督增强频谱形状
MultiResSpecLoss 多分辨率 STFT loss,捕捉不同时频尺度的失真
SdrLoss / SegSdrLoss 时域 SDR,直接优化感知质量
LocalSnrLoss 局部 SNR 估计精度,辅助任务
ASRLoss 语音识别特征一致性,保护语音可懂度

启发:


补充:其他值得关注的设计点

过拟合模式(OVERFIT):
overfit = config("OVERFIT", False, bool) 传给 DataLoader 后,DataLoader 只重复加载少量样本。用于快速验证模型能否在小数据集上过拟合,是"模型是否有足够容量"的快速检测方法。

音频摘要保存(summary_write):
每隔 log_freq 步保存 clean/noisy/enhanced 的 wav 文件和 lsnr 曲线,训练过程可随时监听音质变化,而不必等到训练结束才评估。

@logger.catch 装饰器:
捕获 main() 中所有未处理异常并通过 loguru 格式化输出,确保异常信息写入训练日志文件而不是只打印到 stdout 后消失。