forked from mindspore-Ecosystem/mindspore
fix_api
This commit is contained in:
parent
889f7cb030
commit
18640dd4ac
|
@ -1,14 +1,28 @@
|
|||
mindspore.communication
|
||||
========================
|
||||
集合通信接口的类。
|
||||
集合通信接口。
|
||||
|
||||
注意,集合通信接口需要预先设置环境变量。对于Ascend,用户需要配置rank_table,设置rank_id和device_id,相关教程可参考:
|
||||
<https://www.mindspore.cn/tutorials/zh-CN/master/intermediate/distributed_training/distributed_training_ascend.html>`_。
|
||||
对于GPU,用户需要预先配置host_file以及mpi,相关教程参考:
|
||||
<https://www.mindspore.cn/tutorials/zh-CN/master/intermediate/distributed_training/distributed_training_gpu.html>`_。
|
||||
|
||||
目前尚不支持CPU。
|
||||
|
||||
.. py:class:: mindspore.communication.GlobalComm
|
||||
|
||||
全局通信信息。GlobalComm 是一个全局类。 成员包含:BACKEND、WORLD_COMM_GROUP。
|
||||
GlobalComm 是一个储存通信信息的全局类。 成员包含:BACKEND、WORLD_COMM_GROUP。
|
||||
|
||||
- BACKEND:使用的通信库,HCCL或者NCCL。
|
||||
- WORLD_COMM_GROUP:全局通信域。
|
||||
|
||||
**样例:**
|
||||
|
||||
>>> from mindspore.context import set_context
|
||||
>>> set_context(device_target="Ascend")
|
||||
>>> init()
|
||||
>>> GlobalComm.BACKEND
|
||||
|
||||
.. py:method:: mindspore.communication.init(backend_name=None)
|
||||
|
||||
初始化通信服务需要的分布式后端,例如‘HCCL’或‘NCCL’服务。
|
||||
|
@ -17,13 +31,13 @@ mindspore.communication
|
|||
|
||||
**参数:**
|
||||
|
||||
- **backend_name** (str) – 后台服务的名称,可选HCCL或NCCL。如果未设置则根据硬件平台类型(device_target)进行推断,默认值为None。
|
||||
- **backend_name** (str) – 分布式后端的名称,可选HCCL或NCCL。如果未设置则根据硬件平台类型(device_target)进行推断,默认值为None。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 在参数 `backend_name` 不是字符串时抛出。
|
||||
- **RuntimeError** – 在以下情况将抛出:1)硬件设备类型无效;2)后台服务无效;3)分布式计算初始化失败;4)未设置环境变量 `RANK_ID` 或 `MINDSPORE_HCCL_CONFIG_PATH` 的情况下初始化HCCL服务。
|
||||
- **ValueError** – 在环境变量 `RANK_ID` 设置成非数字时抛出。
|
||||
- **TypeError** – 参数 `backend_name` 不是字符串。
|
||||
- **RuntimeError** – 1)硬件设备类型无效;2)后台服务无效;3)分布式计算初始化失败;4)未设置环境变量 `RANK_ID` 或 `MINDSPORE_HCCL_CONFIG_PATH` 的情况下初始化HCCL服务。
|
||||
- **ValueError** – 环境变量 `RANK_ID` 设置成非数字。
|
||||
|
||||
**样例:**
|
||||
|
||||
|
@ -65,7 +79,7 @@ mindspore.communication
|
|||
|
||||
获取指定通信组实例的rank_size。
|
||||
|
||||
.. note:: `get_group_size` 方法应该在 `init` 方法之后使用。
|
||||
.. note:: `get_group_size` 方法应该在 `init` 方法之后使用。在跑用例之前用户需要预先配置通信相关的环境变量。
|
||||
|
||||
**参数:**
|
||||
|
||||
|
@ -81,6 +95,18 @@ mindspore.communication
|
|||
- **ValueError** – 在后台不可用时抛出。
|
||||
- **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用时抛出。
|
||||
|
||||
**样例:**
|
||||
|
||||
>>> from mindspore.context import set_context
|
||||
>>> set_context(device_target="Ascend")
|
||||
>>> init()
|
||||
>>> group = "0-4"
|
||||
>>> rank_ids = [0,4]
|
||||
>>> create_group(group, rank_ids)
|
||||
>>> group_size = get_group_size(group)
|
||||
>>> print("group_size is: ", group_size)
|
||||
>>> group_size is:2
|
||||
|
||||
.. py:class:: mindspore.communication.get_world_rank_from_group_rank(group, group_rank_id)
|
||||
|
||||
由指定通信组中的设备序号获取通信集群中的全局设备序号。
|
||||
|
@ -101,9 +127,9 @@ mindspore.communication
|
|||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 在参数 `group_rank_id` 不是数字或参数 `group` 不是字符串时抛出。
|
||||
- **ValueError** – 在参数 `group` 是 `hccl_world_group` 或后台不可用时抛出。
|
||||
- **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用,以及使用GPU版本的MindSpore时抛出。
|
||||
- **TypeError** – 参数 `group` 不是字符串或参数 `group_rank_id` 不是数字。
|
||||
- **ValueError** – 参数 `group` 是 `hccl_world_group` 或后台不可用。
|
||||
- **RuntimeError** – ‘HCCL’或‘NCCL’服务不可用,以及使用CPU版本的MindSpore。
|
||||
|
||||
**样例:**
|
||||
|
||||
|
@ -168,9 +194,9 @@ mindspore.communication
|
|||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 在参数 `group_rank_id` 不是数字或参数 `group` 不是字符串时抛出。
|
||||
- **ValueError** – 在列表rank_ids的长度小于1,或列表rank_ids内有重复数据,以及后台无效时抛出。
|
||||
- **RuntimeError** – 在‘HCCL’或‘NCCL’ 服务不可用,以及使用GPU版本的MindSpore时抛出。
|
||||
- **TypeError** – 参数 `group_rank_id` 不是数字或参数 `group` 不是字符串。
|
||||
- **ValueError** – 列表rank_ids的长度小于1,或列表rank_ids内有重复数据,以及后台无效。
|
||||
- **RuntimeError** – 在‘HCCL’或‘NCCL’ 服务不可用,以及使用CPU版本的MindSpore。
|
||||
|
||||
**样例:**
|
||||
|
||||
|
@ -242,4 +268,14 @@ mindspore.communication
|
|||
|
||||
- **TypeError** – 在参数 `group` 不是字符串时抛出。
|
||||
- **ValueError** – 在参数 `group` 是 `hccl_world_group` 或后台不可用时抛出。
|
||||
- **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用时抛出。
|
||||
- **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用时抛出。
|
||||
|
||||
**样例:**
|
||||
|
||||
>>> from mindspore.context import set_context
|
||||
>>> set_context(device_target="Ascend")
|
||||
>>> init()
|
||||
>>> group = "0-8"
|
||||
>>> rank_ids = [0,8]
|
||||
>>> create_group(group, rank_ids)
|
||||
>>> destroy_group(group)
|
||||
|
|
|
@ -314,8 +314,6 @@ MindSpore上下文,用于配置当前执行环境,包括执行模式、执
|
|||
- SEMI_AUTO_PARALLEL:半自动并行模式。
|
||||
- AUTO_PARALLEL:自动并行模式。
|
||||
|
||||
MODE_LIST:表示所有支持的并行模式的列表。
|
||||
|
||||
.. py:function:: mindspore.context.set_ps_context(**kwargs)
|
||||
|
||||
设置参数服务器训练模式的上下文。
|
||||
|
|
|
@ -5,7 +5,7 @@ mindspore.nn.DistributedGradReducer
|
|||
|
||||
分布式优化器。
|
||||
|
||||
对反向梯度进行AllReduce运算。
|
||||
用于数据并行模式中,对所有卡的梯度利用AllReduce进行聚合。
|
||||
|
||||
**参数:**
|
||||
|
||||
|
@ -90,15 +90,3 @@ mindspore.nn.DistributedGradReducer
|
|||
>>> grads = train_cell(inputs, label)
|
||||
>>> print(grads)
|
||||
256.0
|
||||
|
||||
.. py:method:: construct(grads)
|
||||
|
||||
某些情况下,梯度的数据精度可以与float16和float32混合。因此,AllReduce的结果不可靠。要解决这个问题,必须在AllReduce之前强制转换为float32,并在操作之后再强制转换为float32。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **grads** (Union[Tensor, tuple[Tensor]]) - 操作前的梯度Tensor或tuple。
|
||||
|
||||
**返回:**
|
||||
|
||||
- **new_grads** (Union[Tensor, tuple[Tensor]]),操作后的梯度Tensor或tuple。
|
||||
|
|
|
@ -280,7 +280,7 @@ class DistributedGradReducer(Cell):
|
|||
A distributed optimizer.
|
||||
|
||||
Constructs a gradient reducer Cell, which applies communication and average operations on
|
||||
single-process gradient values.
|
||||
single-process gradient values. Used in data parallel.
|
||||
|
||||
Args:
|
||||
parameters (list): the parameters to be updated.
|
||||
|
|
Loading…
Reference in New Issue