This commit is contained in:
zong-shuai 2023-01-03 19:32:28 +08:00
parent fefa5651e5
commit 2abba63c7f
1 changed files with 11 additions and 0 deletions

View File

@ -69,6 +69,7 @@ class NonDeterministicInts(Primitive):
self.dtype = dtype
self.add_prim_attr("max_length", 1000000)
self.init_prim_io_names(inputs=["shape"], outputs=["output"])
self.add_prim_attr("side_effect_hidden", True)
valid_values = (mstype.int32, mstype.int64, mstype.uint32, mstype.uint64)
Validator.check_type_name("dtype", dtype, valid_values, self.name)
@ -124,6 +125,7 @@ class TruncatedNormal(Primitive):
"""Initialize TruncatedNormal"""
self.dtype = dtype
self.add_prim_attr("max_length", 1000000)
self.add_prim_attr("side_effect_hidden", True)
self.init_prim_io_names(inputs=["shape"], outputs=["output"])
Validator.check_value_type('seed', seed, [int], self.name)
Validator.check_value_type('seed2', seed2, [int], self.name)
@ -308,6 +310,7 @@ class LogNormalReverse(Primitive):
@prim_attr_register
def __init__(self, mean=1.0, std=2.0):
"""Initialize LogNormalReverse"""
self.add_prim_attr("side_effect_hidden", True)
Validator.check_value_type("mean", mean, [float], self.name)
Validator.check_value_type("std", std, [float], self.name)
@ -459,6 +462,7 @@ class ParameterizedTruncatedNormal(Primitive):
"""Initialize ParameterizedTruncatedNormal"""
self.init_prim_io_names(
inputs=['shape', 'mean', 'stdevs', 'min', 'max'], outputs=['y'])
self.add_prim_attr("side_effect_hidden", True)
Validator.check_value_type('seed', seed, [int], self.name)
Validator.check_value_type('seed2', seed2, [int], self.name)
@ -577,6 +581,7 @@ class RandomPoisson(Primitive):
self.init_prim_io_names(inputs=['shape', 'rate'], outputs=['output'])
Validator.check_value_type('seed', seed, [int], self.name)
Validator.check_value_type('seed2', seed2, [int], self.name)
self.add_prim_attr("side_effect_hidden", True)
valid_values = (mstype.int64, mstype.int32,
mstype.float16, mstype.float32, mstype.float64)
Validator.check_type_name("dtype", dtype, valid_values, self.name)
@ -824,6 +829,7 @@ class Multinomial(Primitive):
Validator.check_value_type("dtype", dtype, [mstype.Type], self.name)
valid_values = (mstype.int64, mstype.int32)
Validator.check_type_name("dtype", dtype, valid_values, self.name)
self.add_prim_attr("side_effect_hidden", True)
class MultinomialWithReplacement(Primitive):
@ -853,6 +859,7 @@ class MultinomialWithReplacement(Primitive):
Validator.check_non_negative_int(numsamples, "numsamples", self.name)
Validator.check_value_type("replacement", replacement, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'seed', 'offset'], outputs=['y'])
self.add_prim_attr("side_effect_hidden", True)
class UniformCandidateSampler(Primitive):
@ -900,6 +907,7 @@ class UniformCandidateSampler(Primitive):
"value of range_max", range_max, Rel.LE, self.name)
Validator.check("value of seed", seed, '', 0, Rel.GE, self.name)
self.num_sampled = num_sampled
self.add_prim_attr("side_effect_hidden", True)
@ -948,6 +956,7 @@ class LogUniformCandidateSampler(Primitive):
self.range_max = range_max
self.unique = unique
self.seed = Validator.check_number("seed", seed, 0, Rel.GE, self.name)
self.add_prim_attr("side_effect_hidden", True)
class RandomShuffle(Primitive):
@ -1031,6 +1040,7 @@ class Uniform(Primitive):
Validator.check('minval', minval, 'maxval', maxval, Rel.LE, self.name)
Validator.check_non_negative_float(minval, "minval", self.name)
Validator.check_non_negative_float(maxval, "maxval", self.name)
self.add_prim_attr("side_effect_hidden", True)
class RandpermV2(Primitive):
@ -1062,3 +1072,4 @@ class RandpermV2(Primitive):
valid_values = (mstype.int32, mstype.int64, mstype.int16, mstype.int8, mstype.uint8, mstype.float64
, mstype.float32, mstype.float16)
Validator.check_type_name("dtype", dtype, valid_values, self.name)
self.add_prim_attr("side_effect_hidden", True)