diff --git a/docs/api/api_python/parallel/mindspore.communication.rst b/docs/api/api_python/parallel/mindspore.communication.rst new file mode 100644 index 00000000000..30cbab08bda --- /dev/null +++ b/docs/api/api_python/parallel/mindspore.communication.rst @@ -0,0 +1,293 @@ +mindspore.communication +======================== +集合通信接口的类。 + +.. py:class:: mindspore.communication.GlobalComm + + 全局通信信息。GlobalComm 是一个全局类。 成员包含:BACKEND、WORLD_COMM_GROUP。 + +.. py:method:: mindspore.communication.init(backend_name=None) + + 初始化通信服务需要的分布式后端,例如HCCL或NCCL服务。 + + .. note:: + + HCCL的全称是华为集合通信库(Huawei Collective Communication Library),NCCL的全称是英伟达集合通信库(NVIDIA Collective Communication Library)。`init` 方法应该在 `set_context` 方法之后使用。 + + **参数:** + + - **backend_name** (`str`) – 后台服务的名称,可选HCCL或NCCL。如果未设置则根据硬件平台类型(device_target)进行推断,默认值为None。 + + **异常:** + + - **TypeError** – 在参数 `backend_name` 不是字符串时抛出。 + + - **RuntimeError** – 在以下情况将抛出:1)硬件设备类型无效;2)后台服务无效;3)分布式计算初始化失败;4)未设置环境变量 `RANK_ID` 或 `MINDSPORE_HCCL_CONFIG_PATH` 的情况下初始化HCCL服务。 + + - **ValueError** – 在环境变量 `RANK_ID` 设置成非数字时抛出。 + + **样例:** + + .. code-block:: + + >>> from mindspore.context import set_context + >>> set_context(device_target="Ascend") + >>> init() + +.. py:class:: mindspore.communication.release() + + 释放分布式资源,例如‘HCCL’或‘NCCL’服务。 + + .. note:: + + `release` 方法应该在 `init` 方法之后使用。 + + **异常:** + + - **RuntimeError** - 在释放分布式资源失败时抛出。 + +.. py:class:: mindspore.communication.get_rank(group=GlobalComm.WORLD_COMM_GROUP) + + 在指定通信组中获取当前的设备序号。 + + .. note:: + + `get_rank` 方法应该在 `init` 方法之后使用。 + + **参数:** + + - **group** (`str`) - 通信组名称,通常由 `create_group` 方法创建,否则将使用默认组。 + + - **默认值** - ‘WORLD_COMM_GROUP’。 + + **返回:** + + int, 调用该方法的进程对应的组内序号。 + + **异常:** + + - **TypeError** – 在参数 `group` 不是字符串时抛出。 + + - **ValueError** – 在后台不可用时抛出。 + + - **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用时抛出。 + +.. py:class:: mindspore.communication.get_group_size(group=GlobalComm.WORLD_COMM_GROUP) + + 获取指定通信组的设备总数。 + + .. note:: + + `get_group_size` 方法应该在 `init` 方法之后使用。 + + **参数:** + + - **group** (`str`) - 通信组名称,通常由 `create_group` 方法创建,否则将使用默认组。 + + - **默认值** - ‘WORLD_COMM_GROUP’。 + + **返回:** + + int, 指定通信组的设备总数。 + + **异常:** + + - **TypeError** – 在参数 `group` 不是字符串时抛出。 + + - **ValueError** – 在后台不可用时抛出。 + + - **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用时抛出。 + + + +.. py:class:: mindspore.communication.get_world_rank_from_group_rank(group, group_rank_id) + + 由指定通信组中的设备序号获取通信集群中的全局设备序号。 + + .. note:: + + 1. GPU 版本的MindSpore不支持此方法; + 2. 参数 `group` 不能是 `hccl_world_group`; + 3. `get_world_rank_from_group_rank` 方法应该在 `init` 方法之后使用。 + + **参数:** + + - **group** (`str`) - 传入的通信组名称,通常由 `create_group` 方法创建。 + + - **group_rank_id** (`int`) - 通信组内的设备序号 + + **返回:** + + int, 通信集群中的全局设备序号。 + + **异常:** + + - **TypeError** – 在参数 `group_rank_id` 不是数字或参数 `group` 不是字符串时抛出。 + + - **ValueError** – 在参数 `group` 是 `hccl_world_group` 或后台不可用时抛出。 + + - **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用,以及使用GPU版本的MindSpore时抛出。 + + **样例:** + + .. code-block:: + + >>> from mindspore.context import set_context + >>> set_context(device_target="Ascend") + >>> init() + >>> group = "0-4" + >>> rank_ids = [0,4] + >>> create_group(group, rank_ids) + >>> world_rank_id = get_world_rank_from_group_rank(group, 1) + >>> print("world_rank_id is: ", world_rank_id) # 全局设备序号为4 + +.. py:class:: mindspore.communication.get_group_rank_from_world_rank(world_rank_id, group) + + 由通信集群中的全局设备序号获取指定用户通信组中的设备序号。 + + .. note:: + + 1. GPU 版本的MindSpore不支持此方法; + 2. 参数 `group` 不能是 `hccl_world_group`; + 3. `get_group_rank_from_world_rank` 方法应该在 `init` 方法之后使用。 + + **参数:** + + - **world_rank_id** (`int`) - 通信集群内的全局设备序号。 + + - **group** (`str`) - 传入的通信组名称,通常由 `create_group` 方法创建。 + + **返回:** + + int, 当前用户通信组中的设备序号。 + + **异常:** + + - **TypeError** – 在参数 `group_rank_id` 不是数字或参数 `group` 不是字符串时抛出。 + + - **ValueError** – 在参数 `group` 是 `hccl_world_group` 或后台不可用时抛出。 + + - **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用,以及使用GPU版本的MindSpore时抛出。 + + **样例:** + + .. code-block:: + + >>> from mindspore.context import set_context + >>> set_context(device_target="Ascend") + >>> init() + >>> group = "0-4" + >>> rank_ids = [0,4] + >>> create_group(group, rank_ids) + >>> group_rank_id = get_group_rank_from_world_rank(4, group) + >>> print("group_rank_id is: ", group_rank_id) # 组内设备序号是1 + +.. py:class:: mindspore.communication.create_group(group, rank_ids) + + 创建用户通信组。 + + .. note:: + + 1. GPU 版本的MindSpore不支持此方法; + 2. 列表rank_ids的长度应大于1; + 3. 列表rank_ids内不能有重复数据; + 4. `create_group` 方法应该在 `init` 方法之后使用。 + + **参数:** + + - **group** (`str`) - 将被创建的通信组名称。 + + - **rank_ids** (`list`) - 设备编号列表。 + + **异常:** + + - **TypeError** – 在参数 `group_rank_id` 不是数字或参数 `group` 不是字符串时抛出。 + + - **ValueError** – 在列表rank_ids的长度小于1,或列表rank_ids内有重复数据,以及后台无效时抛出。 + + - **RuntimeError** – 在‘HCCL’或‘NCCL’ 服务不可用,以及使用GPU版本的MindSpore时抛出。 + + **样例:** + + .. code-block:: + + >>> from mindspore.context import set_context + >>> set_context(device_target="Ascend") + >>> init() + >>> group = "0-8" + >>> rank_ids = [0,8] + >>> create_group(group, rank_ids) + +.. py:class:: mindspore.communication.get_local_rank(group=GlobalComm.WORLD_COMM_GROUP) + + 获取指定通信组中当前设备的本地设备序号。 + + .. note:: + + 1. GPU 版本的MindSpore不支持此方法; + 2. `get_local_rank` 方法应该在 `init` 方法之后使用。 + + **参数:** + + - **group** (`str`) - 通信组名称,通常由 `create_group` 方法创建,否则将使用默认组名称。 + + - **默认值** - ‘WORLD_COMM_GROUP’。 + + **返回:** + + int, 调用该方法的进程对应的通信组内本地设备序号。 + + **异常:** + + - **TypeError** – 在参数 `group` 不是字符串时抛出。 + + - **ValueError** – 在后台不可用时抛出。 + + - **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用时抛出。 + +.. py:class:: mindspore.communication.get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP) + + 获取指定通信组的本地设备总数。 + + .. note:: + + 1. GPU 版本的MindSpore不支持此方法; + 2. `get_local_rank_size` 方法应该在 `init` 方法之后使用。 + + **参数:** + + - **group** (`str`) - 传入的通信组名称,通常由 `create_group` 方法创建,或默认使用‘WORLD_COMM_GROUP’。 + + **返回:** + + int, 调用该方法的进程对应的通信组设备总数。 + + **异常:** + + - **TypeError** – 在参数 `group` 不是字符串时抛出。 + + - **ValueError** – 在后台不可用时抛出。 + + - **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用时抛出。 + +.. py:class:: mindspore.communication.destroy_group(group) + + 销毁用户通信组。 + + .. note:: + + 1. GPU 版本的MindSpore不支持此方法; + 2. 参数 `group` 不能是 `hccl_world_group`; + 3. `destroy_group` 方法应该在 `init` 方法之后使用。 + + **参数:** + + - **group** (`str`) - 将被销毁的通信组,通常由 `create_group` 方法创建。 + + **异常:** + + - **TypeError** – 在参数 `group` 不是字符串时抛出。 + + - **ValueError** – 在参数 `group` 是 `hccl_world_group` 或后台不可用时抛出。 + + - **RuntimeError** – 在‘HCCL’或‘NCCL’服务不可用时抛出。 \ No newline at end of file diff --git a/docs/api/api_python/parallel/mindspore.context.txt b/docs/api/api_python/parallel/mindspore.context.txt new file mode 100644 index 00000000000..2597a84bea8 --- /dev/null +++ b/docs/api/api_python/parallel/mindspore.context.txt @@ -0,0 +1,125 @@ +get_auto_parallel_context(attr_key) + + 根据key获取自动并行的配置。 + + 参数: + attr_key (str):配置的key。 + + 返回: + 根据key返回配置的值。 + + 异常: + ValueError:输入key不在自动并行的配置列表中。 + +Class mindspore.context.ParallelMode + + 并行模式。 + + 有五种并行模式,分别是STAND_ALONE、DATA_PARALLEL、HYBRID_PARALLEL、SEMI_AUTO_PARALLEL和AUTO_PARALLEL。 + 默认值:STAND_ALONE。 + + - STAND_ALONE:单卡模式。 + - DATA_PARALLEL:数据并行模式。 + - HYBRID_PARALLEL:手动实现数据并行和模型并行。 + - SEMI_AUTO_PARALLEL:半自动并行模式。 + - AUTO_PARALLEL:自动并行模式。 + + MODE_LIST:表示所有支持的并行模式的列表。 + + +reset_auto_parallel_context() + + 重置自动并行的配置为默认值。 + + - device_num:1。 + - global_rank:0。 + - gradients_mean:False。 + - gradient_fp32_sync:True。 + - parallel_mode:'stand_alone'。 + - auto_parallel_search_mode:'dynamic_programming'。 + - parameter_broadcast:False。 + - strategy_ckpt_load_file:''。 + - strategy_ckpt_save_file:''。 + - full_batch:False。 + - enable_parallel_optimizer:False。 + - pipeline_stages:1。 + +set_auto_parallel_context(**kwargs) + + 配置自动并行,仅在Ascend和GPU上有效。 + + 应在init之前配置自动并行。 + + 注: + 配置时,必须输入配置的名称。 + 如果某个程序具有不同并行模式下的任务,则需要再为下一个任务设置新的并行模式之前,调用reset_auto_parallel_context()接口来重置配置。 + 若要设置或更改并行模式,必须在创建任何Initializer之前调用接口,否则,在编译网络时,可能会出现RuntimeError。 + + 某些配置适用于特定的并行模式,有关详细信息,请参见下表: + + =========================== =========================== + Common AUTO_PARALLEL + =========================== =========================== + device_num gradient_fp32_sync + global_rank loss_repeated_mean + gradients_mean auto_parallel_search_mode + parallel_mode strategy_ckpt_load_file + all_reduce_fusion_config strategy_ckpt_save_file + enable_parallel_optimizer dataset_strategy + \ pipeline_stages + \ grad_accumulation_step + =========================== =========================== + + 参数: + device_num (int):表示可用设备的编号,必须在【1,4096】范围中。默认值:1。 + global_rank (int):表示全局秩的ID,必须在【0,4095】范围中。默认值:0。 + gradients_mean (bool):表示是否在梯度的allreduce后执行平均算子。 + stand_alone不支持gradients_mean。默认值:False。 + gradient_fp32_sync (bool):在FP32中运行gradients的allreduce。stand_alone、data_parallel和hybrid_parallel不支持gradient_fp32_sync。默认值:True。 + parallel_mode (str):有五种并行模式,分别是stand_alone、data_parallel、hybrid_parallel、semi_auto_parallel和auto_parallel。默认值:stand_alone。 + + - stand_alone:单卡模式。 + + - data_parallel:数据并行模式。 + + - hybrid_parallel:手动实现数据并行和模型并行。 + + - semi_auto_parallel:半自动并行模式。 + + - auto_parallel:自动并行模式。 + auto_parallel_search_mode (str):表示有两种策略搜索模式,分别是recursive_programming和dynamic_programming。默认值:dynamic_programming。 + + - recursive_programming:表示双递归搜索模式。 + + - dynamic_programming:表示动态规划搜索模式。 + parameter_broadcast (bool):表示在训练前是否广播参数。在训练之前,为了使所有设备的网络初始化参数值相同,请将设备0上的参数广播到其他设备。不同并行模式下的参数广播不同。 + 在data_parallel模式下,除layerwise_parallel属性为True的参数外,所有参数都会被广播。在Hybrid_parallel、semi_auto_parallel和auto_parallel模式下,分段参数不参与广播。默认值:False。 + strategy_ckpt_load_file (str):表示用于加载并行策略checkpoint的路径。默认值:''。 + strategy_ckpt_save_file (str):表示用于保存并行策略checkpoint的路径。默认值:''。 + full_batch (bool):如果在auto_parallel模式下加载整个batch数据集,则此参数应设置为True。默认值:False。目前不建议使用该接口,建议使用dataset_strategy来替换它。 + dataset_strategy (Union[str, tuple]):表示数据集分片策略。默认值:data_parallel。 + dataset_strategy="data_parallel"等于full_batch=False,dataset_strategy="full_batch"等于full_batch=True。对于通过模型并列策略加载到网络的数据集,如ds_stra ((1, 8)、(1, 8)),需要使用set_auto_parallel_context(dataset_strategy=ds_stra)。 + enable_parallel_optimizer (bool):这是一个开发中的特性,它可以为数据并行训练对权重更新计算进行分片,以节省时间和内存。目前,自动和半自动并行模式支持Ascend和GPU中的所有优化器。数据并行模式仅支持Ascend中的`Lamb`和`AdamWeightDecay`。默认值:False。 + all_reduce_fusion_config (list):通过参数索引设置allreduce 融合策略。仅支持ReduceOp.SUM和HCCL_WORLD_GROUP/NCCL_WORLD_GROUP。没有默认值。如果不设置,则关闭算子融合。 + pipeline_stages (int):设置pipeline并行的阶段信息。这表明了设备如何单独分布在pipeline上。所有的设备将被划分为pipeline_stags个阶段。 + 目前,这只能在启动semi_auto_parallel模式的情况下使用。默认值:1。 + grad_accumulation_step (int):在自动和半自动并行模式下设置梯度的累积step。 + 其值应为正整数。默认值:1。 + + 异常: + ValueError:输入key不是自动并行上下文中的属性。 + + 样例: + >>> context.set_auto_parallel_context(device_num=8) + >>> context.set_auto_parallel_context(global_rank=0) + >>> context.set_auto_parallel_context(gradients_mean=True) + >>> context.set_auto_parallel_context(gradient_fp32_sync=False) + >>> context.set_auto_parallel_context(parallel_mode="auto_parallel") + >>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming") + >>> context.set_auto_parallel_context(parameter_broadcast=False) + >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt") + >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt") + >>> context.set_auto_parallel_context(dataset_strategy=((1, 8), (1, 8))) + >>> context.set_auto_parallel_context(enable_parallel_optimizer=False) + >>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160]) + >>> context.set_auto_parallel_context(pipeline_stages=2) diff --git a/docs/api/api_python/parallel/mindspore.load_distributed_checkpoint.txt b/docs/api/api_python/parallel/mindspore.load_distributed_checkpoint.txt new file mode 100644 index 00000000000..da0d8ed6443 --- /dev/null +++ b/docs/api/api_python/parallel/mindspore.load_distributed_checkpoint.txt @@ -0,0 +1,22 @@ +mindspore.load_distributed_checkpoint +====================================== + +.. py:method:: mindspore.load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM') + + 给分布式预测加载checkpoint文件到网络,用于分布式推理。关于分布式推理的细节,请参考:'' 。 + + **参数:** + + - **network** (Cell):分布式预测网络。 + - **checkpoint_filenames** (list[str]):checkpoint文件的名称,按rank id 顺序排列。 + - **predict_strategy** (dict):predict时参数的切分策略。 + - **train_strategy_filename** (str):训练策略proto文件名。默认值:None。 + - **strict_load** (bool):表示是否严格加载参数到网络。如果值为False,则当checkpoint文件中参数名称的后缀与网络中的参数相同时,加载参数到网络。当类型不一致时,对相同类型的参数进行类型转换,如从float32到float16。默认值:False。 + - **dec_key** (Union[None, bytes]):用于解密的字节类型key。如果value为None,则不需要解密。默认值:None。 + - **dec_mode** (str):仅当dec_key不设为None时,该参数有效。指定了解密模式,目前支持AES-GCM和AES-CBC。默认值:AES-GCM。 + + **异常:** + + - **TypeError:** 输入类型不符合要求。 + - **ValueError:** 无法加载checkpoint文件到网络。 + \ No newline at end of file diff --git a/docs/api/api_python/parallel/mindspore.nn.DistributedGradReducer.txt b/docs/api/api_python/parallel/mindspore.nn.DistributedGradReducer.txt new file mode 100644 index 00000000000..a89eee23d0d --- /dev/null +++ b/docs/api/api_python/parallel/mindspore.nn.DistributedGradReducer.txt @@ -0,0 +1,99 @@ +Class mindspore.nn.DistributedGradReducer(parameters, mean=True, degree=None, fusion_type=1, group='hccl_world_group') + + 分布式优化器。 + + 对反向梯度进行AllReduce运算。 + + + 参数: + parameters (list):需要更新的参数。 + mean (bool):当mean为True时,对AllReduce之后的梯度求均值。默认值:False。 + degree (int):平均系数,通常等于设备编号。默认值:None。 + fusion_type (int):AllReduce算子的融合类型。默认值:1。 + + 异常: + ValueError:如果degree不是int或小于0。 + + 支持平台: + ``Ascend`` ``GPU`` + + 示例: + >>> #此示例应与多个进程一起运行。 + >>> #请参考Mindpore.cn上的“教程>分布式训练”。 + >>> import numpy as np + >>> from mindspore.communication import init + >>> from mindspore import ops + >>> from mindspore import context + >>> from mindspore.context import ParallelMode + >>> from mindspore import Parameter, Tensor + >>> from mindspore import nn + >>> + >>> context.set_context(mode=context.GRAPH_MODE) + >>> init() + >>> context.reset_auto_parallel_context() + >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL) + >>> + >>> class TrainingWrapper(nn.Cell): + ... def __init__(self, network, optimizer, sens=1.0): + ... super(TrainingWrapper, self).__init__(auto_prefix=False) + ... self.network = network + ... self.network.add_flags(defer_inline=True) + ... self.weights = optimizer.parameters + ... self.optimizer = optimizer + ... self.grad = ops.GradOperation(get_by_list=True, sens_param=True) + ... self.sens = sens + ... self.reducer_flag = False + ... self.grad_reducer = None + ... self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + ... if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + ... self.reducer_flag = True + ... if self.reducer_flag: + ... mean = context.get_auto_parallel_context("gradients_mean") + ... degree = context.get_auto_parallel_context("device_num") + ... self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) + ... + ... def construct(self, *args): + ... weights = self.weights + ... loss = self.network(*args) + ... sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens) + ... grads = self.grad(self.network, weights)(*args, sens) + ... if self.reducer_flag: + ... # apply grad reducer on grads + ... grads = self.grad_reducer(grads) + ... return ops.Depend(loss, self.optimizer(grads)) + >>> + >>> class Net(nn.Cell): + ... def __init__(self, in_features, out_features): + ... super(Net, self).__init__() + ... self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)), + ... name='weight') + ... self.matmul = ops.MatMul() + ... + ... def construct(self, x): + ... output = self.matmul(x, self.weight) + ... return output + >>> + >>> size, in_features, out_features = 16, 16, 10 + >>> network = Net(in_features, out_features) + >>> loss = nn.MSELoss() + >>> net_with_loss = nn.WithLossCell(network, loss) + >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> train_cell = TrainingWrapper(net_with_loss, optimizer) + >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32)) + >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32)) + >>> grads = train_cell(inputs, label) + >>> print(grads) + 256.0 + +construct(grads) + + 某些情况下,梯度的数据精度可以与float16和float32混合。因此,AllReduce的结果不可靠。 + 要解决这个问题,必须在AllReduce之前强制转换为float32,并在操作之后再强制转换为float32。 + + + 参数: + grads (Union[Tensor, tuple[Tensor]]):操作前的梯度tensor或tuple。 + + 返回: + new_grads (Union[Tensor, tuple[Tensor]]),操作后的梯度tensor或tuple。 + \ No newline at end of file diff --git a/docs/api/api_python/parallel/mindspore.nn.PipelineCell.txt b/docs/api/api_python/parallel/mindspore.nn.PipelineCell.txt new file mode 100644 index 00000000000..e014e3b1b28 --- /dev/null +++ b/docs/api/api_python/parallel/mindspore.nn.PipelineCell.txt @@ -0,0 +1,15 @@ +Class mindspore.nn.PipelineCell(network, micro_size) + + 将MiniBatch切分成更细粒度的MicroBatch,用于流水线并行的训练中。 + + 注: + micro_size必须大于或等于流水线stage的个数。 + + 参数: + network (Cell):要包裹的目标网络。 + micro_size (int):MicroBatch大小。 + + 示例: + >>> net = Net() + >>> net = PipelineCell(net, 4) + \ No newline at end of file diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 73233a69431..3dc246b2aa0 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -157,7 +157,7 @@ class AllReduce(PrimitiveWithInfer): self.add_prim_attr('group', _get_group(group)) self.add_prim_attr('fusion', 0) self.add_prim_attr('index', 0) - self.add_prim_attr('no_elimilate', True) + self.add_prim_attr('no_eliminate', True) def infer_shape(self, x_shape): return x_shape @@ -231,7 +231,7 @@ class AllGather(PrimitiveWithInfer): self.add_prim_attr('group', _get_group(group)) self.add_prim_attr('fusion', 0) self.add_prim_attr('mean_flag', False) - self.add_prim_attr('no_elimilate', True) + self.add_prim_attr('no_eliminate', True) def infer_shape(self, x_shape): validator.check_positive_int(len(x_shape), "x shape", self.name) @@ -350,7 +350,7 @@ class _HostAllGather(PrimitiveWithInfer): validator.check_value_type("rank_id", r, (int,), self.name) self.group_size = len(group) self.add_prim_attr('group', group) - self.add_prim_attr('no_elimilate', True) + self.add_prim_attr('no_eliminate', True) def infer_shape(self, x_shape): validator.check_positive_int(len(x_shape), "x shape", self.name) @@ -425,7 +425,7 @@ class ReduceScatter(PrimitiveWithInfer): self.add_prim_attr('rank_size', self.rank_size) self.add_prim_attr('group', _get_group(group)) self.add_prim_attr('fusion', 0) - self.add_prim_attr('no_elimilate', True) + self.add_prim_attr('no_eliminate', True) def infer_shape(self, x_shape): if self.rank_size == 0: @@ -480,7 +480,7 @@ class _HostReduceScatter(PrimitiveWithInfer): self.op = op self.group_size = len(group) self.add_prim_attr('group', group) - self.add_prim_attr('no_elimilate', True) + self.add_prim_attr('no_eliminate', True) def infer_shape(self, x_shape): if x_shape[0] % self.group_size != 0: @@ -558,7 +558,7 @@ class Broadcast(PrimitiveWithInfer): validator.check_value_type('group', _get_group(group), (str,), self.name) check_hcom_group_valid(group, prim_name=self.name) self.add_prim_attr('group', _get_group(group)) - self.add_prim_attr('no_elimilate', True) + self.add_prim_attr('no_eliminate', True) def infer_shape(self, x_shape): return x_shape @@ -605,7 +605,7 @@ class AllSwap(PrimitiveWithCheck): validator.check_value_type('group', _get_group(group), (str,), self.name) self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out']) self.add_prim_attr('group', _get_group(group)) - self.add_prim_attr('no_elimilate', True) + self.add_prim_attr('no_eliminate', True) def __check__(self, tensor_in, send_size, recv_size): validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name) @@ -650,7 +650,7 @@ class NeighborExchange(Primitive): self.recv_shapes = recv_shapes self.send_shapes = send_shapes self.recv_type = recv_type - self.add_prim_attr('no_elimilate', True) + self.add_prim_attr('no_eliminate', True) def __call__(self, tensor): raise NotImplementedError @@ -690,7 +690,7 @@ class AlltoAll(PrimitiveWithInfer): self.split_dim = split_dim self.concat_dim = concat_dim self.add_prim_attr('group', _get_group(group)) - self.add_prim_attr('no_elimilate', True) + self.add_prim_attr('no_eliminate', True) def infer_shape(self, x_shape): rank_size = get_group_size(_get_group(self.group)) @@ -740,7 +740,7 @@ class NeighborExchangeV2(Primitive): self.send_lens = send_lens self.recv_lens = recv_lens self.format = data_format - self.add_prim_attr('no_elimilate', True) + self.add_prim_attr('no_eliminate', True) def __call__(self, tensor): raise NotImplementedError diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 49828e9db76..e32ed5bd63b 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -1450,9 +1450,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= Args: network (Cell): Network for distributed predication. checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. - predict_strategy (dict): Strategy of predication process, whose key is parameter name, and value is a list or - a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None, - it means that the predication process just uses single device. Default: None. + predict_strategy (dict): Strategy of predication process. Default: None. train_strategy_filename (str): Train strategy proto file name. Default: None. strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter into net when parameter name's suffix in checkpoint file is the same as the diff --git a/tests/ut/cpp/parallel/step_parallel_test.cc b/tests/ut/cpp/parallel/step_parallel_test.cc index 20d629be8e6..4b10ea1fe06 100644 --- a/tests/ut/cpp/parallel/step_parallel_test.cc +++ b/tests/ut/cpp/parallel/step_parallel_test.cc @@ -320,7 +320,7 @@ TEST_F(TestStepParallel, CreatOpInstance) { } else if (name == "index") { parse::ConvertData(py::cast(item.second), &converted_ret); ASSERT_EQ(converted_ret->ToString(), "0"); - } else if (name == "no_elimilate") { + } else if (name == "no_eliminate") { parse::ConvertData(py::cast(item.second), &converted_ret); ASSERT_EQ(converted_ret->ToString(), "true"); } else {