add_CN_API

This commit is contained in:
lichenever 2021-12-03 10:28:43 +08:00
parent 204822ecd8
commit ae40fea111
8 changed files with 566 additions and 14 deletions

View File

@ -0,0 +1,293 @@
mindspore.communication
========================
集合通信接口的类。
.. py:class:: mindspore.communication.GlobalComm
全局通信信息。GlobalComm 是一个全局类。 成员包含BACKEND、WORLD_COMM_GROUP。
.. py:method:: mindspore.communication.init(backend_name=None)
初始化通信服务需要的分布式后端例如HCCL或NCCL服务。
.. note::
HCCL的全称是华为集合通信库Huawei Collective Communication LibraryNCCL的全称是英伟达集合通信库NVIDIA Collective Communication Library`init` 方法应该在 `set_context` 方法之后使用。
**参数:**
- **backend_name** (`str`) 后台服务的名称可选HCCL或NCCL。如果未设置则根据硬件平台类型device_target进行推断默认值为None。
**异常:**
- **TypeError** 在参数 `backend_name` 不是字符串时抛出。
- **RuntimeError** 在以下情况将抛出1硬件设备类型无效2后台服务无效3分布式计算初始化失败4未设置环境变量 `RANK_ID``MINDSPORE_HCCL_CONFIG_PATH` 的情况下初始化HCCL服务。
- **ValueError** 在环境变量 `RANK_ID` 设置成非数字时抛出。
**样例:**
.. code-block::
>>> from mindspore.context import set_context
>>> set_context(device_target="Ascend")
>>> init()
.. py:class:: mindspore.communication.release()
释放分布式资源,例如HCCLNCCL服务。
.. note::
`release` 方法应该在 `init` 方法之后使用。
**异常:**
- **RuntimeError** - 在释放分布式资源失败时抛出。
.. py:class:: mindspore.communication.get_rank(group=GlobalComm.WORLD_COMM_GROUP)
在指定通信组中获取当前的设备序号。
.. note::
`get_rank` 方法应该在 `init` 方法之后使用。
**参数:**
- **group** (`str`) - 通信组名称,通常由 `create_group` 方法创建,否则将使用默认组。
- **默认值** - WORLD_COMM_GROUP
**返回:**
int, 调用该方法的进程对应的组内序号。
**异常:**
- **TypeError** 在参数 `group` 不是字符串时抛出。
- **ValueError** 在后台不可用时抛出。
- **RuntimeError** HCCLNCCL服务不可用时抛出。
.. py:class:: mindspore.communication.get_group_size(group=GlobalComm.WORLD_COMM_GROUP)
获取指定通信组的设备总数。
.. note::
`get_group_size` 方法应该在 `init` 方法之后使用。
**参数:**
- **group** (`str`) - 通信组名称,通常由 `create_group` 方法创建,否则将使用默认组。
- **默认值** - WORLD_COMM_GROUP
**返回:**
int, 指定通信组的设备总数。
**异常:**
- **TypeError** 在参数 `group` 不是字符串时抛出。
- **ValueError** 在后台不可用时抛出。
- **RuntimeError** HCCLNCCL服务不可用时抛出。
.. py:class:: mindspore.communication.get_world_rank_from_group_rank(group, group_rank_id)
由指定通信组中的设备序号获取通信集群中的全局设备序号。
.. note::
1. GPU 版本的MindSpore不支持此方法
2. 参数 `group` 不能是 `hccl_world_group`
3. `get_world_rank_from_group_rank` 方法应该在 `init` 方法之后使用。
**参数:**
- **group** (`str`) - 传入的通信组名称,通常由 `create_group` 方法创建。
- **group_rank_id** (`int`) - 通信组内的设备序号
**返回:**
int, 通信集群中的全局设备序号。
**异常:**
- **TypeError** 在参数 `group_rank_id` 不是数字或参数 `group` 不是字符串时抛出。
- **ValueError** 在参数 `group``hccl_world_group` 或后台不可用时抛出。
- **RuntimeError** HCCLNCCL服务不可用以及使用GPU版本的MindSpore时抛出。
**样例:**
.. code-block::
>>> from mindspore.context import set_context
>>> set_context(device_target="Ascend")
>>> init()
>>> group = "0-4"
>>> rank_ids = [0,4]
>>> create_group(group, rank_ids)
>>> world_rank_id = get_world_rank_from_group_rank(group, 1)
>>> print("world_rank_id is: ", world_rank_id) # 全局设备序号为4
.. py:class:: mindspore.communication.get_group_rank_from_world_rank(world_rank_id, group)
由通信集群中的全局设备序号获取指定用户通信组中的设备序号。
.. note::
1. GPU 版本的MindSpore不支持此方法
2. 参数 `group` 不能是 `hccl_world_group`
3. `get_group_rank_from_world_rank` 方法应该在 `init` 方法之后使用。
**参数:**
- **world_rank_id** (`int`) - 通信集群内的全局设备序号。
- **group** (`str`) - 传入的通信组名称,通常由 `create_group` 方法创建。
**返回:**
int, 当前用户通信组中的设备序号。
**异常:**
- **TypeError** 在参数 `group_rank_id` 不是数字或参数 `group` 不是字符串时抛出。
- **ValueError** 在参数 `group``hccl_world_group` 或后台不可用时抛出。
- **RuntimeError** HCCLNCCL服务不可用以及使用GPU版本的MindSpore时抛出。
**样例:**
.. code-block::
>>> from mindspore.context import set_context
>>> set_context(device_target="Ascend")
>>> init()
>>> group = "0-4"
>>> rank_ids = [0,4]
>>> create_group(group, rank_ids)
>>> group_rank_id = get_group_rank_from_world_rank(4, group)
>>> print("group_rank_id is: ", group_rank_id) # 组内设备序号是1
.. py:class:: mindspore.communication.create_group(group, rank_ids)
创建用户通信组。
.. note::
1. GPU 版本的MindSpore不支持此方法
2. 列表rank_ids的长度应大于1
3. 列表rank_ids内不能有重复数据
4. `create_group` 方法应该在 `init` 方法之后使用。
**参数:**
- **group** (`str`) - 将被创建的通信组名称。
- **rank_ids** (`list`) - 设备编号列表。
**异常:**
- **TypeError** 在参数 `group_rank_id` 不是数字或参数 `group` 不是字符串时抛出。
- **ValueError** 在列表rank_ids的长度小于1或列表rank_ids内有重复数据以及后台无效时抛出。
- **RuntimeError** HCCLNCCL 服务不可用以及使用GPU版本的MindSpore时抛出。
**样例:**
.. code-block::
>>> from mindspore.context import set_context
>>> set_context(device_target="Ascend")
>>> init()
>>> group = "0-8"
>>> rank_ids = [0,8]
>>> create_group(group, rank_ids)
.. py:class:: mindspore.communication.get_local_rank(group=GlobalComm.WORLD_COMM_GROUP)
获取指定通信组中当前设备的本地设备序号。
.. note::
1. GPU 版本的MindSpore不支持此方法
2. `get_local_rank` 方法应该在 `init` 方法之后使用。
**参数:**
- **group** (`str`) - 通信组名称,通常由 `create_group` 方法创建,否则将使用默认组名称。
- **默认值** - WORLD_COMM_GROUP
**返回:**
int, 调用该方法的进程对应的通信组内本地设备序号。
**异常:**
- **TypeError** 在参数 `group` 不是字符串时抛出。
- **ValueError** 在后台不可用时抛出。
- **RuntimeError** HCCLNCCL服务不可用时抛出。
.. py:class:: mindspore.communication.get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP)
获取指定通信组的本地设备总数。
.. note::
1. GPU 版本的MindSpore不支持此方法
2. `get_local_rank_size` 方法应该在 `init` 方法之后使用。
**参数:**
- **group** (`str`) - 传入的通信组名称,通常由 `create_group` 方法创建或默认使用WORLD_COMM_GROUP
**返回:**
int, 调用该方法的进程对应的通信组设备总数。
**异常:**
- **TypeError** 在参数 `group` 不是字符串时抛出。
- **ValueError** 在后台不可用时抛出。
- **RuntimeError** HCCLNCCL服务不可用时抛出。
.. py:class:: mindspore.communication.destroy_group(group)
销毁用户通信组。
.. note::
1. GPU 版本的MindSpore不支持此方法
2. 参数 `group` 不能是 `hccl_world_group`
3. `destroy_group` 方法应该在 `init` 方法之后使用。
**参数:**
- **group** (`str`) - 将被销毁的通信组,通常由 `create_group` 方法创建。
**异常:**
- **TypeError** 在参数 `group` 不是字符串时抛出。
- **ValueError** 在参数 `group``hccl_world_group` 或后台不可用时抛出。
- **RuntimeError** HCCLNCCL服务不可用时抛出。

