forked from mindspore-Ecosystem/mindspore
!17789 update api doc
From: @lianliguang Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qh
This commit is contained in:
commit
cf295a00bb
|
@ -26,7 +26,7 @@ from . import signature as sig
|
||||||
|
|
||||||
class Primitive(Primitive_):
|
class Primitive(Primitive_):
|
||||||
"""
|
"""
|
||||||
Primitive is the base class of primitives in python.
|
Primitive is the base class of operator primitives in python.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name (str): Name for the current Primitive.
|
name (str): Name for the current Primitive.
|
||||||
|
@ -175,15 +175,27 @@ class Primitive(Primitive_):
|
||||||
|
|
||||||
def check_elim(self, *args):
|
def check_elim(self, *args):
|
||||||
"""
|
"""
|
||||||
Check if certain inputs should go to the backend. Subclass in need should override this method.
|
Check if the primitive can be eliminated. Subclass in need should override this method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
args(Primitive args): Same as arguments of current Primitive.
|
args(Primitive args): Same as arguments of current Primitive.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple consisting of two elements. The first element indicates whether we should filter out current
|
A tuple consisting of two elements.
|
||||||
arguments; the second element is the output if we need to filter out the arguments.
|
The first element means if the primitive can be calculated in compiling stage
|
||||||
"""
|
the second element is calculated result .
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> class AddN(Primitive):
|
||||||
|
>>> @prim_attr_register
|
||||||
|
>>> def __init__(self):
|
||||||
|
>>> self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
|
||||||
|
>>> def check_elim(self, inputs):
|
||||||
|
>>> if len(inputs) != 1:
|
||||||
|
>>> return (False, None)
|
||||||
|
>>> if isinstance(inputs[0], Tensor):
|
||||||
|
>>> return (True, inputs[0])
|
||||||
|
"""
|
||||||
return (False, None)
|
return (False, None)
|
||||||
|
|
||||||
def __call__(self, *args):
|
def __call__(self, *args):
|
||||||
|
@ -223,7 +235,7 @@ class Primitive(Primitive_):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def update_parameter(self):
|
def update_parameter(self):
|
||||||
""" Whether the primitive will update the value of parameter."""
|
"""Return whether the primitive will update the value of parameter."""
|
||||||
return self._update_parameter
|
return self._update_parameter
|
||||||
|
|
||||||
def recompute(self, mode=True):
|
def recompute(self, mode=True):
|
||||||
|
@ -321,7 +333,7 @@ class PrimitiveWithCheck(Primitive):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __check__(self, *args):
|
def __check__(self, *args):
|
||||||
"""Check shape, type, and value at the same time by using dictionary as arguments."""
|
"""Checking the input shape and the input type of ops is valid """
|
||||||
tracks = ['dtype', 'shape']
|
tracks = ['dtype', 'shape']
|
||||||
for track in tracks:
|
for track in tracks:
|
||||||
fn = getattr(self, 'check_' + track)
|
fn = getattr(self, 'check_' + track)
|
||||||
|
@ -478,13 +490,23 @@ def prim_attr_register(fn):
|
||||||
Primitive attributes register.
|
Primitive attributes register.
|
||||||
|
|
||||||
Register the decorator of the built-in operator primitive '__init__'.
|
Register the decorator of the built-in operator primitive '__init__'.
|
||||||
The function will add all the parameters of '__init__' as operator attributes.
|
The function will add all the parameters of '__init__' as operator attributes ,
|
||||||
|
and init primtive name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fn (function): __init__ function of primitive.
|
fn (function): __init__ function of primitive.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
function, original function.
|
function, original function.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> class MatMul(PrimitiveWithCheck):
|
||||||
|
>>> @prim_attr_register
|
||||||
|
>>> def __init__(self, transpose_a=False, transpose_b=False):
|
||||||
|
>>> self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
|
||||||
|
>>> cls_name = self.name
|
||||||
|
>>> validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
|
||||||
|
>>> validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def deco(self, *args, **kwargs):
|
def deco(self, *args, **kwargs):
|
||||||
|
@ -539,10 +561,12 @@ def constexpr(fn=None, get_instance=True, name=None):
|
||||||
|
|
||||||
def deco(fn):
|
def deco(fn):
|
||||||
"""Decorator for CompileOp."""
|
"""Decorator for CompileOp."""
|
||||||
|
|
||||||
class CompileOp(PrimitiveWithInfer):
|
class CompileOp(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
CompileOp is a temporary operator used to execute the constexpr function.
|
CompileOp is a temporary operator used to execute the constexpr function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
op_name = name if name else fn.__name__
|
op_name = name if name else fn.__name__
|
||||||
PrimitiveWithInfer.__init__(self, op_name)
|
PrimitiveWithInfer.__init__(self, op_name)
|
||||||
|
|
|
@ -42,6 +42,9 @@ def get_vm_impl_fn(prim):
|
||||||
Args:
|
Args:
|
||||||
prim (Union[Primitive, str]): primitive object or name for operator register.
|
prim (Union[Primitive, str]): primitive object or name for operator register.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This mechanism applied for debugging currently.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
function, vm function
|
function, vm function
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue