forked from mindspore-Ecosystem/mindspore
!5946 fix bugs of op Exp, IOU, GroupNorm and Dropout
Merge pull request !5946 from lihongkang/lhk_master
This commit is contained in:
commit
ffeff2fa5b
|
@ -65,18 +65,22 @@ class Dropout(Cell):
|
|||
dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
|
||||
|
||||
Raises:
|
||||
ValueError: If `keep_prob` is not in range (0, 1).
|
||||
ValueError: If `keep_prob` is not in range (0, 1].
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - An N-D Tensor.
|
||||
- **input** (Tensor) - The input tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, output tensor with the same shape as the input.
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.ones([20, 16, 50]), mindspore.float32)
|
||||
>>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
|
||||
>>> net = nn.Dropout(keep_prob=0.8)
|
||||
>>> net(x)
|
||||
[[[1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0]],
|
||||
[[1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0]]]
|
||||
"""
|
||||
|
||||
def __init__(self, keep_prob=0.5, seed0=0, seed1=0, dtype=mstype.float32):
|
||||
|
@ -84,6 +88,7 @@ class Dropout(Cell):
|
|||
if keep_prob <= 0 or keep_prob > 1:
|
||||
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
|
||||
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
|
||||
validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
|
||||
self.keep_prob = keep_prob
|
||||
self.seed0 = seed0
|
||||
self.seed1 = seed1
|
||||
|
@ -107,8 +112,7 @@ class Dropout(Cell):
|
|||
return x
|
||||
|
||||
shape = self.get_shape(x)
|
||||
dtype = P.DType()(x)
|
||||
keep_prob = self.cast(self.keep_prob, dtype)
|
||||
keep_prob = self.cast(self.keep_prob, mstype.float32)
|
||||
output = self.dropout_gen_mask(shape, keep_prob)
|
||||
return self.dropout_do_mask(x, output, keep_prob)
|
||||
|
||||
|
|
|
@ -585,9 +585,18 @@ class GroupNorm(Cell):
|
|||
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
|
||||
|
||||
Examples:
|
||||
>>> goup_norm_op = nn.GroupNorm(16, 64)
|
||||
>>> x = Tensor(np.ones([1, 64, 256, 256], np.float32))
|
||||
>>> goup_norm_op = nn.GroupNorm(2, 2)
|
||||
>>> x = Tensor(np.ones([1, 2, 4, 4], np.float32))
|
||||
>>> goup_norm_op(x)
|
||||
[[[[0. 0. 0. 0.]
|
||||
[0. 0. 0. 0.]
|
||||
[0. 0. 0. 0.]
|
||||
[0. 0. 0. 0.]]
|
||||
|
||||
[[0. 0. 0. 0.]
|
||||
[0. 0. 0. 0.]
|
||||
[0. 0. 0. 0.]
|
||||
[0. 0. 0. 0.]]]]
|
||||
"""
|
||||
|
||||
def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
|
||||
|
|
|
@ -1360,7 +1360,7 @@ class Tile(PrimitiveWithInfer):
|
|||
|
||||
- **multiples** (tuple[int]) - The input tuple is constructed by multiple
|
||||
integers, i.e., :math:`(y_1, y_2, ..., y_S)`. The length of `multiples`
|
||||
can't be smaller than the length of shape in `input_x`.
|
||||
can't be smaller than the length of shape in `input_x`. Only constant value is allowed.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same type as the `input_x`.
|
||||
|
@ -1400,7 +1400,7 @@ class Tile(PrimitiveWithInfer):
|
|||
def __infer__(self, x, multiples):
|
||||
multiples_v = multiples['value']
|
||||
x_shp = x['shape']
|
||||
validator.check_value_type("shape", multiples_v, [tuple], self.name)
|
||||
validator.check_value_type("multiples", multiples_v, [tuple], self.name)
|
||||
for i, multiple in enumerate(multiples_v):
|
||||
validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name)
|
||||
validator.check_value_type("x[\'dtype\']", x["dtype"], mstype.tensor_type, self.name)
|
||||
|
|
|
@ -1382,10 +1382,10 @@ class Exp(PrimitiveWithInfer):
|
|||
Returns exponential of a tensor element-wise.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input tensor.
|
||||
- **input_x** (Tensor) - The input tensor. The data type mast be float16 or float32.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape as the `input_x`.
|
||||
Tensor, has the same shape and dtype as the `input_x`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
|
||||
|
@ -1452,7 +1452,7 @@ class HistogramFixedWidth(PrimitiveWithInfer):
|
|||
width and determined by the arguments range and nbins.
|
||||
|
||||
Args:
|
||||
dtype (string): An optional attribute. Must be one of the following types: "int32", "int64". Default: "int32".
|
||||
dtype (str): An optional attribute. Must be one of the following types: "int32", "int64". Default: "int32".
|
||||
nbins (int): The number of histogram bins, the type is a positive integer.
|
||||
|
||||
Inputs:
|
||||
|
|
|
@ -264,6 +264,9 @@ class IOU(PrimitiveWithInfer):
|
|||
>>> anchor_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float16)
|
||||
>>> gt_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float16)
|
||||
>>> iou(anchor_boxes, gt_boxes)
|
||||
[[0.0, 65504, 65504],
|
||||
[0.0, 0.0, 0.0],
|
||||
[0.22253, 0.0, 0.0]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
|
Loading…
Reference in New Issue