最近正跟随老师做有关NAS方面的工作,想将NAS与自己的想法结合起来验证一些有趣的结论,这其中会涉及到生成不同的网络结构提交到GPU做验证的过程。
那么问题来了:对即将要提交到GPU的网络结构大小,程序是无法提前预知的,训练batch_size设置大了很可能下一个网络结构就会爆显存;而保守地设置地很小的话,就毫无训练效率可言了;在网络上搜集了一些捕获PyTorch OOM错误并尝试从其中恢复的例子,供记录与参考。
retry_times = 0
try:
while retry_times < 3:
try:
train = TrainModel(gpu_id=str(gpu_id), batch_size=batch_size, num_workers=num_workers)
best_metric = train.start(max_epoch)
break # 训练成功完成则退出while
except RuntimeError as e: # PyTorch的OOM属于RuntimeError
if "out of memory" in str(e).lower():
retry_times += 1
batch_size = batch_size // 2 # batch_size调整策略,可自定义修改
Log.warn("Out of memory detected! Trying to lower batch_size to %d and restart..." % batch_size)
else:
raise e # 其他错误原地抛出,不受影响
finally:
for p in train.model.parameters():
if p.grad is not None:
del p.grad
del train # 手动删除所有梯度与整个模型,避免训练中断的模型继续占用GPU
gc.collect()
torch.cuda.empty_cache()
time.sleep(5)