forked from mindspore-Ecosystem/mindspore
refine doc of CustomRegOp
This commit is contained in:
parent
3f9db0313a
commit
06d93e0cc0
|
@ -0,0 +1,104 @@
|
||||||
|
mindspore.ops.CustomRegOp
|
||||||
|
=========================
|
||||||
|
|
||||||
|
.. py:class:: mindspore.ops.CustomRegOp(op_name)
|
||||||
|
|
||||||
|
用于为 :class:`mindspore.ops.Custom` 的 `func` 参数生成算子注册信息的类。注册信息主要指定了 `func` 的输入和输出Tensor所支持的数据类型和数据格式、属性以及target信息。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **op_name** (str) - 表示kernel名称。无需设置此值,因为 `Custom` 会为该参数自动生成唯一名称。默认值:"Custom"。
|
||||||
|
|
||||||
|
.. py:method:: input(index=None, name=None, param_type="required", **kwargs)
|
||||||
|
|
||||||
|
指定 :class:`mindspore.ops.Custom` 的 `func` 参数的输入Tensor信息。每次调用该函数都会产生一个输入Tensor信息,也就是说,如果 `func` 有两个输入Tensor,那么该函数应该被连续调用两次。输入Tensor信息将生成为一个字典:{"index": `index`, "name": `name`, "param_type": `param_type`}。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **index** (int) - 表示输入的索引,从0开始。0表示第一个输入Tensor,1表示第二个输入Tensor,以此类推。如果该值为None,键"index"将不会出现在输入Tensor信息字典中。默认值:None。
|
||||||
|
- **name** (str) - 表示第 `index` 个输入的名称。如果该值为None,键"name"将不会出现在输入Tensor信息字典中。默认值:None。
|
||||||
|
- **param_type** (str) - 表示第 `index` 个输入的参数类型,可以是["required", "dynamic", "optional"]之一。如果该值为None,键"param_type"将不会出现在输入Tensor信息字典中。默认值:"required"。
|
||||||
|
|
||||||
|
- "required": 表示第 `index` 个输入存在并且只能是单个Tensor。
|
||||||
|
- "dynamic": 表示第 `index` 个输入存在且Tensor个数可能为多个,比如AddN算子的输入属于这种情况。
|
||||||
|
- "optional": 表示第 `index` 个输入存在且为单个Tensor,或者也可能不存在。
|
||||||
|
|
||||||
|
- **kwargs** (dict) - 表示输入的其他信息,用于扩展。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
- **TypeError** - `index` 既不是int也不是None。
|
||||||
|
- **TypeError** - `name` 既不是str也不是None。
|
||||||
|
- **TypeError** - `param_type` 既不是str也不是None。
|
||||||
|
|
||||||
|
.. py:method:: output(index=None, name=None, param_type="required", **kwargs)
|
||||||
|
|
||||||
|
指定 :class:`mindspore.ops.Custom` 的 `func` 参数的输出Tensor信息。每次调用该函数都会产生一个输出Tensor信息,也就是说,如果 `func` 有两个输出Tensor,那么该函数应该被连续调用两次。输出Tensor信息将生成为一个字典:{"index": `index`, "name": `name`, "param_type": `param_type`}。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **index** (int) - 表示输出的索引,从0开始。0表示第一个输出Tensor,1表示第二个输出Tensor,以此类推。如果该值为None,键"index"将不会出现在输出Tensor信息字典中。默认值:None。
|
||||||
|
- **name** (str) - 表示第 `index` 个输出的名称。如果该值为None,键"name"将不会出现在输出Tensor信息字典中。默认值:None。
|
||||||
|
- **param_type** (str) - 表示第 `index` 个输出的参数类型,可以是["required", "dynamic", "optional"]之一。如果该值为None,键"param_type"将不会出现在输出Tensor信息字典中。默认值:"required"。
|
||||||
|
|
||||||
|
- "required": 表示第 `index` 个输出存在并且只能是单个Tensor。
|
||||||
|
- "dynamic": 表示第 `index` 个输出存在且Tensor个数可能为多个。
|
||||||
|
- "optional": 表示第 `index` 个输出存在且为单个Tensor,或者也可能不存在。
|
||||||
|
|
||||||
|
- **kwargs** (dict) - 表示输出的其他信息,用于扩展。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
- **TypeError** - `index` 既不是int也不是None。
|
||||||
|
- **TypeError** - `name` 既不是str也不是None。
|
||||||
|
- **TypeError** - `param_type` 既不是str也不是None。
|
||||||
|
|
||||||
|
.. py:method:: dtype_format(*args)
|
||||||
|
|
||||||
|
指定 :class:`mindspore.ops.Custom` 的 `func` 参数的每个输入Tensor和输出Tensor所支持的数据类型和数据格式。正如上面给出的样例,该函数应在 `input` 和 `output` 函数之后被调用。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **args** (tuple) - 表示(数据类型、格式)组合的列表,`args` 的长度应该等于输入Tensor和输出Tensor数目的总和。 `args` 中的每一项也是一个tuple,tuple[0]和tuple[1]都是str类型,分别指定了一个Tensor的数据类型和数据格式。 :class:`mindspore.ops.DataType` 提供了很多预定义的(数据类型、格式)组合,例如 `DataType.F16_Default` 表示数据类型是float16,数据格式是默认格式。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
- **ValueError** - `args` 的长度不等于输入Tensor和输出Tensor数目的总和。
|
||||||
|
|
||||||
|
.. py:method:: attr(name=None, param_type=None, value_type=None, default_value=None, **kwargs)
|
||||||
|
|
||||||
|
指定 :class:`mindspore.ops.Custom` 的 `func` 参数的属性信息。每次调用该函数都会产生一个属性信息,也就是说,如果 `func` 有两个属性,那么这个函数应该被连续调用两次。属性信息将生成为一个字典:{"name": `name`, "param_type": `param_type`, "value_type": `value_type`, "default_value": `default_value`}。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **name** (str) - 表示属性的名称。如果该值为None,键"index"将不会出现在属性信息字典中。默认值:None。
|
||||||
|
- **param_type** (str) - 表示属性的参数类型,可以是["required", "optional"]之一。如果该值为None,键"param_type"将不会出现在属性信息字典中。默认值:None。
|
||||||
|
|
||||||
|
- "required": 表示必须通过在注册信息中设置默认值的方式或者在调用自定义算子时提供输入值的方式来为此属性提供值。
|
||||||
|
- "optional": 表示不强制为此属性提供值。
|
||||||
|
|
||||||
|
- **value_type** (str) - 表示属性的值的类型,可以是["int", "str", "bool", "float", "listInt", "listStr", "listBool", "listFloat"]之一。如果该值为None,键"value_type"将不会出现在属性信息字典中。默认值:None。
|
||||||
|
|
||||||
|
- "int": Python int类型的字符串表示。
|
||||||
|
- "str": Python str类型的字符串表示。
|
||||||
|
- "bool": Python bool类型的字符串表示。
|
||||||
|
- "float": Python float类型的字符串表示。
|
||||||
|
- "listInt": Python list of int类型的字符串表示。
|
||||||
|
- "listStr": Python list of str类型的字符串表示。
|
||||||
|
- "listBool": Python list of bool类型的字符串表示。
|
||||||
|
- "listFloat": Python list of float类型的字符串表示。
|
||||||
|
|
||||||
|
- **default_value** (str) - 表示属性的默认值。 `default_value` 和 `value_type` 配合使用。如果属性实际的默认值为1.0,那么 `value_type` 是"float", `default_value` 是"1.0"。如果属性实际的默认值是[1, 2, 3],那么 `value_type` 是"listInt", `default_value` 是"1,2,3",其中数值通过','分割。如果该值为None,键"default_value"将不会出现在属性信息字典中。目前用于"akg"、"aicpu"和"tbe"类型的自定义算子。默认值:None。
|
||||||
|
- **kwargs** (dict) - 表示属性的其他信息,用于扩展。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
- **TypeError** - `name` 既不是str也不是None。
|
||||||
|
- **TypeError** - `param_type` 既不是str也不是None。
|
||||||
|
- **TypeError** - `value_type` 既不是str也不是None。
|
||||||
|
- **TypeError** - `default_value` 既不是str也不是None。
|
||||||
|
|
||||||
|
.. py:method:: target(target=None)
|
||||||
|
|
||||||
|
指定当前注册信息所对应的target。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **target** (str) - 表示当前注册信息所对应的target,可以是["Ascend", "GPU", "CPU"]之一。 对于同一个 :class:`mindspore.ops.Custom` 的 `func` 参数,其在不同的target上可能支持不同的数据类型和数据格式,使用此参数指定注册信息用于哪个target。如果该值为None,它将在 :class:`mindspore.ops.Custom` 内部被自动推断。默认值:None。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
- **TypeError** - `target` 既不是str也不是None。
|
||||||
|
|
||||||
|
.. py:method:: get_op_info()
|
||||||
|
|
||||||
|
将生成的注册信息以字典类型返回。正如上面给出的样例, `CustomRegOp` 实例最后调用该函数。
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -752,9 +752,11 @@ class TBERegOp(RegOp):
|
||||||
class CustomRegOp(RegOp):
|
class CustomRegOp(RegOp):
|
||||||
r"""
|
r"""
|
||||||
Class used for generating the registration information for the `func` parameter of :class:`mindspore.ops.Custom`.
|
Class used for generating the registration information for the `func` parameter of :class:`mindspore.ops.Custom`.
|
||||||
|
The registration information mainly specifies the supported data types and formats of input and output tensors,
|
||||||
|
attributes and target of `func`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
op_name (str): kernel name. No need to set this value as `Custom`, operator will generate a unique name
|
op_name (str): kernel name. No need to set this value because `Custom` operator will generate a unique name
|
||||||
automatically. Default: "Custom".
|
automatically. Default: "Custom".
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
@ -774,13 +776,31 @@ class CustomRegOp(RegOp):
|
||||||
|
|
||||||
def input(self, index=None, name=None, param_type="required", **kwargs):
|
def input(self, index=None, name=None, param_type="required", **kwargs):
|
||||||
"""
|
"""
|
||||||
Register Custom op input information.
|
Specifies the input tensor information for the `func` parameter of :class:`mindspore.ops.Custom`. Each
|
||||||
|
invocation of this function will generate one input tensor information, that means, if `func` has two input
|
||||||
|
tensors, then this function should be invoked two times continuously. The input tensor information will be
|
||||||
|
generated as a dict: {"index": `index`, "name": `name`, "param_type": `param_type`}.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
index (int): Order of the input. Default: None.
|
index (int): Index of the input, starts from 0. 0 means the first input tensor, 1 means the second input
|
||||||
name (str): Name of the input. Default: None.
|
tensor and so on. If None, key "index" will not appear in the input tensor information dict.
|
||||||
param_type (str): Param type of the input. Default: "required".
|
Default: None.
|
||||||
kwargs (dict): Other information of the input.
|
name (str): Name of the `index` 'th input. If None, key "name" will not appear in the input tensor
|
||||||
|
information dict. Default: None.
|
||||||
|
param_type (str): Parameter type of the `index` 'th input, can be one of
|
||||||
|
["required", "dynamic", "optional"]. If None, key "param_type" will not appear in the input tensor
|
||||||
|
information dict. Default: "required".
|
||||||
|
|
||||||
|
- "required": means the `index` 'th input exist and can only be a single tensor.
|
||||||
|
- "dynamic": means the `index` 'th input exist and may be multiple tensors, such as the input of AddN.
|
||||||
|
- "optional": means the `index` 'th input may exist and be a single tensor or may not exist.
|
||||||
|
|
||||||
|
kwargs (dict): Other information of the input, used for extension.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `index` is neither int nor None.
|
||||||
|
TypeError: If `name` is neither str nor None.
|
||||||
|
TypeError: If `param_type` is neither str nor None.
|
||||||
"""
|
"""
|
||||||
param_list = [index, name, param_type]
|
param_list = [index, name, param_type]
|
||||||
key_list = ["index", "name", "param_type"]
|
key_list = ["index", "name", "param_type"]
|
||||||
|
@ -791,13 +811,31 @@ class CustomRegOp(RegOp):
|
||||||
|
|
||||||
def output(self, index=None, name=None, param_type="required", **kwargs):
|
def output(self, index=None, name=None, param_type="required", **kwargs):
|
||||||
"""
|
"""
|
||||||
Register Custom op output information.
|
Specifies the output tensor information for the `func` parameter of :class:`mindspore.ops.Custom`. Each
|
||||||
|
invocation of this function will generate one output tensor information, which means, if `func` has two output
|
||||||
|
tensors, then this function should be invoked two times continuously. The output tensor information will be
|
||||||
|
generated as a dict: {"index": `index`, "name": `name`, "param_type": `param_type`}.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
index (int): Order of the output. Default: None.
|
index (int): Index of the output, starts from 0. 0 means the first output tensor, 1 means the second output
|
||||||
name (str): Name of the output. Default: None.
|
tensor and so on. If None, key "index" will not appear in the output tensor information dict.
|
||||||
param_type (str): Param type of the output. Default: "required".
|
Default: None.
|
||||||
kwargs (dict): Other information of the output.
|
name (str): Name of the `index` 'th output. If None, key "name" will not appear in the output tensor
|
||||||
|
information dict. Default: None.
|
||||||
|
param_type (str): Parameter type of the `index` 'th output, can be one of
|
||||||
|
["required", "dynamic", "optional"]. If None, key "param_type" will not appear in the output tensor
|
||||||
|
information dict. Default: "required".
|
||||||
|
|
||||||
|
- "required": means the `index` 'th output exist and can only be a single tensor.
|
||||||
|
- "dynamic": means the `index` 'th output exist and may be multiple tensors.
|
||||||
|
- "optional": means the `index` 'th output may exist and be a single tensor or may not exist.
|
||||||
|
|
||||||
|
kwargs (dict): Other information of the output, used for extension.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `index` is neither int nor None.
|
||||||
|
TypeError: If `name` is neither str nor None.
|
||||||
|
TypeError: If `param_type` is neither str nor None.
|
||||||
"""
|
"""
|
||||||
param_list = [index, name, param_type]
|
param_list = [index, name, param_type]
|
||||||
key_list = ["index", "name", "param_type"]
|
key_list = ["index", "name", "param_type"]
|
||||||
|
@ -806,17 +844,73 @@ class CustomRegOp(RegOp):
|
||||||
self.outputs.append(output_dict)
|
self.outputs.append(output_dict)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def attr(self, name=None, param_type=None, value_type=None, default_value=None, **kwargs):
|
def dtype_format(self, *args):
|
||||||
"""
|
"""
|
||||||
Register Custom op attribute information.
|
Specifies the supported data type and format of each input tensor and output tensor for the `func` parameter
|
||||||
|
of :class:`mindspore.ops.Custom`. This function should be invoked after `input` and `output` function as shown
|
||||||
|
in the above example.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name (str): Name of the attribute. Default: None.
|
args (tuple): A tuple of (data type, format) pair, the length of `args` should be equal to the sum of input
|
||||||
param_type (str): Param type of the attribute. Default: None.
|
tensors and output tensors. Each item in `args` is also a tuple, tuple[0] and tuple[1] are both str
|
||||||
value_type (str): Value type of the attribute. Default: None.
|
type, which specifies the data type and format of a tensor respectively. :class:`mindspore.ops.DataType`
|
||||||
default_value (str): Default value of attribute. If value is a list, each item should split by ','.
|
provides many predefined (data type, format) combinations, for example, `DataType.F16_Default` means the
|
||||||
For example, if `value_type` is "listInt", then `default_value` can be "1,2,3". Default: None.
|
data type is float16 and the format is default format.
|
||||||
kwargs (dict): Other information of the attribute.
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the size of `args` not equal to the sum of input tensors and output tensors.
|
||||||
|
"""
|
||||||
|
io_nums = len(self.inputs) + len(self.outputs)
|
||||||
|
if len(args) != io_nums:
|
||||||
|
raise ValueError("The size of 'args' must be equal to the sum of input tensors and output tensors, but got "
|
||||||
|
"{} vs {}".format(len(args), io_nums))
|
||||||
|
return super(CustomRegOp, self).dtype_format(*args)
|
||||||
|
|
||||||
|
def attr(self, name=None, param_type=None, value_type=None, default_value=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Specifies the attributes information for the `func` parameter of :class:`mindspore.ops.Custom`. Each
|
||||||
|
invocation of this function will generate one attribute information, that means, if `func` has two attributes,
|
||||||
|
then this function should be invoked two times continuously. The attributes information will be
|
||||||
|
generated as a dict: {"name": `name`, "param_type": `param_type`, "value_type": `value_type`, "default_value":
|
||||||
|
`default_value`}.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): Name of the attribute. If None, key "name" will not appear in the attributes tensor information
|
||||||
|
dict. Default: None.
|
||||||
|
param_type (str): Parameter type of the attribute, can be one of ["required", "optional"]. If None, key
|
||||||
|
"param_type" will not appear in the attributes tensor information dict. Default: None.
|
||||||
|
|
||||||
|
- "required": means must provide a value for this attribute either by setting a default value in the
|
||||||
|
registration information or providing an input value when calling the Custom operator.
|
||||||
|
- "optional": means does not have to provide a value for this attribute.
|
||||||
|
|
||||||
|
value_type (str): Value type of the attribute, can be one of ["int", "str", "bool", "float", "listInt",
|
||||||
|
"listStr", "listBool", "listFloat"]. If None, key "value_type" will not appear in the attributes tensor
|
||||||
|
information dict. Default: None.
|
||||||
|
|
||||||
|
- "int": string representation of Python type int.
|
||||||
|
- "str": string representation of Python type str.
|
||||||
|
- "bool": string representation of Python type bool.
|
||||||
|
- "float": string representation of Python type float.
|
||||||
|
- "listInt": string representation of Python type list of int.
|
||||||
|
- "listStr": string representation of Python type list of str.
|
||||||
|
- "listBool": string representation of Python type list of bool.
|
||||||
|
- "listFloat": string representation of Python type list of float.
|
||||||
|
|
||||||
|
default_value (str): Default value of the attribute. `default_value` and `value_type` are used together.
|
||||||
|
If the real default value of the attribute is float type with value 1.0, then the `value_type` should be
|
||||||
|
"float" and `default_value` should be "1.0". If the real default value of the attribute is a list of int
|
||||||
|
with value [1, 2, 3], then the `value_type` should be "listInt" and `default_value` should be "1,2,3",
|
||||||
|
each item should split by ','. If None, means the attribute has no default value and key "default_value"
|
||||||
|
will not appear in the attributes tensor information dict. It is used for "akg", "aicpu" and "tbe"
|
||||||
|
Custom operators currently. Default: None.
|
||||||
|
kwargs (dict): Other information of the attribute, used for extension.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `name` is neither str nor None.
|
||||||
|
TypeError: If `param_type` is neither str nor None.
|
||||||
|
TypeError: If `value_type` is neither str nor None.
|
||||||
|
TypeError: If `default_value` is neither str nor None.
|
||||||
"""
|
"""
|
||||||
param_list = [name, param_type, value_type, default_value]
|
param_list = [name, param_type, value_type, default_value]
|
||||||
key_list = ["name", "param_type", "type", "default_value"]
|
key_list = ["name", "param_type", "type", "default_value"]
|
||||||
|
@ -827,16 +921,34 @@ class CustomRegOp(RegOp):
|
||||||
|
|
||||||
def target(self, target=None):
|
def target(self, target=None):
|
||||||
"""
|
"""
|
||||||
Register Custom op target information.
|
Specifies the target that this registration information is used for.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target (str): Device target for current operator information, should be one of ["Ascend", "GPU", "CPU"].
|
target (str): Device target for current operator information, should be one of ["Ascend", "GPU", "CPU"].
|
||||||
Default: None.
|
For the same `func` of :class:`mindspore.ops.Custom`, it may support different data types and formats
|
||||||
|
on different targets, use `target` to specify which target that this registration information is used
|
||||||
|
for. If None, it will be inferred automatically inside :class:`mindspore.ops.Custom`. Default: None.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `target` is neither str nor None.
|
||||||
"""
|
"""
|
||||||
self._is_string(target)
|
if target is not None:
|
||||||
|
self._is_string(target)
|
||||||
self.target_ = target
|
self.target_ = target
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def get_op_info(self):
|
||||||
|
"""
|
||||||
|
Return the generated registration information as a dict. This function should be invoked at last on the
|
||||||
|
`CustomRegOp` instance as shown in the above example.
|
||||||
|
"""
|
||||||
|
op_info = {}
|
||||||
|
for k, v in self.__dict__.items():
|
||||||
|
if isinstance(k, str) and k.endswith('_'):
|
||||||
|
k = k.rstrip('_')
|
||||||
|
op_info[k] = v
|
||||||
|
return op_info
|
||||||
|
|
||||||
|
|
||||||
class DataType:
|
class DataType:
|
||||||
r"""
|
r"""
|
||||||
|
|
Loading…
Reference in New Issue