!2436 fix nn.PReLU example

Merge pull request !2436 from jiangjinsheng/issue_fix4
This commit is contained in:
mindspore-ci-bot 2020-06-23 11:11:48 +08:00 committed by Gitee
commit c8f26f799b
4 changed files with 11 additions and 4 deletions

View File

@ -380,7 +380,7 @@ class PReLU(Cell):
Tensor, with the same type and shape as the `input_data`.
Examples:
>>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)
>>> input_x = Tensor(np.random.rand(1, 10, 4, 4), mindspore.float32)
>>> prelu = nn.PReLU()
>>> prelu(input_x)

View File

@ -1093,6 +1093,8 @@ class StridedSliceGrad(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
def __infer__(self, dy, shapex, begin, end, strides):
args = {"shapex": shapex['dtype'],"begin": begin['dtype'],"end": end['dtype'],"strides": strides['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return {'shape': shapex['value'],
'dtype': dy['dtype'],
'value': None}

View File

@ -2619,6 +2619,8 @@ class SpaceToBatchND(PrimitiveWithInfer):
for elem in block_shape:
validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
validator.check_value_type('block_shape element', elem, [int], self.name)
self.block_shape = block_shape
validator.check('paddings shape', np.array(paddings).shape, '', (block_rank, 2), Rel.EQ, self.name)
@ -2661,7 +2663,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
The length of block_shape is M correspoding to the number of spatial dimensions.
crops (list): The crop value for H and W dimension, containing 2 sub list, each containing 2 int value.
All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to
input dimension i+2. It is required that input_shape[i+2]*block_size[i] > crops[i][0]+crops[i][1].
input dimension i+2. It is required that input_shape[i+2]*block_shape[i] > crops[i][0]+crops[i][1].
Inputs:
- **input_x** (Tensor) - The input tensor.
@ -2697,6 +2699,8 @@ class BatchToSpaceND(PrimitiveWithInfer):
for elem in block_shape:
validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
validator.check_value_type('block_shape element', elem, [int], self.name)
self.block_shape = block_shape
validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), Rel.EQ, self.name)

View File

@ -2157,10 +2157,10 @@ class ResizeBilinear(PrimitiveWithInfer):
Tensor, resized image. Tensor of shape `(N_i, ..., N_n, new_height, new_width)` in `float32`.
Examples:
>>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.int32)
>>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32)
>>> resize_bilinear = P.ResizeBilinear((5, 5))
>>> result = resize_bilinear(tensor)
>>> assert result.shape == (5, 5)
>>> assert result.shape == (1, 1, 5, 5)
"""
@prim_attr_register
@ -2176,6 +2176,7 @@ class ResizeBilinear(PrimitiveWithInfer):
return out_shape
def infer_dtype(self, input_dtype):
validator.check_tensor_type_same({'input_dtype': input_dtype}, [mstype.float16, mstype.float32], self.name)
return mstype.tensor_type(mstype.float32)