forked from mindspore-Ecosystem/mindspore
!29911 Add chinese document for primitive, ...
Merge pull request !29911 from huangbingjian/code_docs_primitive
This commit is contained in:
commit
bd8867d69b
|
@ -75,3 +75,34 @@ Random类型算子
|
|||
|
||||
mindspore.ops.Gamma
|
||||
mindspore.ops.UniformReal
|
||||
|
||||
|
||||
原语
|
||||
----
|
||||
|
||||
.. cnmsplatformautosummary::
|
||||
:toctree: ops
|
||||
|
||||
mindspore.ops.constexpr
|
||||
mindspore.ops.prim_attr_register
|
||||
mindspore.ops.Primitive
|
||||
mindspore.ops.PrimitiveWithCheck
|
||||
mindspore.ops.PrimitiveWithInfer
|
||||
|
||||
|
||||
函数实现注册
|
||||
--------------
|
||||
|
||||
.. cnmsplatformautosummary::
|
||||
:toctree: ops
|
||||
|
||||
mindspore.ops.get_vm_impl_fn
|
||||
|
||||
|
||||
算子信息注册
|
||||
--------------
|
||||
|
||||
.. cnmsplatformautosummary::
|
||||
:toctree: ops
|
||||
|
||||
mindspore.ops.DataType
|
||||
|
|
|
@ -0,0 +1,173 @@
|
|||
mindspore.ops.DataType
|
||||
======================
|
||||
|
||||
.. py:class:: mindspore.ops.DataType:
|
||||
|
||||
Ascend算子的dtype和format的多种组合。
|
||||
|
||||
当前支持:
|
||||
|
||||
.. code-block::
|
||||
|
||||
None_None = ("", "")
|
||||
None_Default = ("", "DefaultFormat")
|
||||
BOOL_None = ("bool", "")
|
||||
BOOL_Default = ("bool", "DefaultFormat")
|
||||
BOOL_5HD = ("bool", "NC1HWC0")
|
||||
BOOL_FracZ = ("bool", "FracZ")
|
||||
BOOL_FracNZ = ("bool", "FRACTAL_NZ")
|
||||
BOOL_C1HWNCoC0 = ("bool", "C1HWNCoC0")
|
||||
BOOL_NCHW = ("bool", "NCHW")
|
||||
BOOL_NHWC = ("bool", "NHWC")
|
||||
BOOL_HWCN = ("bool", "HWCN")
|
||||
BOOL_NDHWC = ("bool", "NDHWC")
|
||||
BOOL_ChannelLast = ("bool", "ChannelLast")
|
||||
|
||||
I8_None = ("int8", "")
|
||||
I8_Default = ("int8", "DefaultFormat")
|
||||
I8_5HD = ("int8", "NC1HWC0")
|
||||
I8_FracZ = ("int8", "FracZ")
|
||||
I8_FracNZ = ("int8", "FRACTAL_NZ")
|
||||
I8_C1HWNCoC0 = ("int8", "C1HWNCoC0")
|
||||
I8_NCHW = ("int8", "NCHW")
|
||||
I8_NHWC = ("int8", "NHWC")
|
||||
I8_HWCN = ("int8", "HWCN")
|
||||
I8_NDHWC = ("int8", "NDHWC")
|
||||
I8_ChannelLast = ("int8", "ChannelLast")
|
||||
|
||||
U8_None = ("uint8", "")
|
||||
U8_Default = ("uint8", "DefaultFormat")
|
||||
U8_5HD = ("uint8", "NC1HWC0")
|
||||
U8_FracZ = ("uint8", "FracZ")
|
||||
U8_FracNZ = ("uint8", "FRACTAL_NZ")
|
||||
U8_C1HWNCoC0 = ("uint8", "C1HWNCoC0")
|
||||
U8_NCHW = ("uint8", "NCHW")
|
||||
U8_NHWC = ("uint8", "NHWC")
|
||||
U8_HWCN = ("uint8", "HWCN")
|
||||
U8_NDHWC = ("uint8", "NDHWC")
|
||||
U8_ChannelLast = ("uint8", "ChannelLast")
|
||||
|
||||
I16_None = ("int16", "")
|
||||
I16_Default = ("int16", "DefaultFormat")
|
||||
I16_5HD = ("int16", "NC1HWC0")
|
||||
I16_FracZ = ("int16", "FracZ")
|
||||
I16_FracNZ = ("int16", "FRACTAL_NZ")
|
||||
I16_C1HWNCoC0 = ("int16", "C1HWNCoC0")
|
||||
I16_NCHW = ("int16", "NCHW")
|
||||
I16_NHWC = ("int16", "NHWC")
|
||||
I16_HWCN = ("int16", "HWCN")
|
||||
I16_NDHWC = ("int16", "NDHWC")
|
||||
I16_ChannelLast = ("int16", "ChannelLast")
|
||||
|
||||
U16_None = ("uint16", "")
|
||||
U16_Default = ("uint16", "DefaultFormat")
|
||||
U16_5HD = ("uint16", "NC1HWC0")
|
||||
U16_FracZ = ("uint16", "FracZ")
|
||||
U16_FracNZ = ("uint16", "FRACTAL_NZ")
|
||||
U16_C1HWNCoC0 = ("uint16", "C1HWNCoC0")
|
||||
U16_NCHW = ("uint16", "NCHW")
|
||||
U16_NHWC = ("uint16", "NHWC")
|
||||
U16_HWCN = ("uint16", "HWCN")
|
||||
U16_NDHWC = ("uint16", "NDHWC")
|
||||
U16_ChannelLast = ("uint16", "ChannelLast")
|
||||
|
||||
I32_None = ("int32", "")
|
||||
I32_Default = ("int32", "DefaultFormat")
|
||||
I32_5HD = ("int32", "NC1HWC0")
|
||||
I32_FracZ = ("int32", "FracZ")
|
||||
I32_FracNZ = ("int32", "FRACTAL_NZ")
|
||||
I32_C1HWNCoC0 = ("int32", "C1HWNCoC0")
|
||||
I32_NCHW = ("int32", "NCHW")
|
||||
I32_NHWC = ("int32", "NHWC")
|
||||
I32_HWCN = ("int32", "HWCN")
|
||||
I32_NDHWC = ("int32", "NDHWC")
|
||||
I32_ChannelLast = ("int32", "ChannelLast")
|
||||
|
||||
U32_None = ("uint32", "")
|
||||
U32_Default = ("uint32", "DefaultFormat")
|
||||
U32_5HD = ("uint32", "NC1HWC0")
|
||||
U32_FracZ = ("uint32", "FracZ")
|
||||
U32_FracNZ = ("uint32", "FRACTAL_NZ")
|
||||
U32_C1HWNCoC0 = ("uint32", "C1HWNCoC0")
|
||||
U32_NCHW = ("uint32", "NCHW")
|
||||
U32_NHWC = ("uint32", "NHWC")
|
||||
U32_HWCN = ("uint32", "HWCN")
|
||||
U32_NDHWC = ("uint32", "NDHWC")
|
||||
U32_ChannelLast = ("uint32", "ChannelLast")
|
||||
|
||||
I64_None = ("int64", "")
|
||||
I64_Default = ("int64", "DefaultFormat")
|
||||
I64_5HD = ("int64", "NC1HWC0")
|
||||
I64_FracZ = ("int64", "FracZ")
|
||||
I64_FracNZ = ("int64", "FRACTAL_NZ")
|
||||
I64_C1HWNCoC0 = ("int64", "C1HWNCoC0")
|
||||
I64_NCHW = ("int64", "NCHW")
|
||||
I64_NHWC = ("int64", "NHWC")
|
||||
I64_HWCN = ("int64", "HWCN")
|
||||
I64_NDHWC = ("int64", "NDHWC")
|
||||
I64_ChannelLast = ("int64", "ChannelLast")
|
||||
|
||||
U64_None = ("uint64", "")
|
||||
U64_Default = ("uint64", "DefaultFormat")
|
||||
U64_5HD = ("uint64", "NC1HWC0")
|
||||
U64_FracZ = ("uint64", "FracZ")
|
||||
U64_FracNZ = ("uint64", "FRACTAL_NZ")
|
||||
U64_C1HWNCoC0 = ("uint64", "C1HWNCoC0")
|
||||
U64_NCHW = ("uint64", "NCHW")
|
||||
U64_NHWC = ("uint64", "NHWC")
|
||||
U64_HWCN = ("uint64", "HWCN")
|
||||
U64_NDHWC = ("uint64", "NDHWC")
|
||||
U64_ChannelLast = ("uint64", "ChannelLast")
|
||||
|
||||
F16_None = ("float16", "")
|
||||
F16_Default = ("float16", "DefaultFormat")
|
||||
F16_5HD = ("float16", "NC1HWC0")
|
||||
F16_FracZ = ("float16", "FracZ")
|
||||
F16_FracNZ = ("float16", "FRACTAL_NZ")
|
||||
F16_C1HWNCoC0 = ("float16", "C1HWNCoC0")
|
||||
F16_NCHW = ("float16", "NCHW")
|
||||
F16_NHWC = ("float16", "NHWC")
|
||||
F16_HWCN = ("float16", "HWCN")
|
||||
F16_NDHWC = ("float16", "NDHWC")
|
||||
F16_NCDHW = ("float16", "NCDHW")
|
||||
F16_DHWCN = ("float16", "DHWCN")
|
||||
F16_NDC1HWC0 = ("float16", "NDC1HWC0")
|
||||
F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
|
||||
F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
|
||||
F16_FracZNRNN = ("float16", "FRACTAL_ZN_RNN")
|
||||
F16_ND_RNNBIAS = ("float16", "ND_RNN_BIAS")
|
||||
F16_ChannelLast = ("float16", "ChannelLast")
|
||||
|
||||
F32_None = ("float32", "")
|
||||
F32_Default = ("float32", "DefaultFormat")
|
||||
F32_5HD = ("float32", "NC1HWC0")
|
||||
F32_FracZ = ("float32", "FracZ")
|
||||
F32_FracNZ = ("float32", "FRACTAL_NZ")
|
||||
F32_C1HWNCoC0 = ("float32", "C1HWNCoC0")
|
||||
F32_NCHW = ("float32", "NCHW")
|
||||
F32_NHWC = ("float32", "NHWC")
|
||||
F32_HWCN = ("float32", "HWCN")
|
||||
F32_NDHWC = ("float32", "NDHWC")
|
||||
F32_NCDHW = ("float32", "NCDHW")
|
||||
F32_DHWCN = ("float32", "DHWCN")
|
||||
F32_NDC1HWC0 = ("float32", "NDC1HWC0")
|
||||
F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
|
||||
F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
|
||||
F32_FracZNRNN = ("float32", "FRACTAL_ZN_RNN")
|
||||
F32_ND_RNNBIAS = ("float32", "ND_RNN_BIAS")
|
||||
F32_ChannelLast = ("float32", "ChannelLast")
|
||||
|
||||
F64_None = ("float64", "")
|
||||
F64_Default = ("float64", "DefaultFormat")
|
||||
F64_5HD = ("float64", "NC1HWC0")
|
||||
F64_FracZ = ("float64", "FracZ")
|
||||
F64_FracNZ = ("float64", "FRACTAL_NZ")
|
||||
F64_C1HWNCoC0 = ("float64", "C1HWNCoC0")
|
||||
F64_NCHW = ("float64", "NCHW")
|
||||
F64_NHWC = ("float64", "NHWC")
|
||||
F64_HWCN = ("float64", "HWCN")
|
||||
F64_NDHWC = ("float64", "NDHWC")
|
||||
F64_ChannelLast = ("float64", "ChannelLast")
|
||||
|
||||
C64_Default = ("complex64", "DefaultFormat")
|
||||
C128_Default = ("complex128", "DefaultFormat")
|
|
@ -0,0 +1,100 @@
|
|||
mindspore.ops.Primitive
|
||||
=======================
|
||||
|
||||
.. py:class:: mindspore.Primitive(name)
|
||||
|
||||
Primitive是Python中算子原语的基类。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **name** (str) - 当前Primitive的名称。
|
||||
|
||||
.. py:method:: add_prim_attr(name, value)
|
||||
|
||||
添加Primitive的属性。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **name** (str) - 属性名称。
|
||||
- **value** (Any) - 属性值。
|
||||
|
||||
.. py:method:: del_prim_attr(name)
|
||||
|
||||
删除Primitive的属性。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **name** (str) - 属性名称。
|
||||
|
||||
.. py:method:: check_elim(*args)
|
||||
|
||||
检查是否可以消除此Primitive。有需要的子类可以重写该方法。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **args** (Primitive参数的类型) - 与当前Primitive的参数相同。
|
||||
|
||||
**返回:**
|
||||
|
||||
由两个元素组成的元组。第一个元素是指是否能在编译阶段计算Primitive,第二个元素是计算结果。
|
||||
|
||||
.. py:method:: init_prim_io_names(inputs, outputs)
|
||||
|
||||
初始化Tensor或属性的输入输出的名称。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **inputs** (list[str]) - 输入名称的列表。
|
||||
- **outputs** (list[str]) - 输出名称的列表。
|
||||
|
||||
.. py:method:: recompute(mode=True)
|
||||
|
||||
设置Primitive的重计算属性。
|
||||
|
||||
如果有一个被设置了重计算属性的Primitive,并且其结果在计算导数的时候被使用,那么不会保存该Primitive在前向网络中的中间计算结果,而是在自动微分的时候重新进行计算。
|
||||
|
||||
.. note::
|
||||
- 如果计算涉及随机化或全局变量,则暂无法保证等效性。
|
||||
- 在PyNative模式下不支持。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **mode** (bool) - Primitive是否设置了重计算。默认值:True。
|
||||
|
||||
.. py:method:: set_prim_instance_name(instance_name)
|
||||
|
||||
设置Primitive算子的实例的名称。
|
||||
|
||||
.. note::
|
||||
当用户定义Primitive算子时,默认调用它。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **instance_name** (str) - 用户设置的Primitive算子的实例的名称。
|
||||
|
||||
.. py:method:: set_stage(stage)
|
||||
|
||||
将stage的ID添加到Primitive属性中。
|
||||
|
||||
.. note::
|
||||
仅在半自动并行模式下有效。在其他并行模式下,请将其设置为0。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **stage** (int) - 当前stage的ID。
|
||||
|
||||
.. py:method:: shard(in_strategy, out_strategy)
|
||||
|
||||
将切分策略添加到Primitive属性中。
|
||||
|
||||
.. note::
|
||||
仅在半自动并行或自动并行模式下有效。在其他并行模式中,将忽略此处设置的策略。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **in_strategy** (tuple) - 描述算子输入的切分策略。默认值:None。
|
||||
- **out_strategy** (tuple) - 描述算子输出的切分策略,仅针对某些算子,如MatMul。默认值:None。
|
||||
|
||||
.. py:method:: update_parameter()
|
||||
|
||||
判断此Primitive是否会更新参数的值。
|
|
@ -0,0 +1,41 @@
|
|||
mindspore.ops.PrimitiveWithCheck
|
||||
================================
|
||||
|
||||
.. py:class:: mindspore.PrimitiveWithCheck(name)
|
||||
|
||||
PrimitiveWithCheck是Python中原语的基类,定义了检查算子输入参数的函数,但是使用了C++源码中注册的推理方法。
|
||||
|
||||
可以重写三个方法来定义Primitive的检查逻辑: __check__()、check_shape()和check_dtype()。如果在Primitive中定义了__check__(),则__check__()的优先级最高。
|
||||
|
||||
如果未定义__check__(),则可以定义check_shape()和check_dtype()来描述形状和类型的检查逻辑。可以定义infer_value()方法(如PrimitiveWithInfer),用于常量传播。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **name** (str) - 当前Primitive的名称。
|
||||
|
||||
.. py:method:: check_dtype(*args)
|
||||
|
||||
检查输入参数的数据类型。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **args** (:class:`mindspore.dtype`) - 输入的数据类型。
|
||||
|
||||
**返回:**
|
||||
|
||||
None。
|
||||
|
||||
.. py:method:: check_shape(*args)
|
||||
|
||||
检查输入参数的shape。
|
||||
|
||||
.. note::
|
||||
Scalar的shape是一个空元组。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **args** (tuple(int)) - 输入tensor的shape。
|
||||
|
||||
**返回:**
|
||||
|
||||
None。
|
|
@ -0,0 +1,53 @@
|
|||
mindspore.ops.PrimitiveWithInfer
|
||||
================================
|
||||
|
||||
.. py:class:: mindspore.PrimitiveWithInfer(name)
|
||||
|
||||
PrimitiveWithInfer是Python中的原语基类,在python中定义了跟踪推理的函数。
|
||||
|
||||
可以重写四个方法来定义Primitive的推断逻辑:__infer__()、infer_shape()、infer_dtype()和infer_value()。如果在Primitive中定义了__infer__(),则__infer__()的优先级最高。
|
||||
|
||||
如果未定义__infer__(),则可以定义infer_shape()和infer_dtype()来描述shape和类型的推断逻辑。infer_value()用于常量传播。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **name** (str) - 当前Primitive的名称。
|
||||
|
||||
.. py:method:: infer_dtype(*args)
|
||||
|
||||
根据输入类型推断输出类型。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **args** (:class:`mindspore.dtype`) - 输入的数据类型。
|
||||
|
||||
**返回:**
|
||||
|
||||
:class:`mindspore.dtype`,输出的数据类型。
|
||||
|
||||
.. py:method:: infer_shape(*args)
|
||||
|
||||
根据输入形状推断输出形状。
|
||||
|
||||
.. note::
|
||||
Scalar的shape是一个空元组。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **args** (tuple(int)) - 输入tensor的shape。
|
||||
|
||||
**返回:**
|
||||
|
||||
`tuple(int)`,输出tensor的shape。
|
||||
|
||||
.. py:method:: infer_value(*args)
|
||||
|
||||
根据编译时的输入值推断输出值。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **args** (Any) - 输入的值。
|
||||
|
||||
**返回:**
|
||||
|
||||
输出的值。如果编译时无法推断该值,返回`None`。
|
|
@ -0,0 +1,12 @@
|
|||
mindspore.ops.constexpr
|
||||
=======================
|
||||
|
||||
.. py:function:: mindspore.ops.constexpr(fn=None, get_instance=True, name=None):
|
||||
|
||||
创建PrimiveWithInfer算子,用于在编译时推断值。可以用它定义函数,从而使用构造函数中的常量计算出常量值。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **fn** (function) - `fn` 用作输出算子的infer_value。默认值:None。
|
||||
- **get_instance** (bool) - 如果为True,返回算子的实例,否则返回算子的类。默认值:True。
|
||||
- **name** (str) - 定义算子的名称。如果 `name` 为None,则使用函数名称作为算子名称。默认值:None。
|
|
@ -0,0 +1,17 @@
|
|||
mindspore.ops.get_vm_impl_fn
|
||||
============================
|
||||
|
||||
.. py:function:: mindspore.ops.get_vm_impl_fn(prim):
|
||||
|
||||
通过Primitive对象或Primitive名称,获取虚拟实现函数。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **prim** (Union[Primitive, str]) - 算子注册的Primitive对象或名称。
|
||||
|
||||
.. note::
|
||||
该机制目前适用于调试。
|
||||
|
||||
**返回:**
|
||||
|
||||
函数,虚拟实现函数。
|
|
@ -0,0 +1,16 @@
|
|||
mindspore.ops.prim_attr_register
|
||||
================================
|
||||
|
||||
.. py:function:: mindspore.ops.prim_attr_register(fn):
|
||||
|
||||
Primitive属性的注册器。
|
||||
|
||||
注册装饰器,其中装饰器用于内置算子的Primitive的'__init__'函数。该函数将添加'__init__'的所有参数作为算子属性,并且初始化Primitive的名称。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **fn** (function) - Primitive的__init__函数。
|
||||
|
||||
**返回:**
|
||||
|
||||
函数,原始函数。
|
|
@ -764,10 +764,6 @@ class DataType:
|
|||
r"""
|
||||
Various combinations of dtype and format of Ascend ops.
|
||||
|
||||
The current list below may be incomplete.
|
||||
|
||||
Please add it if necessary.
|
||||
|
||||
current support:
|
||||
|
||||
.. code-block::
|
||||
|
@ -933,6 +929,9 @@ class DataType:
|
|||
F64_HWCN = ("float64", "HWCN")
|
||||
F64_NDHWC = ("float64", "NDHWC")
|
||||
F64_ChannelLast = ("float64", "ChannelLast")
|
||||
|
||||
C64_Default = ("complex64", "DefaultFormat")
|
||||
C128_Default = ("complex128", "DefaultFormat")
|
||||
"""
|
||||
|
||||
None_None = ("", "")
|
||||
|
|
|
@ -381,8 +381,8 @@ class Primitive(Primitive_):
|
|||
|
||||
class PrimitiveWithCheck(Primitive):
|
||||
"""
|
||||
PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator
|
||||
input arguments but used the infer method registered in c++ source codes.
|
||||
PrimitiveWithCheck is the base class of primitives in python, which defines functions to check the input arguments
|
||||
of operators, but uses the infer method registered in c++ source codes.
|
||||
|
||||
There are three methods can be overridden to define the check logic of the primitive: __check__(), check_shape(),
|
||||
check_dtype(). If __check__() is defined in primitive, the __check__() has the highest priority to be called.
|
||||
|
|
Loading…
Reference in New Issue