forked from mindspore-Ecosystem/mindspore
fix bugs
This commit is contained in:
parent
88aaec279c
commit
3704e05370
|
@ -75,11 +75,12 @@ class Dropout(Cell):
|
|||
Examples:
|
||||
>>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
|
||||
>>> net = nn.Dropout(keep_prob=0.8)
|
||||
>>> net.set_train()
|
||||
>>> net(x)
|
||||
[[[1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0]],
|
||||
[[1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0]]]
|
||||
[[[0., 1.25, 0.],
|
||||
[1.25, 1.25, 1.25]],
|
||||
[[1.25, 1.25, 1.25],
|
||||
[1.25, 1.25, 1.25]]]
|
||||
"""
|
||||
|
||||
def __init__(self, keep_prob=0.5, dtype=mstype.float32):
|
||||
|
@ -287,7 +288,8 @@ class ClipByNorm(Cell):
|
|||
>>> net = nn.ClipByNorm()
|
||||
>>> input = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
|
||||
>>> clip_norm = Tensor(np.array([100]).astype(np.float32))
|
||||
>>> net(input, clip_norm)
|
||||
>>> net(input, clip_norm).shape
|
||||
(4, 16)
|
||||
|
||||
"""
|
||||
|
||||
|
|
|
@ -447,6 +447,8 @@ class CentralCrop(Cell):
|
|||
>>> net = nn.CentralCrop(central_fraction=0.5)
|
||||
>>> image = Tensor(np.random.random((4, 3, 4, 4)), mindspore.float32)
|
||||
>>> output = net(image)
|
||||
>>> output.shape
|
||||
(4, 3, 2, 2)
|
||||
"""
|
||||
|
||||
def __init__(self, central_fraction):
|
||||
|
|
|
@ -1941,7 +1941,7 @@ class Slice(PrimitiveWithInfer):
|
|||
"""
|
||||
Slices a tensor in the specified shape.
|
||||
|
||||
Args:
|
||||
Inputs:
|
||||
x (Tensor): The target tensor.
|
||||
begin (tuple): The beginning of the slice. Only constant value is allowed.
|
||||
size (tuple): The size of the slice. Only constant value is allowed.
|
||||
|
@ -2262,8 +2262,8 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
validator.check_value_type("strides", strides_v, [tuple], self.name)
|
||||
|
||||
if tuple(filter(lambda x: not isinstance(x, int), begin_v + end_v + strides_v)):
|
||||
raise ValueError(f"For {self.name}, both the begins, ends, and strides must be a tuple of int, "
|
||||
f"but got begins: {begin_v}, ends: {end_v}, strides: {strides_v}.")
|
||||
raise TypeError(f"For {self.name}, both the begins, ends, and strides must be a tuple of int, "
|
||||
f"but got begins: {begin_v}, ends: {end_v}, strides: {strides_v}.")
|
||||
|
||||
if tuple(filter(lambda x: x == 0, strides_v)):
|
||||
raise ValueError(f"For '{self.name}', the strides cannot contain 0, but got strides: {strides_v}.")
|
||||
|
|
|
@ -724,11 +724,29 @@ class BatchMatMul(MatMul):
|
|||
>>> input_y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
|
||||
>>> batmatmul = P.BatchMatMul()
|
||||
>>> output = batmatmul(input_x, input_y)
|
||||
[[[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]]
|
||||
|
||||
[[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]]]
|
||||
>>>
|
||||
>>> input_x = Tensor(np.ones(shape=[2, 4, 3, 1]), mindspore.float32)
|
||||
>>> input_y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
|
||||
>>> batmatmul = P.BatchMatMul(transpose_a=True)
|
||||
>>> output = batmatmul(input_x, input_y)
|
||||
[[[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]]
|
||||
|
||||
[[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]
|
||||
[[3. 3. 3. 3.]]]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
|
Loading…
Reference in New Issue