diff --git a/mindspore/python/mindspore/nn/transformer/moe.py b/mindspore/python/mindspore/nn/transformer/moe.py
index 2245db22fd3..9cd89799584 100644
--- a/mindspore/python/mindspore/nn/transformer/moe.py
+++ b/mindspore/python/mindspore/nn/transformer/moe.py
@@ -26,7 +26,7 @@ import mindspore.communication.management as D
 from mindspore._checkparam import Validator
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
-from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import _primexpr
 from mindspore.nn.cell import Cell
 from mindspore.nn.layer import Dense
 from mindspore.context import ParallelMode
@@ -131,7 +131,7 @@ def _check_moe_config(moe_config=None, parallel_config=None):
                              f"should be less than device_num: {device_num}.")
 
 
-@constexpr
+@_primexpr
 def calculate_expert_capacity(k, tokens_per_group, capacity_factor, expert_dim):
     res = k * tokens_per_group * capacity_factor / expert_dim
     res_int = int(res)
diff --git a/mindspore/python/mindspore/numpy/utils_const.py b/mindspore/python/mindspore/numpy/utils_const.py
index dc0da4f70a6..1d5c30e61e9 100644
--- a/mindspore/python/mindspore/numpy/utils_const.py
+++ b/mindspore/python/mindspore/numpy/utils_const.py
@@ -22,6 +22,7 @@ import operator
 import mindspore.context as context
 from mindspore.ops import functional as F
 from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import _primexpr
 from mindspore.common import dtype as mstype
 from mindspore.common import Tensor
 from mindspore._c_expression import Tensor as Tensor_
@@ -33,7 +34,7 @@ from mindspore.numpy.dtypes import promotion_rule, dtype_tuple, all_types, dtype
 _check_axis_type = constexpr(validator.check_axis_type)
 
 
-@constexpr
+@_primexpr
 def _check_shape(shape):
     """check the shape param to match the numpy style"""
     # convert tensor to int/list, use followed if statements to do further conversions
@@ -66,7 +67,7 @@ def _check_dtype(dtype):
     return dtype
 
 
-@constexpr
+@_primexpr
 def _is_shape_empty(shp):
     """Check whether shape contains zero"""
     if F.is_sequence_shape_unknown(shp):
@@ -78,7 +79,7 @@
     return F.shape_mul(shp) == 0
 
 
-@constexpr
+@_primexpr
 def _check_start_normalize(start, ndim):
     """check and normalize start argument for rollaxis."""
     if start < 0:
@@ -86,7 +87,7 @@
     return start
 
 
-@constexpr
+@_primexpr
 def _check_axes_range(axes, ndim):
     """
     Check axes type and normalize the negative axes.
@@ -112,6 +113,7 @@ def _get_device():
     return context.get_context('device_target')
 
 
+@_primexpr
 def _infer_out_shape(*shapes):
     """
     Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
@@ -126,7 +128,7 @@
     return tuple(shape_out)
 
 
-@constexpr
+@_primexpr
 def _can_broadcast(*shapes):
     """
     Returns Ture if shapes can broadcast, False if they cannot.
@@ -135,13 +137,13 @@ def _can_broadcast(*shapes):
     return True
 
 
-@constexpr
+@_primexpr
 def _check_axis_in_range(axis, ndim):
     """Checks axes are with the bounds of ndim"""
     return axis - axis // ndim * ndim
 
 
-@constexpr
+@_primexpr
 def _check_axis_valid(axes, ndim):
     """
     Checks axes are valid given ndim, and returns axes that can be passed
@@ -156,7 +158,7 @@ def _check_axis_valid(axes, ndim):
     return (_check_axis_in_range(axes, ndim),)
 
 
-@constexpr
+@_primexpr
 def _tile_size(shape, out_shape, ndim):
     """Returns tile_size such that shape*tile_size = out_shape"""
     size = [1]*ndim
@@ -230,7 +232,7 @@ def _raise_unimplemented_error(info, param=None):
     raise NotImplementedError(info + f"{param}")
 
 
-@constexpr
+@_primexpr
 def _empty(dtype, shape):
     """Returns an uninitialized array with dtype and shape."""
     return Tensor_(dtype, shape)
@@ -256,13 +258,13 @@ def _promote_for_trigonometric(dtype):
     return res_dtype
 
 
-@constexpr
+@_primexpr
 def _max(*args):
     """Returns the maximum value."""
     return max(*args)
 
 
-@constexpr
+@_primexpr
 def _min(*args):
     """"Returns the minimum value."""
     return min(*args)
@@ -290,7 +292,7 @@ def _check_is_int(dtype):
     return isinstance(dtype, typing.Int)
 
 
-@constexpr
+@_primexpr
 def _canonicalize_axis(axis, ndim):
     """
     Check axes are within the number of dimensions of tensor x and normalize the negative axes.
@@ -356,7 +358,7 @@ def _broadcast_tuples(tup1, tup2):
     return tup1, tup2
 
 
-@constexpr
+@_primexpr
 def _expanded_shape(ndim, axis_size, axis):
     """
     Returns a shape with size = 1 for all dimensions
@@ -365,7 +367,7 @@ def _expanded_shape(ndim, axis_size, axis):
     return tuple([axis_size if i == axis else 1 for i in range(ndim)])
 
 
-@constexpr
+@_primexpr
 def _add_unit_axes(shape, ndim, append=False):
     """
     Prepends shape with 1s so that it has the number of dimensions ndim.
@@ -401,7 +403,7 @@ def _type_convert(force, obj):
     return force(obj)
 
 
-@constexpr
+@_primexpr
 def _list_comprehensions(obj, item=None, return_tuple=False, make_none=False):
     """
     Generates a new list/tuple by list comprehension.
@@ -432,7 +434,7 @@ def _list_comprehensions(obj, item=None, return_tuple=False, make_none=False):
     return res
 
 
-@constexpr
+@_primexpr
 def _tuple_setitem(tup, idx, value):
     """
     Returns a tuple with specified `idx` set to `value`.
@@ -459,7 +461,7 @@ def _ceil(number):
     return math.ceil(number)
 
 
-@constexpr
+@_primexpr
 def _seq_prod(seq1, seq2):
     """Returns the element-wise product of seq1 and seq2."""
     return tuple(map(lambda x, y: x*y, seq1, seq2))
@@ -471,7 +473,7 @@ def _make_tensor(val, dtype):
     return Tensor(val, dtype)
 
 
-@constexpr
+@_primexpr
 def _tuple_slice(tup, start, end):
     """get sliced tuple from start and end."""
     return tup[start:end]
diff --git a/mindspore/python/mindspore/ops/_grad/grad_array_ops.py b/mindspore/python/mindspore/ops/_grad/grad_array_ops.py
index 704d2a4fbbc..5b62fa2cf62 100644
--- a/mindspore/python/mindspore/ops/_grad/grad_array_ops.py
+++ b/mindspore/python/mindspore/ops/_grad/grad_array_ops.py
@@ -26,6 +26,7 @@ from mindspore.ops.functional import broadcast_gradient_args
 from mindspore.ops import functional as F
 from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element
 from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import _primexpr
 from mindspore.common import dtype as mstype
 from mindspore.common.sparse_tensor import RowTensorInner
 from mindspore.ops._utils.utils import range_op, get_1d_shape, generate_shape_index
@@ -203,7 +204,7 @@ def get_bprop_flatten(self):
     return bprop
 
 
-@constexpr
+@_primexpr
 def _tile_shape(multiples, shapex):
     """Calculate [1,2], [3, 4] -> [1,3,2,4]."""
     len_muli = len(multiples)
@@ -316,7 +317,7 @@ def get_bprop_embedding_lookup(self):
     return bprop_sparse
 
 
-@constexpr
+@_primexpr
 def make_begin(shp):
     """Creates a tuple with zero according to the shape."""
     begin = tuple([0 for _ in shp])
@@ -347,7 +348,7 @@ def get_bprop_padding(self):
     return bprop
 
 
-@constexpr
+@_primexpr
 def _concat_grad_uniform(input_shapes, input_nums):
     """Helper function for bprop of Concat"""
     is_uniform = True
@@ -411,7 +412,7 @@ def get_bprop_slice(self):
     return bprop
 
 
-@constexpr
+@_primexpr
 def _generate_inverse_index(x_shape, axis):
     x_rank = len(x_shape)
     index = tuple(range(x_rank))
@@ -421,7 +422,7 @@ def _generate_inverse_index(x_shape, axis):
     return perm
 
 
-@constexpr
+@_primexpr
 def _regenerate_output_shape(x_shp, ind_shp, axis):
     rank = len(x_shp)
     if axis < 0:
diff --git a/mindspore/python/mindspore/ops/_grad/grad_inner_ops.py b/mindspore/python/mindspore/ops/_grad/grad_inner_ops.py
index ef641ec247d..a2040843b32 100644
--- a/mindspore/python/mindspore/ops/_grad/grad_inner_ops.py
+++ b/mindspore/python/mindspore/ops/_grad/grad_inner_ops.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
-from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import _primexpr
 from mindspore.ops._grad.grad_base import bprop_getters, sum_grad_reduce_axis
 from mindspore.ops.operations import _grad_ops as G
 from mindspore.ops.operations import _inner_ops as inner
@@ -44,7 +44,7 @@ def _get_matrix_diag_part_assist(x_shape, x_dtype):
     return assist
 
 
-@constexpr
+@_primexpr
 def _get_min(x):
     return min(x)
 
diff --git a/mindspore/python/mindspore/ops/_grad/grad_math_ops.py b/mindspore/python/mindspore/ops/_grad/grad_math_ops.py
index 4a61b477149..8ec8ad1b0bc 100755
--- a/mindspore/python/mindspore/ops/_grad/grad_math_ops.py
+++ b/mindspore/python/mindspore/ops/_grad/grad_math_ops.py
@@ -30,6 +30,7 @@ from mindspore.ops._grad.grad_base import convert_to_tensor
 from mindspore.ops._grad.grad_base import sum_grad_reduce_axis, dyn_fill, dyn_rank
 from mindspore.ops._grad.grad_base import dyn_ones, dyn_rank_1d
 from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import _primexpr
 from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
 from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs, IsSubClass, DynamicBroadcastTo
 from mindspore.ops.operations import array_ops as A
@@ -856,7 +857,7 @@ def get_bprop_cumsum(self):
     return bprop
 
 
-@constexpr
+@_primexpr
 def _split_shape_index(input_shape, axis):
     """Calculate reduce_prod grad transpose indices and perm shape."""
     rank = len(input_shape)
diff --git a/mindspore/python/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/python/mindspore/ops/_grad/grad_nn_ops.py
index 8db71ef4b11..1423791bbd3 100755
--- a/mindspore/python/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/python/mindspore/ops/_grad/grad_nn_ops.py
@@ -17,7 +17,7 @@ from mindspore import context
 from mindspore.common import dtype as mstype
 from mindspore.common.tensor import Tensor
-from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import _primexpr
 from mindspore.ops.operations import nn_ops as nps
 from mindspore.ops._grad.grad_base import bprop_getters, dyn_size, create_tensor_by_element, dyn_rank
 from mindspore.ops import functional as F
@@ -29,7 +29,7 @@ from mindspore.ops.operations import _rl_inner_ops as rl_ops
 from mindspore.ops._utils.utils import range_op, get_1d_shape
 
 
-@constexpr
+@_primexpr
 def bias_add_gradgrad_helper(shape, bias_shape, data_format):
     """Helper function of BiasGradGrad to calculate expanded shape."""
     new_shape = list(shape)
@@ -640,7 +640,7 @@ def get_bprop_sigmoid_grad(self):
     return bprop
 
 
-@constexpr
+@_primexpr
 def _get_transpose_axis(x_shp, axis):
     rank = len(x_shp)
     if axis < 0:
diff --git a/mindspore/python/mindspore/ops/_grad_experimental/grad_math_ops.py b/mindspore/python/mindspore/ops/_grad_experimental/grad_math_ops.py
index e1bce261709..4713fbbd0c7 100644
--- a/mindspore/python/mindspore/ops/_grad_experimental/grad_math_ops.py
+++ b/mindspore/python/mindspore/ops/_grad_experimental/grad_math_ops.py
@@ -74,6 +74,7 @@ from mindspore.ops.operations.math_ops import Fmax
 from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs
 from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
 from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import _primexpr
 from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element, dyn_rank
 from mindspore.ops._grad.grad_base import dyn_ones, dyn_fill, sum_grad_reduce_axis
 from mindspore.ops._grad.grad_math_ops import binop_grad_common
@@ -85,7 +86,7 @@ dyn_shape_op = P.TensorShape()
 _conj = P.Conj()
 
 
-@constexpr
+@_primexpr
 def _generate_perm(x_dim):
     perm = tuple(range(x_dim - 2))
     return perm
@@ -358,7 +359,7 @@ def get_bprop_index_addcmul(self):
     return bprop
 
 
-@constexpr
+@_primexpr
 def renew_dim(shape, dim):
     """ Re-new dims"""
     new_dim = dim if dim >= 0 else len(shape) + dim
@@ -687,7 +688,7 @@ def get_bprop_matrix_solve(self):
     return bprop
 
 
-@constexpr
+@_primexpr
 def _generate_perm_matrix_solve_ls(x_dim):
     perm = tuple(range(x_dim - 2))
     perm = perm + (x_dim-1, x_dim-2)
@@ -1853,7 +1854,7 @@ def _fft_rank_offset(norm_shape, rank):
     return norm_shape_product
 
 
-@constexpr
+@_primexpr
 def _fft_with_size_back_norm(norm_shape, norm, inverse, rank):
     """generate reverse term for fft_with_size"""
     if inverse is False:
@@ -1873,7 +1874,7 @@ def _fft_with_size_back_norm(norm_shape, norm, inverse, rank):
     return norm_
 
 
-@constexpr
+@_primexpr
 def _rfft_norm(norm_shape, norm, rank):
     """generate norm for rfft"""
     norm_ = 1.0
@@ -1886,7 +1887,7 @@ def _rfft_norm(norm_shape, norm, rank):
     return norm_
 
 
-@constexpr
+@_primexpr
 def _get_last_dim_slice_shape(tensor_shape, index):
     """generate shape for slice last tensor"""
     from_shape = [0 for x in tensor_shape]
@@ -1899,7 +1900,7 @@ def _get_last_dim_slice_shape(tensor_shape, index):
     return tuple(from_shape), tuple(to_shape)
 
 
-@constexpr
+@_primexpr
 def _rfft_reshape(shape_a, shape_b):
     """generate rfft shape for reshape"""
     new_shape = list(shape_b)
@@ -1908,7 +1909,7 @@ def _rfft_reshape(shape_a, shape_b):
     return tuple(new_shape)
 
 
-@constexpr
+@_primexpr
 def _rfft_tile_reshape(shape_a):
     """generate rfft shape for tile"""
     reshape_a = list(shape_a)
@@ -1917,7 +1918,7 @@ def _rfft_tile_reshape(shape_a):
     return tuple(reshape_a)
 
 
-@constexpr
+@_primexpr
 def _rfft_last_term_shape(shape_a, shape_b):
     """generate rfft shape for last term"""
     new_shape = list(shape_b)
@@ -1926,13 +1927,13 @@ def _rfft_last_term_shape(shape_a, shape_b):
     return tuple(new_shape)
 
 
-@constexpr
+@_primexpr
 def _batch_matmul_shape_increase(shape_before):
     """increase tensor shape for batch_matmul"""
     return (1, *shape_before)
 
 
-@constexpr
+@_primexpr
 def _batch_matmul_shape_decrease(matrix_shape):
     """decrease tensor shape after batch_matmul"""
     shape_tmp = list(matrix_shape)
diff --git a/mindspore/python/mindspore/ops/_utils/utils.py b/mindspore/python/mindspore/ops/_utils/utils.py
index 179143c1c24..fd4658d1211 100644
--- a/mindspore/python/mindspore/ops/_utils/utils.py
+++ b/mindspore/python/mindspore/ops/_utils/utils.py
@@ -21,6 +21,7 @@ from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from mindspore.common import dtype as mstype
 from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import _primexpr
 from mindspore.common._utils import is_dim_unknown
 
 
@@ -116,7 +117,7 @@ def range_op(start, limit, delta, dtype):
     return output_tensor
 
 
-@constexpr
+@_primexpr
 def get_1d_shape(in_shape):
     """helper function to get 1d shape."""
     out_shape = 1
@@ -125,7 +126,7 @@ def get_1d_shape(in_shape):
     return (out_shape,)
 
 
-@constexpr
+@_primexpr
 def generate_shape_index(out_shape, indices_shape, axis):
     out_rank = len(out_shape)
     ind_rank = len(indices_shape)
diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/python/mindspore/ops/composite/multitype_ops/_constexpr_utils.py
index 15cc9c13046..9739256f2f4 100644
--- a/mindspore/python/mindspore/ops/composite/multitype_ops/_constexpr_utils.py
+++ b/mindspore/python/mindspore/ops/composite/multitype_ops/_constexpr_utils.py
@@ -24,7 +24,7 @@ from mindspore.common import dtype as mstype
 from mindspore.common._register_for_tensor import tensor_operator_registry
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
-from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import constexpr, _primexpr
 from mindspore import log as logger
 from mindspore import context
 
@@ -912,7 +912,7 @@ def compute_slice_shape(slice_shape, broadcast_shape_len, slice_cnt, fancy_posit
     return shape
 
 
-@constexpr
+@_primexpr
 def infer_out_shape(*shapes):
     """
     Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
diff --git a/mindspore/python/mindspore/ops/function/math_func.py b/mindspore/python/mindspore/ops/function/math_func.py
index 14cfafeac42..6835ca2b87f 100644
--- a/mindspore/python/mindspore/ops/function/math_func.py
+++ b/mindspore/python/mindspore/ops/function/math_func.py
@@ -21,6 +21,7 @@ import numpy as np
 import mindspore.ops as ops
 from mindspore.common import dtype as mstype
 from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import _primexpr
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops.operations._inner_ops import Cummin
@@ -7770,7 +7771,7 @@ def _tile_size(shape, out_shape, ndim):
     return tuple(size)
 
 
-@constexpr
+@_primexpr
 def _expand(x, ndim):
     """Expand x to ndim from axis, which can be 0 or -1."""
     rank_op = _get_cache_prim(P.Rank)()
diff --git a/mindspore/python/mindspore/ops/primitive.py b/mindspore/python/mindspore/ops/primitive.py
index 9ff84cc1110..4b0b4cac04a 100644
--- a/mindspore/python/mindspore/ops/primitive.py
+++ b/mindspore/python/mindspore/ops/primitive.py
@@ -732,6 +732,24 @@ def prim_attr_register(fn):
     return deco
 
 
+def _check_contains_variable(item_dtype, item_value):
+    """
+    Check whether the item is a variable or contains a variable.
+    """
+    if isinstance(item_value, (list, tuple)):
+        for i in range(len(item_value)):
+            if _check_contains_variable(item_dtype[i], item_value[i]):
+                return True
+    elif isinstance(item_value, dict):
+        for i in range(len(item_value)):
+            if _check_contains_variable(item_dtype[i], list(item_value.keys())[i]):
+                return True
+        for i in range(len(item_value)):
+            if _check_contains_variable(item_dtype[i], list(item_value.values())[i]):
+                return True
+    return item_dtype is not None and item_value is None
+
+
 def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=True):
     """
     Creates a PrimitiveWithInfer operator that can infer the value at compile time. We can use it to define a function
@@ -786,9 +804,60 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr
             def __infer__(self, *args):
                 value_args = []
                 for item in args:
-                    if (item["dtype"] is not None and item["value"] is None and check):
+                    if _check_contains_variable(item["dtype"], item["value"]) and check:
                         logger.warning("The \"" + self.name + "\" is a constexpr function." \
                                        " The input arguments must be all constant value.")
                     value_args.append(item["value"])
                 return {'dtype': None, 'shape': None, 'value': fn(*value_args)}
+
+            def __call__(self, *args, **kwargs):
+                return fn(*args, **kwargs)
+
+        if get_instance:
+            return CompileOp()
+        return CompileOp
+
+    if fn is not None:
+        return deco(fn)
+    return deco
+
+
+def _primexpr(fn=None, get_instance=True, name=None, reuse_result=True):
+    """
+    _primexpr is similar to constexpr except that, when the input to the function decorated by _primexpr contains
+    a variable, the function will be compiled as a graph.
+
+    _primexpr is only for internal use.
+
+    Args:
+        fn (function): A `fn` used as the infer_value of the output operator. Default: None.
+        get_instance (bool): If true, return the instance of operator,
+            otherwise return the operator class. Default: True.
+        name (str): Defines the operator name. If `name` is None, use the function name as op name. Default: None.
+        reuse_result (bool): If true, the operator will be executed once and reuse the result next time,
+            otherwise the operator will always be executed. Default: True.
+    """
+    def deco(fn):
+        """Decorator for CompileOp."""
+
+        class CompileOp(PrimitiveWithInfer):
+            """
+            CompileOp is a temporary operator used to execute the constexpr function.
+            """
+
+            def __init__(self):
+                op_name = name if name else fn.__name__
+                PrimitiveWithInfer.__init__(self, op_name)
+                self.set_const_prim(True)
+                self.fn = fn
+                self.add_prim_attr('constexpr_prim', True)
+                if not reuse_result:
+                    self.add_prim_attr('forbid_reuse_result', True)
+
+            def __infer__(self, *args):
+                value_args = []
+                for item in args:
+                    if _check_contains_variable(item["dtype"], item["value"]):
+                        return {'dtype': None, 'shape': None, 'value': None, 'fn': (fn,)}
+                    value_args.append(item["value"])
+                return {'dtype': None, 'shape': None, 'value': fn(*value_args)}
diff --git a/tests/ut/python/graph_syntax/test_constexpr.py b/tests/ut/python/graph_syntax/test_constexpr.py
index 99c125a5538..9e976284191 100644
--- a/tests/ut/python/graph_syntax/test_constexpr.py
+++ b/tests/ut/python/graph_syntax/test_constexpr.py
@@ -15,6 +15,7 @@
 """ test_dict_get """
 from mindspore import Tensor, jit, context, mutable
 from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import _primexpr
 
 context.set_context(mode=context.GRAPH_MODE)
 
@@ -105,3 +106,48 @@ def test_constexpr_input_with_mutable_list():
 
     out = foo(Tensor([1]))
     assert out == 3
+
+
+@_primexpr
+def count_none_2(arg):
+    if arg is None:
+        raise ValueError("The arg is None")
+    a = 0
+    if not isinstance(arg, (list, tuple)):
+        if arg is None:
+            return 1
+        return 0
+    for e in arg:
+        if e is None:
+            a += 1
+    return a
+
+
+def test__primexpr_with_mutable_list():
+    """
+    Feature: _primexpr with mutable list.
+    Description: If a mutable list is used as _primexpr input, its elements are kept instead of being converted to None.
+    Expectation: No exception.
+    """
+    @jit
+    def foo(x):
+        arg = mutable([Tensor([1]), Tensor([2]), x])
+        return count_none_2(arg)
+
+    out = foo(Tensor([1]))
+    assert out == 0
+
+
+def test__primexpr_with_mutable_scalar():
+    """
+    Feature: _primexpr with mutable scalar.
+    Description: If a mutable scalar is used as _primexpr input, the value is kept and the function still executes.
+    Expectation: No exception.
+    """
+    @jit
+    def foo():
+        arg = mutable(2)
+        return count_none_2(arg)
+
+    out = foo()
+    assert out == 0
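Usage sketch (illustrative, not part of the patch): a minimal contrast of the two decorators on a mutable() input, modelled directly on test_constexpr_input_with_mutable_list (asserts 3) and test__primexpr_with_mutable_list (asserts 0) above. Helper names such as count_none_const/count_none_prim are hypothetical and chosen only for this example.

from mindspore import Tensor, jit, context, mutable
from mindspore.ops.primitive import constexpr, _primexpr

context.set_context(mode=context.GRAPH_MODE)


@constexpr
def count_none_const(arg):
    # Evaluated at compile time; elements of a mutable() list reach it as None.
    n = 0
    for e in arg:
        if e is None:
            n += 1
    return n


@_primexpr
def count_none_prim(arg):
    # Compiled as a graph when `arg` contains variables, so the real elements are kept.
    n = 0
    for e in arg:
        if e is None:
            n += 1
    return n


@jit
def run_const(x):
    return count_none_const(mutable([Tensor([1]), Tensor([2]), x]))


@jit
def run_prim(x):
    return count_none_prim(mutable([Tensor([1]), Tensor([2]), x]))


assert run_const(Tensor([1])) == 3  # constexpr: all three elements were folded to None
assert run_prim(Tensor([1])) == 0   # _primexpr: the elements are preserved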