forked from mindspore-Ecosystem/mindspore
!47912 [OPS] modify meshgrid funtional input form tuple to tensor sequence
Merge pull request !47912 from yangruoqi713/ops
This commit is contained in:
commit
288a901c1e
|
@ -1,14 +1,16 @@
|
|||
mindspore.ops.meshgrid
|
||||
======================
|
||||
|
||||
.. py:function:: mindspore.ops.meshgrid(inputs, indexing='xy')
|
||||
.. py:function:: mindspore.ops.meshgrid(*inputs, indexing='xy')
|
||||
|
||||
从给定的Tensor生成网格矩阵。
|
||||
|
||||
给定N个一维Tensor,对每个Tensor做扩张操作,返回N个N维的Tensor。
|
||||
|
||||
参数:
|
||||
- **inputs** (Union[tuple]) - N个一维Tensor。输入的长度应大于1。数据类型为Number。
|
||||
- **inputs** (tuple[Tensor]) - N个一维Tensor。输入的长度应大于1。数据类型为Number。
|
||||
|
||||
关键字参数:
|
||||
- **indexing** ('xy', 'ij', 可选) - 'xy'或'ij'。影响输出的网格矩阵的size。对于长度为 `M` 和 `N` 的二维输入,取值为'xy'时,输出的shape为 :math:`(N, M)` ,取值为'ij'时,输出的shape为 :math:`(M, N)` 。以长度为 `M` , `N` 和 `P` 的三维输入,取值为'xy'时,输出的shape为 :math:`(N, M, P)` ,取值为'ij'时,输出的shape为 :math:`(M, N, P)` 。默认值:'xy'。
|
||||
|
||||
返回:
|
||||
|
|
|
@ -4003,7 +4003,7 @@ def matrix_set_diag(x, diagonal, k=0, align="RIGHT_LEFT"): # pylint: disable=red
|
|||
return matrix_set_diag_v3_op(x, diagonal, k)
|
||||
|
||||
|
||||
def meshgrid(inputs, indexing='xy'):
|
||||
def meshgrid(*inputs, indexing='xy'):
|
||||
"""
|
||||
Generates coordinate matrices from given coordinate tensors.
|
||||
|
||||
|
@ -4011,14 +4011,16 @@ def meshgrid(inputs, indexing='xy'):
|
|||
coordinate tensors for evaluating expressions on an N-D grid.
|
||||
|
||||
Args:
|
||||
inputs (Union[tuple]): A Tuple of N 1-D Tensor objects.
|
||||
inputs (tuple[Tensor]): A list of N 1-D Tensor objects.
|
||||
The length of input should be greater than 1. The data type is Number.
|
||||
|
||||
Keyword Args:
|
||||
indexing ('xy', 'ij', optional): Cartesian ('xy', default) or
|
||||
matrix ('ij') indexing of output. In the 2-D case with
|
||||
inputs of length `M` and `N`, the outputs are of shape `(N, M)`
|
||||
for 'xy' indexing and `(M, N)` for 'ij' indexing. In the 3-D
|
||||
case with inputs of length `M`, `N` and `P`, outputs are of shape
|
||||
`(N, M, P)` for 'xy' indexing and `(M, N, P)` for 'ij' indexing.
|
||||
`(N, M, P)` for 'xy' indexing and `(M, N, P)` for 'ij' indexing. Default: 'xy'.
|
||||
|
||||
Returns:
|
||||
Tensors, a Tuple of N N-D Tensor objects. The data type is the same with the Inputs.
|
||||
|
|
|
@ -81,7 +81,7 @@ def test_meshgrid(dtype, indexing):
|
|||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
|
||||
# test functional interface
|
||||
output = F.meshgrid((Tensor(x), Tensor(y)), indexing)
|
||||
output = F.meshgrid(Tensor(x), Tensor(y), indexing=indexing)
|
||||
assert np.array_equal(output[0].asnumpy(), np_output[0])
|
||||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
|
||||
|
@ -93,7 +93,7 @@ def test_meshgrid(dtype, indexing):
|
|||
assert np.array_equal(output[2].asnumpy(), np_output[2])
|
||||
|
||||
# test functional interface
|
||||
output = F.meshgrid((Tensor(x), Tensor(y), Tensor(z)), indexing)
|
||||
output = F.meshgrid(Tensor(x), Tensor(y), Tensor(z), indexing=indexing)
|
||||
assert np.array_equal(output[0].asnumpy(), np_output[0])
|
||||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
assert np.array_equal(output[2].asnumpy(), np_output[2])
|
||||
|
|
Loading…
Reference in New Issue