multinomial remove UniformReal cache, rrelu uniformsupport dynamic shape.

This commit is contained in:
gaoshuanglong 2023-08-04 11:31:41 +08:00
parent 504c1735fd
commit 890683caa7
3 changed files with 7 additions and 1 deletions

View File

@ -44,6 +44,7 @@ mindspore/mindspore/python/mindspore/ops/function/math_func.py:cov
mindspore/mindspore/python/mindspore/ops/function/math_func.py:norm
mindspore/mindspore/python/mindspore/ops/function/math_func.py:matrix_norm
mindspore/mindspore/python/mindspore/ops/function/math_func.py:einsum
mindspore/mindspore/python/mindspore/ops/function/random_func.py:multinomial
mindspore/mindspore/python/mindspore/context.py:set_auto_parallel_context
mindspore/mindspore/python/mindspore/common/tensor.py:__init__
mindspore/mindspore/python/mindspore/common/parameter.py:set_data

View File

@ -3591,6 +3591,9 @@ def rrelu(input, lower=1.0 / 8, upper=1.0 / 3):
_lower = Tensor(lower, mstype.float32)
_upper = Tensor(upper, mstype.float32)
_size = input.shape
if ops.is_sequence_value_unknown(_size):
dyn_shape = _get_cache_prim(P.TensorShape)()
_size = dyn_shape(input)
sign_matrix = _get_cache_prim(P.Sign)()(input)
negative_filter = sign_matrix.clip(None, 0)
positive_filter = sign_matrix.clip(0, None)

View File

@ -1351,7 +1351,9 @@ def multinomial(input, num_samples, replacement=True, seed=None):
n_dist = 1
if len(shape(input)) > 1:
n_dist = shape(input)[-2]
random_uniform = _get_cache_prim(P.UniformReal)(seed1, seed2)((n_dist * shape(input)[-1],))
random_uniform_real = P.UniformReal(seed1, seed2)
random_cache_op = _set_prim_op_user_data(random_uniform_real, "random_cache", False)
random_uniform = random_cache_op((n_dist * shape(input)[-1],))
if n_dist != 1:
random_uniform = reshape(random_uniform, (n_dist, shape(input)[-1]))
real_div = _get_cache_prim(P.RealDiv)()