forked from mindspore-Ecosystem/mindspore
!2436 fix nn.PReLU example
Merge pull request !2436 from jiangjinsheng/issue_fix4
This commit is contained in:
commit
c8f26f799b
|
@ -380,7 +380,7 @@ class PReLU(Cell):
|
|||
Tensor, with the same type and shape as the `input_data`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)
|
||||
>>> input_x = Tensor(np.random.rand(1, 10, 4, 4), mindspore.float32)
|
||||
>>> prelu = nn.PReLU()
|
||||
>>> prelu(input_x)
|
||||
|
||||
|
|
|
@ -1093,6 +1093,8 @@ class StridedSliceGrad(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
|
||||
|
||||
def __infer__(self, dy, shapex, begin, end, strides):
|
||||
args = {"shapex": shapex['dtype'],"begin": begin['dtype'],"end": end['dtype'],"strides": strides['dtype']}
|
||||
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
||||
return {'shape': shapex['value'],
|
||||
'dtype': dy['dtype'],
|
||||
'value': None}
|
||||
|
|
|
@ -2619,6 +2619,8 @@ class SpaceToBatchND(PrimitiveWithInfer):
|
|||
|
||||
for elem in block_shape:
|
||||
validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
|
||||
validator.check_value_type('block_shape element', elem, [int], self.name)
|
||||
|
||||
self.block_shape = block_shape
|
||||
|
||||
validator.check('paddings shape', np.array(paddings).shape, '', (block_rank, 2), Rel.EQ, self.name)
|
||||
|
@ -2661,7 +2663,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
|
|||
The length of block_shape is M correspoding to the number of spatial dimensions.
|
||||
crops (list): The crop value for H and W dimension, containing 2 sub list, each containing 2 int value.
|
||||
All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to
|
||||
input dimension i+2. It is required that input_shape[i+2]*block_size[i] > crops[i][0]+crops[i][1].
|
||||
input dimension i+2. It is required that input_shape[i+2]*block_shape[i] > crops[i][0]+crops[i][1].
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input tensor.
|
||||
|
@ -2697,6 +2699,8 @@ class BatchToSpaceND(PrimitiveWithInfer):
|
|||
|
||||
for elem in block_shape:
|
||||
validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
|
||||
validator.check_value_type('block_shape element', elem, [int], self.name)
|
||||
|
||||
self.block_shape = block_shape
|
||||
|
||||
validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), Rel.EQ, self.name)
|
||||
|
|
|
@ -2157,10 +2157,10 @@ class ResizeBilinear(PrimitiveWithInfer):
|
|||
Tensor, resized image. Tensor of shape `(N_i, ..., N_n, new_height, new_width)` in `float32`.
|
||||
|
||||
Examples:
|
||||
>>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.int32)
|
||||
>>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32)
|
||||
>>> resize_bilinear = P.ResizeBilinear((5, 5))
|
||||
>>> result = resize_bilinear(tensor)
|
||||
>>> assert result.shape == (5, 5)
|
||||
>>> assert result.shape == (1, 1, 5, 5)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -2176,6 +2176,7 @@ class ResizeBilinear(PrimitiveWithInfer):
|
|||
return out_shape
|
||||
|
||||
def infer_dtype(self, input_dtype):
|
||||
validator.check_tensor_type_same({'input_dtype': input_dtype}, [mstype.float16, mstype.float32], self.name)
|
||||
return mstype.tensor_type(mstype.float32)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue