forked from mindspore-Ecosystem/mindspore
!41786 [assistant][ops]Add CartesianProd
Merge pull request !41786 from zy26/CartesianProd
This commit is contained in:
commit
de6831caba
|
@ -313,6 +313,7 @@ from .math_func import (
|
|||
atleast_1d,
|
||||
dstack,
|
||||
atleast_2d,
|
||||
cartesian_prod,
|
||||
atleast_3d,
|
||||
vstack,
|
||||
combinations,
|
||||
|
|
|
@ -5687,6 +5687,45 @@ def atleast_2d(inputs):
|
|||
return [_expand(arr, 2) for arr in inputs]
|
||||
|
||||
|
||||
def cartesian_prod(*inputs):
|
||||
r"""
|
||||
Performs a Cartesian product for a given tensor sequence.
|
||||
The behavior is similar to Python's `itertools.product`.
|
||||
|
||||
Args:
|
||||
inputs (List[Tensor]): Tensor sequence.
|
||||
|
||||
Returns:
|
||||
Tensor, a Cartesian product for a given tensor sequence.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x1 = Tensor([1, 2])
|
||||
>>> x2 = Tensor([5])
|
||||
>>> out = ops.cartesian_prod(x1, x2)
|
||||
>>> print(out.asnumpy())
|
||||
[[1 5]
|
||||
[2 5]]
|
||||
>>> x1 = Tensor([1, 2, 3, 4])
|
||||
>>> x2 = Tensor([5, 6, 7])
|
||||
>>> x3 = Tensor([8, 9, 0, 1, 2])
|
||||
>>> out = ops.cartesian_prod(x1, x2, x3)
|
||||
>>> print(len(out))
|
||||
60
|
||||
"""
|
||||
meshgrid = P.Meshgrid(indexing="ij")
|
||||
meshgrid_output = meshgrid(inputs)
|
||||
stack = P.Stack(axis=-1)
|
||||
stack_output = stack(meshgrid_output)
|
||||
reshape = P.Reshape()
|
||||
return reshape(stack_output, (-1, len(inputs)))
|
||||
|
||||
|
||||
def atleast_3d(inputs):
|
||||
r"""
|
||||
Reshapes `inputs` as arrays with at least three dimensions.
|
||||
|
@ -9031,6 +9070,7 @@ __all__ = [
|
|||
'atleast_1d',
|
||||
'dstack',
|
||||
'atleast_2d',
|
||||
'cartesian_prod',
|
||||
'atleast_3d',
|
||||
'vstack',
|
||||
'combinations',
|
||||
|
|
|
@ -459,6 +459,15 @@ class AtLeast2DFunc(nn.Cell):
|
|||
return self.atleast_2d([x1, x2, x3])
|
||||
|
||||
|
||||
class CartesianProdFunc(nn.Cell):
|
||||
def __init__(self):
|
||||
super(CartesianProdFunc, self).__init__()
|
||||
self.cartesian_prod = ops.cartesian_prod
|
||||
|
||||
def construct(self, x1, x2):
|
||||
return self.cartesian_prod(x1, x2)
|
||||
|
||||
|
||||
class AtLeast3DFunc(nn.Cell):
|
||||
def __init__(self):
|
||||
super(AtLeast3DFunc, self).__init__()
|
||||
|
@ -862,6 +871,11 @@ test_case_math_ops = [
|
|||
Tensor(np.array(1), ms.float64),
|
||||
Tensor(np.array([1, 1, 1, 1, 1]), ms.float64)]
|
||||
}),
|
||||
('CartesianProd', {
|
||||
'block': CartesianProdFunc(),
|
||||
'desc_inputs': [Tensor(np.array([1, 2]), ms.int32),
|
||||
Tensor(np.array([5]), ms.int32)]
|
||||
}),
|
||||
('AtLeast3D', {
|
||||
'block': AtLeast3DFunc(),
|
||||
'desc_inputs': [Tensor(np.array([[1, 1, 1], [1, 1, 1]]), ms.float64),
|
||||
|
|
Loading…
Reference in New Issue