Add chinese document for primitive, ...

This commit is contained in:
huangbingjian 2022-02-10 20:57:43 +08:00
parent 5b360e39fb
commit a5d8b0b3ef
10 changed files with 448 additions and 6 deletions

View File

@ -75,3 +75,34 @@ Random类型算子
mindspore.ops.Gamma mindspore.ops.Gamma
mindspore.ops.UniformReal mindspore.ops.UniformReal
原语
----
.. cnmsplatformautosummary::
:toctree: ops
mindspore.ops.constexpr
mindspore.ops.prim_attr_register
mindspore.ops.Primitive
mindspore.ops.PrimitiveWithCheck
mindspore.ops.PrimitiveWithInfer
函数实现注册
--------------
.. cnmsplatformautosummary::
:toctree: ops
mindspore.ops.get_vm_impl_fn
算子信息注册
--------------
.. cnmsplatformautosummary::
:toctree: ops
mindspore.ops.DataType

View File

@ -0,0 +1,173 @@
mindspore.ops.DataType
======================
.. py:class:: mindspore.ops.DataType:
Ascend算子的dtype和format的多种组合。
当前支持:
.. code-block::
None_None = ("", "")
None_Default = ("", "DefaultFormat")
BOOL_None = ("bool", "")
BOOL_Default = ("bool", "DefaultFormat")
BOOL_5HD = ("bool", "NC1HWC0")
BOOL_FracZ = ("bool", "FracZ")
BOOL_FracNZ = ("bool", "FRACTAL_NZ")
BOOL_C1HWNCoC0 = ("bool", "C1HWNCoC0")
BOOL_NCHW = ("bool", "NCHW")
BOOL_NHWC = ("bool", "NHWC")
BOOL_HWCN = ("bool", "HWCN")
BOOL_NDHWC = ("bool", "NDHWC")
BOOL_ChannelLast = ("bool", "ChannelLast")
I8_None = ("int8", "")
I8_Default = ("int8", "DefaultFormat")
I8_5HD = ("int8", "NC1HWC0")
I8_FracZ = ("int8", "FracZ")
I8_FracNZ = ("int8", "FRACTAL_NZ")
I8_C1HWNCoC0 = ("int8", "C1HWNCoC0")
I8_NCHW = ("int8", "NCHW")
I8_NHWC = ("int8", "NHWC")
I8_HWCN = ("int8", "HWCN")
I8_NDHWC = ("int8", "NDHWC")
I8_ChannelLast = ("int8", "ChannelLast")
U8_None = ("uint8", "")
U8_Default = ("uint8", "DefaultFormat")
U8_5HD = ("uint8", "NC1HWC0")
U8_FracZ = ("uint8", "FracZ")
U8_FracNZ = ("uint8", "FRACTAL_NZ")
U8_C1HWNCoC0 = ("uint8", "C1HWNCoC0")
U8_NCHW = ("uint8", "NCHW")
U8_NHWC = ("uint8", "NHWC")
U8_HWCN = ("uint8", "HWCN")
U8_NDHWC = ("uint8", "NDHWC")
U8_ChannelLast = ("uint8", "ChannelLast")
I16_None = ("int16", "")
I16_Default = ("int16", "DefaultFormat")
I16_5HD = ("int16", "NC1HWC0")
I16_FracZ = ("int16", "FracZ")
I16_FracNZ = ("int16", "FRACTAL_NZ")
I16_C1HWNCoC0 = ("int16", "C1HWNCoC0")
I16_NCHW = ("int16", "NCHW")
I16_NHWC = ("int16", "NHWC")
I16_HWCN = ("int16", "HWCN")
I16_NDHWC = ("int16", "NDHWC")
I16_ChannelLast = ("int16", "ChannelLast")
U16_None = ("uint16", "")
U16_Default = ("uint16", "DefaultFormat")
U16_5HD = ("uint16", "NC1HWC0")
U16_FracZ = ("uint16", "FracZ")
U16_FracNZ = ("uint16", "FRACTAL_NZ")
U16_C1HWNCoC0 = ("uint16", "C1HWNCoC0")
U16_NCHW = ("uint16", "NCHW")
U16_NHWC = ("uint16", "NHWC")
U16_HWCN = ("uint16", "HWCN")
U16_NDHWC = ("uint16", "NDHWC")
U16_ChannelLast = ("uint16", "ChannelLast")
I32_None = ("int32", "")
I32_Default = ("int32", "DefaultFormat")
I32_5HD = ("int32", "NC1HWC0")
I32_FracZ = ("int32", "FracZ")
I32_FracNZ = ("int32", "FRACTAL_NZ")
I32_C1HWNCoC0 = ("int32", "C1HWNCoC0")
I32_NCHW = ("int32", "NCHW")
I32_NHWC = ("int32", "NHWC")
I32_HWCN = ("int32", "HWCN")
I32_NDHWC = ("int32", "NDHWC")
I32_ChannelLast = ("int32", "ChannelLast")
U32_None = ("uint32", "")
U32_Default = ("uint32", "DefaultFormat")
U32_5HD = ("uint32", "NC1HWC0")
U32_FracZ = ("uint32", "FracZ")
U32_FracNZ = ("uint32", "FRACTAL_NZ")
U32_C1HWNCoC0 = ("uint32", "C1HWNCoC0")
U32_NCHW = ("uint32", "NCHW")
U32_NHWC = ("uint32", "NHWC")
U32_HWCN = ("uint32", "HWCN")
U32_NDHWC = ("uint32", "NDHWC")
U32_ChannelLast = ("uint32", "ChannelLast")
I64_None = ("int64", "")
I64_Default = ("int64", "DefaultFormat")
I64_5HD = ("int64", "NC1HWC0")
I64_FracZ = ("int64", "FracZ")
I64_FracNZ = ("int64", "FRACTAL_NZ")
I64_C1HWNCoC0 = ("int64", "C1HWNCoC0")
I64_NCHW = ("int64", "NCHW")
I64_NHWC = ("int64", "NHWC")
I64_HWCN = ("int64", "HWCN")
I64_NDHWC = ("int64", "NDHWC")
I64_ChannelLast = ("int64", "ChannelLast")
U64_None = ("uint64", "")
U64_Default = ("uint64", "DefaultFormat")
U64_5HD = ("uint64", "NC1HWC0")
U64_FracZ = ("uint64", "FracZ")
U64_FracNZ = ("uint64", "FRACTAL_NZ")
U64_C1HWNCoC0 = ("uint64", "C1HWNCoC0")
U64_NCHW = ("uint64", "NCHW")
U64_NHWC = ("uint64", "NHWC")
U64_HWCN = ("uint64", "HWCN")
U64_NDHWC = ("uint64", "NDHWC")
U64_ChannelLast = ("uint64", "ChannelLast")
F16_None = ("float16", "")
F16_Default = ("float16", "DefaultFormat")
F16_5HD = ("float16", "NC1HWC0")
F16_FracZ = ("float16", "FracZ")
F16_FracNZ = ("float16", "FRACTAL_NZ")
F16_C1HWNCoC0 = ("float16", "C1HWNCoC0")
F16_NCHW = ("float16", "NCHW")
F16_NHWC = ("float16", "NHWC")
F16_HWCN = ("float16", "HWCN")
F16_NDHWC = ("float16", "NDHWC")
F16_NCDHW = ("float16", "NCDHW")
F16_DHWCN = ("float16", "DHWCN")
F16_NDC1HWC0 = ("float16", "NDC1HWC0")
F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
F16_FracZNRNN = ("float16", "FRACTAL_ZN_RNN")
F16_ND_RNNBIAS = ("float16", "ND_RNN_BIAS")
F16_ChannelLast = ("float16", "ChannelLast")
F32_None = ("float32", "")
F32_Default = ("float32", "DefaultFormat")
F32_5HD = ("float32", "NC1HWC0")
F32_FracZ = ("float32", "FracZ")
F32_FracNZ = ("float32", "FRACTAL_NZ")
F32_C1HWNCoC0 = ("float32", "C1HWNCoC0")
F32_NCHW = ("float32", "NCHW")
F32_NHWC = ("float32", "NHWC")
F32_HWCN = ("float32", "HWCN")
F32_NDHWC = ("float32", "NDHWC")
F32_NCDHW = ("float32", "NCDHW")
F32_DHWCN = ("float32", "DHWCN")
F32_NDC1HWC0 = ("float32", "NDC1HWC0")
F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
F32_FracZNRNN = ("float32", "FRACTAL_ZN_RNN")
F32_ND_RNNBIAS = ("float32", "ND_RNN_BIAS")
F32_ChannelLast = ("float32", "ChannelLast")
F64_None = ("float64", "")
F64_Default = ("float64", "DefaultFormat")
F64_5HD = ("float64", "NC1HWC0")
F64_FracZ = ("float64", "FracZ")
F64_FracNZ = ("float64", "FRACTAL_NZ")
F64_C1HWNCoC0 = ("float64", "C1HWNCoC0")
F64_NCHW = ("float64", "NCHW")
F64_NHWC = ("float64", "NHWC")
F64_HWCN = ("float64", "HWCN")
F64_NDHWC = ("float64", "NDHWC")
F64_ChannelLast = ("float64", "ChannelLast")
C64_Default = ("complex64", "DefaultFormat")
C128_Default = ("complex128", "DefaultFormat")

