!24453 udpate api document

Merge pull request !24453 from huangbingjian/code_docs_api
This commit is contained in:
i-robot 2021-09-30 03:24:56 +00:00 committed by Gitee
commit c863bd16c7
2 changed files with 19 additions and 4 deletions

View File

@ -343,8 +343,8 @@ class PrimitiveWithCheck(Primitive):
PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator
input arguments but used the infer method registered in c++ source codes.
There are three methods can be override to define the check logic of the primitive: __check__(), check_shape(),
check_dtype(). If __check__() is defined in primitive, the __check__() has highest priority to be called.
There are three methods can be overridden to define the check logic of the primitive: __check__(), check_shape(),
check_dtype(). If __check__() is defined in primitive, the __check__() has the highest priority to be called.
If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic of
the shape and type. Method infer_value() can also be defined (such as PrimitiveWithInfer) for constant propagation.
@ -424,8 +424,8 @@ class PrimitiveWithInfer(Primitive):
PrimitiveWithInfer is the base class of primitives in python and defines functions for tracking inference
in python.
There are four method can be override to define the infer logic of the primitive: __infer__(), infer_shape(),
infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority
There are four method can be overridden to define the infer logic of the primitive: __infer__(), infer_shape(),
infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has the highest priority
to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the infer
logic of the shape and type. The infer_value() is used for constant propagation.

View File

@ -47,6 +47,21 @@ def get_vm_impl_fn(prim):
Returns:
function, vm function
Examples:
>>> from mindspore.ops import vm_impl_registry
>>> from mindspore.ops.vm_impl_registry import get_vm_impl_fn
...
>>> @vm_impl_registry.register("Type")
>>> def vm_impl_dtype(self):
... def vm_impl(x):
... return type(x)
... return vm_impl
...
>>> fn = get_vm_impl_fn("Type")
>>> out = fn(1.0)
>>> print(out)
<class 'float'>
"""
out = vm_impl_registry.get(prim, None)
if out: