This commit is contained in:
zwy 2020-06-13 22:00:05 +08:00
commit bed94b4b07
4 changed files with 114 additions and 23 deletions

View File

@ -3,6 +3,98 @@ jittor.mpi
这里是Jittor的MPI模块的API文档，您可以通过`from jittor import mpi`来获取该模块。
## 如何从单卡代码适配多卡代码
使用`mpirun`时，以下几种模块会自动检测mpi环境，并且自动切换成多卡版本：
* jittor.optimizer: 自动同步梯度
* jittor.nn.BatchNorm*: 同步batch norm
* jittor.dataset: 自动数据并行
大部分情况下，单卡训练的代码可以直接使用`mpirun`实现分布式多卡运行。但仍然有如下几种情况，需要对代码进行调整：
1. 对硬盘进行写操作(保存模型,保存曲线)
2. 需要统计全局信息（如 validation 上的全局准确率）
### 对硬盘进行写操作
对于第一点,假设原来您的代码如下:
```python
# Single-card training loop: every step logs the loss through the writer.
for i, (images, labels) in enumerate(dataset):
    output = model(images)
    loss = nn.cross_entropy_loss(output, labels)
    acc1 = accuracy(output, labels)
    SGD.step(loss)
    loss_data = loss.data
    # Pass the fetched loss value and the step index; the tag alone is not
    # enough (assumes a tensorboard-style writer -- confirm against yours).
    writer.add_scalar("Train/loss", loss_data, i)
```
更改后的代码如下:
```python
# Multi-card training loop: identical to the single-card version except that
# only the process with rank 0 writes the loss curve.
for i, (images, labels) in enumerate(dataset):
    output = model(images)
    loss = nn.cross_entropy_loss(output, labels)
    acc1 = accuracy(output, labels)
    SGD.step(loss)
    # Fetch the loss on every process, outside the rank check below.
    loss_data = loss.data
    if jt.rank == 0:
        # Pass the fetched loss value and the step index; the tag alone is not
        # enough (assumes a tensorboard-style writer -- confirm against yours).
        writer.add_scalar("Train/loss", loss_data, i)
```
这里我们使用了 `jt.rank` 来限制，只允许第一个进程可以写 loss。这个代码在单卡下也是有效的，因为单卡的 `jt.rank` 值为 0。需要注意的是，在 `if jt.rank == 0` 代码块里面的代码不允许调用任何jittor的api，因为这很有可能导致多卡之间的api调用不一致而产生**死锁**！
### 需要统计全局信息
统计全局信息有两种方法。第一种是使用提供的 mpi op 来实现全局信息统计。如下所示，是一个validation的代码：
```python
def val(epoch):
    """Run one validation pass and save the model when the top-1 error improves.

    Single-card version: counts are accumulated over ``valdataset`` on this
    process only and compared against the module-level ``min_error``.
    """
    global min_error
    model.eval()
    correct_nums = 0
    for i, (images, labels) in enumerate(valdataset):
        output = model(images)
        correct_nums += top1error(output, labels)
        correct_nums.sync()
    top1_error = (valdataset.total_len - correct_nums.data[0]) / valdataset.total_len
    if top1_error < min_error:
        # Remember the new best error so later epochs compare against it.
        min_error = top1_error
        print("[*] Best model is updated ...")
        model.save('model_best.pkl')
```
更改方案如下:
```python
def val(epoch):
    """Run one validation pass and save the model when the top-1 error improves.

    Multi-card version: each process accumulates its own counts, the counts
    are summed across processes with ``mpi_all_reduce``, and only the process
    with rank 0 saves the model.
    """
    global min_error
    model.eval()
    correct_nums = 0
    for i, (images, labels) in enumerate(valdataset):
        output = model(images)
        correct_nums += top1error(output, labels)
        correct_nums.sync()
    if jt.in_mpi:
        # Sum the per-process counts so every process sees the global total.
        correct_nums = correct_nums.mpi_all_reduce()
    top1_error = (valdataset.total_len - correct_nums.data[0]) / valdataset.total_len
    if jt.rank == 0 and top1_error < min_error:
        # Remember the new best error so later epochs compare against it.
        min_error = top1_error
        print("[*] Best model is updated ...")
        model.save('model_best.pkl')
```
可以留意到，我们首先使用了 `mpi_all_reduce` 来统计多卡的正确数量（`mpi_all_reduce` 会将多个mpi进程的结果累加起来），然后在 `jt.rank == 0` 的情况下才更新模型。
第二种方法是使用`@jt.single_process_scope()`,被装饰的代码会直接以单进程的方式执行,无需处理多卡。
```python
@jt.single_process_scope()
def val(epoch):
    ...  # validation body elided; no multi-card handling is needed here
```
下面是 jittor 的 mpi api reference.
```eval_rst
.. automodule:: jittor_mpi_core
:members:

View File

@ -16,7 +16,7 @@ with lock.lock_scope():
from jittor_core import *
from jittor_core.ops import *
from . import compile_extern
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank
if core.get_device_count() == 0:
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
if has_cuda:
@ -125,24 +125,7 @@ class profile_scope(_call_no_record_scope):
profiler.stop()
self.report.extend(profiler.report())
class single_process_scope:
""" single_process_scope
Code in this scope will only be executed by single process.
All the mpi code inside this scope will have not affect.
mpi.world_rank() and mpi.local_rank() will return 0, world_size() will return 1,
example::
with jt.single_process_scope(rank=0) as flag:
if flag:
......
@jt.single_process_scope(rank=0)
def xxx():
...
"""
class __single_process_scope:
def __init__(self, rank=0):
self.rank = rank
@ -165,15 +148,30 @@ class single_process_scope:
if mpi:
mpi.set_state(self.bk_mpi_state)
def __call__(self, func):
global mpi
def single_process_scope(rank=0):
""" single_process_scope
Code in this scope will only be executed by a single process.
All the mpi code inside this scope will have no effect:
mpi.world_rank() and mpi.local_rank() will return 0, world_size() will return 1.
example::
@jt.single_process_scope(rank=0)
def xxx():
...
"""
def outer(func):
def inner(*args, **kw):
ret = None
with self as flag:
sync_all()
with __single_process_scope(rank) as flag:
if flag:
ret = func(*args, **kw)
return ret
return inner
return outer
def clean():
import gc

View File

@ -396,6 +396,7 @@ def setup_mpi():
setup_mpi()
in_mpi = inside_mpi()
rank = mpi.world_rank() if in_mpi else 0
setup_nccl()
setup_cutt()

View File

@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
setuptools.setup(
name='jittor',
version='1.1.4.4',
version='1.1.4.5',
# scripts=[],
author="Jittor Group",
author_email="ran.donglang@gmail.com",