View File

@ -0,0 +1,125 @@
get_auto_parallel_context(attr_key)
根据key获取自动并行的配置。
参数:
attr_key (str)配置的key。
返回:
根据key返回配置的值。
异常:
ValueError输入key不在自动并行的配置列表中。
Class mindspore.context.ParallelMode
并行模式。
有五种并行模式分别是STAND_ALONE、DATA_PARALLEL、HYBRID_PARALLEL、SEMI_AUTO_PARALLEL和AUTO_PARALLEL。
默认值STAND_ALONE。
- STAND_ALONE单卡模式。
- DATA_PARALLEL数据并行模式。
- HYBRID_PARALLEL手动实现数据并行和模型并行。
- SEMI_AUTO_PARALLEL半自动并行模式。
- AUTO_PARALLEL自动并行模式。
MODE_LIST表示所有支持的并行模式的列表。
reset_auto_parallel_context()
重置自动并行的配置为默认值。
- device_num1。
- global_rank0。
- gradients_meanFalse。
- gradient_fp32_syncTrue。
- parallel_mode'stand_alone'。
- auto_parallel_search_mode'dynamic_programming'。
- parameter_broadcastFalse。
- strategy_ckpt_load_file''。
- strategy_ckpt_save_file''。
- full_batchFalse。
- enable_parallel_optimizerFalse。
- pipeline_stages1。
set_auto_parallel_context(**kwargs)
配置自动并行仅在Ascend和GPU上有效。
应在init之前配置自动并行。
注:
配置时,必须输入配置的名称。
如果某个程序具有不同并行模式下的任务则需要再为下一个任务设置新的并行模式之前调用reset_auto_parallel_context()接口来重置配置。
若要设置或更改并行模式必须在创建任何Initializer之前调用接口否则在编译网络时可能会出现RuntimeError。
某些配置适用于特定的并行模式,有关详细信息,请参见下表:
=========================== ===========================
Common AUTO_PARALLEL
=========================== ===========================
device_num gradient_fp32_sync
global_rank loss_repeated_mean
gradients_mean auto_parallel_search_mode
parallel_mode strategy_ckpt_load_file
all_reduce_fusion_config strategy_ckpt_save_file
enable_parallel_optimizer dataset_strategy
\ pipeline_stages
\ grad_accumulation_step
=========================== ===========================
参数:
device_num (int)表示可用设备的编号必须在【1,4096】范围中。默认值1。
global_rank (int)表示全局秩的ID必须在【0,4095】范围中。默认值0。
gradients_mean (bool)表示是否在梯度的allreduce后执行平均算子。
stand_alone不支持gradients_mean。默认值False。
gradient_fp32_sync (bool)在FP32中运行gradients的allreduce。stand_alone、data_parallel和hybrid_parallel不支持gradient_fp32_sync。默认值True。
parallel_mode (str)有五种并行模式分别是stand_alone、data_parallel、hybrid_parallel、semi_auto_parallel和auto_parallel。默认值stand_alone。
- stand_alone单卡模式。
- data_parallel数据并行模式。
- hybrid_parallel手动实现数据并行和模型并行。
- semi_auto_parallel半自动并行模式。
- auto_parallel自动并行模式。
auto_parallel_search_mode (str)表示有两种策略搜索模式分别是recursive_programming和dynamic_programming。默认值dynamic_programming。
- recursive_programming表示双递归搜索模式。
- dynamic_programming表示动态规划搜索模式。
parameter_broadcast (bool)表示在训练前是否广播参数。在训练之前为了使所有设备的网络初始化参数值相同请将设备0上的参数广播到其他设备。不同并行模式下的参数广播不同。
在data_parallel模式下除layerwise_parallel属性为True的参数外所有参数都会被广播。在Hybrid_parallel、semi_auto_parallel和auto_parallel模式下分段参数不参与广播。默认值False。
strategy_ckpt_load_file (str)表示用于加载并行策略checkpoint的路径。默认值''。
strategy_ckpt_save_file (str)表示用于保存并行策略checkpoint的路径。默认值''。
full_batch (bool)如果在auto_parallel模式下加载整个batch数据集则此参数应设置为True。默认值False。目前不建议使用该接口建议使用dataset_strategy来替换它。
dataset_strategy (Union[str, tuple])表示数据集分片策略。默认值data_parallel。
dataset_strategy="data_parallel"等于full_batch=Falsedataset_strategy="full_batch"等于full_batch=True。对于通过模型并列策略加载到网络的数据集如ds_stra ((1, 8)、(1, 8))需要使用set_auto_parallel_context(dataset_strategy=ds_stra)。
enable_parallel_optimizer (bool)这是一个开发中的特性它可以为数据并行训练对权重更新计算进行分片以节省时间和内存。目前自动和半自动并行模式支持Ascend和GPU中的所有优化器。数据并行模式仅支持Ascend中的`Lamb`和`AdamWeightDecay`。默认值False。
all_reduce_fusion_config (list)通过参数索引设置allreduce 融合策略。仅支持ReduceOp.SUM和HCCL_WORLD_GROUP/NCCL_WORLD_GROUP。没有默认值。如果不设置则关闭算子融合。
pipeline_stages (int)设置pipeline并行的阶段信息。这表明了设备如何单独分布在pipeline上。所有的设备将被划分为pipeline_stags个阶段。
目前这只能在启动semi_auto_parallel模式的情况下使用。默认值1。
grad_accumulation_step (int)在自动和半自动并行模式下设置梯度的累积step。
其值应为正整数。默认值1。
异常:
ValueError输入key不是自动并行上下文中的属性。
样例:
>>> context.set_auto_parallel_context(device_num=8)
>>> context.set_auto_parallel_context(global_rank=0)
>>> context.set_auto_parallel_context(gradients_mean=True)
>>> context.set_auto_parallel_context(gradient_fp32_sync=False)
>>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
>>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")
>>> context.set_auto_parallel_context(parameter_broadcast=False)
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
>>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
>>> context.set_auto_parallel_context(dataset_strategy=((1, 8), (1, 8)))
>>> context.set_auto_parallel_context(enable_parallel_optimizer=False)
>>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160])
>>> context.set_auto_parallel_context(pipeline_stages=2)

