forked from mindspore-Ecosystem/mindspore
commit
4fb8cbea08
|
@ -26,7 +26,7 @@ import mindspore.communication.management as D
|
|||
from mindspore._checkparam import Validator
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops.primitive import _primexpr
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.nn.layer import Dense
|
||||
from mindspore.context import ParallelMode
|
||||
|
@ -131,7 +131,7 @@ def _check_moe_config(moe_config=None, parallel_config=None):
|
|||
f"should be less than device_num: {device_num}.")
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def calculate_expert_capacity(k, tokens_per_group, capacity_factor, expert_dim):
|
||||
res = k * tokens_per_group * capacity_factor / expert_dim
|
||||
res_int = int(res)
|
||||
|
|
|
@ -22,6 +22,7 @@ import operator
|
|||
import mindspore.context as context
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops.primitive import _primexpr
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common import Tensor
|
||||
from mindspore._c_expression import Tensor as Tensor_
|
||||
|
@ -33,7 +34,7 @@ from mindspore.numpy.dtypes import promotion_rule, dtype_tuple, all_types, dtype
|
|||
_check_axis_type = constexpr(validator.check_axis_type)
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _check_shape(shape):
|
||||
"""check the shape param to match the numpy style"""
|
||||
# convert tensor to int/list, use followed if statements to do further conversions
|
||||
|
@ -66,7 +67,7 @@ def _check_dtype(dtype):
|
|||
return dtype
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _is_shape_empty(shp):
|
||||
"""Check whether shape contains zero"""
|
||||
if F.is_sequence_shape_unknown(shp):
|
||||
|
@ -78,7 +79,7 @@ def _is_shape_empty(shp):
|
|||
return F.shape_mul(shp) == 0
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _check_start_normalize(start, ndim):
|
||||
"""check and normalize start argument for rollaxis."""
|
||||
if start < 0:
|
||||
|
@ -86,7 +87,7 @@ def _check_start_normalize(start, ndim):
|
|||
return start
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _check_axes_range(axes, ndim):
|
||||
"""
|
||||
Check axes type and normalize the negative axes.
|
||||
|
@ -112,6 +113,7 @@ def _get_device():
|
|||
return context.get_context('device_target')
|
||||
|
||||
|
||||
@_primexpr
|
||||
def _infer_out_shape(*shapes):
|
||||
"""
|
||||
Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
|
||||
|
@ -126,7 +128,7 @@ def _infer_out_shape(*shapes):
|
|||
return tuple(shape_out)
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _can_broadcast(*shapes):
|
||||
"""
|
||||
Returns Ture if shapes can broadcast, False if they cannot.
|
||||
|
@ -135,13 +137,13 @@ def _can_broadcast(*shapes):
|
|||
return True
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _check_axis_in_range(axis, ndim):
|
||||
"""Checks axes are with the bounds of ndim"""
|
||||
return axis - axis // ndim * ndim
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _check_axis_valid(axes, ndim):
|
||||
"""
|
||||
Checks axes are valid given ndim, and returns axes that can be passed
|
||||
|
@ -156,7 +158,7 @@ def _check_axis_valid(axes, ndim):
|
|||
return (_check_axis_in_range(axes, ndim),)
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _tile_size(shape, out_shape, ndim):
|
||||
"""Returns tile_size such that shape*tile_size = out_shape"""
|
||||
size = [1]*ndim
|
||||
|
@ -230,7 +232,7 @@ def _raise_unimplemented_error(info, param=None):
|
|||
raise NotImplementedError(info + f"{param}")
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _empty(dtype, shape):
|
||||
"""Returns an uninitialized array with dtype and shape."""
|
||||
return Tensor_(dtype, shape)
|
||||
|
@ -256,13 +258,13 @@ def _promote_for_trigonometric(dtype):
|
|||
return res_dtype
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _max(*args):
|
||||
"""Returns the maximum value."""
|
||||
return max(*args)
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _min(*args):
|
||||
""""Returns the minimum value."""
|
||||
return min(*args)
|
||||
|
@ -290,7 +292,7 @@ def _check_is_int(dtype):
|
|||
return isinstance(dtype, typing.Int)
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _canonicalize_axis(axis, ndim):
|
||||
"""
|
||||
Check axes are within the number of dimensions of tensor x and normalize the negative axes.
|
||||
|
@ -356,7 +358,7 @@ def _broadcast_tuples(tup1, tup2):
|
|||
return tup1, tup2
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _expanded_shape(ndim, axis_size, axis):
|
||||
"""
|
||||
Returns a shape with size = 1 for all dimensions
|
||||
|
@ -365,7 +367,7 @@ def _expanded_shape(ndim, axis_size, axis):
|
|||
return tuple([axis_size if i == axis else 1 for i in range(ndim)])
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _add_unit_axes(shape, ndim, append=False):
|
||||
"""
|
||||
Prepends shape with 1s so that it has the number of dimensions ndim.
|
||||
|
@ -401,7 +403,7 @@ def _type_convert(force, obj):
|
|||
return force(obj)
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _list_comprehensions(obj, item=None, return_tuple=False, make_none=False):
|
||||
"""
|
||||
Generates a new list/tuple by list comprehension.
|
||||
|
@ -432,7 +434,7 @@ def _list_comprehensions(obj, item=None, return_tuple=False, make_none=False):
|
|||
return res
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _tuple_setitem(tup, idx, value):
|
||||
"""
|
||||
Returns a tuple with specified `idx` set to `value`.
|
||||
|
@ -459,7 +461,7 @@ def _ceil(number):
|
|||
return math.ceil(number)
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _seq_prod(seq1, seq2):
|
||||
"""Returns the element-wise product of seq1 and seq2."""
|
||||
return tuple(map(lambda x, y: x*y, seq1, seq2))
|
||||
|
@ -471,7 +473,7 @@ def _make_tensor(val, dtype):
|
|||
return Tensor(val, dtype)
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _tuple_slice(tup, start, end):
|
||||
"""get sliced tuple from start and end."""
|
||||
return tup[start:end]
|
||||
|
|
|
@ -26,6 +26,7 @@ from mindspore.ops.functional import broadcast_gradient_args
|
|||
from mindspore.ops import functional as F
|
||||
from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops.primitive import _primexpr
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.sparse_tensor import RowTensorInner
|
||||
from mindspore.ops._utils.utils import range_op, get_1d_shape, generate_shape_index
|
||||
|
@ -203,7 +204,7 @@ def get_bprop_flatten(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _tile_shape(multiples, shapex):
|
||||
"""Calculate [1,2], [3, 4] -> [1,3,2,4]."""
|
||||
len_muli = len(multiples)
|
||||
|
@ -316,7 +317,7 @@ def get_bprop_embedding_lookup(self):
|
|||
return bprop_sparse
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def make_begin(shp):
|
||||
"""Creates a tuple with zero according to the shape."""
|
||||
begin = tuple([0 for _ in shp])
|
||||
|
@ -347,7 +348,7 @@ def get_bprop_padding(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _concat_grad_uniform(input_shapes, input_nums):
|
||||
"""Helper function for bprop of Concat"""
|
||||
is_uniform = True
|
||||
|
@ -411,7 +412,7 @@ def get_bprop_slice(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _generate_inverse_index(x_shape, axis):
|
||||
x_rank = len(x_shape)
|
||||
index = tuple(range(x_rank))
|
||||
|
@ -421,7 +422,7 @@ def _generate_inverse_index(x_shape, axis):
|
|||
return perm
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _regenerate_output_shape(x_shp, ind_shp, axis):
|
||||
rank = len(x_shp)
|
||||
if axis < 0:
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
from __future__ import absolute_import
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops.primitive import _primexpr
|
||||
from mindspore.ops._grad.grad_base import bprop_getters, sum_grad_reduce_axis
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
|
@ -44,7 +44,7 @@ def _get_matrix_diag_part_assist(x_shape, x_dtype):
|
|||
return assist
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _get_min(x):
|
||||
return min(x)
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ from mindspore.ops._grad.grad_base import convert_to_tensor
|
|||
from mindspore.ops._grad.grad_base import sum_grad_reduce_axis, dyn_fill, dyn_rank
|
||||
from mindspore.ops._grad.grad_base import dyn_ones, dyn_rank_1d
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops.primitive import _primexpr
|
||||
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
|
||||
from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs, IsSubClass, DynamicBroadcastTo
|
||||
from mindspore.ops.operations import array_ops as A
|
||||
|
@ -856,7 +857,7 @@ def get_bprop_cumsum(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _split_shape_index(input_shape, axis):
|
||||
"""Calculate reduce_prod grad transpose indices and perm shape."""
|
||||
rank = len(input_shape)
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
from mindspore import context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops.primitive import _primexpr
|
||||
from mindspore.ops.operations import nn_ops as nps
|
||||
from mindspore.ops._grad.grad_base import bprop_getters, dyn_size, create_tensor_by_element, dyn_rank
|
||||
from mindspore.ops import functional as F
|
||||
|
@ -29,7 +29,7 @@ from mindspore.ops.operations import _rl_inner_ops as rl_ops
|
|||
from mindspore.ops._utils.utils import range_op, get_1d_shape
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def bias_add_gradgrad_helper(shape, bias_shape, data_format):
|
||||
"""Helper function of BiasGradGrad to calculate expanded shape."""
|
||||
new_shape = list(shape)
|
||||
|
@ -640,7 +640,7 @@ def get_bprop_sigmoid_grad(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _get_transpose_axis(x_shp, axis):
|
||||
rank = len(x_shp)
|
||||
if axis < 0:
|
||||
|
|
|
@ -74,6 +74,7 @@ from mindspore.ops.operations.math_ops import Fmax
|
|||
from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs
|
||||
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops.primitive import _primexpr
|
||||
from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element, dyn_rank
|
||||
from mindspore.ops._grad.grad_base import dyn_ones, dyn_fill, sum_grad_reduce_axis
|
||||
from mindspore.ops._grad.grad_math_ops import binop_grad_common
|
||||
|
@ -85,7 +86,7 @@ dyn_shape_op = P.TensorShape()
|
|||
_conj = P.Conj()
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _generate_perm(x_dim):
|
||||
perm = tuple(range(x_dim - 2))
|
||||
return perm
|
||||
|
@ -358,7 +359,7 @@ def get_bprop_index_addcmul(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def renew_dim(shape, dim):
|
||||
""" Re-new dims"""
|
||||
new_dim = dim if dim >= 0 else len(shape) + dim
|
||||
|
@ -687,7 +688,7 @@ def get_bprop_matrix_solve(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _generate_perm_matrix_solve_ls(x_dim):
|
||||
perm = tuple(range(x_dim - 2))
|
||||
perm = perm + (x_dim-1, x_dim-2)
|
||||
|
@ -1853,7 +1854,7 @@ def _fft_rank_offset(norm_shape, rank):
|
|||
return norm_shape_product
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _fft_with_size_back_norm(norm_shape, norm, inverse, rank):
|
||||
"""generate reverse term for fft_with_size"""
|
||||
if inverse is False:
|
||||
|
@ -1873,7 +1874,7 @@ def _fft_with_size_back_norm(norm_shape, norm, inverse, rank):
|
|||
return norm_
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _rfft_norm(norm_shape, norm, rank):
|
||||
"""generate norm for rfft"""
|
||||
norm_ = 1.0
|
||||
|
@ -1886,7 +1887,7 @@ def _rfft_norm(norm_shape, norm, rank):
|
|||
return norm_
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _get_last_dim_slice_shape(tensor_shape, index):
|
||||
"""generate shape for slice last tensor"""
|
||||
from_shape = [0 for x in tensor_shape]
|
||||
|
@ -1899,7 +1900,7 @@ def _get_last_dim_slice_shape(tensor_shape, index):
|
|||
return tuple(from_shape), tuple(to_shape)
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _rfft_reshape(shape_a, shape_b):
|
||||
"""generate rfft shape for reshape"""
|
||||
new_shape = list(shape_b)
|
||||
|
@ -1908,7 +1909,7 @@ def _rfft_reshape(shape_a, shape_b):
|
|||
return tuple(new_shape)
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _rfft_tile_reshape(shape_a):
|
||||
"""generate rfft shape for tile"""
|
||||
reshape_a = list(shape_a)
|
||||
|
@ -1917,7 +1918,7 @@ def _rfft_tile_reshape(shape_a):
|
|||
return tuple(reshape_a)
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _rfft_last_term_shape(shape_a, shape_b):
|
||||
"""generate rfft shape for last term"""
|
||||
new_shape = list(shape_b)
|
||||
|
@ -1926,13 +1927,13 @@ def _rfft_last_term_shape(shape_a, shape_b):
|
|||
return tuple(new_shape)
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _batch_matmul_shape_increase(shape_before):
|
||||
"""increase tensor shape for batch_matmul"""
|
||||
return (1, *shape_before)
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _batch_matmul_shape_decrease(matrix_shape):
|
||||
"""decrease tensor shape after batch_matmul"""
|
||||
shape_tmp = list(matrix_shape)
|
||||
|
|
|
@ -21,6 +21,7 @@ from mindspore._checkparam import Validator as validator
|
|||
from mindspore._checkparam import Rel
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops.primitive import _primexpr
|
||||
from mindspore.common._utils import is_dim_unknown
|
||||
|
||||
|
||||
|
@ -116,7 +117,7 @@ def range_op(start, limit, delta, dtype):
|
|||
return output_tensor
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def get_1d_shape(in_shape):
|
||||
"""helper function to get 1d shape."""
|
||||
out_shape = 1
|
||||
|
@ -125,7 +126,7 @@ def get_1d_shape(in_shape):
|
|||
return (out_shape,)
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def generate_shape_index(out_shape, indices_shape, axis):
|
||||
out_rank = len(out_shape)
|
||||
ind_rank = len(indices_shape)
|
||||
|
|
|
@ -24,7 +24,7 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore.common._register_for_tensor import tensor_operator_registry
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops.primitive import constexpr, _primexpr
|
||||
from mindspore import log as logger
|
||||
from mindspore import context
|
||||
|
||||
|
@ -912,7 +912,7 @@ def compute_slice_shape(slice_shape, broadcast_shape_len, slice_cnt, fancy_posit
|
|||
return shape
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def infer_out_shape(*shapes):
|
||||
"""
|
||||
Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
|
||||
|
|
|
@ -21,6 +21,7 @@ import numpy as np
|
|||
import mindspore.ops as ops
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops.primitive import _primexpr
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops.operations._inner_ops import Cummin
|
||||
|
@ -7770,7 +7771,7 @@ def _tile_size(shape, out_shape, ndim):
|
|||
return tuple(size)
|
||||
|
||||
|
||||
@constexpr
|
||||
@_primexpr
|
||||
def _expand(x, ndim):
|
||||
"""Expand x to ndim from axis, which can be 0 or -1."""
|
||||
rank_op = _get_cache_prim(P.Rank)()
|
||||
|
|
|
@ -732,6 +732,24 @@ def prim_attr_register(fn):
|
|||
return deco
|
||||
|
||||
|
||||
def _check_contains_variable(item_dtype, item_value):
|
||||
"""
|
||||
Check whether the item is or contains variable.
|
||||
"""
|
||||
if isinstance(item_value, (list, tuple)):
|
||||
for i in range(len(item_value)):
|
||||
if _check_contains_variable(item_dtype[i], item_value[i]):
|
||||
return True
|
||||
elif isinstance(item_value, dict):
|
||||
for i in range(len(item_value)):
|
||||
if _check_contains_variable(item_dtype[i], list(item_value.keys())[i]):
|
||||
return True
|
||||
for i in range(len(item_value)):
|
||||
if _check_contains_variable(item_dtype[i], list(item_value.values())[i]):
|
||||
return True
|
||||
return item_dtype is not None and item_value is None
|
||||
|
||||
|
||||
def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=True):
|
||||
"""
|
||||
Creates a PrimitiveWithInfer operator that can infer the value at compile time. We can use it to define a function
|
||||
|
@ -786,9 +804,60 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr
|
|||
def __infer__(self, *args):
|
||||
value_args = []
|
||||
for item in args:
|
||||
if (item["dtype"] is not None and item["value"] is None and check):
|
||||
if _check_contains_variable(item["dtype"], item["value"]) and check:
|
||||
logger.warning("The \"" + self.name + "\" is a constexpr function." \
|
||||
" The input arguments must be all constant value.")
|
||||
value_args.append(item["value"])
|
||||
return {'dtype': None, 'shape': None, 'value': fn(*value_args)}
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
if get_instance:
|
||||
return CompileOp()
|
||||
return CompileOp
|
||||
|
||||
if fn is not None:
|
||||
return deco(fn)
|
||||
return deco
|
||||
|
||||
|
||||
def _primexpr(fn=None, get_instance=True, name=None, reuse_result=True):
|
||||
"""
|
||||
_primexpr is similar as constexpr except that when the input to the function decorated by _primexpr contains
|
||||
variable, the function will be compiled as graph.
|
||||
|
||||
_primexpr is only for internal use.
|
||||
|
||||
Args:
|
||||
fn (function): A `fn` use as the infer_value of the output operator. Default: None.
|
||||
get_instance (bool): If true, return the instance of operator,
|
||||
otherwise return the operator class. Default: True.
|
||||
name (str): Defines the operator name. If `name` is None, use the function name as op name. Default: None.
|
||||
reuse_result (bool): If true, the operator will be executed once and reuse the result next time,
|
||||
otherwise the operator will always be executed. Default: True.
|
||||
"""
|
||||
def deco(fn):
|
||||
"""Decorator for CompileOp."""
|
||||
|
||||
class CompileOp(PrimitiveWithInfer):
|
||||
"""
|
||||
CompileOp is a temporary operator used to execute the constexpr function.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
op_name = name if name else fn.__name__
|
||||
PrimitiveWithInfer.__init__(self, op_name)
|
||||
self.set_const_prim(True)
|
||||
self.fn = fn
|
||||
self.add_prim_attr('constexpr_prim', True)
|
||||
if not reuse_result:
|
||||
self.add_prim_attr('forbid_reuse_result', True)
|
||||
|
||||
def __infer__(self, *args):
|
||||
value_args = []
|
||||
for item in args:
|
||||
if _check_contains_variable(item["dtype"], item["value"]):
|
||||
return {'dtype': None, 'shape': None, 'value': None, 'fn': (fn,)}
|
||||
value_args.append(item["value"])
|
||||
return {'dtype': None, 'shape': None, 'value': fn(*value_args)}
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
""" test_dict_get """
|
||||
from mindspore import Tensor, jit, context, mutable
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops.primitive import _primexpr
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -105,3 +106,48 @@ def test_constexpr_input_with_mutable_list():
|
|||
|
||||
out = foo(Tensor([1]))
|
||||
assert out == 3
|
||||
|
||||
|
||||
@_primexpr
|
||||
def count_none_2(arg):
|
||||
if arg is None:
|
||||
raise ValueError("The arg is None")
|
||||
a = 0
|
||||
if not isinstance(a, (list, tuple)):
|
||||
if a is None:
|
||||
return 1
|
||||
return 0
|
||||
for e in arg:
|
||||
if e is None:
|
||||
a += 1
|
||||
return a
|
||||
|
||||
|
||||
def test__primexpr_with_mutable_list():
|
||||
"""
|
||||
Feature: constexpr with mutable list.
|
||||
Description: If mutable list is used as constexpr input, all elements will be converted to None.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(x):
|
||||
arg = mutable([Tensor([1]), Tensor([2]), x])
|
||||
return count_none_2(arg)
|
||||
|
||||
out = foo(Tensor([1]))
|
||||
assert out == 0
|
||||
|
||||
|
||||
def test__primexpr_with_mutable_scalar():
|
||||
"""
|
||||
Feature: constexpr with mutable list.
|
||||
Description: If mutable list is used as constexpr input, all elements will be converted to None.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
arg = mutable(2)
|
||||
return count_none_2(arg)
|
||||
|
||||
out = foo()
|
||||
assert out == 0
|
||||
|
|
Loading…
Reference in New Issue