!47912 [OPS] modify meshgrid funtional input form tuple to tensor sequence

Merge pull request !47912 from yangruoqi713/ops
This commit is contained in:
i-robot 2023-01-17 03:19:05 +00:00 committed by Gitee
commit 288a901c1e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 11 additions and 7 deletions

View File

@ -1,14 +1,16 @@
mindspore.ops.meshgrid
======================
.. py:function:: mindspore.ops.meshgrid(inputs, indexing='xy')
.. py:function:: mindspore.ops.meshgrid(*inputs, indexing='xy')
从给定的Tensor生成网格矩阵。
给定N个一维Tensor对每个Tensor做扩张操作返回N个N维的Tensor。
参数:
- **inputs** (Union[tuple]) - N个一维Tensor。输入的长度应大于1。数据类型为Number。
- **inputs** (tuple[Tensor]) - N个一维Tensor。输入的长度应大于1。数据类型为Number。
关键字参数:
- **indexing** ('xy', 'ij', 可选) - 'xy'或'ij'。影响输出的网格矩阵的size。对于长度为 `M``N` 的二维输入,取值为'xy'时输出的shape为 :math:`(N, M)` ,取值为'ij'时输出的shape为 :math:`(M, N)` 。以长度为 `M` `N``P` 的三维输入,取值为'xy'时输出的shape为 :math:`(N, M, P)` ,取值为'ij'时输出的shape为 :math:`(M, N, P)` 。默认值:'xy'。
返回:

View File

@ -4003,7 +4003,7 @@ def matrix_set_diag(x, diagonal, k=0, align="RIGHT_LEFT"): # pylint: disable=red
return matrix_set_diag_v3_op(x, diagonal, k)
def meshgrid(inputs, indexing='xy'):
def meshgrid(*inputs, indexing='xy'):
"""
Generates coordinate matrices from given coordinate tensors.
@ -4011,14 +4011,16 @@ def meshgrid(inputs, indexing='xy'):
coordinate tensors for evaluating expressions on an N-D grid.
Args:
inputs (Union[tuple]): A Tuple of N 1-D Tensor objects.
inputs (tuple[Tensor]): A list of N 1-D Tensor objects.
The length of input should be greater than 1. The data type is Number.
Keyword Args:
indexing ('xy', 'ij', optional): Cartesian ('xy', default) or
matrix ('ij') indexing of output. In the 2-D case with
inputs of length `M` and `N`, the outputs are of shape `(N, M)`
for 'xy' indexing and `(M, N)` for 'ij' indexing. In the 3-D
case with inputs of length `M`, `N` and `P`, outputs are of shape
`(N, M, P)` for 'xy' indexing and `(M, N, P)` for 'ij' indexing.
`(N, M, P)` for 'xy' indexing and `(M, N, P)` for 'ij' indexing. Default: 'xy'.
Returns:
Tensors, a Tuple of N N-D Tensor objects. The data type is the same with the Inputs.

View File

@ -81,7 +81,7 @@ def test_meshgrid(dtype, indexing):
assert np.array_equal(output[1].asnumpy(), np_output[1])
# test functional interface
output = F.meshgrid((Tensor(x), Tensor(y)), indexing)
output = F.meshgrid(Tensor(x), Tensor(y), indexing=indexing)
assert np.array_equal(output[0].asnumpy(), np_output[0])
assert np.array_equal(output[1].asnumpy(), np_output[1])
@ -93,7 +93,7 @@ def test_meshgrid(dtype, indexing):
assert np.array_equal(output[2].asnumpy(), np_output[2])
# test functional interface
output = F.meshgrid((Tensor(x), Tensor(y), Tensor(z)), indexing)
output = F.meshgrid(Tensor(x), Tensor(y), Tensor(z), indexing=indexing)
assert np.array_equal(output[0].asnumpy(), np_output[0])
assert np.array_equal(output[1].asnumpy(), np_output[1])
assert np.array_equal(output[2].asnumpy(), np_output[2])