View File

@ -0,0 +1,100 @@
mindspore.ops.Primitive
=======================
.. py:class:: mindspore.Primitive(name)
Primitive是Python中算子原语的基类。
**参数:**
- **name** (str) - 当前Primitive的名称。
.. py:method:: add_prim_attr(name, value)
添加Primitive的属性。
**参数:**
- **name** (str) - 属性名称。
- **value** (Any) - 属性值。
.. py:method:: del_prim_attr(name)
删除Primitive的属性。
**参数:**
- **name** (str) - 属性名称。
.. py:method:: check_elim(*args)
检查是否可以消除此Primitive。有需要的子类可以重写该方法。
**参数:**
- **args** (Primitive参数的类型) - 与当前Primitive的参数相同。
**返回:**
由两个元素组成的元组。第一个元素是指是否能在编译阶段计算Primitive第二个元素是计算结果。
.. py:method:: init_prim_io_names(inputs, outputs)
初始化Tensor或属性的输入输出的名称。
**参数:**
- **inputs** (list[str]) - 输入名称的列表。
- **outputs** (list[str]) - 输出名称的列表。
.. py:method:: recompute(mode=True)
设置Primitive的重计算属性。
如果有一个被设置了重计算属性的Primitive并且其结果在计算导数的时候被使用那么不会保存该Primitive在前向网络中的中间计算结果而是在自动微分的时候重新进行计算。
.. note::
- 如果计算涉及随机化或全局变量,则暂无法保证等效性。
- 在PyNative模式下不支持。
**参数:**
- **mode** (bool) - Primitive是否设置了重计算。默认值True。
.. py:method:: set_prim_instance_name(instance_name)
设置Primitive算子的实例的名称。
.. note::
当用户定义Primitive算子时默认调用它。
**参数:**
- **instance_name** (str) - 用户设置的Primitive算子的实例的名称。
.. py:method:: set_stage(stage)
将stage的ID添加到Primitive属性中。
.. note::
仅在半自动并行模式下有效。在其他并行模式下请将其设置为0。
**参数:**
- **stage** (int) - 当前stage的ID。
.. py:method:: shard(in_strategy, out_strategy)
将切分策略添加到Primitive属性中。
.. note::
仅在半自动并行或自动并行模式下有效。在其他并行模式中,将忽略此处设置的策略。
**参数:**
- **in_strategy** (tuple) - 描述算子输入的切分策略。默认值None。
- **out_strategy** (tuple) - 描述算子输出的切分策略仅针对某些算子如MatMul。默认值None。
.. py:method:: update_parameter()
判断此Primitive是否会更新参数的值。