View File

@ -0,0 +1,22 @@
mindspore.load_distributed_checkpoint
======================================
.. py:method:: mindspore.load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM')
给分布式预测加载checkpoint文件到网络用于分布式推理。关于分布式推理的细节请参考'<https://www.mindspore.cn/docs/programming_guide/zh-CN/master/distributed_inference.html>' 。
**参数:**
- **network** (Cell):分布式预测网络。
- **checkpoint_filenames** (list[str])checkpoint文件的名称按rank id 顺序排列。
- **predict_strategy** (dict)predict时参数的切分策略。
- **train_strategy_filename** (str)训练策略proto文件名。默认值None。
- **strict_load** (bool)表示是否严格加载参数到网络。如果值为False则当checkpoint文件中参数名称的后缀与网络中的参数相同时加载参数到网络。当类型不一致时对相同类型的参数进行类型转换如从float32到float16。默认值False。
- **dec_key** (Union[None, bytes])用于解密的字节类型key。如果value为None则不需要解密。默认值None。
- **dec_mode** (str)仅当dec_key不设为None时该参数有效。指定了解密模式目前支持AES-GCM和AES-CBC。默认值AES-GCM。
**异常:**
- **TypeError** 输入类型不符合要求。
- **ValueError** 无法加载checkpoint文件到网络。

