!17789 update api doc

From: @lianliguang
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-06-04 17:46:31 +08:00 committed by Gitee
commit cf295a00bb
2 changed files with 35 additions and 8 deletions

View File

@ -26,7 +26,7 @@ from . import signature as sig
class Primitive(Primitive_):
"""
Primitive is the base class of primitives in python.
Primitive is the base class of operator primitives in python.
Args:
name (str): Name for the current Primitive.
@ -175,14 +175,26 @@ class Primitive(Primitive_):
def check_elim(self, *args):
"""
Check if certain inputs should go to the backend. Subclass in need should override this method.
Check if the primitive can be eliminated. Subclass in need should override this method.
Args:
args(Primitive args): Same as arguments of current Primitive.
Returns:
A tuple consisting of two elements. The first element indicates whether we should filter out current
arguments; the second element is the output if we need to filter out the arguments.
A tuple consisting of two elements.
The first element means if the primitive can be calculated in compiling stage
the second element is calculated result .
Examples:
>>> class AddN(Primitive):
>>> @prim_attr_register
>>> def __init__(self):
>>> self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
>>> def check_elim(self, inputs):
>>> if len(inputs) != 1:
>>> return (False, None)
>>> if isinstance(inputs[0], Tensor):
>>> return (True, inputs[0])
"""
return (False, None)
@ -223,7 +235,7 @@ class Primitive(Primitive_):
@property
def update_parameter(self):
""" Whether the primitive will update the value of parameter."""
"""Return whether the primitive will update the value of parameter."""
return self._update_parameter
def recompute(self, mode=True):
@ -321,7 +333,7 @@ class PrimitiveWithCheck(Primitive):
return None
def __check__(self, *args):
"""Check shape, type, and value at the same time by using dictionary as arguments."""
"""Checking the input shape and the input type of ops is valid """
tracks = ['dtype', 'shape']
for track in tracks:
fn = getattr(self, 'check_' + track)
@ -478,13 +490,23 @@ def prim_attr_register(fn):
Primitive attributes register.
Register the decorator of the built-in operator primitive '__init__'.
The function will add all the parameters of '__init__' as operator attributes.
The function will add all the parameters of '__init__' as operator attributes ,
and init primtive name.
Args:
fn (function): __init__ function of primitive.
Returns:
function, original function.
Examples:
>>> class MatMul(PrimitiveWithCheck):
>>> @prim_attr_register
>>> def __init__(self, transpose_a=False, transpose_b=False):
>>> self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
>>> cls_name = self.name
>>> validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
>>> validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
"""
def deco(self, *args, **kwargs):
@ -539,10 +561,12 @@ def constexpr(fn=None, get_instance=True, name=None):
def deco(fn):
"""Decorator for CompileOp."""
class CompileOp(PrimitiveWithInfer):
"""
CompileOp is a temporary operator used to execute the constexpr function.
"""
def __init__(self):
op_name = name if name else fn.__name__
PrimitiveWithInfer.__init__(self, op_name)

View File

@ -42,6 +42,9 @@ def get_vm_impl_fn(prim):
Args:
prim (Union[Primitive, str]): primitive object or name for operator register.
Note:
This mechanism applied for debugging currently.
Returns:
function, vm function
"""