View File

@ -0,0 +1,41 @@
mindspore.ops.PrimitiveWithCheck
================================
.. py:class:: mindspore.PrimitiveWithCheck(name)
PrimitiveWithCheck是Python中原语的基类定义了检查算子输入参数的函数但是使用了C++源码中注册的推理方法。
可以重写三个方法来定义Primitive的检查逻辑 __check__()、check_shape()和check_dtype()。如果在Primitive中定义了__check__()则__check__()的优先级最高。
如果未定义__check__()则可以定义check_shape()和check_dtype()来描述形状和类型的检查逻辑。可以定义infer_value()方法如PrimitiveWithInfer用于常量传播。
**参数:**
- **name** (str) - 当前Primitive的名称。
.. py:method:: check_dtype(*args)
检查输入参数的数据类型。
**参数:**
- **args** (:class:`mindspore.dtype`) - 输入的数据类型。
**返回:**
None。
.. py:method:: check_shape(*args)
检查输入参数的shape。
.. note::
Scalar的shape是一个空元组。
**参数:**
- **args** (tuple(int)) - 输入tensor的shape。
**返回:**
None。

View File

@ -0,0 +1,53 @@
mindspore.ops.PrimitiveWithInfer
================================
.. py:class:: mindspore.PrimitiveWithInfer(name)
PrimitiveWithInfer是Python中的原语基类在python中定义了跟踪推理的函数。
可以重写四个方法来定义Primitive的推断逻辑__infer__()、infer_shape()、infer_dtype()和infer_value()。如果在Primitive中定义了__infer__()则__infer__()的优先级最高。
如果未定义__infer__()则可以定义infer_shape()和infer_dtype()来描述shape和类型的推断逻辑。infer_value()用于常量传播。
**参数:**
- **name** (str) - 当前Primitive的名称。
.. py:method:: infer_dtype(*args)
根据输入类型推断输出类型。
**参数:**
- **args** (:class:`mindspore.dtype`) - 输入的数据类型。
**返回:**
:class:`mindspore.dtype`,输出的数据类型。
.. py:method:: infer_shape(*args)
根据输入形状推断输出形状。
.. note::
Scalar的shape是一个空元组。
**参数:**
- **args** (tuple(int)) - 输入tensor的shape。
**返回:**
`tuple(int)`输出tensor的shape。
.. py:method:: infer_value(*args)
根据编译时的输入值推断输出值。
**参数:**
- **args** (Any) - 输入的值。
**返回:**
输出的值。如果编译时无法推断该值,返回`None`

