最近正跟随老师做有关NAS方面的工作,想将NAS与自己的想法结合起来验证一些有趣的结论,这其中会涉及到生成不同的网络结构提交到GPU做验证的过程。

那么问题来了:对即将要提交到GPU的网络结构大小,程序是无法提前预知的,训练batch_size设置大了很可能下一个网络结构就会爆显存;而保守地设置地很小的话,就毫无训练效率可言了;在网络上搜集了一些捕获PyTorch OOM错误并尝试从其中恢复的例子,供记录与参考。

retry_times = 0
try:
    while retry_times < 3:
        try:
            train = TrainModel(gpu_id=str(gpu_id), batch_size=batch_size, num_workers=num_workers)
            best_metric = train.start(max_epoch)
            break  # 训练成功完成则退出while
        except RuntimeError as e:  # PyTorch的OOM属于RuntimeError
            if "out of memory" in str(e).lower():
                retry_times += 1
                batch_size = batch_size // 2  # batch_size调整策略,可自定义修改
                Log.warn("Out of memory detected! Trying to lower batch_size to %d and restart..." % batch_size)
            else:
                raise e  # 其他错误原地抛出,不受影响
        finally:
            for p in train.model.parameters():
                if p.grad is not None:
                    del p.grad
            del train  # 手动删除所有梯度与整个模型,避免训练中断的模型继续占用GPU
            gc.collect()
            torch.cuda.empty_cache()
            time.sleep(5)
Last modification:January 19th, 2021 at 09:03 pm
If you think my article is useful to you, please feel free to appreciate