update ms_kernel related files

fix import problem

update code check

update test case

update test case

rename file
This commit is contained in:
Zichun Ye 2022-06-08 16:32:34 +08:00
parent b041966ea2
commit 8de24b4d2f
8 changed files with 203 additions and 180 deletions

View File

@ -31,7 +31,7 @@ MindSpore中 `mindspore.ops` 接口与上一版本相比,新增、删除和支
mindspore.ops.constexpr
mindspore.ops.custom_info_register
mindspore.ops.ms_hybrid
mindspore.ops.ms_kernel
mindspore.ops.op_info_register
mindspore.ops.prim_attr_register

View File

@ -1,7 +1,7 @@
mindspore.ops.ms_hybrid
mindspore.ops.ms_kernel
=======================
.. py:function:: mindspore.ops.ms_hybrid(fn=None, reg_info=None, compile_attrs=None)
.. py:function:: mindspore.ops.ms_kernel(fn=None, reg_info=None, compile_attrs=None)
用于MindSpore Hybrid DSL函数书写的装饰器。
给用MindSpore Hybrid DSL书写的函数加上此装饰器后它可以用作一个普通的Python函数。

View File

@ -31,7 +31,7 @@ Decorators
mindspore.ops.constexpr
mindspore.ops.custom_info_register
mindspore.ops.ms_hybrid
mindspore.ops.ms_kernel
mindspore.ops.op_info_register
mindspore.ops.prim_attr_register

View File

@ -29,7 +29,7 @@ from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg
CusMatMulCubeDenseRight, CusMatMulCubeFraczLeftCast, Im2Col, NewIm2Col,
LoadIm2Col, UpdateThorGradient, Cholesky, CholeskyTrsm,
DetTriangle, ProdForceSeA)
from ._ms_hybrid import (ms_hybrid)
from ._ms_kernel import (ms_kernel, ms_hybrid)
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unstack,
Diag, DiagPart, DType, ExpandDims, Eye,
Fill, Ones, Zeros, GatherNd, GatherV2, Gather, SparseGatherV2, InvertPermutation,
@ -591,6 +591,7 @@ __sponge__ = [
]
__custom__ = [
"ms_kernel",
"ms_hybrid",
]

View File

