最近很荣幸加入了旷视成都研究院参与实习工作,在开始正式工作前的首要任务自然是将用好、用活旷视自建的Brain++计算平台。Brain++的设计很巧妙,也让我这个没出本科校门的fw体会到了多卡训练的爽点。之前在本校接触过3卡2080ti的多卡训练,当时用的开源代码,没有意识到自己动手写torch单机多卡训练代码时有这么多坑,下面一一列举一哈:

报错:Default process group has not been initialized, please make sure to call init_process_group.

调用torch.distributed下任何函数前,必须运行torch.distributed.init_process_group(backend='nccl')初始化。

DistributedSampler的shuffle

torch.utils.data.distributed.DistributedSampler有一个很坑的点,尽管提供了shuffle选项,但此shuffle非彼shuffle,如果不在每个epoch前手动执行下面这两行,在每张卡上每个epoch返回的index permutation都会是一样的...

if global_rank != -1:
     dataloader.sampler.set_epoch(epoch)

理由可见下段注释:

source: torch/utils/data/distributed.py#L126

def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch

torch.load

多卡的checkpoint不要直接load,指定load到cpu会缓解GPU0的显存压力。

checkpoint = torch.load("checkpoint.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["state_dict"])

# 使用下面这种load方式会导致每个进程在GPU0多占用一部分显存,原因是默认load的位置是GPU0
# checkpoint = torch.load("checkpoint.pth")
# model.load_state_dict(checkpoint["state_dict"])

SyncBN

多卡BN在默认情况下计算单张卡下的mean和var,并在单张卡下维护自己的BN参数。解决办法有SyncBN与Gruop Normalization。

SyncBN用同步的方法完成BN以尽可能模拟单卡场景,尽管会降低GPU利用率,但可以提高模型在多卡场景下的表现。可以用一行代码替换模型中所有的BN层为SyncBN层:

model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

注意,这只会替换掉所有直接或间接继承自torch.nn.modules.batchnorm._BatchNorm的module(isinstance判断),意味着模型中自定义的Normalization Layer依然保持原有行为。

获取分布式参数(local_rank, global_rank, world_size)的几个方式

rank分为local_rank和global_rank,分别为本机的第多少个计算设备以及全局第多少个计算设备。单机多卡场景下两者保持一致。
global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
global_rank = torch.distributed.get_rank()

  • 如本文第一条总结所说,这个函数需要初始化torch.distributed.init_process_group(backend='nccl')后才能成功调用。
import argparse
  parser = argparse.ArgumentParser()
  parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  opt = parser.parse_args()
  print(opt.local_rank)

world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1

torch.distributed.barrier()

多卡数据分配的处理由torch.utils.data.distributed.DistributedSampler完成,但有一些其他任务并没有这么方便的接口来处理多进程的同步问题。一个经典的例子是:很多操作只需要在worker 0上执行一遍,其他worker来取worker 0的结果就行,比如说预训练模型下载、数据集检查。torch.distributed.barrier()提供了很方便的、类似于p.join()的同步手段:每当进程执行到torch.distributed.barrier()时,该进程会保持阻塞,直至所有进程都执行到了这句话。

torch.distributed.barrier()的一个经典使用方法是用它构建一个上下文管理器:

from contextlib import contextmanager
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Decorator to make all processes in distributed training wait for each local_master to do something.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()

contextmanager装饰器配合yield使得我们不用编写__enter__ __exit__方法也能定义一个上下文管理器,十分Pythonic!

正如其函数名torch_distributed_zero_first所示,它能够让指定代码段先被worker 0执行一遍,随后才允许其余worker进入,使用方法如下:

with torch_distributed_zero_first(local_rank):
    check_dataset(data_dict)  # check dataset integrity
Last modification:April 4th, 2021 at 04:13 pm
If you think my article is useful to you, please feel free to appreciate