commit
3170c0a90b
|
@ -9,7 +9,7 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **auto_prefix** (bool) – 是否自动为Cell及其子Cell生成NameSpace。`auto_prefix` 的设置影响网络参数的命名,如果设置为True,则自动给网络参数的名称添加前缀,否则不添加前缀。默认值:True。
|
||||
- **auto_prefix** (bool) - 是否自动为Cell及其子Cell生成NameSpace。`auto_prefix` 的设置影响网络参数的命名,如果设置为True,则自动给网络参数的名称添加前缀,否则不添加前缀。默认值:True。
|
||||
- **flags** (dict) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值:None。
|
||||
|
||||
.. py:method:: add_flags(**flags)
|
||||
|
@ -30,6 +30,18 @@
|
|||
|
||||
- **flags** (dict) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值:None。
|
||||
|
||||
.. py:method:: auto_cast_inputs(inputs)
|
||||
|
||||
在混合精度下,自动对输入进行类型转换。
|
||||
|
||||
**参数:**
|
||||
|
||||
**inputs** (tuple) - construct方法的输入。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tuple类型,经过类型转换后的输入。
|
||||
|
||||
.. py:method:: auto_parallel_compile_and_run()
|
||||
|
||||
是否在‘AUTO_PARALLEL’或‘SEMI_AUTO_PARALLEL’模式下执行编译流程。
|
||||
|
@ -64,7 +76,7 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **param** (Parameter) – 需要被转换类型的输入参数。
|
||||
- **param** (Parameter) - 需要被转换类型的输入参数。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -78,14 +90,14 @@
|
|||
|
||||
Iteration类型,Cell的子Cell。
|
||||
|
||||
.. py:method:: cells_and_names(cells=None, name_prefix="")
|
||||
.. py:method:: cells_and_names(cells=None, name_prefix='')
|
||||
|
||||
递归地获取当前Cell及输入 `cells` 的所有子Cell的迭代器,包括Cell的名称及其本身。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **cells** (str) – 需要进行迭代的Cell。默认值:None。
|
||||
- **name_prefix** (str) – 作用域。默认值:''。
|
||||
- **cells** (str) - 需要进行迭代的Cell。默认值:None。
|
||||
- **name_prefix** (str) - 作用域。默认值:''。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -95,25 +107,13 @@
|
|||
|
||||
检查Cell中的网络参数名称是否重复。
|
||||
|
||||
|
||||
.. py:method:: set_inputs(*inputs)
|
||||
|
||||
设置编译计算图所需的输入,输入需与实例中定义的输入一致。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **inputs** (tuple) – Cell的输入。
|
||||
|
||||
.. note::
|
||||
这是一个实验接口,可能会被更改或者删除。
|
||||
|
||||
.. py:method:: compile(*inputs)
|
||||
|
||||
编译Cell为计算图,输入需与construct中定义的输入一致。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **inputs** (tuple) – Cell的输入。
|
||||
- **inputs** (tuple) - Cell的输入。
|
||||
|
||||
.. py:method:: compile_and_run(*inputs)
|
||||
|
||||
|
@ -124,7 +124,7 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **inputs** (tuple) – Cell的输入。
|
||||
- **inputs** (tuple) - Cell的输入。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -139,8 +139,8 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **inputs** – 可变参数列表,默认值:()。
|
||||
- **kwargs** – 可变的关键字参数的字典,默认值:{}。
|
||||
- **inputs** (tuple) - 可变参数列表,默认值:()。
|
||||
- **kwargs** (dict) - 可变的关键字参数的字典,默认值:{}。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -168,13 +168,24 @@
|
|||
|
||||
返回图的二进制原型。
|
||||
|
||||
.. py:method:: get_inputs()
|
||||
|
||||
返回编译计算图所设置的输入。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tuple类型,编译计算图所设置的输入。
|
||||
|
||||
.. note::
|
||||
这是一个实验接口,可能会被更改或者删除。
|
||||
|
||||
.. py:method:: get_parameters(expand=True)
|
||||
|
||||
返回Cell中parameter的迭代器。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **expand** (bool) – 如果为True,则递归地获取当前Cell和所有子Cell的parameter。否则,只生成当前Cell的子Cell的parameter。默认值:True。
|
||||
- **expand** (bool) - 如果为True,则递归地获取当前Cell和所有子Cell的parameter。否则,只生成当前Cell的子Cell的parameter。默认值:True。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -187,17 +198,6 @@
|
|||
**返回:**
|
||||
|
||||
String类型,网络的作用域。
|
||||
|
||||
.. py:method:: get_inputs()
|
||||
|
||||
返回编译计算图所设置的输入。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tuple类型,编译计算图所设置的输入。
|
||||
|
||||
.. note::
|
||||
这是一个实验接口,可能会被更改或者删除。
|
||||
|
||||
.. py:method:: infer_param_pipeline_stage()
|
||||
|
||||
|
@ -213,7 +213,7 @@
|
|||
|
||||
**异常:**
|
||||
|
||||
- **RuntimeError** – 如果参数不属于任何stage。
|
||||
- **RuntimeError** - 如果参数不属于任何stage。
|
||||
|
||||
.. py:method:: init_parameters_data(auto_parallel_mode=False)
|
||||
|
||||
|
@ -224,7 +224,7 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **auto_parallel_mode** (bool) – 是否在自动并行模式下执行。 默认值:False。
|
||||
- **auto_parallel_mode** (bool) - 是否在自动并行模式下执行。 默认值:False。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -236,13 +236,13 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **child_name** (str) – 子Cell名称。
|
||||
- **child_cell** (Cell) – 要插入的子Cell。
|
||||
- **child_name** (str) - 子Cell名称。
|
||||
- **child_cell** (Cell) - 要插入的子Cell。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **KeyError** – 如果子Cell的名称不正确或与其他子Cell名称重复。
|
||||
- **TypeError** – 如果子Cell的类型不正确。
|
||||
- **KeyError** - 如果子Cell的名称不正确或与其他子Cell名称重复。
|
||||
- **TypeError** - 如果子Cell的类型不正确。
|
||||
|
||||
.. py:method:: insert_param_to_cell(param_name, param, check_name_contain_dot=True)
|
||||
|
||||
|
@ -252,14 +252,14 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **param_name** (str) – 参数名称。
|
||||
- **param** (Parameter) – 要插入到Cell的参数。
|
||||
- **check_name_contain_dot** (bool) – 是否对 `param_name` 中的"."进行检查。默认值:True。
|
||||
- **param_name** (str) - 参数名称。
|
||||
- **param** (Parameter) - 要插入到Cell的参数。
|
||||
- **check_name_contain_dot** (bool) - 是否对 `param_name` 中的"."进行检查。默认值:True。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **KeyError** – 如果参数名称为空或包含"."。
|
||||
- **TypeError** – 如果参数的类型不是Parameter。
|
||||
- **KeyError** - 如果参数名称为空或包含"."。
|
||||
- **TypeError** - 如果参数的类型不是Parameter。
|
||||
|
||||
.. py:method:: load_parameter_slice(params)
|
||||
|
||||
|
@ -269,7 +269,7 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
**params** (dict) – 用于初始化数据图的参数字典。
|
||||
**params** (dict) - 用于初始化数据图的参数字典。
|
||||
|
||||
.. py:method:: name_cells()
|
||||
|
||||
|
@ -324,7 +324,7 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **recurse** (bool) – 是否递归得包含所有子Cell的parameter。默认值:True。
|
||||
- **recurse** (bool) - 是否递归得包含所有子Cell的parameter。默认值:True。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -344,54 +344,8 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **mp_comm_recompute** (bool) – 表示在自动并行或半自动并行模式下,指定Cell内部由模型并行引入的通信操作是否重计算。默认值:True。
|
||||
- **parallel_optimizer_comm_recompute** (bool) – 表示在自动并行或半自动并行模式下,指定Cell内部由优化器并行引入的AllGather通信是否重计算。默认值:False。
|
||||
|
||||
.. py:method:: register_forward_pre_hook(hook_fn)
|
||||
|
||||
设置Cell对象的正向pre_hook函数。
|
||||
|
||||
.. note::
|
||||
- `register_forward_pre_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `ms_function` 功能时不起作用。
|
||||
- hook_fn必须有如下代码定义。 `cell_id` 是已注册Cell对象的信息,包括名称和ID。 `inputs` 是网络正向传播时Cell对象的输入数据。用户可以在hook_fn中打印输入数据或者返回新的输入数据。
|
||||
- hook_fn返回新的输入数据或者None:hook_fn(cell_id, inputs) -> New inputs or None。
|
||||
- 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook(hook_fn)` 。
|
||||
- PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook(hook_fn)` ,那么Cell对象每次运行都将增加一个 `hook_fn` 。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **hook_fn** (function) – 捕获Cell对象信息和正向输入数据的hook_fn函数。
|
||||
|
||||
**返回:**
|
||||
|
||||
`mindspore.common.hook_handle.HookHandle` 类型,与 `hook_fn` 函数对应的 `handle` 对象。可通过调用 `handle.remove()` 来删除添加的 `hook_fn` 函数。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 如果 `hook_fn` 不是Python函数。
|
||||
|
||||
.. py:method:: register_forward_hook(hook_fn)
|
||||
|
||||
设置Cell对象的正向hook函数。
|
||||
|
||||
.. note::
|
||||
- `register_forward_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `ms_function` 功能时不起作用。
|
||||
- hook_fn必须有如下代码定义。 `cell_id` 是已注册Cell对象的信息,包括名称和ID。 `inputs` 是网络正向传播时Cell对象的输入数据。 `outputs` 是网络正向传播时Cell对象的输出数据。用户可以在hook_fn中打印数据或者返回新的输出数据。
|
||||
- hook_fn返回新的输出数据或者None:hook_fn(cell_id, inputs, outputs) -> New outputs or None。
|
||||
- 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_forward_hook(hook_fn)` 。
|
||||
- PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_forward_hook(hook_fn)` ,那么Cell对象每次运行都将增加一个 `hook_fn` 。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **hook_fn** (function) – 捕获Cell对象信息和正向输入,输出数据的hook_fn函数。
|
||||
|
||||
**返回:**
|
||||
|
||||
`mindspore.common.hook_handle.HookHandle` 类型,与 `hook_fn` 函数对应的 `handle` 对象。可通过调用 `handle.remove()` 来删除添加的 `hook_fn` 函数。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 如果 `hook_fn` 不是Python函数。
|
||||
- **mp_comm_recompute** (bool) - 表示在自动并行或半自动并行模式下,指定Cell内部由模型并行引入的通信操作是否重计算。默认值:True。
|
||||
- **parallel_optimizer_comm_recompute** (bool) - 表示在自动并行或半自动并行模式下,指定Cell内部由优化器并行引入的AllGather通信是否重计算。默认值:False。
|
||||
|
||||
.. py:method:: register_backward_hook(hook_fn)
|
||||
|
||||
|
@ -406,7 +360,7 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **hook_fn** (function) – 捕获Cell对象信息和反向输入,输出梯度的hook_fn函数。
|
||||
- **hook_fn** (function) - 捕获Cell对象信息和反向输入,输出梯度的hook_fn函数。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -414,7 +368,53 @@
|
|||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 如果 `hook_fn` 不是Python函数。
|
||||
- **TypeError** - 如果 `hook_fn` 不是Python函数。
|
||||
|
||||
.. py:method:: register_forward_hook(hook_fn)
|
||||
|
||||
设置Cell对象的正向hook函数。
|
||||
|
||||
.. note::
|
||||
- `register_forward_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `ms_function` 功能时不起作用。
|
||||
- hook_fn必须有如下代码定义。 `cell_id` 是已注册Cell对象的信息,包括名称和ID。 `inputs` 是网络正向传播时Cell对象的输入数据。 `outputs` 是网络正向传播时Cell对象的输出数据。用户可以在hook_fn中打印数据或者返回新的输出数据。
|
||||
- hook_fn返回新的输出数据或者None:hook_fn(cell_id, inputs, outputs) -> New outputs or None。
|
||||
- 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_forward_hook(hook_fn)` 。
|
||||
- PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_forward_hook(hook_fn)` ,那么Cell对象每次运行都将增加一个 `hook_fn` 。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **hook_fn** (function) - 捕获Cell对象信息和正向输入,输出数据的hook_fn函数。
|
||||
|
||||
**返回:**
|
||||
|
||||
`mindspore.common.hook_handle.HookHandle` 类型,与 `hook_fn` 函数对应的 `handle` 对象。可通过调用 `handle.remove()` 来删除添加的 `hook_fn` 函数。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - 如果 `hook_fn` 不是Python函数。
|
||||
|
||||
.. py:method:: register_forward_pre_hook(hook_fn)
|
||||
|
||||
设置Cell对象的正向pre_hook函数。
|
||||
|
||||
.. note::
|
||||
- `register_forward_pre_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `ms_function` 功能时不起作用。
|
||||
- hook_fn必须有如下代码定义。 `cell_id` 是已注册Cell对象的信息,包括名称和ID。 `inputs` 是网络正向传播时Cell对象的输入数据。用户可以在hook_fn中打印输入数据或者返回新的输入数据。
|
||||
- hook_fn返回新的输入数据或者None:hook_fn(cell_id, inputs) -> New inputs or None。
|
||||
- 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook(hook_fn)` 。
|
||||
- PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook(hook_fn)` ,那么Cell对象每次运行都将增加一个 `hook_fn` 。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **hook_fn** (function) - 捕获Cell对象信息和正向输入数据的hook_fn函数。
|
||||
|
||||
**返回:**
|
||||
|
||||
`mindspore.common.hook_handle.HookHandle` 类型,与 `hook_fn` 函数对应的 `handle` 对象。可通过调用 `handle.remove()` 来删除添加的 `hook_fn` 函数。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - 如果 `hook_fn` 不是Python函数。
|
||||
|
||||
.. py:method:: remove_redundant_parameters()
|
||||
|
||||
|
@ -431,8 +431,8 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **cast_inputs** (tuple) – Cell的输入。
|
||||
- **kwargs** (dict) – 关键字参数。
|
||||
- **cast_inputs** (tuple) - Cell的输入。
|
||||
- **kwargs** (dict) - 关键字参数。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -444,6 +444,35 @@
|
|||
|
||||
.. note:: 如果一个Cell需要使用自动并行或半自动并行模式来进行训练、评估或预测,则该Cell需要调用此接口。
|
||||
|
||||
.. py:method:: set_boost(boost_type)
|
||||
|
||||
为了提升网络性能,可以配置boost内的算法让框架自动使能该算法来加速网络训练。
|
||||
|
||||
请确保 `boost_type` 所选择的算法在
|
||||
`algorithm library <https://gitee.com/mindspore/mindspore/tree/master/mindspore/python/mindspore/boost>`_ 算法库中。
|
||||
|
||||
.. note:: 部分加速算法可能影响网络精度,请谨慎选择。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **boost_type** (str) - 加速算法。
|
||||
|
||||
**返回:**
|
||||
|
||||
Cell类型,Cell本身。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **ValueError** - 如果 `boost_type` 不在boost算法库内。
|
||||
|
||||
.. py:method:: set_broadcast_flag(mode=True)
|
||||
|
||||
设置该Cell的参数广播模式。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **mode** (bool) - 指定当前模式是否进行参数广播。默认值:True。
|
||||
|
||||
.. py:method:: set_comm_fusion(fusion_type, recurse=True)
|
||||
|
||||
为Cell中的参数设置融合类型。请参考 :class:`mindspore.Parameter.comm_fusion` 的描述。
|
||||
|
@ -452,16 +481,8 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **fusion_type** (int) – Parameter的 `comm_fusion` 属性的设置值。
|
||||
- **recurse** (bool) – 是否递归地设置子Cell的可训练参数。默认值:True。
|
||||
|
||||
.. py:method:: set_broadcast_flag(mode=True)
|
||||
|
||||
设置该Cell的参数广播模式。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **mode** (bool) – 指定当前模式是否进行参数广播。默认值:True。
|
||||
- **fusion_type** (int) - Parameter的 `comm_fusion` 属性的设置值。
|
||||
- **recurse** (bool) - 是否递归地设置子Cell的可训练参数。默认值:True。
|
||||
|
||||
.. py:method:: set_data_parallel()
|
||||
|
||||
|
@ -469,6 +490,72 @@
|
|||
|
||||
.. note:: 仅在图模式、全自动并行(AUTO_PARALLEL)模式下生效。
|
||||
|
||||
.. py:method:: set_grad(requires_grad=True)
|
||||
|
||||
Cell的梯度设置。在PyNative模式下,该参数指定Cell是否需要梯度。如果为True,则在执行正向网络时,将生成需要计算梯度的反向网络。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **requires_grad** (bool) - 指定网络是否需要梯度,如果为True,PyNative模式下Cell将构建反向网络。默认值:True。
|
||||
|
||||
**返回:**
|
||||
|
||||
Cell类型,Cell本身。
|
||||
|
||||
.. py:method:: set_inputs(*inputs)
|
||||
|
||||
设置编译计算图所需的输入,输入需与实例中定义的输入一致。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **inputs** (tuple) - Cell的输入。
|
||||
|
||||
.. note::
|
||||
这是一个实验接口,可能会被更改或者删除。
|
||||
|
||||
.. py:method:: set_parallel_input_with_inputs(*inputs)
|
||||
|
||||
通过并行策略对输入张量进行切分。
|
||||
|
||||
**参数:**
|
||||
|
||||
**inputs** (tuple) - construct方法的输入。
|
||||
|
||||
.. py:method:: set_param_fl(push_to_server=False, pull_from_server=False, requires_aggr=True)
|
||||
|
||||
设置参数与服务器交互的方式。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **push_to_server** (bool) - 是否将参数推送到服务器。默认值:False。
|
||||
- **pull_from_server** (bool) - 是否从服务器提取参数。默认值:False。
|
||||
- **requires_aggr** (bool) - 是否在服务器中聚合参数。默认值:True。
|
||||
|
||||
.. py:method:: set_param_ps(recurse=True, init_in_server=False)
|
||||
|
||||
设置可训练参数是否由参数服务器更新,以及是否在服务器上初始化可训练参数。
|
||||
|
||||
.. note:: 只在运行的任务处于参数服务器模式时有效。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **recurse** (bool) - 是否设置子网络的可训练参数。默认值:True。
|
||||
- **init_in_server** (bool) - 是否在服务器上初始化由参数服务器更新的可训练参数。默认值:False。
|
||||
|
||||
.. py:method:: set_train(mode=True)
|
||||
|
||||
将Cell设置为训练模式。
|
||||
|
||||
设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 `BatchNorm`),将通过这个属性区分分支。如果设置为True,则执行训练分支,否则执行另一个分支。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **mode** (bool) - 指定模型是否为训练模式。默认值:True。
|
||||
|
||||
**返回:**
|
||||
|
||||
Cell类型,Cell本身。
|
||||
|
||||
.. py:method:: shard(in_strategy, out_strategy, device="Ascend", level=0)
|
||||
|
||||
指定输入/输出Tensor的分布策略,其余算子的策略推导得到。在PyNative模式下,可以利用此方法指定某个Cell以图模式进行分布式执行。 in_strategy/out_strategy需要为元组类型,
|
||||
|
@ -479,78 +566,10 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **in_strategy** (tuple) – 指定各输入的切分策略,输入元组的每个元素可以为元组或None,元组即具体指定输入每一维的切分策略,None则会默认以数据并行执行。
|
||||
- **out_strategy** (tuple) – 指定各输出的切分策略,用法同in_strategy。
|
||||
- **in_strategy** (tuple) - 指定各输入的切分策略,输入元组的每个元素可以为元组或None,元组即具体指定输入每一维的切分策略,None则会默认以数据并行执行。
|
||||
- **out_strategy** (tuple) - 指定各输出的切分策略,用法同in_strategy。
|
||||
- **device** (string) - 指定执行设备,可以为["CPU", "GPU", "Ascend"]中任意一个,默认值:"Ascend"。目前尚未使能。
|
||||
- **level** (int) - 指定搜索切分策略的目标函数,即是最大化计算通信比、最小化内存消耗、最大化执行速度等。可以为[0, 1, 2]中任意一个,默认值:0。目前仅支持
|
||||
最大化计算通信比,其余模式尚未使能。
|
||||
|
||||
**返回:**
|
||||
|
||||
Cell类型,Cell本身。
|
||||
|
||||
.. py:method:: auto_cast_inputs(inputs)
|
||||
|
||||
在混合精度下,自动对输入进行类型转换。
|
||||
|
||||
**参数:**
|
||||
|
||||
**inputs** (tuple) – construct方法的输入。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tuple类型,经过类型转换后的输入。
|
||||
|
||||
.. py:method:: set_grad(requires_grad=True)
|
||||
|
||||
Cell的梯度设置。在PyNative模式下,该参数指定Cell是否需要梯度。如果为True,则在执行正向网络时,将生成需要计算梯度的反向网络。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **requires_grad** (bool) – 指定网络是否需要梯度,如果为True,PyNative模式下Cell将构建反向网络。默认值:True。
|
||||
|
||||
**返回:**
|
||||
|
||||
Cell类型,Cell本身。
|
||||
|
||||
.. py:method:: set_parallel_input_with_inputs(*inputs)
|
||||
|
||||
通过并行策略对输入张量进行切分。
|
||||
|
||||
**参数:**
|
||||
|
||||
**inputs** (tuple) – construct方法的输入。
|
||||
|
||||
.. py:method:: set_param_fl(push_to_server=False, pull_from_server=False, requires_aggr=True)
|
||||
|
||||
设置参数与服务器交互的方式。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **push_to_server** (bool) – 是否将参数推送到服务器。默认值:False。
|
||||
- **pull_from_server** (bool) – 是否从服务器提取参数。默认值:False。
|
||||
- **requires_aggr** (bool) – 是否在服务器中聚合参数。默认值:True。
|
||||
|
||||
.. py:method:: set_param_ps(recurse=True, init_in_server=False)
|
||||
|
||||
设置可训练参数是否由参数服务器更新,以及是否在服务器上初始化可训练参数。
|
||||
|
||||
.. note:: 只在运行的任务处于参数服务器模式时有效。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **recurse** (bool) – 是否设置子网络的可训练参数。默认值:True。
|
||||
- **init_in_server** (bool) – 是否在服务器上初始化由参数服务器更新的可训练参数。默认值:False。
|
||||
|
||||
.. py:method:: set_train(mode=True)
|
||||
|
||||
将Cell设置为训练模式。
|
||||
|
||||
设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 `BatchNorm`),将通过这个属性区分分支。如果设置为True,则执行训练分支,否则执行另一个分支。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **mode** (bool) – 指定模型是否为训练模式。默认值:True。
|
||||
- **level** (int) - 指定搜索切分策略的目标函数,即是最大化计算通信比、最小化内存消耗、最大化执行速度等。可以为[0, 1, 2]中任意一个,默认值:0。目前仅支持最大化计算通信比,其余模式尚未使能。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -566,7 +585,7 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **dst_type** (mindspore.dtype) – Cell转换为 `dst_type` 类型运行。 `dst_type` 可以是 `mindspore.dtype.float16` 或者 `mindspore.dtype.float32` 。
|
||||
- **dst_type** (mindspore.dtype) - Cell转换为 `dst_type` 类型运行。 `dst_type` 可以是 `mindspore.dtype.float16` 或者 `mindspore.dtype.float32` 。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -574,29 +593,7 @@
|
|||
|
||||
**异常:**
|
||||
|
||||
- **ValueError** – 如果 `dst_type` 不是 `mindspore.dtype.float32` ,也不是 `mindspore.dtype.float16`。
|
||||
|
||||
.. py:method:: set_boost(boost_type)
|
||||
|
||||
为了提升网络性能,可以配置boost内的算法让框架自动使能该算法来加速网络训练。
|
||||
|
||||
请确保 `boost_type` 所选择的算法在
|
||||
`algorithm library <https://gitee.com/mindspore/mindspore/tree/master/mindspore/python/mindspore/boost>`_ 算法库中。
|
||||
|
||||
.. note:: 部分加速算法可能影响网络精度,请谨慎选择。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **boost_type** (str) – 加速算法。
|
||||
|
||||
**返回:**
|
||||
|
||||
Cell类型,Cell本身。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **ValueError** – 如果 `boost_type` 不在boost算法库内。
|
||||
|
||||
- **ValueError** - 如果 `dst_type` 不是 `mindspore.dtype.float32` ,也不是 `mindspore.dtype.float16`。
|
||||
|
||||
.. py:method:: trainable_params(recurse=True)
|
||||
|
||||
|
@ -606,7 +603,7 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **recurse** (bool) – 是否递归地包含当前Cell的所有子Cell的可训练参数。默认值:True。
|
||||
- **recurse** (bool) - 是否递归地包含当前Cell的所有子Cell的可训练参数。默认值:True。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -620,7 +617,7 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **recurse** (bool) – 是否递归地包含当前Cell的所有子Cell的不可训练参数。默认值:True。
|
||||
- **recurse** (bool) - 是否递归地包含当前Cell的所有子Cell的不可训练参数。默认值:True。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -640,7 +637,7 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **cell_type** (str) – 被更新的类型,`cell_type` 可以是"quant"或"second-order"。
|
||||
- **cell_type** (str) - 被更新的类型,`cell_type` 可以是"quant"或"second-order"。
|
||||
|
||||
.. py:method:: update_parameters_name(prefix='', recurse=True)
|
||||
|
||||
|
@ -648,5 +645,5 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **prefix** (str) – 前缀字符串。默认值:''。
|
||||
- **recurse** (bool) – 是否递归地包含所有子Cell的参数。默认值:True。
|
||||
- **prefix** (str) - 前缀字符串。默认值:''。
|
||||
- **recurse** (bool) - 是否递归地包含所有子Cell的参数。默认值:True。
|
||||
|
|
|
@ -18,27 +18,27 @@ mindspore.nn.Conv2d
|
|||
|
||||
**参数:**
|
||||
|
||||
- **in_channels** (`int`) – Conv2d层输入Tensor的空间维度。
|
||||
- **in_channels** (`int`) - Conv2d层输入Tensor的空间维度。
|
||||
- **out_channels** (`dict`) - Conv2d层输出Tensor的空间维度。
|
||||
- **kernel_size** (`Union[int, tuple[int]]`) – 指定二维卷积核的高度和宽度。数据类型为整型或两个整型的tuple。一个整数表示卷积核的高度和宽度均为该值。两个整数的tuple分别表示卷积核的高度和宽度。
|
||||
- **stride** (`Union[int, tuple[int]]`) – 二维卷积核的移动步长。数据类型为整型或两个整型的tuple。一个整数表示在高度和宽度方向的移动步长均为该值。两个整数的tuple分别表示在高度和宽度方向的移动步长。默认值:1。
|
||||
- **pad_mode** (`str`) – 指定填充模式。可选值为"same"、"valid"、"pad"。默认值:"same"。
|
||||
- **kernel_size** (`Union[int, tuple[int]]`) - 指定二维卷积核的高度和宽度。数据类型为整型或两个整型的tuple。一个整数表示卷积核的高度和宽度均为该值。两个整数的tuple分别表示卷积核的高度和宽度。
|
||||
- **stride** (`Union[int, tuple[int]]`) - 二维卷积核的移动步长。数据类型为整型或两个整型的tuple。一个整数表示在高度和宽度方向的移动步长均为该值。两个整数的tuple分别表示在高度和宽度方向的移动步长。默认值:1。
|
||||
- **pad_mode** (`str`) - 指定填充模式。可选值为"same"、"valid"、"pad"。默认值:"same"。
|
||||
|
||||
- **same**:输出的高度和宽度分别与输入整除 `stride` 后的值相同。若设置该模式,`padding` 的值必须为0。
|
||||
- **valid**:在不填充的前提下返回有效计算所得的输出。不满足计算的多余像素会被丢弃。如果设置此模式,则 `padding` 的值必须为0。
|
||||
- **pad**:对输入进行填充。在输入的高度和宽度方向上填充 `padding` 大小的0。如果设置此模式, `padding` 必须大于或等于0。
|
||||
|
||||
- **padding** (`Union[int, tuple[int]]`) – 输入的高度和宽度方向上填充的数量。数据类型为int或包含4个整数的tuple。如果 `padding` 是一个整数,那么上、下、左、右的填充都等于 `padding` 。如果 `padding` 是一个有4个整数的tuple,那么上、下、左、右的填充分别等于 `padding[0]` 、 `padding[1]` 、 `padding[2]` 和 `padding[3]` 。值应该要大于等于0,默认值:0。
|
||||
- **dilation** (`Union[int, tuple[int]]`) – 二维卷积核膨胀尺寸。数据类型为整型或具有两个整型的tuple。若 :math:`k > 1` ,则kernel间隔 `k` 个元素进行采样。垂直和水平方向上的 `k` ,其取值范围分别为[1, H]和[1, W]。默认值:1。
|
||||
- **group** (`int`) – 将过滤器拆分为组, `in_channels` 和 `out_channels` 必须可被 `group` 整除。如果组数等于 `in_channels` 和 `out_channels` ,这个二维卷积层也被称为二维深度卷积层。默认值:1.
|
||||
- **has_bias** (`bool`) – Conv2d层是否添加偏置参数。默认值:False。
|
||||
- **padding** (`Union[int, tuple[int]]`) - 输入的高度和宽度方向上填充的数量。数据类型为int或包含4个整数的tuple。如果 `padding` 是一个整数,那么上、下、左、右的填充都等于 `padding` 。如果 `padding` 是一个有4个整数的tuple,那么上、下、左、右的填充分别等于 `padding[0]` 、 `padding[1]` 、 `padding[2]` 和 `padding[3]` 。值应该要大于等于0,默认值:0。
|
||||
- **dilation** (`Union[int, tuple[int]]`) - 二维卷积核膨胀尺寸。数据类型为整型或具有两个整型的tuple。若 :math:`k > 1` ,则kernel间隔 `k` 个元素进行采样。垂直和水平方向上的 `k` ,其取值范围分别为[1, H]和[1, W]。默认值:1。
|
||||
- **group** (`int`) - 将过滤器拆分为组, `in_channels` 和 `out_channels` 必须可被 `group` 整除。如果组数等于 `in_channels` 和 `out_channels` ,这个二维卷积层也被称为二维深度卷积层。默认值:1.
|
||||
- **has_bias** (`bool`) - Conv2d层是否添加偏置参数。默认值:False。
|
||||
- **weight_init** (Union[Tensor, str, Initializer, numbers.Number]) - 权重参数的初始化方法。它可以是Tensor,str,Initializer或numbers.Number。当使用str时,可选"TruncatedNormal","Normal","Uniform","HeUniform"和"XavierUniform"分布以及常量"One"和"Zero"分布的值,可接受别名"xavier_uniform","he_uniform","ones"和"zeros"。上述字符串大小写均可。更多细节请参考Initializer的值。默认值:"normal"。
|
||||
- **bias_init** (Union[Tensor, str, Initializer, numbers.Number]) - 偏置参数的初始化方法。可以使用的初始化方法与"weight_init"相同。更多细节请参考Initializer的值。默认值:"zeros"。
|
||||
- **data_format** (`str`) – 数据格式的可选值有"NHWC","NCHW"。默认值:"NCHW"。
|
||||
- **data_format** (`str`) - 数据格式的可选值有"NHWC","NCHW"。默认值:"NCHW"。
|
||||
|
||||
**输入:**
|
||||
|
||||
**x** (Tensor) - Shape为 :math:`(N, C_{in}, H_{in}, W_{in})` 或者 :math:`(N, H_{in}, W_{in}, C_{in})` 的Tensor。
|
||||
- **x** (Tensor) - Shape为 :math:`(N, C_{in}, H_{in}, W_{in})` 或者 :math:`(N, H_{in}, W_{in}, C_{in})` 的Tensor。
|
||||
|
||||
**输出:**
|
||||
|
||||
|
|
|
@ -13,20 +13,20 @@ mindspore.nn.Conv2dTranspose
|
|||
|
||||
**参数:**
|
||||
|
||||
- **in_channels** (`int`) – Conv2dTranspose层输入Tensor的空间维度。
|
||||
- **in_channels** (`int`) - Conv2dTranspose层输入Tensor的空间维度。
|
||||
- **out_channels** (`dict`) - Conv2dTranspose层输出Tensor的空间维度。
|
||||
- **kernel_size** (`Union[int, tuple[int]]`) – 指定二维卷积核的高度和宽度。数据类型为整型或两个整型的tuple。一个整数表示卷积核的高度和宽度均为该值。两个整数的tuple分别表示卷积核的高度和宽度。
|
||||
- **stride** (`Union[int, tuple[int]]`) – 二维卷积核的移动步长。数据类型为整型或两个整型的tuple。一个整数表示在高度和宽度方向的移动步长均为该值。两个整数的tuple分别表示在高度和宽度方向的移动步长。默认值:1。
|
||||
- **pad_mode** (`str`) – 指定填充模式。可选值为"same"、"valid"、"pad"。默认值:"same"。
|
||||
- **kernel_size** (`Union[int, tuple[int]]`) - 指定二维卷积核的高度和宽度。数据类型为整型或两个整型的tuple。一个整数表示卷积核的高度和宽度均为该值。两个整数的tuple分别表示卷积核的高度和宽度。
|
||||
- **stride** (`Union[int, tuple[int]]`) - 二维卷积核的移动步长。数据类型为整型或两个整型的tuple。一个整数表示在高度和宽度方向的移动步长均为该值。两个整数的tuple分别表示在高度和宽度方向的移动步长。默认值:1。
|
||||
- **pad_mode** (`str`) - 指定填充模式。可选值为"same"、"valid"、"pad"。默认值:"same"。
|
||||
|
||||
- **same**:输出的高度和宽度分别与输入整除 `stride` 后的值相同。若设置该模式,`padding` 的值必须为0。
|
||||
- **valid**:在不填充的前提下返回有效计算所得的输出。不满足计算的多余像素会被丢弃。如果设置此模式,则 `padding` 的值必须为0。
|
||||
- **pad**:对输入进行填充。在输入的高度和宽度方向上填充 `padding` 大小的0。如果设置此模式, `padding` 必须大于或等于0。
|
||||
|
||||
- **padding** (`Union[int, tuple[int]]`) – 输入的高度和宽度方向上填充的数量。数据类型为整型或包含四个整数的tuple。如果 `padding` 是一个整数,那么上、下、左、右的填充都等于 `padding` 。如果 `padding` 是一个有四个整数的tuple,那么上、下、左、右的填充分别等于 `padding[0]` 、 `padding[1]` 、 `padding[2]` 和 `padding[3]` 。值应该要大于等于0,默认值:0。
|
||||
- **dilation** (`Union[int, tuple[int]]`) – 二维卷积核膨胀尺寸。数据类型为整型或具有两个整型的tuple。若 :math:`k > 1` ,则kernel间隔 `k` 个元素进行采样。高度和宽度方向上的 `k` ,其取值范围分别为[1, H]和[1, W]。默认值:1。
|
||||
- **group** (`int`) – 将过滤器拆分为组, `in_channels` 和 `out_channels` 必须可被 `group` 整除。如果组数等于 `in_channels` 和 `out_channels` ,这个二维卷积层也被称为二维深度卷积层。默认值:1.
|
||||
- **has_bias** (`bool`) – Conv2dTranspose层是否添加偏置参数。默认值:False。
|
||||
- **padding** (`Union[int, tuple[int]]`) - 输入的高度和宽度方向上填充的数量。数据类型为整型或包含四个整数的tuple。如果 `padding` 是一个整数,那么上、下、左、右的填充都等于 `padding` 。如果 `padding` 是一个有四个整数的tuple,那么上、下、左、右的填充分别等于 `padding[0]` 、 `padding[1]` 、 `padding[2]` 和 `padding[3]` 。值应该要大于等于0,默认值:0。
|
||||
- **dilation** (`Union[int, tuple[int]]`) - 二维卷积核膨胀尺寸。数据类型为整型或具有两个整型的tuple。若 :math:`k > 1` ,则kernel间隔 `k` 个元素进行采样。高度和宽度方向上的 `k` ,其取值范围分别为[1, H]和[1, W]。默认值:1。
|
||||
- **group** (`int`) - 将过滤器拆分为组, `in_channels` 和 `out_channels` 必须可被 `group` 整除。如果组数等于 `in_channels` 和 `out_channels` ,这个二维卷积层也被称为二维深度卷积层。默认值:1.
|
||||
- **has_bias** (`bool`) - Conv2dTranspose层是否添加偏置参数。默认值:False。
|
||||
- **weight_init** (Union[Tensor, str, Initializer, numbers.Number]) - 权重参数的初始化方法。它可以是Tensor,str,Initializer或numbers.Number。当使用str时,可选"TruncatedNormal","Normal","Uniform","HeUniform"和"XavierUniform"分布以及常量"One"和"Zero"分布的值,可接受别名"xavier_uniform","he_uniform","ones"和"zeros"。上述字符串大小写均可。更多细节请参考Initializer的值。默认值:"normal"。
|
||||
- **bias_init** (Union[Tensor, str, Initializer, numbers.Number]) - 偏置参数的初始化方法。可以使用的初始化方法与"weight_init"相同。更多细节请参考Initializer的值。默认值:"zeros"。
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ mindspore.nn.Conv3dTranspose
|
|||
- pad:对输入进行填充。 在输入的深度、高度和宽度方向上填充 `padding` 大小的0。如果设置此模式, `padding` 必须大于或等于0。
|
||||
|
||||
- **padding** (Union(int, tuple[int])) - 输入的深度、高度和宽度方向上填充的数量。数据类型为int或包含6个整数的tuple。如果 `padding` 是一个整数,则前部、后部、顶部,底部,左边和右边的填充都等于 `padding` 。如果 `padding` 是6个整数的tuple,则前部、尾部、顶部、底部、左边和右边的填充分别等于填充padding[0]、padding[1]、padding[2]、padding[3]、padding[4]和padding[5]。值应该要大于等于0,默认值:0。
|
||||
- **dilation** (Union[int, tuple[int]]) - 三维卷积核膨胀尺寸。数据类型为int或三个整数的tuple。若 :math:`k > 1` ,则kernel间隔 `k` 个元素进行采样。深度、高度和宽度方向上的 `k` ,其取值范围分别为[1, D]、[1, H]和[1, W]。默认值:1。
|
||||
- **dilation** (Union[int, tuple[int]]) - 三维卷积核膨胀尺寸。数据类型为int或三个整数的tuple。若 :math:`k > 1` ,则kernel间隔 `k` 个元素进行采样。深度、高度和宽度方向上的 `k` ,其取值范围分别为[1, D]、[1, H]和[1, W]。默认值:1。
|
||||
- **group** (int) - 将过滤器拆分为组, `in_channels` 和 `out_channels` 必须可被 `group` 整除。当 `group` 大于1时,暂不支持Ascend平台。默认值:1。当前仅支持1。
|
||||
- **output_padding** (Union(int, tuple[int])) - 输出的深度、高度和宽度方向上填充的数量。数据类型为int或包含6个整数的tuple。如果 `output_padding` 是一个整数,则前部、后部、顶部,底部,左边和右边的填充都等于 `output_padding` 。如果 `output_padding` 是6个整数的tuple,则前部、尾部、顶部、底部、左边和右边的填充分别等于填充output_padding[0]、output_padding[1]、output_padding[2]、output_padding[3]、output_padding[4]output_padding[5]。值应该要大于等于0,默认值:0。
|
||||
- **has_bias** (bool) - Conv3dTranspose层是否添加偏置参数。默认值:False。
|
||||
|
|
|
@ -15,7 +15,7 @@ mindspore.nn.Embedding
|
|||
- **vocab_size** (int) - 词典的大小。
|
||||
- **embedding_size** (int) - 每个嵌入向量的大小。
|
||||
- **use_one_hot** (bool) - 指定是否使用one-hot形式。默认值:False。
|
||||
- **embedding_table** (Union[Tensor, str, Initializer, numbers.Number]) – embedding_table的初始化方法。当指定为字符串,字符串取值请参见类 `mindspore.common.initializer` 。默认值:'normal'。
|
||||
- **embedding_table** (Union[Tensor, str, Initializer, numbers.Number]) - embedding_table的初始化方法。当指定为字符串,字符串取值请参见类 `mindspore.common.initializer` 。默认值:'normal'。
|
||||
- **dtype** (mindspore.dtype) - x的数据类型。默认值:mindspore.float32。
|
||||
- **padding_idx** (int, None) - 将 `padding_idx` 对应索引所输出的嵌入向量用零填充。默认值:None。该功能已停用。
|
||||
|
||||
|
|
|
@ -9,15 +9,15 @@ mindspore.nn.ForwardValueAndGrad
|
|||
通过梯度函数来创建反向图,用以计算梯度。
|
||||
|
||||
**参数:**
|
||||
|
||||
|
||||
- **network** (Cell) - 训练网络。
|
||||
- **weights** (ParameterTuple) - 训练网络中需要计算梯度的的参数。
|
||||
- **get_all** (bool) - 如果为True,则计算网络输入对应的梯度。默认值:False。
|
||||
- **get_by_list** (bool) - 如果为True,则计算参数变量对应的梯度。如果 `get_all` 和 `get_by_list` 都为False,则计算第一个输入对应的梯度。如果 `get_all` 和 `get_by_list` 都为True,则以((输入的梯度),(参数的梯度))的形式同时获取输入和参数变量的梯度。默认值:False。
|
||||
- **sens_param** (bool) - 是否将sens作为输入。如果 `sens_param` 为False,则sens默认为'ones_like(outputs)'。默认值:False。如果 `sens_param` 为True,则需要指定sens的值。
|
||||
|
||||
|
||||
**输入:**
|
||||
|
||||
|
||||
- **(\*inputs)** (Tuple(Tensor...)):shape为 :math:`(N, \ldots)` 的输入tuple。
|
||||
- **(sens)**:反向传播梯度的缩放值。如果网络有单个输出,则sens是tensor。如果网络有多个输出,则sens是tuple(tensor)。
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ mindspore.nn.GraphCell
|
|||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** – 如果图不是FuncGraph类型。
|
||||
- **TypeError** – 如果 `params_init` 不是字典。
|
||||
- **TypeError** – 如果 `params_init` 的key不是字符串。
|
||||
- **TypeError** – 如果 `params_init` 的value既不是 Tensor也不是Parameter。
|
||||
- **TypeError** - 如果图不是FuncGraph类型。
|
||||
- **TypeError** - 如果 `params_init` 不是字典。
|
||||
- **TypeError** - 如果 `params_init` 的key不是字符串。
|
||||
- **TypeError** - 如果 `params_init` 的value既不是 Tensor也不是Parameter。
|
||||
|
|
|
@ -16,7 +16,7 @@ mindspore.nn.TrainOneStepCell
|
|||
|
||||
**输入:**
|
||||
|
||||
**(\*inputs)** (Tuple(Tensor)) - shape为 :math:`(N, \ldots)` 的Tensor组成的元组。
|
||||
- **(\*inputs)** (Tuple(Tensor)) - shape为 :math:`(N, \ldots)` 的Tensor组成的元组。
|
||||
|
||||
**输出:**
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ mindspore.nn.TrainOneStepWithLossScaleCell
|
|||
|
||||
**输入:**
|
||||
|
||||
**(*inputs)** (Tuple(Tensor))- shape为 :math:`(N, \ldots)` 的Tensor组成的元组。
|
||||
- **(*inputs)** (Tuple(Tensor)) - shape为 :math:`(N, \ldots)` 的Tensor组成的元组。
|
||||
|
||||
**输出:**
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@ mindspore.nn.WithLossCell
|
|||
|
||||
**输入:**
|
||||
|
||||
- **data** (Tensor) - shape为 :math:`(N, \ldots)` 的Tensor。
|
||||
- **label** (Tensor) - shape为 :math:`(N, \ldots)` 的Tensor。
|
||||
- **data** (Tensor) - shape为 :math:`(N, \ldots)` 的Tensor。
|
||||
- **label** (Tensor) - shape为 :math:`(N, \ldots)` 的Tensor。
|
||||
|
||||
**输出:**
|
||||
|
||||
|
|
Loading…
Reference in New Issue