From 666854eb65b20f35a2c5ec4dd2c211b06da945a9 Mon Sep 17 00:00:00 2001
From: Dun Liang
Date: Sat, 13 Jun 2020 20:52:03 +0800
Subject: [PATCH] update mpi usage

---
 doc/source/jittor.mpi.md        | 111 ++++++++++++++++++++++++++++++++
 python/jittor/__init__.py       |  42 +++++++--------
 python/jittor/compile_extern.py |   1 +
 setup.py                        |   2 +-
 4 files changed, 133 insertions(+), 23 deletions(-)

diff --git a/doc/source/jittor.mpi.md b/doc/source/jittor.mpi.md
index c6c00548..74d498bf 100644
--- a/doc/source/jittor.mpi.md
+++ b/doc/source/jittor.mpi.md
@@ -3,6 +3,117 @@ jittor.mpi
 
 This is the API documentation of Jittor's MPI module; you can access it via `from jittor import mpi`.
 
+## How to adapt single-GPU code to multi-GPU
+
+When run with `mpirun`, the following modules detect the MPI environment and automatically switch to their multi-GPU versions:
+
+* jittor.optimizer: synchronizes gradients automatically
+* jittor.nn.BatchNorm*: synchronized batch norm
+* jittor.dataset: automatic data parallelism
+
+In most cases, single-GPU training code can be run distributed on multiple GPUs directly with `mpirun` (for example, `mpirun -np 4 python train.py` launches four processes). However, the code still needs to be adjusted in the following two cases:
+
+1. Writing to disk (saving the model, saving curves)
+2. Aggregating global statistics (e.g. the overall accuracy on the validation set)
+
+### Writing to disk
+
+For the first case, suppose your original code looks like this:
+
+```python
+for i, (images, labels) in enumerate(dataset):
+    output = model(images)
+    loss = nn.cross_entropy_loss(output, labels)
+    acc1 = accuracy(output, labels)
+    SGD.step(loss)
+    loss_data = loss.data
+    writer.add_scalar("Train/loss", loss_data)
+```
+
+The adapted code looks like this:
+
+```python
+for i, (images, labels) in enumerate(dataset):
+    output = model(images)
+    loss = nn.cross_entropy_loss(output, labels)
+    acc1 = accuracy(output, labels)
+    SGD.step(loss)
+    loss_data = loss.data
+    if jt.rank == 0:
+        writer.add_scalar("Train/loss", loss_data)
+```
+
+Here we use jt.rank so that only the first process writes the loss. This code also works on a single GPU, because jt.rank is 0 in that case. Note that code inside the `if jt.rank == 0` block must not call any Jittor API; doing so is very likely to make the API calls inconsistent across processes and cause a **deadlock**!
+
+### Aggregating global statistics
+
+There are two ways to aggregate global statistics. The first is to use the provided MPI ops. Take the following validation code as an example:
+
+```python
+def val(epoch):
+    global min_error
+    model.eval()
+    correct_nums = 0
+    for i, (images, labels) in enumerate(valdataset):
+        output = model(images)
+        correct_nums += top1error(output, labels)
+        correct_nums.sync()
+    top1_error = (valdataset.total_len - correct_nums.data[0]) / valdataset.total_len
+    if top1_error < min_error:
+        print("[*] Best model is updated ...")
+        model.save('model_best.pkl')
+```
+
+The adapted version is as follows:
+
+```python
+def val(epoch):
+    global min_error
+    model.eval()
+    correct_nums = 0
+    for i, (images, labels) in enumerate(valdataset):
+        output = model(images)
+        correct_nums += top1error(output, labels)
+        correct_nums.sync()
+    if jt.in_mpi:
+        correct_nums = correct_nums.mpi_all_reduce()
+    top1_error = (valdataset.total_len - correct_nums.data[0]) / valdataset.total_len
+    if jt.rank == 0 and top1_error < min_error:
+        print("[*] Best model is updated ...")
+        model.save('model_best.pkl')
+```
+
+Note that we first use `mpi_all_reduce` to aggregate the number of correct predictions across GPUs (`mpi_all_reduce` sums the results of all MPI processes), and then update the model only when `jt.rank == 0`.
+
+The second way is to use `@jt.single_process_scope()`: the decorated function is executed by a single process, so no multi-GPU handling is needed.
+
+```python
+@jt.single_process_scope()
+def val(epoch):
+    ......
+```
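+
+For example, a minimal sketch based on the validation code above: inside the scope, world_size() is 1 and rank is 0, so the plain single-GPU validation logic can be kept as-is:
+
+```python
+@jt.single_process_scope()
+def val(epoch):
+    global min_error
+    model.eval()
+    correct_nums = 0
+    for i, (images, labels) in enumerate(valdataset):
+        output = model(images)
+        correct_nums += top1error(output, labels)
+        correct_nums.sync()
+    # executed by a single process, so no mpi_all_reduce or jt.rank check is needed
+    top1_error = (valdataset.total_len - correct_nums.data[0]) / valdataset.total_len
+    if top1_error < min_error:
+        print("[*] Best model is updated ...")
+        model.save('model_best.pkl')
+```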
+
+Below is the Jittor MPI API reference.
+
 ```eval_rst
 .. automodule:: jittor_mpi_core
    :members:
diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py
index 81a3a75c..949b6956 100644
--- a/python/jittor/__init__.py
+++ b/python/jittor/__init__.py
@@ -16,7 +16,7 @@ with lock.lock_scope():
     from jittor_core import *
     from jittor_core.ops import *
     from . import compile_extern
-    from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi
+    from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank
     if core.get_device_count() == 0:
         has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
     if has_cuda:
@@ -125,24 +125,7 @@ class profile_scope(_call_no_record_scope):
         profiler.stop()
         self.report.extend(profiler.report())
 
-class single_process_scope:
-    """ single_process_scope
-
-    Code in this scope will only be executed by single process.
-
-    All the mpi code inside this scope will have not affect.
-    mpi.world_rank() and mpi.local_rank() will return 0, world_size() will return 1,
-
-    example::
-
-        with jt.single_process_scope(rank=0) as flag:
-            if flag:
-                ......
-
-        @jt.single_process_scope(rank=0)
-        def xxx():
-            ...
-    """
+class __single_process_scope:
     def __init__(self, rank=0):
         self.rank = rank
 
@@ -165,15 +148,30 @@
         if mpi:
             mpi.set_state(self.bk_mpi_state)
 
-    def __call__(self, func):
-        global mpi
+def single_process_scope(rank=0):
+    """ single_process_scope
+
+    Code in this scope will only be executed by a single process.
+
+    All the mpi code inside this scope has no effect.
+    mpi.world_rank() and mpi.local_rank() will return 0, world_size() will return 1.
+
+    example::
+
+        @jt.single_process_scope(rank=0)
+        def xxx():
+            ...
+    """
+    def outer(func):
         def inner(*args, **kw):
             ret = None
-            with self as flag:
+            sync_all()
+            with __single_process_scope(rank) as flag:
                 if flag:
                     ret = func(*args, **kw)
             return ret
         return inner
+    return outer
 
 def clean():
     import gc
diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py
index 1bdf7c84..55983bd3 100644
--- a/python/jittor/compile_extern.py
+++ b/python/jittor/compile_extern.py
@@ -396,6 +396,7 @@ def setup_mpi():
 
 setup_mpi()
 in_mpi = inside_mpi()
+rank = mpi.world_rank() if in_mpi else 0
 setup_nccl()
 setup_cutt()
 
diff --git a/setup.py b/setup.py
index 425c267c..8651fdf8 100644
--- a/setup.py
+++ b/setup.py
@@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
 
 setuptools.setup(
     name='jittor',
-    version='1.1.4.4',
+    version='1.1.4.5',
     # scripts=[],
     author="Jittor Group",
     author_email="ran.donglang@gmail.com",