!40030 sparse_concat 资料修改
Merge pull request !40030 from baochong/master
This commit is contained in:
commit
03b53cc3b7
|
@ -452,22 +452,22 @@ def sparse_concat(sp_input, concat_dim=0):
|
|||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> indics0 = Tensor([[0, 1], [1, 2]], dtype=mstype.int32)
|
||||
>>> indices0 = Tensor([[0, 1], [1, 2]], dtype=mstype.int32)
|
||||
>>> values0 = Tensor([1, 2], dtype=mstype.int32)
|
||||
>>> shape0 = (3, 4)
|
||||
>>> input0 = COOTensor(indics0, values0, shape0)
|
||||
>>> indics1 = Tensor([[0, 0], [1, 1]], dtype=mstype.int32)
|
||||
>>> input0 = COOTensor(indices0, values0, shape0)
|
||||
>>> indices1 = Tensor([[0, 0], [1, 1]], dtype=mstype.int32)
|
||||
>>> values1 = Tensor([3, 4], dtype=mstype.int32)
|
||||
>>> shape1 = (3, 4)
|
||||
>>> input1 = COOTensor(indics1, values1, shape1)
|
||||
>>> input1 = COOTensor(indices1, values1, shape1)
|
||||
>>> concat_dim = 1
|
||||
>>> out = F.sparse_concat((input0, input1), concat_dim)
|
||||
>>> print(out)
|
||||
shape = [3 4]
|
||||
[0 1]: "1"
|
||||
[0 4]: "3"
|
||||
[1 2]: "4"
|
||||
[1 5]: "2"
|
||||
COOTensor(shape=[3, 8], dtype=Int32, indices=Tensor(shape=[4, 2], dtype=Int32, value=
|
||||
[[0 1]
|
||||
[0 4]
|
||||
[1 2]
|
||||
[1 5]]), values=Tensor(shape=[4], dtype=Int32, value=[1 3 2 4]))
|
||||
"""
|
||||
if len(sp_input) < 2:
|
||||
raise_value_error("For sparse_concat, not support COOTensor input number < 2.")
|
||||
|
|
|
@ -671,20 +671,23 @@ class SparseConcat(Primitive):
|
|||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> indics0 = Tensor([[0, 1], [1, 2]], dtype=mstype.int32)
|
||||
>>> indices0 = Tensor([[0, 1], [1, 2]], dtype=mstype.int32)
|
||||
>>> values0 = Tensor([1, 2], dtype=mstype.int32)
|
||||
>>> shape0 = Tensor([3, 4], dtype=mstype.int64)
|
||||
>>> indics1 = Tensor([[0, 0], [1, 1]], dtype=mstype.int32)
|
||||
>>> indices1 = Tensor([[0, 0], [1, 1]], dtype=mstype.int32)
|
||||
>>> values1 = Tensor([3, 4], dtype=mstype.int32)
|
||||
>>> shape1 = Tensor([3, 4], dtype=mstype.int64)
|
||||
>>> sparse_concat = ops.SparseConcat(0)
|
||||
>>> out = sparse_concat((indices0, indices1), (values0, values1), (shape0, shape1))
|
||||
>>> print(out)
|
||||
shape = [3 4]
|
||||
[0 1]: "1"
|
||||
[0 4]: "3"
|
||||
[1 2]: "4"
|
||||
[1 5]: "2"
|
||||
>>> sparse_concat = ops.SparseConcat(1)
|
||||
>>> indices, value, shape = sparse_concat((indices0, indices1), (values0, values1), (shape0, shape1))
|
||||
>>> print(indices)
|
||||
[[0 1]
|
||||
[0 4]
|
||||
[1 2]
|
||||
[1 5]]
|
||||
>>> print(value)
|
||||
[1 3 2 4]
|
||||
>>> print(shape)
|
||||
[3 8]
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, concat_dim=0):
|
||||
|
|
Loading…
Reference in New Issue