forked from mindspore-Ecosystem/mindspore
!24453 udpate api document
Merge pull request !24453 from huangbingjian/code_docs_api
This commit is contained in:
commit
c863bd16c7
|
@ -343,8 +343,8 @@ class PrimitiveWithCheck(Primitive):
|
|||
PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator
|
||||
input arguments but used the infer method registered in c++ source codes.
|
||||
|
||||
There are three methods can be override to define the check logic of the primitive: __check__(), check_shape(),
|
||||
check_dtype(). If __check__() is defined in primitive, the __check__() has highest priority to be called.
|
||||
There are three methods can be overridden to define the check logic of the primitive: __check__(), check_shape(),
|
||||
check_dtype(). If __check__() is defined in primitive, the __check__() has the highest priority to be called.
|
||||
If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic of
|
||||
the shape and type. Method infer_value() can also be defined (such as PrimitiveWithInfer) for constant propagation.
|
||||
|
||||
|
@ -424,8 +424,8 @@ class PrimitiveWithInfer(Primitive):
|
|||
PrimitiveWithInfer is the base class of primitives in python and defines functions for tracking inference
|
||||
in python.
|
||||
|
||||
There are four method can be override to define the infer logic of the primitive: __infer__(), infer_shape(),
|
||||
infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority
|
||||
There are four method can be overridden to define the infer logic of the primitive: __infer__(), infer_shape(),
|
||||
infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has the highest priority
|
||||
to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the infer
|
||||
logic of the shape and type. The infer_value() is used for constant propagation.
|
||||
|
||||
|
|
|
@ -47,6 +47,21 @@ def get_vm_impl_fn(prim):
|
|||
|
||||
Returns:
|
||||
function, vm function
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import vm_impl_registry
|
||||
>>> from mindspore.ops.vm_impl_registry import get_vm_impl_fn
|
||||
...
|
||||
>>> @vm_impl_registry.register("Type")
|
||||
>>> def vm_impl_dtype(self):
|
||||
... def vm_impl(x):
|
||||
... return type(x)
|
||||
... return vm_impl
|
||||
...
|
||||
>>> fn = get_vm_impl_fn("Type")
|
||||
>>> out = fn(1.0)
|
||||
>>> print(out)
|
||||
<class 'float'>
|
||||
"""
|
||||
out = vm_impl_registry.get(prim, None)
|
||||
if out:
|
||||
|
|
Loading…
Reference in New Issue