multinomial remove UniformReal cache, rrelu uniformsupport dynamic shape.
This commit is contained in:
parent
504c1735fd
commit
890683caa7
|
@ -44,6 +44,7 @@ mindspore/mindspore/python/mindspore/ops/function/math_func.py:cov
|
|||
mindspore/mindspore/python/mindspore/ops/function/math_func.py:norm
|
||||
mindspore/mindspore/python/mindspore/ops/function/math_func.py:matrix_norm
|
||||
mindspore/mindspore/python/mindspore/ops/function/math_func.py:einsum
|
||||
mindspore/mindspore/python/mindspore/ops/function/random_func.py:multinomial
|
||||
mindspore/mindspore/python/mindspore/context.py:set_auto_parallel_context
|
||||
mindspore/mindspore/python/mindspore/common/tensor.py:__init__
|
||||
mindspore/mindspore/python/mindspore/common/parameter.py:set_data
|
||||
|
|
|
@ -3591,6 +3591,9 @@ def rrelu(input, lower=1.0 / 8, upper=1.0 / 3):
|
|||
_lower = Tensor(lower, mstype.float32)
|
||||
_upper = Tensor(upper, mstype.float32)
|
||||
_size = input.shape
|
||||
if ops.is_sequence_value_unknown(_size):
|
||||
dyn_shape = _get_cache_prim(P.TensorShape)()
|
||||
_size = dyn_shape(input)
|
||||
sign_matrix = _get_cache_prim(P.Sign)()(input)
|
||||
negative_filter = sign_matrix.clip(None, 0)
|
||||
positive_filter = sign_matrix.clip(0, None)
|
||||
|
|
|
@ -1351,7 +1351,9 @@ def multinomial(input, num_samples, replacement=True, seed=None):
|
|||
n_dist = 1
|
||||
if len(shape(input)) > 1:
|
||||
n_dist = shape(input)[-2]
|
||||
random_uniform = _get_cache_prim(P.UniformReal)(seed1, seed2)((n_dist * shape(input)[-1],))
|
||||
random_uniform_real = P.UniformReal(seed1, seed2)
|
||||
random_cache_op = _set_prim_op_user_data(random_uniform_real, "random_cache", False)
|
||||
random_uniform = random_cache_op((n_dist * shape(input)[-1],))
|
||||
if n_dist != 1:
|
||||
random_uniform = reshape(random_uniform, (n_dist, shape(input)[-1]))
|
||||
real_div = _get_cache_prim(P.RealDiv)()
|
||||
|
|
Loading…
Reference in New Issue