From de5e96618b9ef2d97d1d8b86771e0eb746101417 Mon Sep 17 00:00:00 2001 From: huangbingjian Date: Thu, 30 Sep 2021 10:32:24 +0800 Subject: [PATCH] update api documents --- mindspore/ops/primitive.py | 8 ++++---- mindspore/ops/vm_impl_registry.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index a96308506fe..e52aa28155a 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -343,8 +343,8 @@ class PrimitiveWithCheck(Primitive): PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator input arguments but used the infer method registered in c++ source codes. - There are three methods can be override to define the check logic of the primitive: __check__(), check_shape(), - check_dtype(). If __check__() is defined in primitive, the __check__() has highest priority to be called. + There are three methods can be overridden to define the check logic of the primitive: __check__(), check_shape(), + check_dtype(). If __check__() is defined in primitive, the __check__() has the highest priority to be called. If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic of the shape and type. Method infer_value() can also be defined (such as PrimitiveWithInfer) for constant propagation. @@ -424,8 +424,8 @@ class PrimitiveWithInfer(Primitive): PrimitiveWithInfer is the base class of primitives in python and defines functions for tracking inference in python. - There are four method can be override to define the infer logic of the primitive: __infer__(), infer_shape(), - infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority + There are four method can be overridden to define the infer logic of the primitive: __infer__(), infer_shape(), + infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has the highest priority to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the infer logic of the shape and type. The infer_value() is used for constant propagation. diff --git a/mindspore/ops/vm_impl_registry.py b/mindspore/ops/vm_impl_registry.py index c7480b039c9..3273b616612 100644 --- a/mindspore/ops/vm_impl_registry.py +++ b/mindspore/ops/vm_impl_registry.py @@ -47,6 +47,21 @@ def get_vm_impl_fn(prim): Returns: function, vm function + + Examples: + >>> from mindspore.ops import vm_impl_registry + >>> from mindspore.ops.vm_impl_registry import get_vm_impl_fn + ... + >>> @vm_impl_registry.register("Type") + >>> def vm_impl_dtype(self): + ... def vm_impl(x): + ... return type(x) + ... return vm_impl + ... + >>> fn = get_vm_impl_fn("Type") + >>> out = fn(1.0) + >>> print(out) + """ out = vm_impl_registry.get(prim, None) if out: