From e17dffbd767f5901358b951363a41f14a273027d Mon Sep 17 00:00:00 2001 From: huangmengxi Date: Tue, 23 Feb 2021 11:27:23 +0800 Subject: [PATCH] fix meshgrid with complex step & stop <= start --- mindspore/numpy/array_creations.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/mindspore/numpy/array_creations.py b/mindspore/numpy/array_creations.py index 5d62f191e32..2d2b655b0fa 100644 --- a/mindspore/numpy/array_creations.py +++ b/mindspore/numpy/array_creations.py @@ -1291,6 +1291,12 @@ def meshgrid(*xi, sparse=False, indexing='xy'): if indexing not in ('xy', 'ij'): _raise_type_error("Valid values for `indexing` are 'xy' and 'ij'.") + shape_out = () + for x in xi: + shape_out += (x.size,) + if _is_shape_empty(shape_out): + return ones(shape_out) + grids = [] for x in xi: if F.rank(x) == 1: @@ -1351,7 +1357,7 @@ class nd_grid: else: step = 1 if isinstance(step, complex): - v = linspace(k.start, k.stop, int(abs(step.imag))) + v = linspace(k.start, k.stop, int(abs(step))) else: v = arange(k.start, k.stop, step) xi.append(v) @@ -1362,6 +1368,8 @@ class nd_grid: if self.sparse: return grids + if isinstance(grids, Tensor_): + return grids expanded = [] for grid in grids: expanded.append(F.expand_dims(grid, 0)) @@ -1380,6 +1388,11 @@ class mGridClass(nd_grid): as specifying the number of points to create between the start and stop values, where the stop value is inclusive. + Note: + Unlike Numpy, if the step length is a complex number with a real + component, the step length is handled as equivalent to + ``int(abs(step))``. + Returns: Tensor or tuple of tensor, a meshgrid. @@ -1422,6 +1435,11 @@ class oGridClass(nd_grid): as specifying the number of points to create between the start and stop values, where the stop value is inclusive. + Note: + Unlike Numpy, if the step length is a complex number with a real + component, the step length is handled as equivalent to + ``int(abs(step))``. + Raises: TypeError: if slicing indices are not integers.