View File

@ -0,0 +1,99 @@
Class mindspore.nn.DistributedGradReducer(parameters, mean=True, degree=None, fusion_type=1, group='hccl_world_group')
分布式优化器。
对反向梯度进行AllReduce运算。
参数:
parameters (list):需要更新的参数。
mean (bool)当mean为True时对AllReduce之后的梯度求均值。默认值False。
degree (int)平均系数通常等于设备编号。默认值None。
fusion_type (int)AllReduce算子的融合类型。默认值1。
异常:
ValueError如果degree不是int或小于0。
支持平台:
``Ascend`` ``GPU``
示例:
>>> #此示例应与多个进程一起运行。
>>> #请参考Mindpore.cn上的“教程>分布式训练”。
>>> import numpy as np
>>> from mindspore.communication import init
>>> from mindspore import ops
>>> from mindspore import context
>>> from mindspore.context import ParallelMode
>>> from mindspore import Parameter, Tensor
>>> from mindspore import nn
>>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> init()
>>> context.reset_auto_parallel_context()
>>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
>>>
>>> class TrainingWrapper(nn.Cell):
... def __init__(self, network, optimizer, sens=1.0):
... super(TrainingWrapper, self).__init__(auto_prefix=False)
... self.network = network
... self.network.add_flags(defer_inline=True)
... self.weights = optimizer.parameters
... self.optimizer = optimizer
... self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
... self.sens = sens
... self.reducer_flag = False
... self.grad_reducer = None
... self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
... if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
... self.reducer_flag = True
... if self.reducer_flag:
... mean = context.get_auto_parallel_context("gradients_mean")
... degree = context.get_auto_parallel_context("device_num")
... self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
...
... def construct(self, *args):
... weights = self.weights
... loss = self.network(*args)
... sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
... grads = self.grad(self.network, weights)(*args, sens)
... if self.reducer_flag:
... # apply grad reducer on grads
... grads = self.grad_reducer(grads)
... return ops.Depend(loss, self.optimizer(grads))
>>>
>>> class Net(nn.Cell):
... def __init__(self, in_features, out_features)
... super(Net, self).__init__()
... self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
... name='weight')
... self.matmul = ops.MatMul()
...
... def construct(self, x)
... output = self.matmul(x, self.weight)
... return output
>>>
>>> size, in_features, out_features = 16, 16, 10
>>> network = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> net_with_loss = nn.WithLossCell(network, loss)
>>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> train_cell = TrainingWrapper(net_with_loss, optimizer)
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
>>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
>>> grads = train_cell(inputs, label)
>>> print(grads)
256.0
construct(grads)
某些情况下梯度的数据精度可以与float16和float32混合。因此AllReduce的结果不可靠。
要解决这个问题必须在AllReduce之前强制转换为float32并在操作之后再强制转换为float32。
参数:
grads (Union[Tensor, tuple[Tensor]])操作前的梯度tensor或tuple。
返回:
new_grads (Union[Tensor, tuple[Tensor]])操作后的梯度tensor或tuple。