View File

@ -0,0 +1,12 @@
mindspore.ops.constexpr
=======================
.. py:function:: mindspore.ops.constexpr(fn=None, get_instance=True, name=None):
创建PrimiveWithInfer算子用于在编译时推断值。可以用它定义函数从而使用构造函数中的常量计算出常量值。
**参数:**
- **fn** (function) - `fn` 用作输出算子的infer_value。默认值None。
- **get_instance** (bool) - 如果为True返回算子的实例否则返回算子的类。默认值True。
- **name** (str) - 定义算子的名称。如果 `name` 为None则使用函数名称作为算子名称。默认值None。

View File

@ -0,0 +1,17 @@
mindspore.ops.get_vm_impl_fn
============================
.. py:function:: mindspore.ops.get_vm_impl_fn(prim):
通过Primitive对象或Primitive名称获取虚拟实现函数。
**参数:**
- **prim** (Union[Primitive, str]) - 算子注册的Primitive对象或名称。
.. note::
该机制目前适用于调试。
**返回:**
函数,虚拟实现函数。

View File

@ -0,0 +1,16 @@
mindspore.ops.prim_attr_register
================================
.. py:function:: mindspore.ops.prim_attr_register(fn):
Primitive属性的注册器。
注册装饰器其中装饰器用于内置算子的Primitive的'__init__'函数。该函数将添加'__init__'的所有参数作为算子属性并且初始化Primitive的名称。
**参数:**
- **fn** (function) - Primitive的__init__函数。
**返回:**
函数,原始函数。

View File

@ -764,10 +764,6 @@ class DataType:
r""" r"""
Various combinations of dtype and format of Ascend ops. Various combinations of dtype and format of Ascend ops.
The current list below may be incomplete.
Please add it if necessary.
current support: current support:
.. code-block:: .. code-block::
@ -933,6 +929,9 @@ class DataType:
F64_HWCN = ("float64", "HWCN") F64_HWCN = ("float64", "HWCN")
F64_NDHWC = ("float64", "NDHWC") F64_NDHWC = ("float64", "NDHWC")
F64_ChannelLast = ("float64", "ChannelLast") F64_ChannelLast = ("float64", "ChannelLast")
C64_Default = ("complex64", "DefaultFormat")
C128_Default = ("complex128", "DefaultFormat")
""" """
None_None = ("", "") None_None = ("", "")

View File

@ -381,8 +381,8 @@ class Primitive(Primitive_):
class PrimitiveWithCheck(Primitive): class PrimitiveWithCheck(Primitive):
""" """
PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator PrimitiveWithCheck is the base class of primitives in python, which defines functions to check the input arguments
input arguments but used the infer method registered in c++ source codes. of operators, but uses the infer method registered in c++ source codes.
There are three methods can be overridden to define the check logic of the primitive: __check__(), check_shape(), There are three methods can be overridden to define the check logic of the primitive: __check__(), check_shape(),
check_dtype(). If __check__() is defined in primitive, the __check__() has the highest priority to be called. check_dtype(). If __check__() is defined in primitive, the __check__() has the highest priority to be called.