!12530 add meshgrid support for start <= stop & complex step with real component

From: @jachua
Reviewed-by: @guoqi1024,@liangchenghui
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2021-02-23 17:41:57 +08:00 committed by Gitee
commit 842ca43df3
1 changed files with 19 additions and 1 deletions

View File

@ -1291,6 +1291,12 @@ def meshgrid(*xi, sparse=False, indexing='xy'):
if indexing not in ('xy', 'ij'): if indexing not in ('xy', 'ij'):
_raise_type_error("Valid values for `indexing` are 'xy' and 'ij'.") _raise_type_error("Valid values for `indexing` are 'xy' and 'ij'.")
shape_out = ()
for x in xi:
shape_out += (x.size,)
if _is_shape_empty(shape_out):
return ones(shape_out)
grids = [] grids = []
for x in xi: for x in xi:
if F.rank(x) == 1: if F.rank(x) == 1:
@ -1351,7 +1357,7 @@ class nd_grid:
else: else:
step = 1 step = 1
if isinstance(step, complex): if isinstance(step, complex):
v = linspace(k.start, k.stop, int(abs(step.imag))) v = linspace(k.start, k.stop, int(abs(step)))
else: else:
v = arange(k.start, k.stop, step) v = arange(k.start, k.stop, step)
xi.append(v) xi.append(v)
@ -1362,6 +1368,8 @@ class nd_grid:
if self.sparse: if self.sparse:
return grids return grids
if isinstance(grids, Tensor_):
return grids
expanded = [] expanded = []
for grid in grids: for grid in grids:
expanded.append(F.expand_dims(grid, 0)) expanded.append(F.expand_dims(grid, 0))
@ -1380,6 +1388,11 @@ class mGridClass(nd_grid):
as specifying the number of points to create between the start and as specifying the number of points to create between the start and
stop values, where the stop value is inclusive. stop values, where the stop value is inclusive.
Note:
Unlike Numpy, if the step length is a complex number with a real
component, the step length is handled as equivalent to
``int(abs(step))``.
Returns: Returns:
Tensor or tuple of tensor, a meshgrid. Tensor or tuple of tensor, a meshgrid.
@ -1422,6 +1435,11 @@ class oGridClass(nd_grid):
as specifying the number of points to create between the start and as specifying the number of points to create between the start and
stop values, where the stop value is inclusive. stop values, where the stop value is inclusive.
Note:
Unlike Numpy, if the step length is a complex number with a real
component, the step length is handled as equivalent to
``int(abs(step))``.
Raises: Raises:
TypeError: if slicing indices are not integers. TypeError: if slicing indices are not integers.