View File

@ -0,0 +1,15 @@
Class mindspore.nn.PipelineCell(network, micro_size)
将MiniBatch切分成更细粒度的MicroBatch用于流水线并行的训练中。
注:
micro_size必须大于或等于流水线stage的个数。
参数:
network (Cell):要包裹的目标网络。
micro_size (int)MicroBatch大小。
示例:
>>> net = Net()
>>> net = PipelineCell(net, 4)

View File

@ -157,7 +157,7 @@ class AllReduce(PrimitiveWithInfer):
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0)
self.add_prim_attr('index', 0)
self.add_prim_attr('no_elimilate', True)
self.add_prim_attr('no_eliminate', True)
def infer_shape(self, x_shape):
return x_shape
@ -231,7 +231,7 @@ class AllGather(PrimitiveWithInfer):
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0)
self.add_prim_attr('mean_flag', False)
self.add_prim_attr('no_elimilate', True)
self.add_prim_attr('no_eliminate', True)
def infer_shape(self, x_shape):
validator.check_positive_int(len(x_shape), "x shape", self.name)
@ -350,7 +350,7 @@ class _HostAllGather(PrimitiveWithInfer):
validator.check_value_type("rank_id", r, (int,), self.name)
self.group_size = len(group)
self.add_prim_attr('group', group)
self.add_prim_attr('no_elimilate', True)
self.add_prim_attr('no_eliminate', True)
def infer_shape(self, x_shape):
validator.check_positive_int(len(x_shape), "x shape", self.name)
@ -425,7 +425,7 @@ class ReduceScatter(PrimitiveWithInfer):
self.add_prim_attr('rank_size', self.rank_size)
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0)
self.add_prim_attr('no_elimilate', True)
self.add_prim_attr('no_eliminate', True)
def infer_shape(self, x_shape):
if self.rank_size == 0:
@ -480,7 +480,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
self.op = op
self.group_size = len(group)
self.add_prim_attr('group', group)
self.add_prim_attr('no_elimilate', True)
self.add_prim_attr('no_eliminate', True)
def infer_shape(self, x_shape):
if x_shape[0] % self.group_size != 0:
@ -558,7 +558,7 @@ class Broadcast(PrimitiveWithInfer):
validator.check_value_type('group', _get_group(group), (str,), self.name)
check_hcom_group_valid(group, prim_name=self.name)
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('no_elimilate', True)
self.add_prim_attr('no_eliminate', True)
def infer_shape(self, x_shape):
return x_shape
@ -605,7 +605,7 @@ class AllSwap(PrimitiveWithCheck):
validator.check_value_type('group', _get_group(group), (str,), self.name)
self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('no_elimilate', True)
self.add_prim_attr('no_eliminate', True)
def __check__(self, tensor_in, send_size, recv_size):
validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name)
@ -650,7 +650,7 @@ class NeighborExchange(Primitive):
self.recv_shapes = recv_shapes
self.send_shapes = send_shapes
self.recv_type = recv_type
self.add_prim_attr('no_elimilate', True)
self.add_prim_attr('no_eliminate', True)
def __call__(self, tensor):
raise NotImplementedError
@ -690,7 +690,7 @@ class AlltoAll(PrimitiveWithInfer):
self.split_dim = split_dim
self.concat_dim = concat_dim
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('no_elimilate', True)
self.add_prim_attr('no_eliminate', True)
def infer_shape(self, x_shape):
rank_size = get_group_size(_get_group(self.group))
@ -740,7 +740,7 @@ class NeighborExchangeV2(Primitive):
self.send_lens = send_lens
self.recv_lens = recv_lens
self.format = data_format
self.add_prim_attr('no_elimilate', True)
self.add_prim_attr('no_eliminate', True)
def __call__(self, tensor):
raise NotImplementedError

View File

@ -1450,9 +1450,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
Args:
network (Cell): Network for distributed predication.
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
predict_strategy (dict): Strategy of predication process, whose key is parameter name, and value is a list or
a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
it means that the predication process just uses single device. Default: None.
predict_strategy (dict): Strategy of predication process. Default: None.
train_strategy_filename (str): Train strategy proto file name. Default: None.
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
into net when parameter name's suffix in checkpoint file is the same as the

View File

@ -320,7 +320,7 @@ TEST_F(TestStepParallel, CreatOpInstance) {
} else if (name == "index") {
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
ASSERT_EQ(converted_ret->ToString(), "0");
} else if (name == "no_elimilate") {
} else if (name == "no_eliminate") {
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
ASSERT_EQ(converted_ret->ToString(), "true");
} else {