!40030 sparse_concat 资料修改

Merge pull request !40030 from baochong/master
This commit is contained in:
i-robot 2022-08-09 07:48:52 +00:00 committed by Gitee
commit 03b53cc3b7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 22 additions and 19 deletions

View File

@ -452,22 +452,22 @@ def sparse_concat(sp_input, concat_dim=0):
``CPU``
Examples:
>>> indics0 = Tensor([[0, 1], [1, 2]], dtype=mstype.int32)
>>> indices0 = Tensor([[0, 1], [1, 2]], dtype=mstype.int32)
>>> values0 = Tensor([1, 2], dtype=mstype.int32)
>>> shape0 = (3, 4)
>>> input0 = COOTensor(indics0, values0, shape0)
>>> indics1 = Tensor([[0, 0], [1, 1]], dtype=mstype.int32)
>>> input0 = COOTensor(indices0, values0, shape0)
>>> indices1 = Tensor([[0, 0], [1, 1]], dtype=mstype.int32)
>>> values1 = Tensor([3, 4], dtype=mstype.int32)
>>> shape1 = (3, 4)
>>> input1 = COOTensor(indics1, values1, shape1)
>>> input1 = COOTensor(indices1, values1, shape1)
>>> concat_dim = 1
>>> out = F.sparse_concat((input0, input1), concat_dim)
>>> print(out)
shape = [3 4]
[0 1]: "1"
[0 4]: "3"
[1 2]: "4"
[1 5]: "2"
COOTensor(shape=[3, 8], dtype=Int32, indices=Tensor(shape=[4, 2], dtype=Int32, value=
[[0 1]
[0 4]
[1 2]
[1 5]]), values=Tensor(shape=[4], dtype=Int32, value=[1 3 2 4]))
"""
if len(sp_input) < 2:
raise_value_error("For sparse_concat, not support COOTensor input number < 2.")

View File

@ -671,20 +671,23 @@ class SparseConcat(Primitive):
``CPU``
Examples:
>>> indics0 = Tensor([[0, 1], [1, 2]], dtype=mstype.int32)
>>> indices0 = Tensor([[0, 1], [1, 2]], dtype=mstype.int32)
>>> values0 = Tensor([1, 2], dtype=mstype.int32)
>>> shape0 = Tensor([3, 4], dtype=mstype.int64)
>>> indics1 = Tensor([[0, 0], [1, 1]], dtype=mstype.int32)
>>> indices1 = Tensor([[0, 0], [1, 1]], dtype=mstype.int32)
>>> values1 = Tensor([3, 4], dtype=mstype.int32)
>>> shape1 = Tensor([3, 4], dtype=mstype.int64)
>>> sparse_concat = ops.SparseConcat(0)
>>> out = sparse_concat((indices0, indices1), (values0, values1), (shape0, shape1))
>>> print(out)
shape = [3 4]
[0 1]: "1"
[0 4]: "3"
[1 2]: "4"
[1 5]: "2"
>>> sparse_concat = ops.SparseConcat(1)
>>> indices, value, shape = sparse_concat((indices0, indices1), (values0, values1), (shape0, shape1))
>>> print(indices)
[[0 1]
[0 4]
[1 2]
[1 5]]
>>> print(value)
[1 3 2 4]
>>> print(shape)
[3 8]
"""
@prim_attr_register
def __init__(self, concat_dim=0):