!48664 Add _primexpr

Merge pull request !48664 from LiangZhibo/func
i-robot 2023-02-10 09:29:23 +00:00 committed by Gitee
commit 4fb8cbea08
12 changed files with 170 additions and 48 deletions
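Most of the diff below is the same mechanical swap: compile-time shape helpers decorated with @constexpr are switched to the new @_primexpr, so they are still folded when their inputs are constant but are compiled as a graph when an input is a variable. A minimal sketch of the pattern, assuming a hypothetical helper (_check_rank is illustrative, not one of the functions changed below):

from mindspore.ops.primitive import _primexpr

@_primexpr
def _check_rank(shape, expected):
    """Raise if `shape` does not have `expected` dimensions; folded when `shape` is constant."""
    if len(shape) != expected:
        raise ValueError(f"expected rank {expected}, got shape {shape}")
    return shape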

View File

@ -26,7 +26,7 @@ import mindspore.communication.management as D
from mindspore._checkparam import Validator
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.ops.primitive import _primexpr
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from mindspore.context import ParallelMode
@ -131,7 +131,7 @@ def _check_moe_config(moe_config=None, parallel_config=None):
f"should be less than device_num: {device_num}.")
@constexpr
@_primexpr
def calculate_expert_capacity(k, tokens_per_group, capacity_factor, expert_dim):
res = k * tokens_per_group * capacity_factor / expert_dim
res_int = int(res)

View File

@ -22,6 +22,7 @@ import operator
import mindspore.context as context
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.ops.primitive import _primexpr
from mindspore.common import dtype as mstype
from mindspore.common import Tensor
from mindspore._c_expression import Tensor as Tensor_
@ -33,7 +34,7 @@ from mindspore.numpy.dtypes import promotion_rule, dtype_tuple, all_types, dtype
_check_axis_type = constexpr(validator.check_axis_type)
@constexpr
@_primexpr
def _check_shape(shape):
"""check the shape param to match the numpy style"""
# convert tensor to int/list, then use the following if statements to do further conversions
@ -66,7 +67,7 @@ def _check_dtype(dtype):
return dtype
@constexpr
@_primexpr
def _is_shape_empty(shp):
"""Check whether shape contains zero"""
if F.is_sequence_shape_unknown(shp):
@ -78,7 +79,7 @@ def _is_shape_empty(shp):
return F.shape_mul(shp) == 0
@constexpr
@_primexpr
def _check_start_normalize(start, ndim):
"""check and normalize start argument for rollaxis."""
if start < 0:
@ -86,7 +87,7 @@ def _check_start_normalize(start, ndim):
return start
@constexpr
@_primexpr
def _check_axes_range(axes, ndim):
"""
Check axes type and normalize the negative axes.
@ -112,6 +113,7 @@ def _get_device():
return context.get_context('device_target')
@_primexpr
def _infer_out_shape(*shapes):
"""
Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
@ -126,7 +128,7 @@ def _infer_out_shape(*shapes):
return tuple(shape_out)
@constexpr
@_primexpr
def _can_broadcast(*shapes):
"""
Returns True if shapes can broadcast, False if they cannot.
@ -135,13 +137,13 @@ def _can_broadcast(*shapes):
return True
@constexpr
@_primexpr
def _check_axis_in_range(axis, ndim):
"""Checks axes are with the bounds of ndim"""
return axis - axis // ndim * ndim
@constexpr
@_primexpr
def _check_axis_valid(axes, ndim):
"""
Checks axes are valid given ndim, and returns axes that can be passed
@ -156,7 +158,7 @@ def _check_axis_valid(axes, ndim):
return (_check_axis_in_range(axes, ndim),)
@constexpr
@_primexpr
def _tile_size(shape, out_shape, ndim):
"""Returns tile_size such that shape*tile_size = out_shape"""
size = [1]*ndim
@ -230,7 +232,7 @@ def _raise_unimplemented_error(info, param=None):
raise NotImplementedError(info + f"{param}")
@constexpr
@_primexpr
def _empty(dtype, shape):
"""Returns an uninitialized array with dtype and shape."""
return Tensor_(dtype, shape)
@ -256,13 +258,13 @@ def _promote_for_trigonometric(dtype):
return res_dtype
@constexpr
@_primexpr
def _max(*args):
"""Returns the maximum value."""
return max(*args)
@constexpr
@_primexpr
def _min(*args):
""""Returns the minimum value."""
return min(*args)
@ -290,7 +292,7 @@ def _check_is_int(dtype):
return isinstance(dtype, typing.Int)
@constexpr
@_primexpr
def _canonicalize_axis(axis, ndim):
"""
Check axes are within the number of dimensions of tensor x and normalize the negative axes.
@ -356,7 +358,7 @@ def _broadcast_tuples(tup1, tup2):
return tup1, tup2
@constexpr
@_primexpr
def _expanded_shape(ndim, axis_size, axis):
"""
Returns a shape with size = 1 for all dimensions
@ -365,7 +367,7 @@ def _expanded_shape(ndim, axis_size, axis):
return tuple([axis_size if i == axis else 1 for i in range(ndim)])
@constexpr
@_primexpr
def _add_unit_axes(shape, ndim, append=False):
"""
Prepends shape with 1s so that it has the number of dimensions ndim.
@ -401,7 +403,7 @@ def _type_convert(force, obj):
return force(obj)
@constexpr
@_primexpr
def _list_comprehensions(obj, item=None, return_tuple=False, make_none=False):
"""
Generates a new list/tuple by list comprehension.
@ -432,7 +434,7 @@ def _list_comprehensions(obj, item=None, return_tuple=False, make_none=False):
return res
@constexpr
@_primexpr
def _tuple_setitem(tup, idx, value):
"""
Returns a tuple with specified `idx` set to `value`.
@ -459,7 +461,7 @@ def _ceil(number):
return math.ceil(number)
@constexpr
@_primexpr
def _seq_prod(seq1, seq2):
"""Returns the element-wise product of seq1 and seq2."""
return tuple(map(lambda x, y: x*y, seq1, seq2))
@ -471,7 +473,7 @@ def _make_tensor(val, dtype):
return Tensor(val, dtype)
@constexpr
@_primexpr
def _tuple_slice(tup, start, end):
"""get sliced tuple from start and end."""
return tup[start:end]
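The helpers in this file are pure shape arithmetic, which is what makes them safe to fold at compile time or, after this change, to compile as a graph when a shape is dynamic. Below is a minimal self-contained sketch of the broadcasting rule that _infer_out_shape and _can_broadcast implement, reconstructed from their docstrings rather than taken from the MindSpore sources:

def _broadcast_shapes(*shapes):
    """Return the broadcast shape, or raise ValueError if the shapes are incompatible."""
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for dims in zip(*padded):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} cannot be broadcast")
        out.append(sizes.pop() if sizes else 1)
    return tuple(out)

# e.g. _broadcast_shapes((2, 1, 3), (4, 3)) -> (2, 4, 3)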

View File

@ -26,6 +26,7 @@ from mindspore.ops.functional import broadcast_gradient_args
from mindspore.ops import functional as F
from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element
from mindspore.ops.primitive import constexpr
from mindspore.ops.primitive import _primexpr
from mindspore.common import dtype as mstype
from mindspore.common.sparse_tensor import RowTensorInner
from mindspore.ops._utils.utils import range_op, get_1d_shape, generate_shape_index
@ -203,7 +204,7 @@ def get_bprop_flatten(self):
return bprop
@constexpr
@_primexpr
def _tile_shape(multiples, shapex):
"""Calculate [1,2], [3, 4] -> [1,3,2,4]."""
len_muli = len(multiples)
@ -316,7 +317,7 @@ def get_bprop_embedding_lookup(self):
return bprop_sparse
@constexpr
@_primexpr
def make_begin(shp):
"""Creates a tuple with zero according to the shape."""
begin = tuple([0 for _ in shp])
@ -347,7 +348,7 @@ def get_bprop_padding(self):
return bprop
@constexpr
@_primexpr
def _concat_grad_uniform(input_shapes, input_nums):
"""Helper function for bprop of Concat"""
is_uniform = True
@ -411,7 +412,7 @@ def get_bprop_slice(self):
return bprop
@constexpr
@_primexpr
def _generate_inverse_index(x_shape, axis):
x_rank = len(x_shape)
index = tuple(range(x_rank))
@ -421,7 +422,7 @@ def _generate_inverse_index(x_shape, axis):
return perm
@constexpr
@_primexpr
def _regenerate_output_shape(x_shp, ind_shp, axis):
rank = len(x_shp)
if axis < 0:

View File

@ -18,7 +18,7 @@
from __future__ import absolute_import
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.ops.primitive import _primexpr
from mindspore.ops._grad.grad_base import bprop_getters, sum_grad_reduce_axis
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import _inner_ops as inner
@ -44,7 +44,7 @@ def _get_matrix_diag_part_assist(x_shape, x_dtype):
return assist
@constexpr
@_primexpr
def _get_min(x):
return min(x)

View File

@ -30,6 +30,7 @@ from mindspore.ops._grad.grad_base import convert_to_tensor
from mindspore.ops._grad.grad_base import sum_grad_reduce_axis, dyn_fill, dyn_rank
from mindspore.ops._grad.grad_base import dyn_ones, dyn_rank_1d
from mindspore.ops.primitive import constexpr
from mindspore.ops.primitive import _primexpr
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs, IsSubClass, DynamicBroadcastTo
from mindspore.ops.operations import array_ops as A
@ -856,7 +857,7 @@ def get_bprop_cumsum(self):
return bprop
@constexpr
@_primexpr
def _split_shape_index(input_shape, axis):
"""Calculate reduce_prod grad transpose indices and perm shape."""
rank = len(input_shape)

View File

@ -17,7 +17,7 @@
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops.primitive import constexpr
from mindspore.ops.primitive import _primexpr
from mindspore.ops.operations import nn_ops as nps
from mindspore.ops._grad.grad_base import bprop_getters, dyn_size, create_tensor_by_element, dyn_rank
from mindspore.ops import functional as F
@ -29,7 +29,7 @@ from mindspore.ops.operations import _rl_inner_ops as rl_ops
from mindspore.ops._utils.utils import range_op, get_1d_shape
@constexpr
@_primexpr
def bias_add_gradgrad_helper(shape, bias_shape, data_format):
"""Helper function of BiasGradGrad to calculate expanded shape."""
new_shape = list(shape)
@ -640,7 +640,7 @@ def get_bprop_sigmoid_grad(self):
return bprop
@constexpr
@_primexpr
def _get_transpose_axis(x_shp, axis):
rank = len(x_shp)
if axis < 0:

View File

@ -74,6 +74,7 @@ from mindspore.ops.operations.math_ops import Fmax
from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.primitive import constexpr
from mindspore.ops.primitive import _primexpr
from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element, dyn_rank
from mindspore.ops._grad.grad_base import dyn_ones, dyn_fill, sum_grad_reduce_axis
from mindspore.ops._grad.grad_math_ops import binop_grad_common
@ -85,7 +86,7 @@ dyn_shape_op = P.TensorShape()
_conj = P.Conj()
@constexpr
@_primexpr
def _generate_perm(x_dim):
perm = tuple(range(x_dim - 2))
return perm
@ -358,7 +359,7 @@ def get_bprop_index_addcmul(self):
return bprop
@constexpr
@_primexpr
def renew_dim(shape, dim):
""" Re-new dims"""
new_dim = dim if dim >= 0 else len(shape) + dim
@ -687,7 +688,7 @@ def get_bprop_matrix_solve(self):
return bprop
@constexpr
@_primexpr
def _generate_perm_matrix_solve_ls(x_dim):
perm = tuple(range(x_dim - 2))
perm = perm + (x_dim-1, x_dim-2)
@ -1853,7 +1854,7 @@ def _fft_rank_offset(norm_shape, rank):
return norm_shape_product
@constexpr
@_primexpr
def _fft_with_size_back_norm(norm_shape, norm, inverse, rank):
"""generate reverse term for fft_with_size"""
if inverse is False:
@ -1873,7 +1874,7 @@ def _fft_with_size_back_norm(norm_shape, norm, inverse, rank):
return norm_
@constexpr
@_primexpr
def _rfft_norm(norm_shape, norm, rank):
"""generate norm for rfft"""
norm_ = 1.0
@ -1886,7 +1887,7 @@ def _rfft_norm(norm_shape, norm, rank):
return norm_
@constexpr
@_primexpr
def _get_last_dim_slice_shape(tensor_shape, index):
"""generate shape for slice last tensor"""
from_shape = [0 for x in tensor_shape]
@ -1899,7 +1900,7 @@ def _get_last_dim_slice_shape(tensor_shape, index):
return tuple(from_shape), tuple(to_shape)
@constexpr
@_primexpr
def _rfft_reshape(shape_a, shape_b):
"""generate rfft shape for reshape"""
new_shape = list(shape_b)
@ -1908,7 +1909,7 @@ def _rfft_reshape(shape_a, shape_b):
return tuple(new_shape)
@constexpr
@_primexpr
def _rfft_tile_reshape(shape_a):
"""generate rfft shape for tile"""
reshape_a = list(shape_a)
@ -1917,7 +1918,7 @@ def _rfft_tile_reshape(shape_a):
return tuple(reshape_a)
@constexpr
@_primexpr
def _rfft_last_term_shape(shape_a, shape_b):
"""generate rfft shape for last term"""
new_shape = list(shape_b)
@ -1926,13 +1927,13 @@ def _rfft_last_term_shape(shape_a, shape_b):
return tuple(new_shape)
@constexpr
@_primexpr
def _batch_matmul_shape_increase(shape_before):
"""increase tensor shape for batch_matmul"""
return (1, *shape_before)
@constexpr
@_primexpr
def _batch_matmul_shape_decrease(matrix_shape):
"""decrease tensor shape after batch_matmul"""
shape_tmp = list(matrix_shape)

View File

@ -21,6 +21,7 @@ from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.common import dtype as mstype
from mindspore.ops.primitive import constexpr
from mindspore.ops.primitive import _primexpr
from mindspore.common._utils import is_dim_unknown
@ -116,7 +117,7 @@ def range_op(start, limit, delta, dtype):
return output_tensor
@constexpr
@_primexpr
def get_1d_shape(in_shape):
"""helper function to get 1d shape."""
out_shape = 1
@ -125,7 +126,7 @@ def get_1d_shape(in_shape):
return (out_shape,)
@constexpr
@_primexpr
def generate_shape_index(out_shape, indices_shape, axis):
out_rank = len(out_shape)
ind_rank = len(indices_shape)

View File

@ -24,7 +24,7 @@ from mindspore.common import dtype as mstype
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops.primitive import constexpr
from mindspore.ops.primitive import constexpr, _primexpr
from mindspore import log as logger
from mindspore import context
@ -912,7 +912,7 @@ def compute_slice_shape(slice_shape, broadcast_shape_len, slice_cnt, fancy_posit
return shape
@constexpr
@_primexpr
def infer_out_shape(*shapes):
"""
Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.

View File

@ -21,6 +21,7 @@ import numpy as np
import mindspore.ops as ops
from mindspore.common import dtype as mstype
from mindspore.ops.primitive import constexpr
from mindspore.ops.primitive import _primexpr
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops.operations._inner_ops import Cummin
@ -7770,7 +7771,7 @@ def _tile_size(shape, out_shape, ndim):
return tuple(size)
@constexpr
@_primexpr
def _expand(x, ndim):
"""Expand x to ndim from axis, which can be 0 or -1."""
rank_op = _get_cache_prim(P.Rank)()

View File

@ -732,6 +732,24 @@ def prim_attr_register(fn):
return deco
def _check_contains_variable(item_dtype, item_value):
"""
Check whether the item is or contains a variable.
"""
if isinstance(item_value, (list, tuple)):
for i in range(len(item_value)):
if _check_contains_variable(item_dtype[i], item_value[i]):
return True
elif isinstance(item_value, dict):
for i in range(len(item_value)):
if _check_contains_variable(item_dtype[i], list(item_value.keys())[i]):
return True
for i in range(len(item_value)):
if _check_contains_variable(item_dtype[i], list(item_value.values())[i]):
return True
return item_dtype is not None and item_value is None
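# Hedged illustration (not part of the commit): expected results of the helper above for
# the (dtype, value) pairs that __infer__ passes in. A compile-time variable shows up as
# a non-None dtype paired with a None value, possibly nested in a list/tuple/dict:
#   _check_contains_variable(mstype.int64, None)                       -> True   (scalar variable)
#   _check_contains_variable(mstype.int64, 3)                          -> False  (constant)
#   _check_contains_variable([mstype.int64, mstype.int64], [1, None])  -> True   (variable element)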
def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=True):
"""
Creates a PrimitiveWithInfer operator that can infer the value at compile time. We can use it to define a function
@ -786,9 +804,60 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr
def __infer__(self, *args):
value_args = []
for item in args:
if (item["dtype"] is not None and item["value"] is None and check):
if _check_contains_variable(item["dtype"], item["value"]) and check:
logger.warning("The \"" + self.name + "\" is a constexpr function." \
" The input arguments must be all constant value.")
value_args.append(item["value"])
return {'dtype': None, 'shape': None, 'value': fn(*value_args)}
def __call__(self, *args, **kwargs):
return fn(*args, **kwargs)
if get_instance:
return CompileOp()
return CompileOp
if fn is not None:
return deco(fn)
return deco
def _primexpr(fn=None, get_instance=True, name=None, reuse_result=True):
"""
_primexpr is similar to constexpr, except that when the input to the function decorated by _primexpr
contains a variable, the function is compiled as a graph.
_primexpr is for internal use only.
Args:
fn (function): A `fn` used as the infer_value of the output operator. Default: None.
get_instance (bool): If true, return the instance of operator,
otherwise return the operator class. Default: True.
name (str): Defines the operator name. If `name` is None, use the function name as op name. Default: None.
reuse_result (bool): If true, the operator will be executed once and its result reused on later calls,
otherwise the operator will always be executed. Default: True.
"""
def deco(fn):
"""Decorator for CompileOp."""
class CompileOp(PrimitiveWithInfer):
"""
CompileOp is a temporary operator used to execute the constexpr function.
"""
def __init__(self):
op_name = name if name else fn.__name__
PrimitiveWithInfer.__init__(self, op_name)
self.set_const_prim(True)
self.fn = fn
self.add_prim_attr('constexpr_prim', True)
if not reuse_result:
self.add_prim_attr('forbid_reuse_result', True)
def __infer__(self, *args):
value_args = []
for item in args:
if _check_contains_variable(item["dtype"], item["value"]):
return {'dtype': None, 'shape': None, 'value': None, 'fn': (fn,)}
value_args.append(item["value"])
return {'dtype': None, 'shape': None, 'value': fn(*value_args)}
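The behavioural difference between the two decorators is confined to __infer__ above: constexpr always folds the call at compile time (variable arguments arrive as None, after a warning), while _primexpr returns the original Python function so the frontend can compile it as a graph with the real values. A minimal pure-Python sketch of that dispatch rule, using illustrative names rather than the MindSpore internals:

def constexpr_infer(fn, args, is_variable):
    # constexpr: evaluate at compile time no matter what; variable arguments degrade to None.
    folded = [None if var else a for a, var in zip(args, is_variable)]
    return fn(*folded)

def primexpr_infer(fn, args, is_variable):
    # _primexpr: fold only when every argument is a constant, otherwise hand back the
    # original function so it can be compiled as a graph with the real (dynamic) values.
    if any(is_variable):
        return ("compile_as_graph", fn)
    return fn(*args)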

View File

@ -15,6 +15,7 @@
""" test_dict_get """
from mindspore import Tensor, jit, context, mutable
from mindspore.ops.primitive import constexpr
from mindspore.ops.primitive import _primexpr
context.set_context(mode=context.GRAPH_MODE)
@ -105,3 +106,48 @@ def test_constexpr_input_with_mutable_list():
out = foo(Tensor([1]))
assert out == 3
@_primexpr
def count_none_2(arg):
if arg is None:
raise ValueError("The arg is None")
a = 0
if not isinstance(arg, (list, tuple)):
if arg is None:
return 1
return 0
for e in arg:
if e is None:
a += 1
return a
def test__primexpr_with_mutable_list():
"""
Feature: _primexpr with mutable list.
Description: If a mutable list is used as _primexpr input, the function is compiled as a graph and no element is converted to None.
Expectation: No exception.
"""
@jit
def foo(x):
arg = mutable([Tensor([1]), Tensor([2]), x])
return count_none_2(arg)
out = foo(Tensor([1]))
assert out == 0
def test__primexpr_with_mutable_scalar():
"""
Feature: constexpr with mutable list.
Description: If mutable list is used as constexpr input, all elements will be converted to None.
Expectation: No exception.
"""
@jit
def foo():
arg = mutable(2)
return count_none_2(arg)
out = foo()
assert out == 0
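For contrast with the two mutable cases above, a hedged companion sketch (not part of the commit): with an all-constant input the @_primexpr helper is still folded at compile time, just like @constexpr.

@jit
def foo_const():
    # every element is a compile-time constant, so count_none_2 is folded during compilation
    return count_none_2((1, 2, 3))

# expected: no element is None, so the folded result is 0
# assert foo_const() == 0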