forked from mindspore-Ecosystem/mindspore
!4773 Fix empty shape issue in distribution sample functions
Merge pull request !4773 from peixu_ren/custom_bijector
This commit is contained in:
commit 86616ac515
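
Every _sample change below follows the same pattern: the requested sample shape combined with the parameters' broadcast shape can be empty (()), so sampling is done with shape (1,) instead and the padding axis is squeezed away before returning (a P.Squeeze(0) op is added to each constructor for this). A minimal NumPy sketch of that pattern, outside MindSpore and with a hypothetical helper name:

    import numpy as np

    def sample_with_padded_shape(shape, param, rng=None):
        # Hypothetical helper (not MindSpore API): mirrors the pattern added in this PR.
        rng = rng or np.random.default_rng()
        origin_shape = shape + np.shape(param)            # requested shape + parameter shape
        sample_shape = (1,) if origin_shape == () else origin_shape
        value = rng.uniform(0.0, 1.0, size=sample_shape)
        if origin_shape == ():                            # squeeze the padding axis back out
            value = np.squeeze(value, axis=0)
        return value

    print(sample_with_padded_shape((), 0.5).shape)        # () -- a scalar-shaped sample
    print(sample_with_padded_shape((3,), 0.5).shape)      # (3,)
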
@@ -45,10 +45,6 @@ def cast_to_tensor(t, hint_type=mstype.float32):
         return t
     t_type = hint_type
     if isinstance(t, Tensor):
-        #check if the Tensor in shape of Tensor(4)
-        if t.dim() == 0:
-            value = t.asnumpy()
-            return Tensor([value], dtype=t_type)
         #convert the type of tensor to dtype
         return Tensor(t.asnumpy(), dtype=t_type)
     if isinstance(t, (list, np.ndarray)):

@@ -56,7 +52,7 @@ def cast_to_tensor(t, hint_type=mstype.float32):
     if isinstance(t, bool):
         raise TypeError(f'Input cannot be Type Bool')
     if isinstance(t, (int, float)):
-        return Tensor([t], dtype=t_type)
+        return Tensor(t, dtype=t_type)
     raise TypeError("Input type is not supported.")

 def convert_to_batch(t, batch_shape, required_type):
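
With the two hunks above, cast_to_tensor no longer expands zero-dimensional tensors or wraps Python scalars in a one-element list, so scalar inputs keep an empty shape. A rough NumPy analogue of the before/after behavior (an illustration, not code from the repository):

    import numpy as np

    old_style = np.array([0.5])   # previous behavior: Tensor([t]) -> shape (1,)
    new_style = np.array(0.5)     # behavior after this change: Tensor(t) -> shape ()
    print(old_style.shape, new_style.shape)
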
@@ -107,6 +107,7 @@ class Bernoulli(Distribution):
         self._probs = probs

         # ops needed for the class
+        self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
         self.dtypeop = P.DType()

@@ -284,8 +285,16 @@ class Bernoulli(Distribution):
         probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
         if probs1 is None:
             raise_none_error("probs")
+        origin_shape = shape + self.shape(probs1)
+        if origin_shape == ():
+            sample_shape = (1,)
+        else:
+            sample_shape = origin_shape
         l_zero = self.const(0.0)
         h_one = self.const(1.0)
-        sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one, self.seed)
+        sample_uniform = self.uniform(sample_shape, l_zero, h_one, self.seed)
         sample = self.less(sample_uniform, probs1)
-        return self.cast(sample, self.dtype)
+        value = self.cast(sample, self.dtype)
+        if origin_shape == ():
+            value = self.squeeze(value)
+        return value
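
For reference, the Bernoulli sampling transform kept by this hunk, as a standalone NumPy sketch (the empty-shape padding is the same as in the overview sketch and is omitted here):

    import numpy as np

    rng = np.random.default_rng(0)
    probs = 0.3
    u = rng.uniform(0.0, 1.0, size=(4,))        # U(0, 1) draws, as in self.uniform above
    sample = (u < probs).astype(np.float32)     # 1.0 where U < probs, else 0.0
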
@@ -111,6 +111,7 @@ class Exponential(Distribution):
         self.minval = np.finfo(np.float).tiny

         # ops needed for the class
+        self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
         self.dtypeop = P.DType()

@@ -276,8 +277,16 @@ class Exponential(Distribution):
         rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
         if rate is None:
             raise_none_error("rate")
+        origin_shape = shape + self.shape(rate)
+        if origin_shape == ():
+            sample_shape = (1,)
+        else:
+            sample_shape = origin_shape
         minval = self.const(self.minval)
         maxval = self.const(1.0)
-        sample_uniform = self.uniform(shape + self.shape(rate), minval, maxval, self.seed)
+        sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed)
         sample = -self.log(sample_uniform) / rate
-        return self.cast(sample, self.dtype)
+        value = self.cast(sample, self.dtype)
+        if origin_shape == ():
+            value = self.squeeze(value)
+        return value
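
The Exponential hunk keeps the inverse-CDF transform, drawing U on [minval, 1) so the logarithm stays finite; a standalone NumPy sketch:

    import numpy as np

    rng = np.random.default_rng(0)
    rate = 2.0
    tiny = np.finfo(np.float64).tiny            # analogue of self.minval, keeps log() finite
    u = rng.uniform(tiny, 1.0, size=(4,))
    sample = -np.log(u) / rate                  # inverse-CDF transform for Exponential(rate)
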
@@ -112,6 +112,7 @@ class Geometric(Distribution):
         self.minval = np.finfo(np.float).tiny

         # ops needed for the class
+        self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
         self.dtypeop = P.DType()

@@ -283,8 +284,16 @@ class Geometric(Distribution):
         probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
         if probs1 is None:
             raise_none_error("probs")
+        origin_shape = shape + self.shape(probs1)
+        if origin_shape == ():
+            sample_shape = (1,)
+        else:
+            sample_shape = origin_shape
         minval = self.const(self.minval)
         maxval = self.const(1.0)
-        sample_uniform = self.uniform(shape + self.shape(probs1), minval, maxval, self.seed)
+        sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed)
         sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs1))
-        return self.cast(sample, self.dtype)
+        value = self.cast(sample, self.dtype)
+        if origin_shape == ():
+            value = self.squeeze(value)
+        return value
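
The Geometric hunk keeps the transform floor(log(U) / log(1 - probs)), which counts failures before the first success; a standalone NumPy sketch:

    import numpy as np

    rng = np.random.default_rng(0)
    probs = 0.3
    tiny = np.finfo(np.float64).tiny
    u = rng.uniform(tiny, 1.0, size=(4,))
    sample = np.floor(np.log(u) / np.log(1.0 - probs))   # failures before the first success
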
@@ -114,6 +114,7 @@ class Normal(Distribution):


         #ops needed for the class
+        self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
         self.erf = P.Erf()

@@ -305,7 +306,14 @@ class Normal(Distribution):
         sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
         if sd is None:
             raise_none_error("sd")
-        batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
-        sample_shape = shape + batch_shape
+        batch_shape = self.shape(mean + sd)
+        origin_shape = shape + batch_shape
+        if origin_shape == ():
+            sample_shape = (1,)
+        else:
+            sample_shape = origin_shape
         sample_norm = C.normal(sample_shape, mean, sd, self.seed)
-        return sample_norm
+        value = self.cast(sample_norm, self.dtype)
+        if origin_shape == ():
+            value = self.squeeze(value)
+        return value
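
In the Normal hunk, the parameters' broadcast shape is now taken from mean + sd rather than from zeroslike tensors, then padded like the other distributions; a standalone NumPy sketch of the shape handling:

    import numpy as np

    rng = np.random.default_rng(0)
    mean, sd = 0.0, 1.0                       # scalar parameters give an empty batch shape
    shape = ()                                # requested sample shape
    batch_shape = np.broadcast_shapes(np.shape(mean), np.shape(sd))   # shape of mean + sd
    origin_shape = shape + batch_shape
    sample_shape = (1,) if origin_shape == () else origin_shape
    value = rng.normal(mean, sd, size=sample_shape)
    if origin_shape == ():
        value = np.squeeze(value, axis=0)     # scalar sample, as after the fix
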
@@ -112,6 +112,7 @@ class Uniform(Distribution):
         self._high = high

         # ops needed for the class
+        self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
         self.dtypeop = P.DType()

@@ -327,8 +328,16 @@ class Uniform(Distribution):
         if high is None:
             raise_none_error("high")
         broadcast_shape = self.shape(low + high)
+        origin_shape = shape + broadcast_shape
+        if origin_shape == ():
+            sample_shape = (1,)
+        else:
+            sample_shape = origin_shape
         l_zero = self.const(0.0)
         h_one = self.const(1.0)
-        sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one, self.seed)
+        sample_uniform = self.uniform(sample_shape, l_zero, h_one, self.seed)
         sample = (high - low) * sample_uniform + low
-        return self.cast(sample, self.dtype)
+        value = self.cast(sample, self.dtype)
+        if origin_shape == ():
+            value = self.squeeze(value)
+        return value
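
Finally, the Uniform hunk rescales a U(0, 1) draw onto [low, high); a standalone NumPy sketch:

    import numpy as np

    rng = np.random.default_rng(0)
    low, high = 0.0, 2.0
    u = rng.uniform(0.0, 1.0, size=(4,))
    sample = (high - low) * u + low             # Uniform(low, high) draws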