@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ms_hybrid decorator and related util functions"""
"""ms_kernel decorator and related util functions"""
import ast
import json
from functools import wraps
from itertools import product
import numpy
from mindspore import context
from mindspore import context, log
def _allocate(shape, dtype='float32', scope='global'):
@ -86,15 +86,15 @@ def _erf(x):
p = 0.3275911
# A&S formula 7.1.26
t = 1.0/(1.0 + p*x)
y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*numpy.exp(-x*x)
return sign*y # erf(-x) = -erf(x)
t = 1.0 / (1.0 + p * x)
y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * numpy.exp(-x * x)
return sign * y # erf(-x) = -erf(x)
def _grid(extents):
extents_list = []
for ext in extents:
extents_list.append(range(ext))
extents_list.append(list(range(ext)))
return product(*extents_list)
@ -120,89 +120,92 @@ class WithStub:
return self
INTRIN_BUFFER = {
'allocate': _allocate,
'output_tensor': _allocate
}
INTRIN_LOOP = {
'range': range,
'grid': _grid,
}
INTRIN_WITH_SCOPE = {
'attr': WithStub(),
'block_realize': WithStub(),
}
INTRIN_UNARY_OP = {
'sqrt': numpy.sqrt,
'sign': numpy.sign,
'log': numpy.log,
'tanh': numpy.tanh,
'exp': numpy.exp,
'abs': numpy.abs,
'int32': numpy.int32,
'float16': numpy.float16,
'float32': numpy.float32,
}
INTRIN_BINARY_OP = {
'power': numpy.power,
}
INTRIN_GLOBALS = {
**INTRIN_BUFFER,
**INTRIN_LOOP,
**INTRIN_WITH_SCOPE,
**INTRIN_UNARY_OP,
**INTRIN_BINARY_OP,
}
INTRIN_GENERAL_UNARY_OP = {
'rsqrt': _rsqrt,
'erf': _erf,
'isnan': numpy.isnan,
'int8': numpy.int8,
'int16': numpy.int16,
'int64': numpy.int64,
'float64': numpy.float64,
'sin': numpy.sin,
'cos': numpy.cos,
'isinf': numpy.isinf,
'isfinite': numpy.isfinite,
'atan': numpy.arctan,
'atan2': numpy.arctan2,
'expm1': numpy.expm1,
'floor': numpy.floor,
'ceil': numpy.ceil,
'trunc': numpy.trunc,
'round': numpy.round,
}
INTRIN_CPU_NOT_SUPPORT = ["atan2", "expm1", "float16"]
INTRIN_GENERAL_BINARY_OP = {
'ceil_div': lambda a, b: (a + b - 1) // b,
}
INTRIN_GENERAL = {
**INTRIN_GENERAL_UNARY_OP,
**INTRIN_GENERAL_BINARY_OP
}
INTRIN_RUNTIME = {
**INTRIN_GLOBALS,
**INTRIN_GENERAL
}
class VariableUsage(ast.NodeVisitor):
"""
The ast visitor to perform static check for the source code,
and determine the index of inplace assign outputs
"""
intrin_buffer = {
'allocate': _allocate,
'output_tensor': _allocate
}
intrin_loop = {
'range': range,
'serial': range,
'vectorize': range,
'parallel': range,
'reduce': range,
'grid': _grid,
}
intrin_with_scope = {
'attr': WithStub(),
'block_realize': WithStub(),
}
intrin_unary_op = {
'sqrt': numpy.sqrt,
'sign': numpy.sign,
'log': numpy.log,
'tanh': numpy.tanh,
'exp': numpy.exp,
'abs': numpy.abs,
'int32': numpy.int32,
'float16': numpy.float16,
'float32': numpy.float32,
}
intrin_bin_op = {
'power': numpy.power,
}
intrin_globals = {
**intrin_buffer,
**intrin_loop,
**intrin_with_scope,
**intrin_unary_op,
**intrin_bin_op,
}
intrin_general_unary_op = {
'rsqrt': _rsqrt,
'erf': _erf,
'isnan': numpy.isnan,
'int8': numpy.int8,
'int16': numpy.int16,
'int64': numpy.int64,
'float64': numpy.float64,
'sin': numpy.sin,
'cos': numpy.cos,
'isinf': numpy.isinf,
'isfinite': numpy.isfinite,
'atan': numpy.arctan,
'atan2': numpy.arctan2,
'expm1': numpy.expm1,
'floor': numpy.floor,
'ceil': numpy.ceil,
'trunc': numpy.trunc,
'round': numpy.round,
}
intrin_cpu_not_support = ["atan2", "expm1", "float16"]
intrin_general_bin_op = {
'ceil_div': lambda a, b: (a + b - 1) // b,
}
intrin_general = {
**intrin_general_unary_op,
**intrin_general_bin_op
}
intrin_runtime = {
**intrin_globals,
**intrin_general
}
def __init__(self, func_name):
self.func_name = func_name
self.scope_level = []
@ -276,27 +279,27 @@ class VariableUsage(ast.NodeVisitor):
"""
Ast visitor for Call
Check the func call used in the DSL. Only those in INTRIN_RUNTIME are supported for now.
Check the func call used in the DSL. Only those in intrin_runtime are supported for now.
"""
func_id = node.func.id
if not (func_id in list(INTRIN_RUNTIME.keys()) +
['max', 'min', 'len', 'ms_hybrid']):
if not (func_id in list(VariableUsage.intrin_runtime.keys()) +
['max', 'min', 'len', 'ms_kernel']):
raise ValueError(
"In the function {} written in the Hybrid DSL, function call id {} "
"not in intrinsics' list".format(self.func_name, func_id))
if (self.device == "Ascend" and func_id in list(INTRIN_GENERAL.keys())) or \
(self.device == "CPU" and func_id in INTRIN_CPU_NOT_SUPPORT):
if (self.device == "Ascend" and func_id in list(VariableUsage.intrin_general.keys())) or \
(self.device == "CPU" and func_id in VariableUsage.intrin_cpu_not_support):
raise ValueError(
"In the function {} written in the Hybrid DSL, function {} is not available on the "
"device {}".format(self.func_name, func_id, self.device))
if func_id in list(INTRIN_UNARY_OP.keys()) + list(INTRIN_GENERAL_UNARY_OP.keys()) + list(INTRIN_LOOP.keys()) \
if func_id in list(VariableUsage.intrin_unary_op.keys()) + list(VariableUsage.intrin_general_unary_op.keys()) \
and len(node.args) != 1:
raise TypeError(
"In the function {} written in the Hybrid DSL, function {} "
"expects one input, but get {}".format(self.func_name, func_id, len(node.args)))
if func_id in list(INTRIN_BINARY_OP.keys()) + list(INTRIN_GENERAL_BINARY_OP.keys()) + \
list(INTRIN_BUFFER.keys()) and len(node.args) != 2:
if func_id in list(VariableUsage.intrin_bin_op.keys()) + list(VariableUsage.intrin_general_bin_op.keys()) + \
list(VariableUsage.intrin_buffer.keys()) and len(node.args) != 2:
raise TypeError(
"In the function {} written in the Hybrid DSL, function {} "
"expects two inputs, but get {}".format(self.func_name, func_id, len(node.args)))
@ -470,10 +473,10 @@ def determine_variable_usage(root, func_name):
return visitor.inplace_assign_output
def ms_hybrid(fn=None, reg_info=None, compile_attrs=None):
def ms_kernel(fn=None, reg_info=None, compile_attrs=None):
"""
The decorator of the Hybrid DSL function for the Custom Op.
When a function written by the Hybrid DSL is decorated by ms_hybrid,
When a function written by the Hybrid DSL is decorated by ms_kernel,
it can be run as a usual Python function.
Also, this function can be used in the api Custom and to create a Custom op, with func_type
"hybrid" or "pyfunc". Creating a custom op with mode "hybrid" by the Hybrid DSL function
@ -495,7 +498,7 @@ def ms_hybrid(fn=None, reg_info=None, compile_attrs=None):
Examples:
>>> import numpy as np
>>> from mindspore import ops, Tensor
>>> from mindspore.ops import ms_hybrid, DataType, CustomRegOp
>>> from mindspore.ops import ms_kernel, DataType, CustomRegOp
...
>>> # Create a dict for the compile flags.
>>> attrs = {
@ -516,9 +519,9 @@ def ms_hybrid(fn=None, reg_info=None, compile_attrs=None):
>>> input_x = np.ones([4, 4]).astype(np.float32)
>>> input_y = np.ones([4, 4]).astype(np.float32)
...
>>> # Write a Hybrid DSL function through the decorator @ms_hybrid.
>>> # Write a Hybrid DSL function through the decorator @ms_kernel.
>>> # We can also pass the compile attrs and the reg info through the decorator.
>>> @ms_hybrid(reg_info=op_gpu_info, compile_attrs=attrs)
>>> @ms_kernel(reg_info=op_gpu_info, compile_attrs=attrs)
... def outer_product(a, b):
... c = output_tensor(a.shape, a.dtype)
...
@ -544,21 +547,25 @@ def ms_hybrid(fn=None, reg_info=None, compile_attrs=None):
compile_attrs = {}
if not isinstance(compile_attrs, dict):
raise TypeError("The input 'compile_attrs' of @ms_hybrid must be a dict, "
raise TypeError("The input 'compile_attrs' of @ms_kernel must be a dict, "
"but get a {}".format(type(compile_attrs)))
for key in compile_attrs.keys():
if not isinstance(key, str):
raise TypeError("The key of 'compile_attrs' of @ms_hybrid must be a str, "
raise TypeError("The key of 'compile_attrs' of @ms_kernel must be a str, "
"but get a {}".format(type(key)))
if reg_info is not None and not isinstance(reg_info, (str, dict, tuple)):
raise TypeError(
"The input 'reg_info' of @ms_hybrid should be one of "
"The input 'reg_info' of @ms_kernel should be one of "
"str, dict and tuple, but get a {}".format(type(reg_info)))
def wrap_ms_hybrid(func):
setattr(func, "ms_hybrid_flag", True)
def wrap_ms_kernel(func):
setattr(func, "ms_kernel_flag", True)
# we enable ml scheduler automatically for ms_kernel function
compile_attrs["enable_mlsched"] = True
setattr(func, "compile_attrs", json.dumps(compile_attrs))
if reg_info is not None:
setattr(func, "reg_info", reg_info)
@ -566,12 +573,25 @@ def ms_hybrid(fn=None, reg_info=None, compile_attrs=None):
@wraps(func)
def _patch_intrins_to_runtime(*args):
_globals = func.__globals__
for elem in list(INTRIN_RUNTIME.keys()):
_globals[elem] = INTRIN_RUNTIME[elem]
for elem in list(VariableUsage.intrin_runtime.keys()):
_globals[elem] = VariableUsage.intrin_runtime[elem]
return func(*args)
return _patch_intrins_to_runtime
if fn is not None:
return wrap_ms_hybrid(fn)
return wrap_ms_hybrid
return wrap_ms_kernel(fn)
return wrap_ms_kernel
def ms_hybrid(fn=None, reg_info=None, compile_attrs=None):
"""
Same as docarator ms_kernel. ms_hybrid will be deprecated in the future.
Please use ms_kernel instead.
Supported Platforms:
Deprecated
"""
log.warning("'ms_hybrid' is deprecated from version 1.8 and "
"will be removed in a future version, use 'ms_kernel' instead.")
return ms_kernel(fn, reg_info, compile_attrs)

View File

@ -29,7 +29,7 @@ from mindspore.common import dtype as mstype
from mindspore.ops import DataType
from mindspore import log as logger
from mindspore import ops
from ._ms_hybrid import determine_variable_usage
from ._ms_kernel import determine_variable_usage
from ._custom_grad import autodiff_bprop
from ._pyfunc_registry import add_pyfunc
@ -107,7 +107,7 @@ class Custom(ops.PrimitiveWithInfer):
1. A AKG operator implementation function, which can use ir builder/tvm compute/hybrid grammar.
2. A TBE operator implementation function.
3. A pure python function
4. An ms_hybrid decorated function written by the Hybrid DSL.
4. An ms_kernel decorated function written by the Hybrid DSL.
- str: If func is of str type, then str should be a path of file along with a function name.
This could be used when func_type is "aot" or "julia".
@ -260,7 +260,7 @@ class Custom(ops.PrimitiveWithInfer):
Examples:
>>> import mindspore.ops as ops
>>> import numpy as np
>>> from mindspore.ops import CustomRegOp, custom_info_register, DataType, ms_hybrid
>>> from mindspore.ops import CustomRegOp, custom_info_register, DataType, ms_kernel
>>> from mindspore.common import dtype as mstype
>>> from mindspore.nn import Cell
>>> input_x = Tensor(np.ones([16, 16]).astype(np.float32))
@ -270,8 +270,8 @@ class Custom(ops.PrimitiveWithInfer):
>>> # This is the default func_type in Custom,
>>> # and both out_shape and out_dtype can be None(default value).
>>> # In this case, the input func must be a function written in the Hybrid DSL
>>> # and decorated by @ms_hybrid.
>>> @ms_hybrid
>>> # and decorated by @ms_kernel.
>>> @ms_kernel
... def outer_product_script(a, b):
... c = output_tensor(a.shape, a.dtype)
... for i0 in range(a.shape[0]):
@ -402,7 +402,7 @@ class Custom(ops.PrimitiveWithInfer):
self.imply_path = ""
self.func_source_str = ""
self._func_compile_attrs = {}
self._is_ms_hybrid = False
self._is_ms_kernel = False
self._check_func()
self._update_func_info()
@ -488,20 +488,20 @@ class Custom(ops.PrimitiveWithInfer):
elif self.func_type == "julia":
self._check_julia_func()
elif self.func_type == "hybrid":
if not hasattr(self.func, "ms_hybrid_flag"):
raise TypeError("{}, 'func' must be a function decorated by ms_hybrid".format(self.log_prefix))
self._is_ms_hybrid = True
if not hasattr(self.func, "ms_kernel_flag"):
raise TypeError("{}, 'func' must be a function decorated by ms_kernel".format(self.log_prefix))
self._is_ms_kernel = True
self._func_compile_attrs = getattr(self.func, "compile_attrs", {})
elif self.func_type == "akg":
if hasattr(self.func, "ms_hybrid_flag"):
if hasattr(self.func, "ms_kernel_flag"):
logger.warning("{}. To have a better user experience, the mode hybrid is suggested "
"for the input function with decorator @ms_hybrid. "
"for the input function with decorator @ms_kernel. "
"To enable this mode, set the 'func_type' to be \"hybrid\"".format(self.log_prefix))
elif self.func_type == "pyfunc":
if hasattr(self.func, "ms_hybrid_flag"):
logger.warning("{}. Now you are using the function with decorator @ms_hybrid in the mode pyfunc. "
if hasattr(self.func, "ms_kernel_flag"):
logger.warning("{}. Now you are using the function with decorator @ms_kernel in the mode pyfunc. "
"The kernel will be executed as a native python function, which might lead to "
"low efficiency. To accelerate the kernel, set the 'func_type' to be \"ms_hybrid\""
"low efficiency. To accelerate the kernel, set the 'func_type' to be \"hybrid\""
.format(self.log_prefix))
else:
if not callable(self.func):
@ -524,7 +524,7 @@ class Custom(ops.PrimitiveWithInfer):
if index != -1:
self.func_source_str = self.func_source_str[index:]
if self._is_ms_hybrid:
if self._is_ms_kernel:
# static check for the Hybrid DSL in hybrid
root = ast.parse(self.func_source_str)
inplace_assign_output = determine_variable_usage(root, self.func_name)
@ -848,7 +848,7 @@ class Custom(ops.PrimitiveWithInfer):
def _auto_infer(self, *args):
"""
the automatic infer function for functions with @ms_hybrid decorator
the automatic infer function for functions with @ms_kernel decorator
"""
fake_input = []
enable_infer_value = True
@ -893,7 +893,7 @@ class Custom(ops.PrimitiveWithInfer):
# deal with the case of ms script
# enable auto infer function if any infer information is missing
if self._is_ms_hybrid and (infer_dtype is None or infer_shape is None):
if self._is_ms_kernel and (infer_dtype is None or infer_shape is None):
logger.warning("{}, 'out_shape' or 'out_dtype' is None, infer the output shape and output dtype "
"automatically. There might be some Python RuntimeWarning but it wouldn't influence the "
"result.".format(self.log_prefix))

View File

@ -334,6 +334,7 @@ def irbuilder_case():
raise ValueError("Precision error, compare result: {}".format(compare_res))
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@ -347,6 +348,7 @@ def test_irbuilder_ascend_graph_mode():
irbuilder_case()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

View File

@ -19,10 +19,10 @@ import numpy as np
from mindspore import context, Tensor
from mindspore.nn import Cell
import mindspore.ops as ops
from mindspore.ops import ms_hybrid
from mindspore.ops import ms_kernel
@ms_hybrid
@ms_kernel
def dtype_and_cast_example(a, b):
"""
test function for dtype and cast in Hybrid DSL
@ -38,7 +38,7 @@ def dtype_and_cast_example(a, b):
return c
@ms_hybrid
@ms_kernel
def allocate_and_math_intrin_example(a, b):
"""
test function for allocate and math function in Hybrid DSL
@ -53,7 +53,7 @@ def allocate_and_math_intrin_example(a, b):
return c
@ms_hybrid
@ms_kernel
def grid_example(a, b):
"""
test function for grid in Hybrid DSL
@ -77,7 +77,7 @@ class TestMsHybridDSL(Cell):
return self.program(x, y)
def ms_hybrid_cast_with_infer():
def ms_kernel_cast_with_infer():
"""
test case Custom Op with functions written in Hybrid DSL and infer functions
"""
@ -93,7 +93,7 @@ def ms_hybrid_cast_with_infer():
raise ValueError("Precision error, compare result: {}".format(compare_res))
def ms_hybrid_cast_without_infer():
def ms_kernel_cast_without_infer():
"""
test case Custom Op with functions written in Hybrid DSL and without infer functions
"""
@ -109,7 +109,7 @@ def ms_hybrid_cast_without_infer():
raise ValueError("Precision error, compare result: {}".format(compare_res))
def ms_hybrid_cast_pyfunc():
def ms_kernel_cast_pyfunc():
"""
test case Custom Op with functions written in Hybrid DSL and func_type pyfunc
"""
@ -125,7 +125,7 @@ def ms_hybrid_cast_pyfunc():
raise ValueError("Precision error, compare result: {}".format(compare_res))
def ms_hybrid_allocate():
def ms_kernel_allocate():
"""
test case Custom Op with functions written in Hybrid DSL about math functions and allocate
"""
@ -141,7 +141,7 @@ def ms_hybrid_allocate():
raise ValueError("Precision error, compare result: {}".format(compare_res))
def ms_hybrid_allocate_cpu():
def ms_kernel_allocate_cpu():
"""
test case Custom Op with functions written in Hybrid DSL about math functions and allocate
for cpu, we test fp32 to avoid env diff in support of data types.
@ -158,7 +158,7 @@ def ms_hybrid_allocate_cpu():
raise ValueError("Precision error, compare result: {}".format(compare_res))
def ms_hybrid_grid():
def ms_kernel_grid():
"""
test case Custom Op with functions written in Hybrid DSL about grid
"""
@ -174,7 +174,7 @@ def ms_hybrid_grid():
raise ValueError("Precision error, compare result: {}".format(compare_res))
def ms_hybrid_grid_cpu():
def ms_kernel_grid_cpu():
"""
test case Custom Op with functions written in Hybrid DSL about grid
"""
@ -194,79 +194,79 @@ def ms_hybrid_grid_cpu():
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ms_hybrid_ascend_graph_mode():
def test_ms_kernel_ascend_graph_mode():
"""
Feature: test case for Custom op with func_type="ms_hybrid"
Description: ascend test case, Python DSL with ms_hybrid decorator in GRAPH_MODE.
Feature: test case for Custom op with func_type="ms_kernel"
Description: ascend test case, Python DSL with ms_kernel decorator in GRAPH_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
ms_hybrid_cast_pyfunc()
ms_hybrid_cast_with_infer()
ms_hybrid_cast_without_infer()
ms_hybrid_allocate()
ms_hybrid_grid()
ms_kernel_cast_pyfunc()
ms_kernel_cast_with_infer()
ms_kernel_cast_without_infer()
ms_kernel_allocate()
ms_kernel_grid()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ms_hybrid_ascend_pynative_mode():
def test_ms_kernel_ascend_pynative_mode():
"""
Feature: test case for Custom op with func_type="ms_hybrid"
Description: ascend test case, Python DSL with ms_hybrid decorator in PYNATIVE_MODE.
Feature: test case for Custom op with func_type="ms_kernel"
Description: ascend test case, Python DSL with ms_kernel decorator in PYNATIVE_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
ms_hybrid_cast_pyfunc()
ms_hybrid_cast_with_infer()
ms_hybrid_cast_without_infer()
ms_hybrid_allocate()
ms_hybrid_grid()
ms_kernel_cast_pyfunc()
ms_kernel_cast_with_infer()
ms_kernel_cast_without_infer()
ms_kernel_allocate()
ms_kernel_grid()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ms_hybrid_gpu_graph_mode():
def test_ms_kernel_gpu_graph_mode():
"""
Feature: test case for Custom op with func_type="ms_hybrid"
Description: gpu test case, Python DSL with ms_hybrid decorator in GRAPH_MODE.
Feature: test case for Custom op with func_type="ms_kernel"
Description: gpu test case, Python DSL with ms_kernel decorator in GRAPH_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
ms_hybrid_cast_pyfunc()
ms_hybrid_cast_with_infer()
ms_hybrid_cast_without_infer()
ms_hybrid_allocate()
ms_hybrid_grid()
ms_kernel_cast_pyfunc()
ms_kernel_cast_with_infer()
ms_kernel_cast_without_infer()
ms_kernel_allocate()
ms_kernel_grid()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ms_hybrid_gpu_pynative_mode():
def test_ms_kernel_gpu_pynative_mode():
"""
Feature: test case for Custom op with func_type="ms_hybrid"
Description: gpu test case, Python DSL with ms_hybrid decorator in PYNATIVE_MODE.
Feature: test case for Custom op with func_type="ms_kernel"
Description: gpu test case, Python DSL with ms_kernel decorator in PYNATIVE_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
ms_hybrid_cast_pyfunc()
ms_hybrid_cast_with_infer()
ms_hybrid_cast_without_infer()
ms_hybrid_allocate()
ms_hybrid_grid()
ms_kernel_cast_pyfunc()
ms_kernel_cast_with_infer()
ms_kernel_cast_without_infer()
ms_kernel_allocate()
ms_kernel_grid()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_ms_hybrid_cpu_graph_mode():
def test_ms_kernel_cpu_graph_mode():
"""
Feature: test case for Custom op with func_type="ms_hybrid"
Description: cpu test case, Python DSL with ms_hybrid decorator in GRAPH_MODE.
Feature: test case for Custom op with func_type="ms_kernel"
Description: cpu test case, Python DSL with ms_kernel decorator in GRAPH_MODE.
Expectation: the result match with numpy result
"""
if platform.system().lower() in {"windows", "darwin"}:
@ -274,22 +274,22 @@ def test_ms_hybrid_cpu_graph_mode():
pass
else:
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
ms_hybrid_allocate_cpu()
ms_hybrid_grid_cpu()
ms_kernel_allocate_cpu()
ms_kernel_grid_cpu()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_ms_hybrid_cpu_pynative_mode():
def test_ms_kernel_cpu_pynative_mode():
"""
Feature: test case for Custom op with func_type="ms_hybrid"
Description: cpu test case, Python DSL with ms_hybrid decorator in PYNATIVE_MODE.
Feature: test case for Custom op with func_type="ms_kernel"
Description: cpu test case, Python DSL with ms_kernel decorator in PYNATIVE_MODE.
Expectation: the result match with numpy result
"""
if platform.system().lower() in {"windows", "darwin"}:
pass
else:
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
ms_hybrid_allocate_cpu()
ms_hybrid_grid_cpu()
ms_kernel_allocate_cpu()
ms_kernel_grid_cpu()