This commit is contained in:
lihongkang 2020-10-29 19:27:29 +08:00
parent 88aaec279c
commit 3704e05370
4 changed files with 30 additions and 8 deletions

View File

@ -75,11 +75,12 @@ class Dropout(Cell):
Examples:
>>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
>>> net = nn.Dropout(keep_prob=0.8)
>>> net.set_train()
>>> net(x)
[[[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0]],
[[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0]]]
[[[0., 1.25, 0.],
[1.25, 1.25, 1.25]],
[[1.25, 1.25, 1.25],
[1.25, 1.25, 1.25]]]
"""
def __init__(self, keep_prob=0.5, dtype=mstype.float32):
@ -287,7 +288,8 @@ class ClipByNorm(Cell):
>>> net = nn.ClipByNorm()
>>> input = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
>>> clip_norm = Tensor(np.array([100]).astype(np.float32))
>>> net(input, clip_norm)
>>> net(input, clip_norm).shape
(4, 16)
"""

View File

@ -447,6 +447,8 @@ class CentralCrop(Cell):
>>> net = nn.CentralCrop(central_fraction=0.5)
>>> image = Tensor(np.random.random((4, 3, 4, 4)), mindspore.float32)
>>> output = net(image)
>>> output.shape
(4, 3, 2, 2)
"""
def __init__(self, central_fraction):

View File

@ -1941,7 +1941,7 @@ class Slice(PrimitiveWithInfer):
"""
Slices a tensor in the specified shape.
Args:
Inputs:
x (Tensor): The target tensor.
begin (tuple): The beginning of the slice. Only constant value is allowed.
size (tuple): The size of the slice. Only constant value is allowed.
@ -2262,8 +2262,8 @@ class StridedSlice(PrimitiveWithInfer):
validator.check_value_type("strides", strides_v, [tuple], self.name)
if tuple(filter(lambda x: not isinstance(x, int), begin_v + end_v + strides_v)):
raise ValueError(f"For {self.name}, both the begins, ends, and strides must be a tuple of int, "
f"but got begins: {begin_v}, ends: {end_v}, strides: {strides_v}.")
raise TypeError(f"For {self.name}, both the begins, ends, and strides must be a tuple of int, "
f"but got begins: {begin_v}, ends: {end_v}, strides: {strides_v}.")
if tuple(filter(lambda x: x == 0, strides_v)):
raise ValueError(f"For '{self.name}', the strides cannot contain 0, but got strides: {strides_v}.")

View File

@ -724,11 +724,29 @@ class BatchMatMul(MatMul):
>>> input_y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
>>> batmatmul = P.BatchMatMul()
>>> output = batmatmul(input_x, input_y)
[[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]
[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]]
>>>
>>> input_x = Tensor(np.ones(shape=[2, 4, 3, 1]), mindspore.float32)
>>> input_y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
>>> batmatmul = P.BatchMatMul(transpose_a=True)
>>> output = batmatmul(input_x, input_y)
[[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]
[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]]
"""
@prim_attr_register