GrokAdamW
GrokAdamW 是一种新的优化器,旨在通过结合Grokfast技术与AdamW优化器的优势来加速AI模型的训练,尤其适用于那些在训练数据上过拟合后,验证数据表现突然改善的延迟泛化模型。
使用场景
GrokAdamW特别适用于处理表现出延迟泛化现象的深度学习模型。在这些模型中,在一段时间的过拟合之后,验证数据上的性能会大幅提升。这个优化器通过以下方式加速这种“grokking”现象:
- 自适应动量:Grokfast组件的动量会根据“Grokking Signal”动态调整,反映模型的泛化进展。
- 逐层动量衰减:不同层的学习速率不同,GrokAdamW在前层引入动量衰减以促进泛化,而在后层防止过拟合。
- 多种Grokking Signals:支持多种信号函数的组合,以捕获泛化性能的不同方面。
- 可选梯度裁剪:提供裁剪梯度的选项,以增强训练稳定性,防止梯度爆炸。
理论基础
Grokking现象表示在长时间的过拟合后,深度学习模型能突然实现泛化。Grokfast技术通过放大训练过程中的慢变化梯度来加速这种现象。
GrokAdamW通过以下数学机制集成Grokfast与AdamW优化算法:
- AdamW更新: 部分采用了动量和平方动量的估计,进行偏差修正后,利用这些估计更新参数。
- Grokfast整合: 通过计算Grokking Signal(反映泛化进展的信号),调整EMA滤波器的动量,进而放大慢变化梯度。
安装与使用
通过pip简单安装
pip install grokadamw
并在训练模型时使用
import torch
import torch.nn as nn
from grokadamw import GrokAdamW
# 建立模型
model = nn.Linear(10, 1)
# 定义Grokking信号函数
def grokking_signal_fn(training_loss: float, validation_loss: float) -> float:
if training_loss == 0:
return 0.0 # 避免除零错误
return (validation_loss - training_loss) / training_loss
# 初始化GrokAdamW优化器
optimizer = GrokAdamW(model.parameters(), lr=1e-3, grokking_signal_fn=grokking_signal_fn)
# 训练循环
for epoch in range(num_epochs):
# ... 训练代码 ...
loss = optimizer.step(closure=lambda: your_loss_function(model, data))
配置参数
GrokAdamW支持标准的AdamW参数(如lr
, betas
, eps
, weight_decay
)和Grokfast的特定参数(如alpha_init
, lamb
, gamma
, grokking_signal_fns
, grokking_signal_decay_rate
, gradient_clipping
)。
调试建议
解决常见问题,如:
- 确保Grokking信号函数返回有意义的值。
- 如果梯度频繁裁剪,可能需要调整学习率或其他超参数。
- 监控不同层的学习动态,调整
gamma
或各层的超参数。
贡献
欢迎提交问题、功能请求或拉取请求。 详细信息请见CONTRIBUTING.md文件。
许可
GrokAdamW遵循MIT许可证。详细信息请参见许可证文件。
广告:私人定制视频文本提取,字幕翻译制作等,欢迎联系QQ:1752338621