diff --git a/docs/api/api_python/mindspore.ops.rst b/docs/api/api_python/mindspore.ops.rst index 306d672648a..6ef090716b4 100644 --- a/docs/api/api_python/mindspore.ops.rst +++ b/docs/api/api_python/mindspore.ops.rst @@ -31,7 +31,7 @@ MindSpore中 `mindspore.ops` 接口与上一版本相比,新增、删除和支 mindspore.ops.constexpr mindspore.ops.custom_info_register - mindspore.ops.ms_hybrid + mindspore.ops.ms_kernel mindspore.ops.op_info_register mindspore.ops.prim_attr_register diff --git a/docs/api/api_python/ops/mindspore.ops.ms_hybrid.rst b/docs/api/api_python/ops/mindspore.ops.ms_kernel.rst similarity index 86% rename from docs/api/api_python/ops/mindspore.ops.ms_hybrid.rst rename to docs/api/api_python/ops/mindspore.ops.ms_kernel.rst index cc4fc5fe23e..0f5157a4492 100644 --- a/docs/api/api_python/ops/mindspore.ops.ms_hybrid.rst +++ b/docs/api/api_python/ops/mindspore.ops.ms_kernel.rst @@ -1,7 +1,7 @@ -mindspore.ops.ms_hybrid +mindspore.ops.ms_kernel ======================= -.. py:function:: mindspore.ops.ms_hybrid(fn=None, reg_info=None, compile_attrs=None) +.. py:function:: mindspore.ops.ms_kernel(fn=None, reg_info=None, compile_attrs=None) 用于MindSpore Hybrid DSL函数书写的装饰器。 给用MindSpore Hybrid DSL书写的函数加上此装饰器后,它可以用作一个普通的Python函数。 diff --git a/docs/api/api_python_en/mindspore.ops.rst b/docs/api/api_python_en/mindspore.ops.rst index 92380705db2..1fa45233ac4 100644 --- a/docs/api/api_python_en/mindspore.ops.rst +++ b/docs/api/api_python_en/mindspore.ops.rst @@ -31,7 +31,7 @@ Decorators mindspore.ops.constexpr mindspore.ops.custom_info_register - mindspore.ops.ms_hybrid + mindspore.ops.ms_kernel mindspore.ops.op_info_register mindspore.ops.prim_attr_register diff --git a/mindspore/python/mindspore/ops/operations/__init__.py b/mindspore/python/mindspore/ops/operations/__init__.py index 1d1c3f61d19..22000118014 100644 --- a/mindspore/python/mindspore/ops/operations/__init__.py +++ b/mindspore/python/mindspore/ops/operations/__init__.py @@ -29,7 +29,7 @@ from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg CusMatMulCubeDenseRight, CusMatMulCubeFraczLeftCast, Im2Col, NewIm2Col, LoadIm2Col, UpdateThorGradient, Cholesky, CholeskyTrsm, DetTriangle, ProdForceSeA) -from ._ms_hybrid import (ms_hybrid) +from ._ms_kernel import (ms_kernel, ms_hybrid) from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unstack, Diag, DiagPart, DType, ExpandDims, Eye, Fill, Ones, Zeros, GatherNd, GatherV2, Gather, SparseGatherV2, InvertPermutation, @@ -591,6 +591,7 @@ __sponge__ = [ ] __custom__ = [ + "ms_kernel", "ms_hybrid", ] diff --git a/mindspore/python/mindspore/ops/operations/_ms_hybrid.py b/mindspore/python/mindspore/ops/operations/_ms_kernel.py similarity index 81% rename from mindspore/python/mindspore/ops/operations/_ms_hybrid.py rename to mindspore/python/mindspore/ops/operations/_ms_kernel.py index 2a496e7bc3e..6ab0fcf4f3d 100644 --- a/mindspore/python/mindspore/ops/operations/_ms_hybrid.py +++ b/mindspore/python/mindspore/ops/operations/_ms_kernel.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""ms_hybrid decorator and related util functions""" +"""ms_kernel decorator and related util functions""" import ast import json from functools import wraps from itertools import product import numpy -from mindspore import context +from mindspore import context, log def _allocate(shape, dtype='float32', scope='global'): @@ -86,15 +86,15 @@ def _erf(x): p = 0.3275911 # A&S formula 7.1.26 - t = 1.0/(1.0 + p*x) - y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*numpy.exp(-x*x) - return sign*y # erf(-x) = -erf(x) + t = 1.0 / (1.0 + p * x) + y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * numpy.exp(-x * x) + return sign * y # erf(-x) = -erf(x) def _grid(extents): extents_list = [] for ext in extents: - extents_list.append(range(ext)) + extents_list.append(list(range(ext))) return product(*extents_list) @@ -120,89 +120,92 @@ class WithStub: return self -INTRIN_BUFFER = { - 'allocate': _allocate, - 'output_tensor': _allocate -} - -INTRIN_LOOP = { - 'range': range, - 'grid': _grid, -} - -INTRIN_WITH_SCOPE = { - 'attr': WithStub(), - 'block_realize': WithStub(), -} - -INTRIN_UNARY_OP = { - 'sqrt': numpy.sqrt, - 'sign': numpy.sign, - 'log': numpy.log, - 'tanh': numpy.tanh, - 'exp': numpy.exp, - 'abs': numpy.abs, - 'int32': numpy.int32, - 'float16': numpy.float16, - 'float32': numpy.float32, -} - -INTRIN_BINARY_OP = { - 'power': numpy.power, -} - -INTRIN_GLOBALS = { - **INTRIN_BUFFER, - **INTRIN_LOOP, - **INTRIN_WITH_SCOPE, - **INTRIN_UNARY_OP, - **INTRIN_BINARY_OP, -} - -INTRIN_GENERAL_UNARY_OP = { - 'rsqrt': _rsqrt, - 'erf': _erf, - 'isnan': numpy.isnan, - 'int8': numpy.int8, - 'int16': numpy.int16, - 'int64': numpy.int64, - 'float64': numpy.float64, - 'sin': numpy.sin, - 'cos': numpy.cos, - 'isinf': numpy.isinf, - 'isfinite': numpy.isfinite, - 'atan': numpy.arctan, - 'atan2': numpy.arctan2, - 'expm1': numpy.expm1, - 'floor': numpy.floor, - 'ceil': numpy.ceil, - 'trunc': numpy.trunc, - 'round': numpy.round, -} - -INTRIN_CPU_NOT_SUPPORT = ["atan2", "expm1", "float16"] - -INTRIN_GENERAL_BINARY_OP = { - 'ceil_div': lambda a, b: (a + b - 1) // b, -} - -INTRIN_GENERAL = { - **INTRIN_GENERAL_UNARY_OP, - **INTRIN_GENERAL_BINARY_OP -} - -INTRIN_RUNTIME = { - **INTRIN_GLOBALS, - **INTRIN_GENERAL -} - - class VariableUsage(ast.NodeVisitor): """ The ast visitor to perform static check for the source code, and determine the index of inplace assign outputs """ + intrin_buffer = { + 'allocate': _allocate, + 'output_tensor': _allocate + } + + intrin_loop = { + 'range': range, + 'serial': range, + 'vectorize': range, + 'parallel': range, + 'reduce': range, + 'grid': _grid, + } + + intrin_with_scope = { + 'attr': WithStub(), + 'block_realize': WithStub(), + } + + intrin_unary_op = { + 'sqrt': numpy.sqrt, + 'sign': numpy.sign, + 'log': numpy.log, + 'tanh': numpy.tanh, + 'exp': numpy.exp, + 'abs': numpy.abs, + 'int32': numpy.int32, + 'float16': numpy.float16, + 'float32': numpy.float32, + } + + intrin_bin_op = { + 'power': numpy.power, + } + + intrin_globals = { + **intrin_buffer, + **intrin_loop, + **intrin_with_scope, + **intrin_unary_op, + **intrin_bin_op, + } + + intrin_general_unary_op = { + 'rsqrt': _rsqrt, + 'erf': _erf, + 'isnan': numpy.isnan, + 'int8': numpy.int8, + 'int16': numpy.int16, + 'int64': numpy.int64, + 'float64': numpy.float64, + 'sin': numpy.sin, + 'cos': numpy.cos, + 'isinf': numpy.isinf, + 'isfinite': numpy.isfinite, + 'atan': numpy.arctan, + 'atan2': numpy.arctan2, + 'expm1': numpy.expm1, + 'floor': numpy.floor, + 'ceil': numpy.ceil, + 'trunc': numpy.trunc, + 'round': numpy.round, + } + + intrin_cpu_not_support = ["atan2", "expm1", "float16"] + + intrin_general_bin_op = { + 'ceil_div': lambda a, b: (a + b - 1) // b, + } + + intrin_general = { + **intrin_general_unary_op, + **intrin_general_bin_op + } + + intrin_runtime = { + **intrin_globals, + **intrin_general + } + def __init__(self, func_name): self.func_name = func_name self.scope_level = [] @@ -276,27 +279,27 @@ class VariableUsage(ast.NodeVisitor): """ Ast visitor for Call - Check the func call used in the DSL. Only those in INTRIN_RUNTIME are supported for now. + Check the func call used in the DSL. Only those in intrin_runtime are supported for now. """ func_id = node.func.id - if not (func_id in list(INTRIN_RUNTIME.keys()) + - ['max', 'min', 'len', 'ms_hybrid']): + if not (func_id in list(VariableUsage.intrin_runtime.keys()) + + ['max', 'min', 'len', 'ms_kernel']): raise ValueError( "In the function {} written in the Hybrid DSL, function call id {} " "not in intrinsics' list".format(self.func_name, func_id)) - if (self.device == "Ascend" and func_id in list(INTRIN_GENERAL.keys())) or \ - (self.device == "CPU" and func_id in INTRIN_CPU_NOT_SUPPORT): + if (self.device == "Ascend" and func_id in list(VariableUsage.intrin_general.keys())) or \ + (self.device == "CPU" and func_id in VariableUsage.intrin_cpu_not_support): raise ValueError( "In the function {} written in the Hybrid DSL, function {} is not available on the " "device {}".format(self.func_name, func_id, self.device)) - if func_id in list(INTRIN_UNARY_OP.keys()) + list(INTRIN_GENERAL_UNARY_OP.keys()) + list(INTRIN_LOOP.keys()) \ + if func_id in list(VariableUsage.intrin_unary_op.keys()) + list(VariableUsage.intrin_general_unary_op.keys()) \ and len(node.args) != 1: raise TypeError( "In the function {} written in the Hybrid DSL, function {} " "expects one input, but get {}".format(self.func_name, func_id, len(node.args))) - if func_id in list(INTRIN_BINARY_OP.keys()) + list(INTRIN_GENERAL_BINARY_OP.keys()) + \ - list(INTRIN_BUFFER.keys()) and len(node.args) != 2: + if func_id in list(VariableUsage.intrin_bin_op.keys()) + list(VariableUsage.intrin_general_bin_op.keys()) + \ + list(VariableUsage.intrin_buffer.keys()) and len(node.args) != 2: raise TypeError( "In the function {} written in the Hybrid DSL, function {} " "expects two inputs, but get {}".format(self.func_name, func_id, len(node.args))) @@ -470,10 +473,10 @@ def determine_variable_usage(root, func_name): return visitor.inplace_assign_output -def ms_hybrid(fn=None, reg_info=None, compile_attrs=None): +def ms_kernel(fn=None, reg_info=None, compile_attrs=None): """ The decorator of the Hybrid DSL function for the Custom Op. - When a function written by the Hybrid DSL is decorated by ms_hybrid, + When a function written by the Hybrid DSL is decorated by ms_kernel, it can be run as a usual Python function. Also, this function can be used in the api Custom and to create a Custom op, with func_type "hybrid" or "pyfunc". Creating a custom op with mode "hybrid" by the Hybrid DSL function @@ -495,7 +498,7 @@ def ms_hybrid(fn=None, reg_info=None, compile_attrs=None): Examples: >>> import numpy as np >>> from mindspore import ops, Tensor - >>> from mindspore.ops import ms_hybrid, DataType, CustomRegOp + >>> from mindspore.ops import ms_kernel, DataType, CustomRegOp ... >>> # Create a dict for the compile flags. >>> attrs = { @@ -516,9 +519,9 @@ def ms_hybrid(fn=None, reg_info=None, compile_attrs=None): >>> input_x = np.ones([4, 4]).astype(np.float32) >>> input_y = np.ones([4, 4]).astype(np.float32) ... - >>> # Write a Hybrid DSL function through the decorator @ms_hybrid. + >>> # Write a Hybrid DSL function through the decorator @ms_kernel. >>> # We can also pass the compile attrs and the reg info through the decorator. - >>> @ms_hybrid(reg_info=op_gpu_info, compile_attrs=attrs) + >>> @ms_kernel(reg_info=op_gpu_info, compile_attrs=attrs) ... def outer_product(a, b): ... c = output_tensor(a.shape, a.dtype) ... @@ -544,21 +547,25 @@ def ms_hybrid(fn=None, reg_info=None, compile_attrs=None): compile_attrs = {} if not isinstance(compile_attrs, dict): - raise TypeError("The input 'compile_attrs' of @ms_hybrid must be a dict, " + raise TypeError("The input 'compile_attrs' of @ms_kernel must be a dict, " "but get a {}".format(type(compile_attrs))) for key in compile_attrs.keys(): if not isinstance(key, str): - raise TypeError("The key of 'compile_attrs' of @ms_hybrid must be a str, " + raise TypeError("The key of 'compile_attrs' of @ms_kernel must be a str, " "but get a {}".format(type(key))) if reg_info is not None and not isinstance(reg_info, (str, dict, tuple)): raise TypeError( - "The input 'reg_info' of @ms_hybrid should be one of " + "The input 'reg_info' of @ms_kernel should be one of " "str, dict and tuple, but get a {}".format(type(reg_info))) - def wrap_ms_hybrid(func): - setattr(func, "ms_hybrid_flag", True) + def wrap_ms_kernel(func): + setattr(func, "ms_kernel_flag", True) + + # we enable ml scheduler automatically for ms_kernel function + compile_attrs["enable_mlsched"] = True + setattr(func, "compile_attrs", json.dumps(compile_attrs)) if reg_info is not None: setattr(func, "reg_info", reg_info) @@ -566,12 +573,25 @@ def ms_hybrid(fn=None, reg_info=None, compile_attrs=None): @wraps(func) def _patch_intrins_to_runtime(*args): _globals = func.__globals__ - for elem in list(INTRIN_RUNTIME.keys()): - _globals[elem] = INTRIN_RUNTIME[elem] + for elem in list(VariableUsage.intrin_runtime.keys()): + _globals[elem] = VariableUsage.intrin_runtime[elem] return func(*args) return _patch_intrins_to_runtime if fn is not None: - return wrap_ms_hybrid(fn) - return wrap_ms_hybrid + return wrap_ms_kernel(fn) + return wrap_ms_kernel + + +def ms_hybrid(fn=None, reg_info=None, compile_attrs=None): + """ + Same as docarator ms_kernel. ms_hybrid will be deprecated in the future. + Please use ms_kernel instead. + + Supported Platforms: + Deprecated + """ + log.warning("'ms_hybrid' is deprecated from version 1.8 and " + "will be removed in a future version, use 'ms_kernel' instead.") + return ms_kernel(fn, reg_info, compile_attrs) diff --git a/mindspore/python/mindspore/ops/operations/custom_ops.py b/mindspore/python/mindspore/ops/operations/custom_ops.py index 5ed21fd8994..7537a451d46 100644 --- a/mindspore/python/mindspore/ops/operations/custom_ops.py +++ b/mindspore/python/mindspore/ops/operations/custom_ops.py @@ -29,7 +29,7 @@ from mindspore.common import dtype as mstype from mindspore.ops import DataType from mindspore import log as logger from mindspore import ops -from ._ms_hybrid import determine_variable_usage +from ._ms_kernel import determine_variable_usage from ._custom_grad import autodiff_bprop from ._pyfunc_registry import add_pyfunc @@ -107,7 +107,7 @@ class Custom(ops.PrimitiveWithInfer): 1. A AKG operator implementation function, which can use ir builder/tvm compute/hybrid grammar. 2. A TBE operator implementation function. 3. A pure python function - 4. An ms_hybrid decorated function written by the Hybrid DSL. + 4. An ms_kernel decorated function written by the Hybrid DSL. - str: If func is of str type, then str should be a path of file along with a function name. This could be used when func_type is "aot" or "julia". @@ -260,7 +260,7 @@ class Custom(ops.PrimitiveWithInfer): Examples: >>> import mindspore.ops as ops >>> import numpy as np - >>> from mindspore.ops import CustomRegOp, custom_info_register, DataType, ms_hybrid + >>> from mindspore.ops import CustomRegOp, custom_info_register, DataType, ms_kernel >>> from mindspore.common import dtype as mstype >>> from mindspore.nn import Cell >>> input_x = Tensor(np.ones([16, 16]).astype(np.float32)) @@ -270,8 +270,8 @@ class Custom(ops.PrimitiveWithInfer): >>> # This is the default func_type in Custom, >>> # and both out_shape and out_dtype can be None(default value). >>> # In this case, the input func must be a function written in the Hybrid DSL - >>> # and decorated by @ms_hybrid. - >>> @ms_hybrid + >>> # and decorated by @ms_kernel. + >>> @ms_kernel ... def outer_product_script(a, b): ... c = output_tensor(a.shape, a.dtype) ... for i0 in range(a.shape[0]): @@ -402,7 +402,7 @@ class Custom(ops.PrimitiveWithInfer): self.imply_path = "" self.func_source_str = "" self._func_compile_attrs = {} - self._is_ms_hybrid = False + self._is_ms_kernel = False self._check_func() self._update_func_info() @@ -488,20 +488,20 @@ class Custom(ops.PrimitiveWithInfer): elif self.func_type == "julia": self._check_julia_func() elif self.func_type == "hybrid": - if not hasattr(self.func, "ms_hybrid_flag"): - raise TypeError("{}, 'func' must be a function decorated by ms_hybrid".format(self.log_prefix)) - self._is_ms_hybrid = True + if not hasattr(self.func, "ms_kernel_flag"): + raise TypeError("{}, 'func' must be a function decorated by ms_kernel".format(self.log_prefix)) + self._is_ms_kernel = True self._func_compile_attrs = getattr(self.func, "compile_attrs", {}) elif self.func_type == "akg": - if hasattr(self.func, "ms_hybrid_flag"): + if hasattr(self.func, "ms_kernel_flag"): logger.warning("{}. To have a better user experience, the mode hybrid is suggested " - "for the input function with decorator @ms_hybrid. " + "for the input function with decorator @ms_kernel. " "To enable this mode, set the 'func_type' to be \"hybrid\"".format(self.log_prefix)) elif self.func_type == "pyfunc": - if hasattr(self.func, "ms_hybrid_flag"): - logger.warning("{}. Now you are using the function with decorator @ms_hybrid in the mode pyfunc. " + if hasattr(self.func, "ms_kernel_flag"): + logger.warning("{}. Now you are using the function with decorator @ms_kernel in the mode pyfunc. " "The kernel will be executed as a native python function, which might lead to " - "low efficiency. To accelerate the kernel, set the 'func_type' to be \"ms_hybrid\"" + "low efficiency. To accelerate the kernel, set the 'func_type' to be \"hybrid\"" .format(self.log_prefix)) else: if not callable(self.func): @@ -524,7 +524,7 @@ class Custom(ops.PrimitiveWithInfer): if index != -1: self.func_source_str = self.func_source_str[index:] - if self._is_ms_hybrid: + if self._is_ms_kernel: # static check for the Hybrid DSL in hybrid root = ast.parse(self.func_source_str) inplace_assign_output = determine_variable_usage(root, self.func_name) @@ -848,7 +848,7 @@ class Custom(ops.PrimitiveWithInfer): def _auto_infer(self, *args): """ - the automatic infer function for functions with @ms_hybrid decorator + the automatic infer function for functions with @ms_kernel decorator """ fake_input = [] enable_infer_value = True @@ -893,7 +893,7 @@ class Custom(ops.PrimitiveWithInfer): # deal with the case of ms script # enable auto infer function if any infer information is missing - if self._is_ms_hybrid and (infer_dtype is None or infer_shape is None): + if self._is_ms_kernel and (infer_dtype is None or infer_shape is None): logger.warning("{}, 'out_shape' or 'out_dtype' is None, infer the output shape and output dtype " "automatically. There might be some Python RuntimeWarning but it wouldn't influence the " "result.".format(self.log_prefix)) diff --git a/tests/st/ops/graph_kernel/custom/test_custom_akg.py b/tests/st/ops/graph_kernel/custom/test_custom_akg.py index b89b39c4b59..35dc5c0827e 100644 --- a/tests/st/ops/graph_kernel/custom/test_custom_akg.py +++ b/tests/st/ops/graph_kernel/custom/test_custom_akg.py @@ -334,6 +334,7 @@ def irbuilder_case(): raise ValueError("Precision error, compare result: {}".format(compare_res)) +@pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard @@ -347,6 +348,7 @@ def test_irbuilder_ascend_graph_mode(): irbuilder_case() +@pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard diff --git a/tests/st/ops/graph_kernel/custom/test_ms_hybrid.py b/tests/st/ops/graph_kernel/custom/test_ms_kernel.py similarity index 80% rename from tests/st/ops/graph_kernel/custom/test_ms_hybrid.py rename to tests/st/ops/graph_kernel/custom/test_ms_kernel.py index 5bd7987869b..5d92ff2bd32 100644 --- a/tests/st/ops/graph_kernel/custom/test_ms_hybrid.py +++ b/tests/st/ops/graph_kernel/custom/test_ms_kernel.py @@ -19,10 +19,10 @@ import numpy as np from mindspore import context, Tensor from mindspore.nn import Cell import mindspore.ops as ops -from mindspore.ops import ms_hybrid +from mindspore.ops import ms_kernel -@ms_hybrid +@ms_kernel def dtype_and_cast_example(a, b): """ test function for dtype and cast in Hybrid DSL @@ -38,7 +38,7 @@ def dtype_and_cast_example(a, b): return c -@ms_hybrid +@ms_kernel def allocate_and_math_intrin_example(a, b): """ test function for allocate and math function in Hybrid DSL @@ -53,7 +53,7 @@ def allocate_and_math_intrin_example(a, b): return c -@ms_hybrid +@ms_kernel def grid_example(a, b): """ test function for grid in Hybrid DSL @@ -77,7 +77,7 @@ class TestMsHybridDSL(Cell): return self.program(x, y) -def ms_hybrid_cast_with_infer(): +def ms_kernel_cast_with_infer(): """ test case Custom Op with functions written in Hybrid DSL and infer functions """ @@ -93,7 +93,7 @@ def ms_hybrid_cast_with_infer(): raise ValueError("Precision error, compare result: {}".format(compare_res)) -def ms_hybrid_cast_without_infer(): +def ms_kernel_cast_without_infer(): """ test case Custom Op with functions written in Hybrid DSL and without infer functions """ @@ -109,7 +109,7 @@ def ms_hybrid_cast_without_infer(): raise ValueError("Precision error, compare result: {}".format(compare_res)) -def ms_hybrid_cast_pyfunc(): +def ms_kernel_cast_pyfunc(): """ test case Custom Op with functions written in Hybrid DSL and func_type pyfunc """ @@ -125,7 +125,7 @@ def ms_hybrid_cast_pyfunc(): raise ValueError("Precision error, compare result: {}".format(compare_res)) -def ms_hybrid_allocate(): +def ms_kernel_allocate(): """ test case Custom Op with functions written in Hybrid DSL about math functions and allocate """ @@ -141,7 +141,7 @@ def ms_hybrid_allocate(): raise ValueError("Precision error, compare result: {}".format(compare_res)) -def ms_hybrid_allocate_cpu(): +def ms_kernel_allocate_cpu(): """ test case Custom Op with functions written in Hybrid DSL about math functions and allocate for cpu, we test fp32 to avoid env diff in support of data types. @@ -158,7 +158,7 @@ def ms_hybrid_allocate_cpu(): raise ValueError("Precision error, compare result: {}".format(compare_res)) -def ms_hybrid_grid(): +def ms_kernel_grid(): """ test case Custom Op with functions written in Hybrid DSL about grid """ @@ -174,7 +174,7 @@ def ms_hybrid_grid(): raise ValueError("Precision error, compare result: {}".format(compare_res)) -def ms_hybrid_grid_cpu(): +def ms_kernel_grid_cpu(): """ test case Custom Op with functions written in Hybrid DSL about grid """ @@ -194,79 +194,79 @@ def ms_hybrid_grid_cpu(): @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard -def test_ms_hybrid_ascend_graph_mode(): +def test_ms_kernel_ascend_graph_mode(): """ - Feature: test case for Custom op with func_type="ms_hybrid" - Description: ascend test case, Python DSL with ms_hybrid decorator in GRAPH_MODE. + Feature: test case for Custom op with func_type="ms_kernel" + Description: ascend test case, Python DSL with ms_kernel decorator in GRAPH_MODE. Expectation: the result match with numpy result """ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - ms_hybrid_cast_pyfunc() - ms_hybrid_cast_with_infer() - ms_hybrid_cast_without_infer() - ms_hybrid_allocate() - ms_hybrid_grid() + ms_kernel_cast_pyfunc() + ms_kernel_cast_with_infer() + ms_kernel_cast_without_infer() + ms_kernel_allocate() + ms_kernel_grid() @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard -def test_ms_hybrid_ascend_pynative_mode(): +def test_ms_kernel_ascend_pynative_mode(): """ - Feature: test case for Custom op with func_type="ms_hybrid" - Description: ascend test case, Python DSL with ms_hybrid decorator in PYNATIVE_MODE. + Feature: test case for Custom op with func_type="ms_kernel" + Description: ascend test case, Python DSL with ms_kernel decorator in PYNATIVE_MODE. Expectation: the result match with numpy result """ context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") - ms_hybrid_cast_pyfunc() - ms_hybrid_cast_with_infer() - ms_hybrid_cast_without_infer() - ms_hybrid_allocate() - ms_hybrid_grid() + ms_kernel_cast_pyfunc() + ms_kernel_cast_with_infer() + ms_kernel_cast_without_infer() + ms_kernel_allocate() + ms_kernel_grid() @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_ms_hybrid_gpu_graph_mode(): +def test_ms_kernel_gpu_graph_mode(): """ - Feature: test case for Custom op with func_type="ms_hybrid" - Description: gpu test case, Python DSL with ms_hybrid decorator in GRAPH_MODE. + Feature: test case for Custom op with func_type="ms_kernel" + Description: gpu test case, Python DSL with ms_kernel decorator in GRAPH_MODE. Expectation: the result match with numpy result """ context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - ms_hybrid_cast_pyfunc() - ms_hybrid_cast_with_infer() - ms_hybrid_cast_without_infer() - ms_hybrid_allocate() - ms_hybrid_grid() + ms_kernel_cast_pyfunc() + ms_kernel_cast_with_infer() + ms_kernel_cast_without_infer() + ms_kernel_allocate() + ms_kernel_grid() @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_ms_hybrid_gpu_pynative_mode(): +def test_ms_kernel_gpu_pynative_mode(): """ - Feature: test case for Custom op with func_type="ms_hybrid" - Description: gpu test case, Python DSL with ms_hybrid decorator in PYNATIVE_MODE. + Feature: test case for Custom op with func_type="ms_kernel" + Description: gpu test case, Python DSL with ms_kernel decorator in PYNATIVE_MODE. Expectation: the result match with numpy result """ context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - ms_hybrid_cast_pyfunc() - ms_hybrid_cast_with_infer() - ms_hybrid_cast_without_infer() - ms_hybrid_allocate() - ms_hybrid_grid() + ms_kernel_cast_pyfunc() + ms_kernel_cast_with_infer() + ms_kernel_cast_without_infer() + ms_kernel_allocate() + ms_kernel_grid() @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_ms_hybrid_cpu_graph_mode(): +def test_ms_kernel_cpu_graph_mode(): """ - Feature: test case for Custom op with func_type="ms_hybrid" - Description: cpu test case, Python DSL with ms_hybrid decorator in GRAPH_MODE. + Feature: test case for Custom op with func_type="ms_kernel" + Description: cpu test case, Python DSL with ms_kernel decorator in GRAPH_MODE. Expectation: the result match with numpy result """ if platform.system().lower() in {"windows", "darwin"}: @@ -274,22 +274,22 @@ def test_ms_hybrid_cpu_graph_mode(): pass else: context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - ms_hybrid_allocate_cpu() - ms_hybrid_grid_cpu() + ms_kernel_allocate_cpu() + ms_kernel_grid_cpu() @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_ms_hybrid_cpu_pynative_mode(): +def test_ms_kernel_cpu_pynative_mode(): """ - Feature: test case for Custom op with func_type="ms_hybrid" - Description: cpu test case, Python DSL with ms_hybrid decorator in PYNATIVE_MODE. + Feature: test case for Custom op with func_type="ms_kernel" + Description: cpu test case, Python DSL with ms_kernel decorator in PYNATIVE_MODE. Expectation: the result match with numpy result """ if platform.system().lower() in {"windows", "darwin"}: pass else: context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") - ms_hybrid_allocate_cpu() - ms_hybrid_grid_cpu() + ms_kernel_allocate_cpu() + ms_kernel_grid_cpu()