AI新工具
banner

GrokAdamW


介绍:

GrokAdamW 是一种结合 Grokfast 和 AdamW 优化器优势加速深度学习模型训练的新型优化器。









GrokAdamW

GrokAdamW 是一种新的优化器,旨在通过结合Grokfast技术与AdamW优化器的优势来加速AI模型的训练,尤其适用于那些在训练数据上过拟合后,验证数据表现突然改善的延迟泛化模型。

使用场景

GrokAdamW特别适用于处理表现出延迟泛化现象的深度学习模型。在这些模型中,在一段时间的过拟合之后,验证数据上的性能会大幅提升。这个优化器通过以下方式加速这种“grokking”现象:

  1. 自适应动量:Grokfast组件的动量会根据“Grokking Signal”动态调整,反映模型的泛化进展。
  2. 逐层动量衰减:不同层的学习速率不同,GrokAdamW在前层引入动量衰减以促进泛化,而在后层防止过拟合。
  3. 多种Grokking Signals:支持多种信号函数的组合,以捕获泛化性能的不同方面。
  4. 可选梯度裁剪:提供裁剪梯度的选项,以增强训练稳定性,防止梯度爆炸。
理论基础

Grokking现象表示在长时间的过拟合后,深度学习模型能突然实现泛化。Grokfast技术通过放大训练过程中的慢变化梯度来加速这种现象。

GrokAdamW通过以下数学机制集成Grokfast与AdamW优化算法:

  • AdamW更新: 部分采用了动量和平方动量的估计,进行偏差修正后,利用这些估计更新参数。
  • Grokfast整合: 通过计算Grokking Signal(反映泛化进展的信号),调整EMA滤波器的动量,进而放大慢变化梯度。
安装与使用

通过pip简单安装

pip install grokadamw

并在训练模型时使用

import torch
import torch.nn as nn
from grokadamw import GrokAdamW

# 建立模型
model = nn.Linear(10, 1)

# 定义Grokking信号函数
def grokking_signal_fn(training_loss: float, validation_loss: float) -> float:
    if training_loss == 0:
        return 0.0  # 避免除零错误
    return (validation_loss - training_loss) / training_loss

# 初始化GrokAdamW优化器
optimizer = GrokAdamW(model.parameters(), lr=1e-3, grokking_signal_fn=grokking_signal_fn)

# 训练循环
for epoch in range(num_epochs):
    # ... 训练代码 ...
    loss = optimizer.step(closure=lambda: your_loss_function(model, data)) 
配置参数

GrokAdamW支持标准的AdamW参数(如lr, betas, eps, weight_decay)和Grokfast的特定参数(如alpha_init, lamb, gamma, grokking_signal_fns, grokking_signal_decay_rate, gradient_clipping)。

调试建议

解决常见问题,如:

  • 确保Grokking信号函数返回有意义的值。
  • 如果梯度频繁裁剪,可能需要调整学习率或其他超参数。
  • 监控不同层的学习动态,调整gamma或各层的超参数。
贡献

欢迎提交问题、功能请求或拉取请求。 详细信息请见CONTRIBUTING.md文件。

许可

GrokAdamW遵循MIT许可证。详细信息请参见许可证文件。

可关注我们的公众号:每天AI新工具

广告:私人定制视频文本提取,字幕翻译制作等,欢迎联系QQ:1752338621