fix examples issues

luojianing 2022-11-22 20:40:25 +08:00
parent 1d04c8f8f6
commit a32b19be56
7 changed files with 21 additions and 27 deletions

View File

@@ -358,8 +358,8 @@ def dot(x1, x2):
``Ascend`` ``GPU`` ``CPU``
Examples:
- >>> from mindspore import Tensor, ops
+ >>> import mindspore
+ >>> from mindspore import Tensor, ops
>>> input_x1 = Tensor(np.ones(shape=[2, 3]), mindspore.float32)
>>> input_x2 = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
>>> output = ops.dot(input_x1, input_x2)
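
A self-contained form of the updated dot() example, with the NumPy import the doctest assumes; the result is not printed in this hunk, so only the shape is inspected here.

import numpy as np
import mindspore
from mindspore import Tensor, ops

input_x1 = Tensor(np.ones(shape=[2, 3]), mindspore.float32)
input_x2 = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
# dot product over the shared axis of size 3
output = ops.dot(input_x1, input_x2)
print(output.shape)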

View File

@@ -2822,11 +2822,10 @@ def approximate_equal(x, y, tolerance=1e-5):
``Ascend`` ``GPU`` ``CPU``
Examples:
- >>> from mindspore.ops.function.math_func import approximate_equal
>>> tol = 1.5
>>> x = Tensor(np.array([1, 2, 3]), mstype.float32)
>>> y = Tensor(np.array([2, 4, 6]), mstype.float32)
- >>> output = approximate_equal(Tensor(x), Tensor(y), tol)
+ >>> output = ops.approximate_equal(Tensor(x), Tensor(y), tol)
>>> print(output)
[ True False False]
"""

View File

@@ -2343,7 +2343,7 @@ def conv3d_transpose(inputs, weight, pad_mode='valid', padding=0, stride=1, dila
Examples:
>>> dout = Tensor(np.ones([32, 16, 10, 32, 32]), mindspore.float16)
>>> weight = Tensor(np.ones([16, 3, 4, 6, 2]), mindspore.float16)
- >>> output = conv3d_transpose(dout, weight)
+ >>> output = ops.conv3d_transpose(dout, weight)
>>> print(output.shape)
(32, 3, 13, 37, 33)
"""
@@ -2487,7 +2487,7 @@ def hardsigmoid(input_x):
Examples:
>>> x = Tensor(np.array([ -3.5, 0, 4.3]), mindspore.float32)
- >>> output = F.hardsigmoid(x)
+ >>> output = ops.hardsigmoid(x)
>>> print(output)
[0. 0.5 1. ]
"""

View File

@@ -1266,10 +1266,9 @@ class Padding(Primitive):
``Ascend`` ``GPU`` ``CPU``
Examples:
- >>> from mindspore.ops.operations.array_ops import Padding
>>> x = Tensor(np.array([[8], [10]]), mindspore.float32)
>>> pad_dim_size = 4
- >>> output = Padding(pad_dim_size)(x)
+ >>> output = ops.Padding(pad_dim_size)(x)
>>> print(output)
[[ 8. 0. 0. 0.]
[10. 0. 0. 0.]]
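
A self-contained version of the updated Padding example:

import numpy as np
import mindspore
from mindspore import Tensor, ops

x = Tensor(np.array([[8], [10]]), mindspore.float32)
pad_dim_size = 4
# extend the last dimension from 1 to pad_dim_size, filling with zeros
output = ops.Padding(pad_dim_size)(x)
print(output)
# [[ 8. 0. 0. 0.]
#  [10. 0. 0. 0.]]
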
@@ -1444,7 +1443,6 @@ class MatrixDiagV3(Primitive):
``Ascend`` ``GPU`` ``CPU``
Examples:
- >>> from mindspore.ops.operations.array_ops import MatrixDiagV3
>>> x = Tensor(np.array([[8, 9, 0],
... [1, 2, 3],
... [0, 4, 5]]), mindspore.float32)
@@ -1452,7 +1450,7 @@ class MatrixDiagV3(Primitive):
>>> num_rows = Tensor(np.array(3), mindspore.int32)
>>> num_cols = Tensor(np.array(3), mindspore.int32)
>>> padding_value = Tensor(np.array(11), mindspore.float32)
- >>> matrix_diag_v3 = MatrixDiagV3(align='LEFT_RIGHT')
+ >>> matrix_diag_v3 = ops.MatrixDiagV3(align='LEFT_RIGHT')
>>> output = matrix_diag_v3(x, k, num_rows, num_cols, padding_value)
>>> print(output)
[[ 1. 8. 11.]
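
A runnable sketch of the updated constructor usage. The diagonal index k is defined outside this hunk, so the value below is only an assumption chosen to be consistent with the 3x3 output shown; the printed result is therefore not quoted here.

import numpy as np
import mindspore
from mindspore import Tensor, ops

x = Tensor(np.array([[8, 9, 0],
                     [1, 2, 3],
                     [0, 4, 5]]), mindspore.float32)
k = Tensor(np.array([-1, 1]), mindspore.int32)  # assumed diagonal range, not visible in the hunk
num_rows = Tensor(np.array(3), mindspore.int32)
num_cols = Tensor(np.array(3), mindspore.int32)
padding_value = Tensor(np.array(11), mindspore.float32)
matrix_diag_v3 = ops.MatrixDiagV3(align='LEFT_RIGHT')
# builds a num_rows x num_cols matrix from the given diagonals, padding the rest with 11
output = matrix_diag_v3(x, k, num_rows, num_cols, padding_value)
print(output.shape)
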
@@ -1485,7 +1483,7 @@ class MatrixDiagPartV3(Primitive):
... [9, 8, 7, 6]]), mindspore.float32)
>>> k =Tensor(np.array([1, 3]), mindspore.int32)
>>> padding_value = Tensor(np.array(9), mindspore.float32)
- >>> matrix_diag_part_v3 = ops.operations.array_ops.MatrixDiagPartV3(align='RIGHT_LEFT')
+ >>> matrix_diag_part_v3 = ops.MatrixDiagPartV3(align='RIGHT_LEFT')
>>> output = matrix_diag_part_v3(x, k, padding_value)
>>> print(output)
[[9. 9. 4.]
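
The corresponding sketch for MatrixDiagPartV3. The first rows of the input matrix are cut off in this hunk, so only the last row below is taken from the diff and the rest is illustrative; the output is not quoted here.

import numpy as np
import mindspore
from mindspore import Tensor, ops

x = Tensor(np.array([[1, 2, 3, 4],   # illustrative rows
                     [5, 6, 7, 8],
                     [9, 8, 7, 6]]), mindspore.float32)   # row shown in the hunk
k = Tensor(np.array([1, 3]), mindspore.int32)  # extract diagonals 1 through 3
padding_value = Tensor(np.array(9), mindspore.float32)
matrix_diag_part_v3 = ops.MatrixDiagPartV3(align='RIGHT_LEFT')
output = matrix_diag_part_v3(x, k, padding_value)
print(output)
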
@@ -3723,13 +3721,12 @@ class DiagPart(PrimitiveWithInfer):
Supported Platforms:
``Ascend`` ``GPU``
- Examples
+ Examples:
>>> input_x = Tensor([[1, 0, 0, 0],
... [0, 2, 0, 0],
... [0, 0, 3, 0],
... [0, 0, 0, 4]])
- >>> import mindspore.ops as P
- >>> diag_part = P.DiagPart()
+ >>> diag_part = ops.DiagPart()
>>> output = diag_part(input_x)
>>> print(output)
[1 2 3 4]
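
The fixed DiagPart example as a standalone snippet:

from mindspore import Tensor, ops

input_x = Tensor([[1, 0, 0, 0],
                  [0, 2, 0, 0],
                  [0, 0, 3, 0],
                  [0, 0, 0, 4]])
diag_part = ops.DiagPart()  # returns the main diagonal of a square matrix
output = diag_part(input_x)
print(output)
# [1 2 3 4]
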
@@ -4961,11 +4958,10 @@ class ScatterNdMul(_ScatterNdOp):
``GPU`` ``CPU``
Examples:
- >>> from mindspore.ops.operations.array_ops import ScatterNdMul
>>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
>>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
- >>> scatter_nd_mul = ScatterNdMul()
+ >>> scatter_nd_mul = ops.ScatterNdMul()
>>> output = scatter_nd_mul(input_x, indices, updates)
>>> print(output)
[ 1. 16. 18. 4. 35. 6. 7. 72.]
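
The first ScatterNdMul example above, written out with the imports it assumes (Parameter comes from the top-level mindspore package):

import numpy as np
import mindspore
from mindspore import Tensor, Parameter, ops

input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
scatter_nd_mul = ops.ScatterNdMul()  # multiplies the indexed elements of input_x by updates
output = scatter_nd_mul(input_x, indices, updates)
print(output)
# [ 1. 16. 18. 4. 35. 6. 7. 72.]
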
@@ -4973,7 +4969,7 @@ class ScatterNdMul(_ScatterNdOp):
>>> indices = Tensor(np.array([[0], [2]]), mindspore.int32)
>>> updates = Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]],
... [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]]]), mindspore.int32)
- >>> scatter_nd_mul = ScatterNdMul()
+ >>> scatter_nd_mul = ops.ScatterNdMul()
>>> output = scatter_nd_mul(input_x, indices, updates)
>>> print(output)
[[[1 1 1 1]
@@ -5057,11 +5053,10 @@ class ScatterNdMax(_ScatterNdOp):
``Ascend`` ``GPU`` ``CPU``
Examples:
- >>> from mindspore.ops.operations.array_ops import ScatterNdMax
>>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
>>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
- >>> scatter_nd_max = ScatterNdMax()
+ >>> scatter_nd_max = ops.ScatterNdMax()
>>> output = scatter_nd_max(input_x, indices, updates)
>>> print(output)
[ 1. 8. 6. 4. 7. 6. 7. 9.]
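
Likewise for the first ScatterNdMax example:

import numpy as np
import mindspore
from mindspore import Tensor, Parameter, ops

input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
scatter_nd_max = ops.ScatterNdMax()  # element-wise maximum of the indexed elements and updates
output = scatter_nd_max(input_x, indices, updates)
print(output)
# [ 1. 8. 6. 4. 7. 6. 7. 9.]
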
@@ -5069,7 +5064,7 @@ class ScatterNdMax(_ScatterNdOp):
>>> indices = Tensor(np.array([[0], [2]]), mindspore.int32)
>>> updates = Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]],
... [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]]]), mindspore.int32)
- >>> scatter_nd_max = ScatterNdMax()
+ >>> scatter_nd_max = ops.ScatterNdMax()
>>> output = scatter_nd_max(input_x, indices, updates)
>>> print(output)
[[[1 1 1 1]

View File

@@ -96,9 +96,9 @@ class Svd(Primitive):
Examples:
>>> import numpy as np
>>> from mindspore import Tensor, set_context
- >>> from mindspore.ops.operations import linalg_ops as linalg
+ >>> from mindspore import ops
>>> set_context(device_target="CPU")
- >>> svd = linalg.Svd(full_matrices=True, compute_uv=True)
+ >>> svd = ops.Svd(full_matrices=True, compute_uv=True)
>>> a = Tensor(np.array([[1, 2], [-4, -5], [2, 1]]).astype(np.float32))
>>> s, u, v = svd(a)
>>> print(s)
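
A self-contained form of the updated Svd example; the singular values printed by the docstring are not visible in this hunk, so the sketch only checks the factor shapes:

import numpy as np
import mindspore
from mindspore import Tensor, ops, set_context

set_context(device_target="CPU")
svd = ops.Svd(full_matrices=True, compute_uv=True)
a = Tensor(np.array([[1, 2], [-4, -5], [2, 1]]).astype(np.float32))
s, u, v = svd(a)  # singular values plus the U and V factors
print(s.shape, u.shape, v.shape)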

View File

@@ -5860,9 +5860,10 @@ class IsClose(Primitive):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor
- >>> from mindspore.ops.operations.math_ops import IsClose
+ >>> from mindspore.ops import IsClose
>>> input = Tensor(np.array([1.3, 2.1, 3.2, 4.1, 5.1]), mindspore.float16)
>>> other = Tensor(np.array([1.3, 3.3, 2.3, 3.1, 5.1]), mindspore.float16)
>>> isclose = IsClose()
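
The updated import in context. The actual call and its expected output are cut off in this hunk, so the last two lines below follow the usual two-input primitive pattern rather than the docstring itself:

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import IsClose

input = Tensor(np.array([1.3, 2.1, 3.2, 4.1, 5.1]), mindspore.float16)
other = Tensor(np.array([1.3, 3.3, 2.3, 3.1, 5.1]), mindspore.float16)
isclose = IsClose()
output = isclose(input, other)  # element-wise closeness within the default tolerances
print(output)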

View File

@@ -1309,7 +1309,7 @@ class DataFormatVecPermute(Primitive):
>>> class Net(nn.Cell):
... def __init__(self, src_format="NHWC", dst_format="NCHW"):
... super().__init__()
- ... self.op = P.nn_ops.DataFormatVecPermute(src_format, dst_format)
+ ... self.op = ops.DataFormatVecPermute(src_format, dst_format)
... def construct(self, x):
... return self.op(x)
...
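
A runnable sketch around the updated line; the input tensor is not shown in this hunk, so the 4-element int32 vector below is illustrative only:

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor, ops

class Net(nn.Cell):
    def __init__(self, src_format="NHWC", dst_format="NCHW"):
        super().__init__()
        self.op = ops.DataFormatVecPermute(src_format, dst_format)
    def construct(self, x):
        return self.op(x)

x = Tensor(np.array([1, 2, 3, 4]), mindspore.int32)  # illustrative NHWC-ordered sizes
net = Net()
print(net(x))  # the same sizes reordered into NCHW order
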
@@ -3027,9 +3027,9 @@ class L2Loss(Primitive):
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
- Examples
+ Examples:
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float16)
- >>> l2_loss = L2Loss()
+ >>> l2_loss = ops.L2Loss()
>>> output = l2_loss(input_x)
>>> print(output)
7.0
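
The fixed L2Loss example as a standalone snippet:

import numpy as np
import mindspore
from mindspore import Tensor, ops

input_x = Tensor(np.array([1, 2, 3]), mindspore.float16)
l2_loss = ops.L2Loss()  # computes sum(x ** 2) / 2
output = l2_loss(input_x)
print(output)
# 7.0
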
@@ -6938,8 +6938,7 @@ class Dropout2D(PrimitiveWithInfer):
``Ascend`` ``GPU`` ``CPU``
Examples:
- >>> from mindspore.ops.operations.nn_ops import Dropout2D
- >>> dropout = Dropout2D(keep_prob=0.5)
+ >>> dropout = ops.Dropout2D(keep_prob=0.5)
>>> x = Tensor(np.ones([2, 1, 2, 3]), mindspore.float32)
>>> output, mask = dropout(x)
>>> print(output.shape)
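
A self-contained form of the updated Dropout2D example; the expected value of the print is cut off above, and since the operator keeps the input shape only the shape is printed here:

import numpy as np
import mindspore
from mindspore import Tensor, ops

dropout = ops.Dropout2D(keep_prob=0.5)  # zeroes whole channels with probability 1 - keep_prob
x = Tensor(np.ones([2, 1, 2, 3]), mindspore.float32)
output, mask = dropout(x)
print(output.shape)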