forked from mindspore-Ecosystem/mindspore
update ms_kernel related files
fix import problem update code check update test case update test case rename file
This commit is contained in:
parent
b041966ea2
commit
8de24b4d2f
|
@ -31,7 +31,7 @@ MindSpore中 `mindspore.ops` 接口与上一版本相比,新增、删除和支
|
|||
|
||||
mindspore.ops.constexpr
|
||||
mindspore.ops.custom_info_register
|
||||
mindspore.ops.ms_hybrid
|
||||
mindspore.ops.ms_kernel
|
||||
mindspore.ops.op_info_register
|
||||
mindspore.ops.prim_attr_register
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
mindspore.ops.ms_hybrid
|
||||
mindspore.ops.ms_kernel
|
||||
=======================
|
||||
|
||||
.. py:function:: mindspore.ops.ms_hybrid(fn=None, reg_info=None, compile_attrs=None)
|
||||
.. py:function:: mindspore.ops.ms_kernel(fn=None, reg_info=None, compile_attrs=None)
|
||||
|
||||
用于MindSpore Hybrid DSL函数书写的装饰器。
|
||||
给用MindSpore Hybrid DSL书写的函数加上此装饰器后,它可以用作一个普通的Python函数。
|
|
@ -31,7 +31,7 @@ Decorators
|
|||
|
||||
mindspore.ops.constexpr
|
||||
mindspore.ops.custom_info_register
|
||||
mindspore.ops.ms_hybrid
|
||||
mindspore.ops.ms_kernel
|
||||
mindspore.ops.op_info_register
|
||||
mindspore.ops.prim_attr_register
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg
|
|||
CusMatMulCubeDenseRight, CusMatMulCubeFraczLeftCast, Im2Col, NewIm2Col,
|
||||
LoadIm2Col, UpdateThorGradient, Cholesky, CholeskyTrsm,
|
||||
DetTriangle, ProdForceSeA)
|
||||
from ._ms_hybrid import (ms_hybrid)
|
||||
from ._ms_kernel import (ms_kernel, ms_hybrid)
|
||||
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unstack,
|
||||
Diag, DiagPart, DType, ExpandDims, Eye,
|
||||
Fill, Ones, Zeros, GatherNd, GatherV2, Gather, SparseGatherV2, InvertPermutation,
|
||||
|
@ -591,6 +591,7 @@ __sponge__ = [
|
|||
]
|
||||
|
||||
__custom__ = [
|
||||
"ms_kernel",
|
||||
"ms_hybrid",
|
||||
]
|
||||
|
||||
|
|
|
@ -12,14 +12,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""ms_hybrid decorator and related util functions"""
|
||||
"""ms_kernel decorator and related util functions"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
from functools import wraps
|
||||
from itertools import product
|
||||
import numpy
|
||||
from mindspore import context
|
||||
from mindspore import context, log
|
||||
|
||||
|
||||
def _allocate(shape, dtype='float32', scope='global'):
|
||||
|
@ -86,15 +86,15 @@ def _erf(x):
|
|||
p = 0.3275911
|
||||
|
||||
# A&S formula 7.1.26
|
||||
t = 1.0/(1.0 + p*x)
|
||||
y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*numpy.exp(-x*x)
|
||||
return sign*y # erf(-x) = -erf(x)
|
||||
t = 1.0 / (1.0 + p * x)
|
||||
y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * numpy.exp(-x * x)
|
||||
return sign * y # erf(-x) = -erf(x)
|
||||
|
||||
|
||||
def _grid(extents):
|
||||
extents_list = []
|
||||
for ext in extents:
|
||||
extents_list.append(range(ext))
|
||||
extents_list.append(list(range(ext)))
|
||||
return product(*extents_list)
|
||||
|
||||
|
||||
|
@ -120,89 +120,92 @@ class WithStub:
|
|||
return self
|
||||
|
||||
|
||||
INTRIN_BUFFER = {
|
||||
'allocate': _allocate,
|
||||
'output_tensor': _allocate
|
||||
}
|
||||
|
||||
INTRIN_LOOP = {
|
||||
'range': range,
|
||||
'grid': _grid,
|
||||
}
|
||||
|
||||
INTRIN_WITH_SCOPE = {
|
||||
'attr': WithStub(),
|
||||
'block_realize': WithStub(),
|
||||
}
|
||||
|
||||
INTRIN_UNARY_OP = {
|
||||
'sqrt': numpy.sqrt,
|
||||
'sign': numpy.sign,
|
||||
'log': numpy.log,
|
||||
'tanh': numpy.tanh,
|
||||
'exp': numpy.exp,
|
||||
'abs': numpy.abs,
|
||||
'int32': numpy.int32,
|
||||
'float16': numpy.float16,
|
||||
'float32': numpy.float32,
|
||||
}
|
||||
|
||||
INTRIN_BINARY_OP = {
|
||||
'power': numpy.power,
|
||||
}
|
||||
|
||||
INTRIN_GLOBALS = {
|
||||
**INTRIN_BUFFER,
|
||||
**INTRIN_LOOP,
|
||||
**INTRIN_WITH_SCOPE,
|
||||
**INTRIN_UNARY_OP,
|
||||
**INTRIN_BINARY_OP,
|
||||
}
|
||||
|
||||
INTRIN_GENERAL_UNARY_OP = {
|
||||
'rsqrt': _rsqrt,
|
||||
'erf': _erf,
|
||||
'isnan': numpy.isnan,
|
||||
'int8': numpy.int8,
|
||||
'int16': numpy.int16,
|
||||
'int64': numpy.int64,
|
||||
'float64': numpy.float64,
|
||||
'sin': numpy.sin,
|
||||
'cos': numpy.cos,
|
||||
'isinf': numpy.isinf,
|
||||
'isfinite': numpy.isfinite,
|
||||
'atan': numpy.arctan,
|
||||
'atan2': numpy.arctan2,
|
||||
'expm1': numpy.expm1,
|
||||
'floor': numpy.floor,
|
||||
'ceil': numpy.ceil,
|
||||
'trunc': numpy.trunc,
|
||||
'round': numpy.round,
|
||||
}
|
||||
|
||||
INTRIN_CPU_NOT_SUPPORT = ["atan2", "expm1", "float16"]
|
||||
|
||||
INTRIN_GENERAL_BINARY_OP = {
|
||||
'ceil_div': lambda a, b: (a + b - 1) // b,
|
||||
}
|
||||
|
||||
INTRIN_GENERAL = {
|
||||
**INTRIN_GENERAL_UNARY_OP,
|
||||
**INTRIN_GENERAL_BINARY_OP
|
||||
}
|
||||
|
||||
INTRIN_RUNTIME = {
|
||||
**INTRIN_GLOBALS,
|
||||
**INTRIN_GENERAL
|
||||
}
|
||||
|
||||
|
||||
class VariableUsage(ast.NodeVisitor):
|
||||
"""
|
||||
The ast visitor to perform static check for the source code,
|
||||
and determine the index of inplace assign outputs
|
||||
"""
|
||||
|
||||
intrin_buffer = {
|
||||
'allocate': _allocate,
|
||||
'output_tensor': _allocate
|
||||
}
|
||||
|
||||
intrin_loop = {
|
||||
'range': range,
|
||||
'serial': range,
|
||||
'vectorize': range,
|
||||
'parallel': range,
|
||||
'reduce': range,
|
||||
'grid': _grid,
|
||||
}
|
||||
|
||||
intrin_with_scope = {
|
||||
'attr': WithStub(),
|
||||
'block_realize': WithStub(),
|
||||
}
|
||||
|
||||
intrin_unary_op = {
|
||||
'sqrt': numpy.sqrt,
|
||||
'sign': numpy.sign,
|
||||
'log': numpy.log,
|
||||
'tanh': numpy.tanh,
|
||||
'exp': numpy.exp,
|
||||
'abs': numpy.abs,
|
||||
'int32': numpy.int32,
|
||||
'float16': numpy.float16,
|
||||
'float32': numpy.float32,
|
||||
}
|
||||
|
||||
intrin_bin_op = {
|
||||
'power': numpy.power,
|
||||
}
|
||||
|
||||
intrin_globals = {
|
||||
**intrin_buffer,
|
||||
**intrin_loop,
|
||||
**intrin_with_scope,
|
||||
**intrin_unary_op,
|
||||
**intrin_bin_op,
|
||||
}
|
||||
|
||||
intrin_general_unary_op = {
|
||||
'rsqrt': _rsqrt,
|
||||
'erf': _erf,
|
||||
'isnan': numpy.isnan,
|
||||
'int8': numpy.int8,
|
||||
'int16': numpy.int16,
|
||||
'int64': numpy.int64,
|
||||
'float64': numpy.float64,
|
||||
'sin': numpy.sin,
|
||||
'cos': numpy.cos,
|
||||
'isinf': numpy.isinf,
|
||||
'isfinite': numpy.isfinite,
|
||||
'atan': numpy.arctan,
|
||||
'atan2': numpy.arctan2,
|
||||
'expm1': numpy.expm1,
|
||||
'floor': numpy.floor,
|
||||
'ceil': numpy.ceil,
|
||||
'trunc': numpy.trunc,
|
||||
'round': numpy.round,
|
||||
}
|
||||
|
||||
intrin_cpu_not_support = ["atan2", "expm1", "float16"]
|
||||
|
||||
intrin_general_bin_op = {
|
||||
'ceil_div': lambda a, b: (a + b - 1) // b,
|
||||
}
|
||||
|
||||
intrin_general = {
|
||||
**intrin_general_unary_op,
|
||||
**intrin_general_bin_op
|
||||
}
|
||||
|
||||
intrin_runtime = {
|
||||
**intrin_globals,
|
||||
**intrin_general
|
||||
}
|
||||
|
||||
def __init__(self, func_name):
|
||||
self.func_name = func_name
|
||||
self.scope_level = []
|
||||
|
@ -276,27 +279,27 @@ class VariableUsage(ast.NodeVisitor):
|
|||
"""
|
||||
Ast visitor for Call
|
||||
|
||||
Check the func call used in the DSL. Only those in INTRIN_RUNTIME are supported for now.
|
||||
Check the func call used in the DSL. Only those in intrin_runtime are supported for now.
|
||||
"""
|
||||
|
||||
func_id = node.func.id
|
||||
if not (func_id in list(INTRIN_RUNTIME.keys()) +
|
||||
['max', 'min', 'len', 'ms_hybrid']):
|
||||
if not (func_id in list(VariableUsage.intrin_runtime.keys()) +
|
||||
['max', 'min', 'len', 'ms_kernel']):
|
||||
raise ValueError(
|
||||
"In the function {} written in the Hybrid DSL, function call id {} "
|
||||
"not in intrinsics' list".format(self.func_name, func_id))
|
||||
if (self.device == "Ascend" and func_id in list(INTRIN_GENERAL.keys())) or \
|
||||
(self.device == "CPU" and func_id in INTRIN_CPU_NOT_SUPPORT):
|
||||
if (self.device == "Ascend" and func_id in list(VariableUsage.intrin_general.keys())) or \
|
||||
(self.device == "CPU" and func_id in VariableUsage.intrin_cpu_not_support):
|
||||
raise ValueError(
|
||||
"In the function {} written in the Hybrid DSL, function {} is not available on the "
|
||||
"device {}".format(self.func_name, func_id, self.device))
|
||||
if func_id in list(INTRIN_UNARY_OP.keys()) + list(INTRIN_GENERAL_UNARY_OP.keys()) + list(INTRIN_LOOP.keys()) \
|
||||
if func_id in list(VariableUsage.intrin_unary_op.keys()) + list(VariableUsage.intrin_general_unary_op.keys()) \
|
||||
and len(node.args) != 1:
|
||||
raise TypeError(
|
||||
"In the function {} written in the Hybrid DSL, function {} "
|
||||
"expects one input, but get {}".format(self.func_name, func_id, len(node.args)))
|
||||
if func_id in list(INTRIN_BINARY_OP.keys()) + list(INTRIN_GENERAL_BINARY_OP.keys()) + \
|
||||
list(INTRIN_BUFFER.keys()) and len(node.args) != 2:
|
||||
if func_id in list(VariableUsage.intrin_bin_op.keys()) + list(VariableUsage.intrin_general_bin_op.keys()) + \
|
||||
list(VariableUsage.intrin_buffer.keys()) and len(node.args) != 2:
|
||||
raise TypeError(
|
||||
"In the function {} written in the Hybrid DSL, function {} "
|
||||
"expects two inputs, but get {}".format(self.func_name, func_id, len(node.args)))
|
||||
|
@ -470,10 +473,10 @@ def determine_variable_usage(root, func_name):
|
|||
return visitor.inplace_assign_output
|
||||
|
||||
|
||||
def ms_hybrid(fn=None, reg_info=None, compile_attrs=None):
|
||||
def ms_kernel(fn=None, reg_info=None, compile_attrs=None):
|
||||
"""
|
||||
The decorator of the Hybrid DSL function for the Custom Op.
|
||||
When a function written by the Hybrid DSL is decorated by ms_hybrid,
|
||||
When a function written by the Hybrid DSL is decorated by ms_kernel,
|
||||
it can be run as a usual Python function.
|
||||
Also, this function can be used in the api Custom and to create a Custom op, with func_type
|
||||
"hybrid" or "pyfunc". Creating a custom op with mode "hybrid" by the Hybrid DSL function
|
||||
|
@ -495,7 +498,7 @@ def ms_hybrid(fn=None, reg_info=None, compile_attrs=None):
|
|||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import ops, Tensor
|
||||
>>> from mindspore.ops import ms_hybrid, DataType, CustomRegOp
|
||||
>>> from mindspore.ops import ms_kernel, DataType, CustomRegOp
|
||||
...
|
||||
>>> # Create a dict for the compile flags.
|
||||
>>> attrs = {
|
||||
|
@ -516,9 +519,9 @@ def ms_hybrid(fn=None, reg_info=None, compile_attrs=None):
|
|||
>>> input_x = np.ones([4, 4]).astype(np.float32)
|
||||
>>> input_y = np.ones([4, 4]).astype(np.float32)
|
||||
...
|
||||
>>> # Write a Hybrid DSL function through the decorator @ms_hybrid.
|
||||
>>> # Write a Hybrid DSL function through the decorator @ms_kernel.
|
||||
>>> # We can also pass the compile attrs and the reg info through the decorator.
|
||||
>>> @ms_hybrid(reg_info=op_gpu_info, compile_attrs=attrs)
|
||||
>>> @ms_kernel(reg_info=op_gpu_info, compile_attrs=attrs)
|
||||
... def outer_product(a, b):
|
||||
... c = output_tensor(a.shape, a.dtype)
|
||||
...
|
||||
|
@ -544,21 +547,25 @@ def ms_hybrid(fn=None, reg_info=None, compile_attrs=None):
|
|||
compile_attrs = {}
|
||||
|
||||
if not isinstance(compile_attrs, dict):
|
||||
raise TypeError("The input 'compile_attrs' of @ms_hybrid must be a dict, "
|
||||
raise TypeError("The input 'compile_attrs' of @ms_kernel must be a dict, "
|
||||
"but get a {}".format(type(compile_attrs)))
|
||||
|
||||
for key in compile_attrs.keys():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("The key of 'compile_attrs' of @ms_hybrid must be a str, "
|
||||
raise TypeError("The key of 'compile_attrs' of @ms_kernel must be a str, "
|
||||
"but get a {}".format(type(key)))
|
||||
|
||||
if reg_info is not None and not isinstance(reg_info, (str, dict, tuple)):
|
||||
raise TypeError(
|
||||
"The input 'reg_info' of @ms_hybrid should be one of "
|
||||
"The input 'reg_info' of @ms_kernel should be one of "
|
||||
"str, dict and tuple, but get a {}".format(type(reg_info)))
|
||||
|
||||
def wrap_ms_hybrid(func):
|
||||
setattr(func, "ms_hybrid_flag", True)
|
||||
def wrap_ms_kernel(func):
|
||||
setattr(func, "ms_kernel_flag", True)
|
||||
|
||||
# we enable ml scheduler automatically for ms_kernel function
|
||||
compile_attrs["enable_mlsched"] = True
|
||||
|
||||
setattr(func, "compile_attrs", json.dumps(compile_attrs))
|
||||
if reg_info is not None:
|
||||
setattr(func, "reg_info", reg_info)
|
||||
|
@ -566,12 +573,25 @@ def ms_hybrid(fn=None, reg_info=None, compile_attrs=None):
|
|||
@wraps(func)
|
||||
def _patch_intrins_to_runtime(*args):
|
||||
_globals = func.__globals__
|
||||
for elem in list(INTRIN_RUNTIME.keys()):
|
||||
_globals[elem] = INTRIN_RUNTIME[elem]
|
||||
for elem in list(VariableUsage.intrin_runtime.keys()):
|
||||
_globals[elem] = VariableUsage.intrin_runtime[elem]
|
||||
return func(*args)
|
||||
|
||||
return _patch_intrins_to_runtime
|
||||
|
||||
if fn is not None:
|
||||
return wrap_ms_hybrid(fn)
|
||||
return wrap_ms_hybrid
|
||||
return wrap_ms_kernel(fn)
|
||||
return wrap_ms_kernel
|
||||
|
||||
|
||||
def ms_hybrid(fn=None, reg_info=None, compile_attrs=None):
|
||||
"""
|
||||
Same as docarator ms_kernel. ms_hybrid will be deprecated in the future.
|
||||
Please use ms_kernel instead.
|
||||
|
||||
Supported Platforms:
|
||||
Deprecated
|
||||
"""
|
||||
log.warning("'ms_hybrid' is deprecated from version 1.8 and "
|
||||
"will be removed in a future version, use 'ms_kernel' instead.")
|
||||
return ms_kernel(fn, reg_info, compile_attrs)
|
|
@ -29,7 +29,7 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore.ops import DataType
|
||||
from mindspore import log as logger
|
||||
from mindspore import ops
|
||||
from ._ms_hybrid import determine_variable_usage
|
||||
from ._ms_kernel import determine_variable_usage
|
||||
from ._custom_grad import autodiff_bprop
|
||||
from ._pyfunc_registry import add_pyfunc
|
||||
|
||||
|
@ -107,7 +107,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
1. A AKG operator implementation function, which can use ir builder/tvm compute/hybrid grammar.
|
||||
2. A TBE operator implementation function.
|
||||
3. A pure python function
|
||||
4. An ms_hybrid decorated function written by the Hybrid DSL.
|
||||
4. An ms_kernel decorated function written by the Hybrid DSL.
|
||||
|
||||
- str: If func is of str type, then str should be a path of file along with a function name.
|
||||
This could be used when func_type is "aot" or "julia".
|
||||
|
@ -260,7 +260,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
Examples:
|
||||
>>> import mindspore.ops as ops
|
||||
>>> import numpy as np
|
||||
>>> from mindspore.ops import CustomRegOp, custom_info_register, DataType, ms_hybrid
|
||||
>>> from mindspore.ops import CustomRegOp, custom_info_register, DataType, ms_kernel
|
||||
>>> from mindspore.common import dtype as mstype
|
||||
>>> from mindspore.nn import Cell
|
||||
>>> input_x = Tensor(np.ones([16, 16]).astype(np.float32))
|
||||
|
@ -270,8 +270,8 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
>>> # This is the default func_type in Custom,
|
||||
>>> # and both out_shape and out_dtype can be None(default value).
|
||||
>>> # In this case, the input func must be a function written in the Hybrid DSL
|
||||
>>> # and decorated by @ms_hybrid.
|
||||
>>> @ms_hybrid
|
||||
>>> # and decorated by @ms_kernel.
|
||||
>>> @ms_kernel
|
||||
... def outer_product_script(a, b):
|
||||
... c = output_tensor(a.shape, a.dtype)
|
||||
... for i0 in range(a.shape[0]):
|
||||
|
@ -402,7 +402,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
self.imply_path = ""
|
||||
self.func_source_str = ""
|
||||
self._func_compile_attrs = {}
|
||||
self._is_ms_hybrid = False
|
||||
self._is_ms_kernel = False
|
||||
|
||||
self._check_func()
|
||||
self._update_func_info()
|
||||
|
@ -488,20 +488,20 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
elif self.func_type == "julia":
|
||||
self._check_julia_func()
|
||||
elif self.func_type == "hybrid":
|
||||
if not hasattr(self.func, "ms_hybrid_flag"):
|
||||
raise TypeError("{}, 'func' must be a function decorated by ms_hybrid".format(self.log_prefix))
|
||||
self._is_ms_hybrid = True
|
||||
if not hasattr(self.func, "ms_kernel_flag"):
|
||||
raise TypeError("{}, 'func' must be a function decorated by ms_kernel".format(self.log_prefix))
|
||||
self._is_ms_kernel = True
|
||||
self._func_compile_attrs = getattr(self.func, "compile_attrs", {})
|
||||
elif self.func_type == "akg":
|
||||
if hasattr(self.func, "ms_hybrid_flag"):
|
||||
if hasattr(self.func, "ms_kernel_flag"):
|
||||
logger.warning("{}. To have a better user experience, the mode hybrid is suggested "
|
||||
"for the input function with decorator @ms_hybrid. "
|
||||
"for the input function with decorator @ms_kernel. "
|
||||
"To enable this mode, set the 'func_type' to be \"hybrid\"".format(self.log_prefix))
|
||||
elif self.func_type == "pyfunc":
|
||||
if hasattr(self.func, "ms_hybrid_flag"):
|
||||
logger.warning("{}. Now you are using the function with decorator @ms_hybrid in the mode pyfunc. "
|
||||
if hasattr(self.func, "ms_kernel_flag"):
|
||||
logger.warning("{}. Now you are using the function with decorator @ms_kernel in the mode pyfunc. "
|
||||
"The kernel will be executed as a native python function, which might lead to "
|
||||
"low efficiency. To accelerate the kernel, set the 'func_type' to be \"ms_hybrid\""
|
||||
"low efficiency. To accelerate the kernel, set the 'func_type' to be \"hybrid\""
|
||||
.format(self.log_prefix))
|
||||
else:
|
||||
if not callable(self.func):
|
||||
|
@ -524,7 +524,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
if index != -1:
|
||||
self.func_source_str = self.func_source_str[index:]
|
||||
|
||||
if self._is_ms_hybrid:
|
||||
if self._is_ms_kernel:
|
||||
# static check for the Hybrid DSL in hybrid
|
||||
root = ast.parse(self.func_source_str)
|
||||
inplace_assign_output = determine_variable_usage(root, self.func_name)
|
||||
|
@ -848,7 +848,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
|
||||
def _auto_infer(self, *args):
|
||||
"""
|
||||
the automatic infer function for functions with @ms_hybrid decorator
|
||||
the automatic infer function for functions with @ms_kernel decorator
|
||||
"""
|
||||
fake_input = []
|
||||
enable_infer_value = True
|
||||
|
@ -893,7 +893,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
|
||||
# deal with the case of ms script
|
||||
# enable auto infer function if any infer information is missing
|
||||
if self._is_ms_hybrid and (infer_dtype is None or infer_shape is None):
|
||||
if self._is_ms_kernel and (infer_dtype is None or infer_shape is None):
|
||||
logger.warning("{}, 'out_shape' or 'out_dtype' is None, infer the output shape and output dtype "
|
||||
"automatically. There might be some Python RuntimeWarning but it wouldn't influence the "
|
||||
"result.".format(self.log_prefix))
|
||||
|
|
|
@ -334,6 +334,7 @@ def irbuilder_case():
|
|||
raise ValueError("Precision error, compare result: {}".format(compare_res))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -347,6 +348,7 @@ def test_irbuilder_ascend_graph_mode():
|
|||
irbuilder_case()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
|
|
|
@ -19,10 +19,10 @@ import numpy as np
|
|||
from mindspore import context, Tensor
|
||||
from mindspore.nn import Cell
|
||||
import mindspore.ops as ops
|
||||
from mindspore.ops import ms_hybrid
|
||||
from mindspore.ops import ms_kernel
|
||||
|
||||
|
||||
@ms_hybrid
|
||||
@ms_kernel
|
||||
def dtype_and_cast_example(a, b):
|
||||
"""
|
||||
test function for dtype and cast in Hybrid DSL
|
||||
|
@ -38,7 +38,7 @@ def dtype_and_cast_example(a, b):
|
|||
return c
|
||||
|
||||
|
||||
@ms_hybrid
|
||||
@ms_kernel
|
||||
def allocate_and_math_intrin_example(a, b):
|
||||
"""
|
||||
test function for allocate and math function in Hybrid DSL
|
||||
|
@ -53,7 +53,7 @@ def allocate_and_math_intrin_example(a, b):
|
|||
return c
|
||||
|
||||
|
||||
@ms_hybrid
|
||||
@ms_kernel
|
||||
def grid_example(a, b):
|
||||
"""
|
||||
test function for grid in Hybrid DSL
|
||||
|
@ -77,7 +77,7 @@ class TestMsHybridDSL(Cell):
|
|||
return self.program(x, y)
|
||||
|
||||
|
||||
def ms_hybrid_cast_with_infer():
|
||||
def ms_kernel_cast_with_infer():
|
||||
"""
|
||||
test case Custom Op with functions written in Hybrid DSL and infer functions
|
||||
"""
|
||||
|
@ -93,7 +93,7 @@ def ms_hybrid_cast_with_infer():
|
|||
raise ValueError("Precision error, compare result: {}".format(compare_res))
|
||||
|
||||
|
||||
def ms_hybrid_cast_without_infer():
|
||||
def ms_kernel_cast_without_infer():
|
||||
"""
|
||||
test case Custom Op with functions written in Hybrid DSL and without infer functions
|
||||
"""
|
||||
|
@ -109,7 +109,7 @@ def ms_hybrid_cast_without_infer():
|
|||
raise ValueError("Precision error, compare result: {}".format(compare_res))
|
||||
|
||||
|
||||
def ms_hybrid_cast_pyfunc():
|
||||
def ms_kernel_cast_pyfunc():
|
||||
"""
|
||||
test case Custom Op with functions written in Hybrid DSL and func_type pyfunc
|
||||
"""
|
||||
|
@ -125,7 +125,7 @@ def ms_hybrid_cast_pyfunc():
|
|||
raise ValueError("Precision error, compare result: {}".format(compare_res))
|
||||
|
||||
|
||||
def ms_hybrid_allocate():
|
||||
def ms_kernel_allocate():
|
||||
"""
|
||||
test case Custom Op with functions written in Hybrid DSL about math functions and allocate
|
||||
"""
|
||||
|
@ -141,7 +141,7 @@ def ms_hybrid_allocate():
|
|||
raise ValueError("Precision error, compare result: {}".format(compare_res))
|
||||
|
||||
|
||||
def ms_hybrid_allocate_cpu():
|
||||
def ms_kernel_allocate_cpu():
|
||||
"""
|
||||
test case Custom Op with functions written in Hybrid DSL about math functions and allocate
|
||||
for cpu, we test fp32 to avoid env diff in support of data types.
|
||||
|
@ -158,7 +158,7 @@ def ms_hybrid_allocate_cpu():
|
|||
raise ValueError("Precision error, compare result: {}".format(compare_res))
|
||||
|
||||
|
||||
def ms_hybrid_grid():
|
||||
def ms_kernel_grid():
|
||||
"""
|
||||
test case Custom Op with functions written in Hybrid DSL about grid
|
||||
"""
|
||||
|
@ -174,7 +174,7 @@ def ms_hybrid_grid():
|
|||
raise ValueError("Precision error, compare result: {}".format(compare_res))
|
||||
|
||||
|
||||
def ms_hybrid_grid_cpu():
|
||||
def ms_kernel_grid_cpu():
|
||||
"""
|
||||
test case Custom Op with functions written in Hybrid DSL about grid
|
||||
"""
|
||||
|
@ -194,79 +194,79 @@ def ms_hybrid_grid_cpu():
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ms_hybrid_ascend_graph_mode():
|
||||
def test_ms_kernel_ascend_graph_mode():
|
||||
"""
|
||||
Feature: test case for Custom op with func_type="ms_hybrid"
|
||||
Description: ascend test case, Python DSL with ms_hybrid decorator in GRAPH_MODE.
|
||||
Feature: test case for Custom op with func_type="ms_kernel"
|
||||
Description: ascend test case, Python DSL with ms_kernel decorator in GRAPH_MODE.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
ms_hybrid_cast_pyfunc()
|
||||
ms_hybrid_cast_with_infer()
|
||||
ms_hybrid_cast_without_infer()
|
||||
ms_hybrid_allocate()
|
||||
ms_hybrid_grid()
|
||||
ms_kernel_cast_pyfunc()
|
||||
ms_kernel_cast_with_infer()
|
||||
ms_kernel_cast_without_infer()
|
||||
ms_kernel_allocate()
|
||||
ms_kernel_grid()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ms_hybrid_ascend_pynative_mode():
|
||||
def test_ms_kernel_ascend_pynative_mode():
|
||||
"""
|
||||
Feature: test case for Custom op with func_type="ms_hybrid"
|
||||
Description: ascend test case, Python DSL with ms_hybrid decorator in PYNATIVE_MODE.
|
||||
Feature: test case for Custom op with func_type="ms_kernel"
|
||||
Description: ascend test case, Python DSL with ms_kernel decorator in PYNATIVE_MODE.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
ms_hybrid_cast_pyfunc()
|
||||
ms_hybrid_cast_with_infer()
|
||||
ms_hybrid_cast_without_infer()
|
||||
ms_hybrid_allocate()
|
||||
ms_hybrid_grid()
|
||||
ms_kernel_cast_pyfunc()
|
||||
ms_kernel_cast_with_infer()
|
||||
ms_kernel_cast_without_infer()
|
||||
ms_kernel_allocate()
|
||||
ms_kernel_grid()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ms_hybrid_gpu_graph_mode():
|
||||
def test_ms_kernel_gpu_graph_mode():
|
||||
"""
|
||||
Feature: test case for Custom op with func_type="ms_hybrid"
|
||||
Description: gpu test case, Python DSL with ms_hybrid decorator in GRAPH_MODE.
|
||||
Feature: test case for Custom op with func_type="ms_kernel"
|
||||
Description: gpu test case, Python DSL with ms_kernel decorator in GRAPH_MODE.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
ms_hybrid_cast_pyfunc()
|
||||
ms_hybrid_cast_with_infer()
|
||||
ms_hybrid_cast_without_infer()
|
||||
ms_hybrid_allocate()
|
||||
ms_hybrid_grid()
|
||||
ms_kernel_cast_pyfunc()
|
||||
ms_kernel_cast_with_infer()
|
||||
ms_kernel_cast_without_infer()
|
||||
ms_kernel_allocate()
|
||||
ms_kernel_grid()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ms_hybrid_gpu_pynative_mode():
|
||||
def test_ms_kernel_gpu_pynative_mode():
|
||||
"""
|
||||
Feature: test case for Custom op with func_type="ms_hybrid"
|
||||
Description: gpu test case, Python DSL with ms_hybrid decorator in PYNATIVE_MODE.
|
||||
Feature: test case for Custom op with func_type="ms_kernel"
|
||||
Description: gpu test case, Python DSL with ms_kernel decorator in PYNATIVE_MODE.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
ms_hybrid_cast_pyfunc()
|
||||
ms_hybrid_cast_with_infer()
|
||||
ms_hybrid_cast_without_infer()
|
||||
ms_hybrid_allocate()
|
||||
ms_hybrid_grid()
|
||||
ms_kernel_cast_pyfunc()
|
||||
ms_kernel_cast_with_infer()
|
||||
ms_kernel_cast_without_infer()
|
||||
ms_kernel_allocate()
|
||||
ms_kernel_grid()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_ms_hybrid_cpu_graph_mode():
|
||||
def test_ms_kernel_cpu_graph_mode():
|
||||
"""
|
||||
Feature: test case for Custom op with func_type="ms_hybrid"
|
||||
Description: cpu test case, Python DSL with ms_hybrid decorator in GRAPH_MODE.
|
||||
Feature: test case for Custom op with func_type="ms_kernel"
|
||||
Description: cpu test case, Python DSL with ms_kernel decorator in GRAPH_MODE.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
if platform.system().lower() in {"windows", "darwin"}:
|
||||
|
@ -274,22 +274,22 @@ def test_ms_hybrid_cpu_graph_mode():
|
|||
pass
|
||||
else:
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
ms_hybrid_allocate_cpu()
|
||||
ms_hybrid_grid_cpu()
|
||||
ms_kernel_allocate_cpu()
|
||||
ms_kernel_grid_cpu()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_ms_hybrid_cpu_pynative_mode():
|
||||
def test_ms_kernel_cpu_pynative_mode():
|
||||
"""
|
||||
Feature: test case for Custom op with func_type="ms_hybrid"
|
||||
Description: cpu test case, Python DSL with ms_hybrid decorator in PYNATIVE_MODE.
|
||||
Feature: test case for Custom op with func_type="ms_kernel"
|
||||
Description: cpu test case, Python DSL with ms_kernel decorator in PYNATIVE_MODE.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
if platform.system().lower() in {"windows", "darwin"}:
|
||||
pass
|
||||
else:
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
ms_hybrid_allocate_cpu()
|
||||
ms_hybrid_grid_cpu()
|
||||
ms_kernel_allocate_cpu()
|
||||
ms_kernel_grid_cpu()
|
Loading…
Reference in New Issue