!17789 update api doc

From: @lianliguang
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-06-04 17:46:31 +08:00 committed by Gitee
commit cf295a00bb
2 changed files with 35 additions and 8 deletions

View File

@ -26,7 +26,7 @@ from . import signature as sig
class Primitive(Primitive_): class Primitive(Primitive_):
""" """
Primitive is the base class of primitives in python. Primitive is the base class of operator primitives in python.
Args: Args:
name (str): Name for the current Primitive. name (str): Name for the current Primitive.
@ -175,15 +175,27 @@ class Primitive(Primitive_):
def check_elim(self, *args): def check_elim(self, *args):
""" """
Check if certain inputs should go to the backend. Subclass in need should override this method. Check if the primitive can be eliminated. Subclass in need should override this method.
Args: Args:
args(Primitive args): Same as arguments of current Primitive. args(Primitive args): Same as arguments of current Primitive.
Returns: Returns:
A tuple consisting of two elements. The first element indicates whether we should filter out current A tuple consisting of two elements.
arguments; the second element is the output if we need to filter out the arguments. The first element means if the primitive can be calculated in compiling stage
""" the second element is calculated result .
Examples:
>>> class AddN(Primitive):
>>> @prim_attr_register
>>> def __init__(self):
>>> self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
>>> def check_elim(self, inputs):
>>> if len(inputs) != 1:
>>> return (False, None)
>>> if isinstance(inputs[0], Tensor):
>>> return (True, inputs[0])
"""
return (False, None) return (False, None)
def __call__(self, *args): def __call__(self, *args):
@ -223,7 +235,7 @@ class Primitive(Primitive_):
@property @property
def update_parameter(self): def update_parameter(self):
""" Whether the primitive will update the value of parameter.""" """Return whether the primitive will update the value of parameter."""
return self._update_parameter return self._update_parameter
def recompute(self, mode=True): def recompute(self, mode=True):
@ -321,7 +333,7 @@ class PrimitiveWithCheck(Primitive):
return None return None
def __check__(self, *args): def __check__(self, *args):
"""Check shape, type, and value at the same time by using dictionary as arguments.""" """Checking the input shape and the input type of ops is valid """
tracks = ['dtype', 'shape'] tracks = ['dtype', 'shape']
for track in tracks: for track in tracks:
fn = getattr(self, 'check_' + track) fn = getattr(self, 'check_' + track)
@ -478,13 +490,23 @@ def prim_attr_register(fn):
Primitive attributes register. Primitive attributes register.
Register the decorator of the built-in operator primitive '__init__'. Register the decorator of the built-in operator primitive '__init__'.
The function will add all the parameters of '__init__' as operator attributes. The function will add all the parameters of '__init__' as operator attributes ,
and init primtive name.
Args: Args:
fn (function): __init__ function of primitive. fn (function): __init__ function of primitive.
Returns: Returns:
function, original function. function, original function.
Examples:
>>> class MatMul(PrimitiveWithCheck):
>>> @prim_attr_register
>>> def __init__(self, transpose_a=False, transpose_b=False):
>>> self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
>>> cls_name = self.name
>>> validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
>>> validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
""" """
def deco(self, *args, **kwargs): def deco(self, *args, **kwargs):
@ -539,10 +561,12 @@ def constexpr(fn=None, get_instance=True, name=None):
def deco(fn): def deco(fn):
"""Decorator for CompileOp.""" """Decorator for CompileOp."""
class CompileOp(PrimitiveWithInfer): class CompileOp(PrimitiveWithInfer):
""" """
CompileOp is a temporary operator used to execute the constexpr function. CompileOp is a temporary operator used to execute the constexpr function.
""" """
def __init__(self): def __init__(self):
op_name = name if name else fn.__name__ op_name = name if name else fn.__name__
PrimitiveWithInfer.__init__(self, op_name) PrimitiveWithInfer.__init__(self, op_name)

View File

@ -42,6 +42,9 @@ def get_vm_impl_fn(prim):
Args: Args:
prim (Union[Primitive, str]): primitive object or name for operator register. prim (Union[Primitive, str]): primitive object or name for operator register.
Note:
This mechanism applied for debugging currently.
Returns: Returns:
function, vm function function, vm function
""" """