forked from mindspore-Ecosystem/mindspore
add_CN_API
This commit is contained in:
parent
204822ecd8
commit
ae40fea111
|
@ -0,0 +1,293 @@
|
|||
mindspore.communication
|
||||
========================
|
||||
集合通信接口的类。
|
||||
|
||||
.. py:class:: mindspore.communication.GlobalComm
|
||||
|
||||
全局通信信息。GlobalComm 是一个全局类。 成员包含:BACKEND、WORLD_COMM_GROUP。
|
||||
|
||||
.. py:method:: mindspore.communication.init(backend_name=None)
|
||||
|
||||
初始化通信服务需要的分布式后端,例如HCCL或NCCL服务。
|
||||
|
||||
.. note::
|
||||
|
||||
HCCL的全称是华为集合通信库(Huawei Collective Communication Library),NCCL的全称是英伟达集合通信库(NVIDIA Collective Communication Library)。`init` 方法应该在 `set_context` 方法之后使用。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **backend_name** (`str`) – 后台服务的名称,可选HCCL或NCCL。如果未设置则根据硬件平台类型(device_target)进行推断,默认值为None。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 在参数 `backend_name` 不是字符串时抛出。
|
||||
|
||||
- **RuntimeError** – 在以下情况将抛出:1)硬件设备类型无效;2)后台服务无效;3)分布式计算初始化失败;4)未设置环境变量 `RANK_ID` 或 `MINDSPORE_HCCL_CONFIG_PATH` 的情况下初始化HCCL服务。
|
||||
|
||||
- **ValueError** – 在环境变量 `RANK_ID` 设置成非数字时抛出。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from mindspore.context import set_context
|
||||
>>> set_context(device_target="Ascend")
|
||||
>>> init()
|
||||
|
||||
.. py:class:: mindspore.communication.release()
|
||||
|
||||
释放分布式资源,例如‘HCCL’或‘NCCL’服务。
|
||||
|
||||
.. note::
|
||||
|
||||
`release` 方法应该在 `init` 方法之后使用。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **RuntimeError** - 在释放分布式资源失败时抛出。
|
||||
|
||||
.. py:class:: mindspore.communication.get_rank(group=GlobalComm.WORLD_COMM_GROUP)
|
||||
|
||||
在指定通信组中获取当前的设备序号。
|
||||
|
||||
.. note::
|
||||
|
||||
`get_rank` 方法应该在 `init` 方法之后使用。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **group** (`str`) - 通信组名称,通常由 `create_group` 方法创建,否则将使用默认组。
|
||||
|
||||
- **默认值** - ‘WORLD_COMM_GROUP’。
|
||||
|
||||
**返回:**
|
||||
|
||||
int, 调用该方法的进程对应的组内序号。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 在参数 `group` 不是字符串时抛出。
|
||||
|
||||
- **ValueError** – 在后台不可用时抛出。
|
||||
|
||||
- **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用时抛出。
|
||||
|
||||
.. py:class:: mindspore.communication.get_group_size(group=GlobalComm.WORLD_COMM_GROUP)
|
||||
|
||||
获取指定通信组的设备总数。
|
||||
|
||||
.. note::
|
||||
|
||||
`get_group_size` 方法应该在 `init` 方法之后使用。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **group** (`str`) - 通信组名称,通常由 `create_group` 方法创建,否则将使用默认组。
|
||||
|
||||
- **默认值** - ‘WORLD_COMM_GROUP’。
|
||||
|
||||
**返回:**
|
||||
|
||||
int, 指定通信组的设备总数。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 在参数 `group` 不是字符串时抛出。
|
||||
|
||||
- **ValueError** – 在后台不可用时抛出。
|
||||
|
||||
- **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用时抛出。
|
||||
|
||||
|
||||
|
||||
.. py:class:: mindspore.communication.get_world_rank_from_group_rank(group, group_rank_id)
|
||||
|
||||
由指定通信组中的设备序号获取通信集群中的全局设备序号。
|
||||
|
||||
.. note::
|
||||
|
||||
1. GPU 版本的MindSpore不支持此方法;
|
||||
2. 参数 `group` 不能是 `hccl_world_group`;
|
||||
3. `get_world_rank_from_group_rank` 方法应该在 `init` 方法之后使用。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **group** (`str`) - 传入的通信组名称,通常由 `create_group` 方法创建。
|
||||
|
||||
- **group_rank_id** (`int`) - 通信组内的设备序号
|
||||
|
||||
**返回:**
|
||||
|
||||
int, 通信集群中的全局设备序号。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 在参数 `group_rank_id` 不是数字或参数 `group` 不是字符串时抛出。
|
||||
|
||||
- **ValueError** – 在参数 `group` 是 `hccl_world_group` 或后台不可用时抛出。
|
||||
|
||||
- **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用,以及使用GPU版本的MindSpore时抛出。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from mindspore.context import set_context
|
||||
>>> set_context(device_target="Ascend")
|
||||
>>> init()
|
||||
>>> group = "0-4"
|
||||
>>> rank_ids = [0,4]
|
||||
>>> create_group(group, rank_ids)
|
||||
>>> world_rank_id = get_world_rank_from_group_rank(group, 1)
|
||||
>>> print("world_rank_id is: ", world_rank_id) # 全局设备序号为4
|
||||
|
||||
.. py:class:: mindspore.communication.get_group_rank_from_world_rank(world_rank_id, group)
|
||||
|
||||
由通信集群中的全局设备序号获取指定用户通信组中的设备序号。
|
||||
|
||||
.. note::
|
||||
|
||||
1. GPU 版本的MindSpore不支持此方法;
|
||||
2. 参数 `group` 不能是 `hccl_world_group`;
|
||||
3. `get_group_rank_from_world_rank` 方法应该在 `init` 方法之后使用。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **world_rank_id** (`int`) - 通信集群内的全局设备序号。
|
||||
|
||||
- **group** (`str`) - 传入的通信组名称,通常由 `create_group` 方法创建。
|
||||
|
||||
**返回:**
|
||||
|
||||
int, 当前用户通信组中的设备序号。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 在参数 `group_rank_id` 不是数字或参数 `group` 不是字符串时抛出。
|
||||
|
||||
- **ValueError** – 在参数 `group` 是 `hccl_world_group` 或后台不可用时抛出。
|
||||
|
||||
- **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用,以及使用GPU版本的MindSpore时抛出。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from mindspore.context import set_context
|
||||
>>> set_context(device_target="Ascend")
|
||||
>>> init()
|
||||
>>> group = "0-4"
|
||||
>>> rank_ids = [0,4]
|
||||
>>> create_group(group, rank_ids)
|
||||
>>> group_rank_id = get_group_rank_from_world_rank(4, group)
|
||||
>>> print("group_rank_id is: ", group_rank_id) # 组内设备序号是1
|
||||
|
||||
.. py:class:: mindspore.communication.create_group(group, rank_ids)
|
||||
|
||||
创建用户通信组。
|
||||
|
||||
.. note::
|
||||
|
||||
1. GPU 版本的MindSpore不支持此方法;
|
||||
2. 列表rank_ids的长度应大于1;
|
||||
3. 列表rank_ids内不能有重复数据;
|
||||
4. `create_group` 方法应该在 `init` 方法之后使用。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **group** (`str`) - 将被创建的通信组名称。
|
||||
|
||||
- **rank_ids** (`list`) - 设备编号列表。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 在参数 `group_rank_id` 不是数字或参数 `group` 不是字符串时抛出。
|
||||
|
||||
- **ValueError** – 在列表rank_ids的长度小于1,或列表rank_ids内有重复数据,以及后台无效时抛出。
|
||||
|
||||
- **RuntimeError** – 在‘HCCL’或‘NCCL’ 服务不可用,以及使用GPU版本的MindSpore时抛出。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from mindspore.context import set_context
|
||||
>>> set_context(device_target="Ascend")
|
||||
>>> init()
|
||||
>>> group = "0-8"
|
||||
>>> rank_ids = [0,8]
|
||||
>>> create_group(group, rank_ids)
|
||||
|
||||
.. py:class:: mindspore.communication.get_local_rank(group=GlobalComm.WORLD_COMM_GROUP)
|
||||
|
||||
获取指定通信组中当前设备的本地设备序号。
|
||||
|
||||
.. note::
|
||||
|
||||
1. GPU 版本的MindSpore不支持此方法;
|
||||
2. `get_local_rank` 方法应该在 `init` 方法之后使用。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **group** (`str`) - 通信组名称,通常由 `create_group` 方法创建,否则将使用默认组名称。
|
||||
|
||||
- **默认值** - ‘WORLD_COMM_GROUP’。
|
||||
|
||||
**返回:**
|
||||
|
||||
int, 调用该方法的进程对应的通信组内本地设备序号。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 在参数 `group` 不是字符串时抛出。
|
||||
|
||||
- **ValueError** – 在后台不可用时抛出。
|
||||
|
||||
- **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用时抛出。
|
||||
|
||||
.. py:class:: mindspore.communication.get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP)
|
||||
|
||||
获取指定通信组的本地设备总数。
|
||||
|
||||
.. note::
|
||||
|
||||
1. GPU 版本的MindSpore不支持此方法;
|
||||
2. `get_local_rank_size` 方法应该在 `init` 方法之后使用。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **group** (`str`) - 传入的通信组名称,通常由 `create_group` 方法创建,或默认使用‘WORLD_COMM_GROUP’。
|
||||
|
||||
**返回:**
|
||||
|
||||
int, 调用该方法的进程对应的通信组设备总数。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 在参数 `group` 不是字符串时抛出。
|
||||
|
||||
- **ValueError** – 在后台不可用时抛出。
|
||||
|
||||
- **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用时抛出。
|
||||
|
||||
.. py:class:: mindspore.communication.destroy_group(group)
|
||||
|
||||
销毁用户通信组。
|
||||
|
||||
.. note::
|
||||
|
||||
1. GPU 版本的MindSpore不支持此方法;
|
||||
2. 参数 `group` 不能是 `hccl_world_group`;
|
||||
3. `destroy_group` 方法应该在 `init` 方法之后使用。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **group** (`str`) - 将被销毁的通信组,通常由 `create_group` 方法创建。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 在参数 `group` 不是字符串时抛出。
|
||||
|
||||
- **ValueError** – 在参数 `group` 是 `hccl_world_group` 或后台不可用时抛出。
|
||||
|
||||
- **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用时抛出。
|
|
@ -0,0 +1,125 @@
|
|||
get_auto_parallel_context(attr_key)
|
||||
|
||||
根据key获取自动并行的配置。
|
||||
|
||||
参数:
|
||||
attr_key (str):配置的key。
|
||||
|
||||
返回:
|
||||
根据key返回配置的值。
|
||||
|
||||
异常:
|
||||
ValueError:输入key不在自动并行的配置列表中。
|
||||
|
||||
Class mindspore.context.ParallelMode
|
||||
|
||||
并行模式。
|
||||
|
||||
有五种并行模式,分别是STAND_ALONE、DATA_PARALLEL、HYBRID_PARALLEL、SEMI_AUTO_PARALLEL和AUTO_PARALLEL。
|
||||
默认值:STAND_ALONE。
|
||||
|
||||
- STAND_ALONE:单卡模式。
|
||||
- DATA_PARALLEL:数据并行模式。
|
||||
- HYBRID_PARALLEL:手动实现数据并行和模型并行。
|
||||
- SEMI_AUTO_PARALLEL:半自动并行模式。
|
||||
- AUTO_PARALLEL:自动并行模式。
|
||||
|
||||
MODE_LIST:表示所有支持的并行模式的列表。
|
||||
|
||||
|
||||
reset_auto_parallel_context()
|
||||
|
||||
重置自动并行的配置为默认值。
|
||||
|
||||
- device_num:1。
|
||||
- global_rank:0。
|
||||
- gradients_mean:False。
|
||||
- gradient_fp32_sync:True。
|
||||
- parallel_mode:'stand_alone'。
|
||||
- auto_parallel_search_mode:'dynamic_programming'。
|
||||
- parameter_broadcast:False。
|
||||
- strategy_ckpt_load_file:''。
|
||||
- strategy_ckpt_save_file:''。
|
||||
- full_batch:False。
|
||||
- enable_parallel_optimizer:False。
|
||||
- pipeline_stages:1。
|
||||
|
||||
set_auto_parallel_context(**kwargs)
|
||||
|
||||
配置自动并行,仅在Ascend和GPU上有效。
|
||||
|
||||
应在init之前配置自动并行。
|
||||
|
||||
注:
|
||||
配置时,必须输入配置的名称。
|
||||
如果某个程序具有不同并行模式下的任务,则需要再为下一个任务设置新的并行模式之前,调用reset_auto_parallel_context()接口来重置配置。
|
||||
若要设置或更改并行模式,必须在创建任何Initializer之前调用接口,否则,在编译网络时,可能会出现RuntimeError。
|
||||
|
||||
某些配置适用于特定的并行模式,有关详细信息,请参见下表:
|
||||
|
||||
=========================== ===========================
|
||||
Common AUTO_PARALLEL
|
||||
=========================== ===========================
|
||||
device_num gradient_fp32_sync
|
||||
global_rank loss_repeated_mean
|
||||
gradients_mean auto_parallel_search_mode
|
||||
parallel_mode strategy_ckpt_load_file
|
||||
all_reduce_fusion_config strategy_ckpt_save_file
|
||||
enable_parallel_optimizer dataset_strategy
|
||||
\ pipeline_stages
|
||||
\ grad_accumulation_step
|
||||
=========================== ===========================
|
||||
|
||||
参数:
|
||||
device_num (int):表示可用设备的编号,必须在【1,4096】范围中。默认值:1。
|
||||
global_rank (int):表示全局秩的ID,必须在【0,4095】范围中。默认值:0。
|
||||
gradients_mean (bool):表示是否在梯度的allreduce后执行平均算子。
|
||||
stand_alone不支持gradients_mean。默认值:False。
|
||||
gradient_fp32_sync (bool):在FP32中运行gradients的allreduce。stand_alone、data_parallel和hybrid_parallel不支持gradient_fp32_sync。默认值:True。
|
||||
parallel_mode (str):有五种并行模式,分别是stand_alone、data_parallel、hybrid_parallel、semi_auto_parallel和auto_parallel。默认值:stand_alone。
|
||||
|
||||
- stand_alone:单卡模式。
|
||||
|
||||
- data_parallel:数据并行模式。
|
||||
|
||||
- hybrid_parallel:手动实现数据并行和模型并行。
|
||||
|
||||
- semi_auto_parallel:半自动并行模式。
|
||||
|
||||
- auto_parallel:自动并行模式。
|
||||
auto_parallel_search_mode (str):表示有两种策略搜索模式,分别是recursive_programming和dynamic_programming。默认值:dynamic_programming。
|
||||
|
||||
- recursive_programming:表示双递归搜索模式。
|
||||
|
||||
- dynamic_programming:表示动态规划搜索模式。
|
||||
parameter_broadcast (bool):表示在训练前是否广播参数。在训练之前,为了使所有设备的网络初始化参数值相同,请将设备0上的参数广播到其他设备。不同并行模式下的参数广播不同。
|
||||
在data_parallel模式下,除layerwise_parallel属性为True的参数外,所有参数都会被广播。在Hybrid_parallel、semi_auto_parallel和auto_parallel模式下,分段参数不参与广播。默认值:False。
|
||||
strategy_ckpt_load_file (str):表示用于加载并行策略checkpoint的路径。默认值:''。
|
||||
strategy_ckpt_save_file (str):表示用于保存并行策略checkpoint的路径。默认值:''。
|
||||
full_batch (bool):如果在auto_parallel模式下加载整个batch数据集,则此参数应设置为True。默认值:False。目前不建议使用该接口,建议使用dataset_strategy来替换它。
|
||||
dataset_strategy (Union[str, tuple]):表示数据集分片策略。默认值:data_parallel。
|
||||
dataset_strategy="data_parallel"等于full_batch=False,dataset_strategy="full_batch"等于full_batch=True。对于通过模型并列策略加载到网络的数据集,如ds_stra ((1, 8)、(1, 8)),需要使用set_auto_parallel_context(dataset_strategy=ds_stra)。
|
||||
enable_parallel_optimizer (bool):这是一个开发中的特性,它可以为数据并行训练对权重更新计算进行分片,以节省时间和内存。目前,自动和半自动并行模式支持Ascend和GPU中的所有优化器。数据并行模式仅支持Ascend中的`Lamb`和`AdamWeightDecay`。默认值:False。
|
||||
all_reduce_fusion_config (list):通过参数索引设置allreduce 融合策略。仅支持ReduceOp.SUM和HCCL_WORLD_GROUP/NCCL_WORLD_GROUP。没有默认值。如果不设置,则关闭算子融合。
|
||||
pipeline_stages (int):设置pipeline并行的阶段信息。这表明了设备如何单独分布在pipeline上。所有的设备将被划分为pipeline_stags个阶段。
|
||||
目前,这只能在启动semi_auto_parallel模式的情况下使用。默认值:1。
|
||||
grad_accumulation_step (int):在自动和半自动并行模式下设置梯度的累积step。
|
||||
其值应为正整数。默认值:1。
|
||||
|
||||
异常:
|
||||
ValueError:输入key不是自动并行上下文中的属性。
|
||||
|
||||
样例:
|
||||
>>> context.set_auto_parallel_context(device_num=8)
|
||||
>>> context.set_auto_parallel_context(global_rank=0)
|
||||
>>> context.set_auto_parallel_context(gradients_mean=True)
|
||||
>>> context.set_auto_parallel_context(gradient_fp32_sync=False)
|
||||
>>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
>>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")
|
||||
>>> context.set_auto_parallel_context(parameter_broadcast=False)
|
||||
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
|
||||
>>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
|
||||
>>> context.set_auto_parallel_context(dataset_strategy=((1, 8), (1, 8)))
|
||||
>>> context.set_auto_parallel_context(enable_parallel_optimizer=False)
|
||||
>>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160])
|
||||
>>> context.set_auto_parallel_context(pipeline_stages=2)
|
|
@ -0,0 +1,22 @@
|
|||
mindspore.load_distributed_checkpoint
|
||||
======================================
|
||||
|
||||
.. py:method:: mindspore.load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM')
|
||||
|
||||
给分布式预测加载checkpoint文件到网络,用于分布式推理。关于分布式推理的细节,请参考:'<https://www.mindspore.cn/docs/programming_guide/zh-CN/master/distributed_inference.html>' 。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **network** (Cell):分布式预测网络。
|
||||
- **checkpoint_filenames** (list[str]):checkpoint文件的名称,按rank id 顺序排列。
|
||||
- **predict_strategy** (dict):predict时参数的切分策略。
|
||||
- **train_strategy_filename** (str):训练策略proto文件名。默认值:None。
|
||||
- **strict_load** (bool):表示是否严格加载参数到网络。如果值为False,则当checkpoint文件中参数名称的后缀与网络中的参数相同时,加载参数到网络。当类型不一致时,对相同类型的参数进行类型转换,如从float32到float16。默认值:False。
|
||||
- **dec_key** (Union[None, bytes]):用于解密的字节类型key。如果value为None,则不需要解密。默认值:None。
|
||||
- **dec_mode** (str):仅当dec_key不设为None时,该参数有效。指定了解密模式,目前支持AES-GCM和AES-CBC。默认值:AES-GCM。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError:** 输入类型不符合要求。
|
||||
- **ValueError:** 无法加载checkpoint文件到网络。
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
Class mindspore.nn.DistributedGradReducer(parameters, mean=True, degree=None, fusion_type=1, group='hccl_world_group')
|
||||
|
||||
分布式优化器。
|
||||
|
||||
对反向梯度进行AllReduce运算。
|
||||
|
||||
|
||||
参数:
|
||||
parameters (list):需要更新的参数。
|
||||
mean (bool):当mean为True时,对AllReduce之后的梯度求均值。默认值:False。
|
||||
degree (int):平均系数,通常等于设备编号。默认值:None。
|
||||
fusion_type (int):AllReduce算子的融合类型。默认值:1。
|
||||
|
||||
异常:
|
||||
ValueError:如果degree不是int或小于0。
|
||||
|
||||
支持平台:
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
示例:
|
||||
>>> #此示例应与多个进程一起运行。
|
||||
>>> #请参考Mindpore.cn上的“教程>分布式训练”。
|
||||
>>> import numpy as np
|
||||
>>> from mindspore.communication import init
|
||||
>>> from mindspore import ops
|
||||
>>> from mindspore import context
|
||||
>>> from mindspore.context import ParallelMode
|
||||
>>> from mindspore import Parameter, Tensor
|
||||
>>> from mindspore import nn
|
||||
>>>
|
||||
>>> context.set_context(mode=context.GRAPH_MODE)
|
||||
>>> init()
|
||||
>>> context.reset_auto_parallel_context()
|
||||
>>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
|
||||
>>>
|
||||
>>> class TrainingWrapper(nn.Cell):
|
||||
... def __init__(self, network, optimizer, sens=1.0):
|
||||
... super(TrainingWrapper, self).__init__(auto_prefix=False)
|
||||
... self.network = network
|
||||
... self.network.add_flags(defer_inline=True)
|
||||
... self.weights = optimizer.parameters
|
||||
... self.optimizer = optimizer
|
||||
... self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
|
||||
... self.sens = sens
|
||||
... self.reducer_flag = False
|
||||
... self.grad_reducer = None
|
||||
... self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
... if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
... self.reducer_flag = True
|
||||
... if self.reducer_flag:
|
||||
... mean = context.get_auto_parallel_context("gradients_mean")
|
||||
... degree = context.get_auto_parallel_context("device_num")
|
||||
... self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
...
|
||||
... def construct(self, *args):
|
||||
... weights = self.weights
|
||||
... loss = self.network(*args)
|
||||
... sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
|
||||
... grads = self.grad(self.network, weights)(*args, sens)
|
||||
... if self.reducer_flag:
|
||||
... # apply grad reducer on grads
|
||||
... grads = self.grad_reducer(grads)
|
||||
... return ops.Depend(loss, self.optimizer(grads))
|
||||
>>>
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self, in_features, out_features):
|
||||
... super(Net, self).__init__()
|
||||
... self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
|
||||
... name='weight')
|
||||
... self.matmul = ops.MatMul()
|
||||
...
|
||||
... def construct(self, x):
|
||||
... output = self.matmul(x, self.weight)
|
||||
... return output
|
||||
>>>
|
||||
>>> size, in_features, out_features = 16, 16, 10
|
||||
>>> network = Net(in_features, out_features)
|
||||
>>> loss = nn.MSELoss()
|
||||
>>> net_with_loss = nn.WithLossCell(network, loss)
|
||||
>>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> train_cell = TrainingWrapper(net_with_loss, optimizer)
|
||||
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
|
||||
>>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
|
||||
>>> grads = train_cell(inputs, label)
|
||||
>>> print(grads)
|
||||
256.0
|
||||
|
||||
construct(grads)
|
||||
|
||||
某些情况下,梯度的数据精度可以与float16和float32混合。因此,AllReduce的结果不可靠。
|
||||
要解决这个问题,必须在AllReduce之前强制转换为float32,并在操作之后再强制转换为float32。
|
||||
|
||||
|
||||
参数:
|
||||
grads (Union[Tensor, tuple[Tensor]]):操作前的梯度tensor或tuple。
|
||||
|
||||
返回:
|
||||
new_grads (Union[Tensor, tuple[Tensor]]),操作后的梯度tensor或tuple。
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
Class mindspore.nn.PipelineCell(network, micro_size)
|
||||
|
||||
将MiniBatch切分成更细粒度的MicroBatch,用于流水线并行的训练中。
|
||||
|
||||
注:
|
||||
micro_size必须大于或等于流水线stage的个数。
|
||||
|
||||
参数:
|
||||
network (Cell):要包裹的目标网络。
|
||||
micro_size (int):MicroBatch大小。
|
||||
|
||||
示例:
|
||||
>>> net = Net()
|
||||
>>> net = PipelineCell(net, 4)
|
||||
|
|
@ -157,7 +157,7 @@ class AllReduce(PrimitiveWithInfer):
|
|||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('fusion', 0)
|
||||
self.add_prim_attr('index', 0)
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
self.add_prim_attr('no_eliminate', True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
@ -231,7 +231,7 @@ class AllGather(PrimitiveWithInfer):
|
|||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('fusion', 0)
|
||||
self.add_prim_attr('mean_flag', False)
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
self.add_prim_attr('no_eliminate', True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check_positive_int(len(x_shape), "x shape", self.name)
|
||||
|
@ -350,7 +350,7 @@ class _HostAllGather(PrimitiveWithInfer):
|
|||
validator.check_value_type("rank_id", r, (int,), self.name)
|
||||
self.group_size = len(group)
|
||||
self.add_prim_attr('group', group)
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
self.add_prim_attr('no_eliminate', True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check_positive_int(len(x_shape), "x shape", self.name)
|
||||
|
@ -425,7 +425,7 @@ class ReduceScatter(PrimitiveWithInfer):
|
|||
self.add_prim_attr('rank_size', self.rank_size)
|
||||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('fusion', 0)
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
self.add_prim_attr('no_eliminate', True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
if self.rank_size == 0:
|
||||
|
@ -480,7 +480,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
|
|||
self.op = op
|
||||
self.group_size = len(group)
|
||||
self.add_prim_attr('group', group)
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
self.add_prim_attr('no_eliminate', True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
if x_shape[0] % self.group_size != 0:
|
||||
|
@ -558,7 +558,7 @@ class Broadcast(PrimitiveWithInfer):
|
|||
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
||||
check_hcom_group_valid(group, prim_name=self.name)
|
||||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
self.add_prim_attr('no_eliminate', True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
@ -605,7 +605,7 @@ class AllSwap(PrimitiveWithCheck):
|
|||
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
||||
self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
|
||||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
self.add_prim_attr('no_eliminate', True)
|
||||
|
||||
def __check__(self, tensor_in, send_size, recv_size):
|
||||
validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name)
|
||||
|
@ -650,7 +650,7 @@ class NeighborExchange(Primitive):
|
|||
self.recv_shapes = recv_shapes
|
||||
self.send_shapes = send_shapes
|
||||
self.recv_type = recv_type
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
self.add_prim_attr('no_eliminate', True)
|
||||
|
||||
def __call__(self, tensor):
|
||||
raise NotImplementedError
|
||||
|
@ -690,7 +690,7 @@ class AlltoAll(PrimitiveWithInfer):
|
|||
self.split_dim = split_dim
|
||||
self.concat_dim = concat_dim
|
||||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
self.add_prim_attr('no_eliminate', True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
rank_size = get_group_size(_get_group(self.group))
|
||||
|
@ -740,7 +740,7 @@ class NeighborExchangeV2(Primitive):
|
|||
self.send_lens = send_lens
|
||||
self.recv_lens = recv_lens
|
||||
self.format = data_format
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
self.add_prim_attr('no_eliminate', True)
|
||||
|
||||
def __call__(self, tensor):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1450,9 +1450,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|||
Args:
|
||||
network (Cell): Network for distributed predication.
|
||||
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
|
||||
predict_strategy (dict): Strategy of predication process, whose key is parameter name, and value is a list or
|
||||
a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
|
||||
it means that the predication process just uses single device. Default: None.
|
||||
predict_strategy (dict): Strategy of predication process. Default: None.
|
||||
train_strategy_filename (str): Train strategy proto file name. Default: None.
|
||||
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
|
||||
into net when parameter name's suffix in checkpoint file is the same as the
|
||||
|
|
|
@ -320,7 +320,7 @@ TEST_F(TestStepParallel, CreatOpInstance) {
|
|||
} else if (name == "index") {
|
||||
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
|
||||
ASSERT_EQ(converted_ret->ToString(), "0");
|
||||
} else if (name == "no_elimilate") {
|
||||
} else if (name == "no_eliminate") {
|
||||
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
|
||||
ASSERT_EQ(converted_ret->ToString(), "true");
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue