add sparse api docs

This commit is contained in:
panyifeng 2020-07-24 11:03:32 +08:00
parent 9b8d38eab4
commit 963bd67a60
4 changed files with 68 additions and 16 deletions

View File

@ -387,16 +387,10 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be positive, but got "
<< dense_shape_vec[i];
}
if (i == 0) {
if (dense_shape_vec[i] < values_shp[i]) {
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape should be greator or equal to the " << i
<< "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
}
} else {
if (dense_shape_vec[i] != values_shp[i]) {
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i
<< "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
}
// The 0th mode might be less or exceed dense_shape[0] due to duplicated selection
if (i != 0 && dense_shape_vec[i] != values_shp[i]) {
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i
<< "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
}
}
auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);

View File

@ -213,9 +213,73 @@ class Tensor(Tensor_):
class IndexedSlices:
"""
A sparse representation of a set of tensor slices at given indices.
An IndexedSlices is typically used to represent a subset of a larger
tensor dense of shape [LARGE0, D1, .. , DN] where LARGE0 >> D0.
The values in indices are the indices in the first dimension of the slices
that have been extracted from the larger tensor.
The dense tensor dense represented by an IndexedSlices slices has
`dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]`.
IndexedSlices can only be used in `Cell`'s contruct method.
Args:
indices (Tensor): A 1-D integer Tensor of shape [D0].
values (Tensor): A Tensor of any dtype of shape [D0, D1, ..., Dn].
dense_shape: (tuple): A integer tuple containing the shape
of the corresponding dense tensor.
Returns:
IndexedSlices, composed of `indices`, `values`, `dense_shape`.
Examples:
>>> # Create a IndexedSlices.
>>> indices = Tensor([1, 2])
>>> values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
>>> dense_shape = (3, 2)
>>> indexed_slices = IndexedSlices(indices, values, dense_shape)
>>>
>>> # Get atrr.
>>> indices = indexed_slices.indices()
>>> values = indexed_slices.values()
>>> dense_shape = indexed_slices.dense_shape()
"""
def __init__(self, indices, values, dense_shape):
raise NotImplementedError
class SparseTensor:
"""
A sparse representation of a set of nonzero elememts from a tensor at given indices.
SparseTensor can only be used in `Cell`'s contruct method.
For a tensor dense, its SparseTensor(indices, values, dense_shape) has
`dense[indices[i]] = values[i]`.
Args:
indices (Tensor): A 2-D integer Tensor of shape `[N, ndims]`,
where N and ndims are the number of values and number of dimensions in
the SparseTensor, respectively.
values (Tensor): A 1-D tensor of any type and shape `[N]`, which
supplies the values for each element in indices.
dense_shape: (tuple): A integer tuple of size `ndims`,
which specifies the dense_shape of the sparse tensor.
Returns:
SparseTensor, composed of `indices`, `values`, `dense_shape`.
Examples:
>>> # Create a SparseTensor.
>>> indices = Tensor([[0, 1], [1, 2]])
>>> values = Tensor([1, 2], dtype=ms.float32)
>>> dense_shape = (3, 4)
>>> sparse_tensor = SparseTensor(indices, values, dense_shape)
>>>
>>> # Get atrr.
>>> indices = sparse_tensor.indices()
>>> values = sparse_tensor.values()
>>> dense_shape = sparse_tensor.dense_shape()
"""
def __init__(self, indices, values, dense_shape):
raise NotImplementedError

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
@ -100,7 +99,6 @@ def test_embeddinglookup_reducescatter_true_grad():
_executor.compile(net, x, y)
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
def test_embeddinglookup_semi_auto1():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 32]
@ -115,7 +113,6 @@ def test_embeddinglookup_semi_auto1():
_executor.compile(net, x, y)
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
def test_embeddinglookup_semi_auto2():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 32]

View File

@ -61,7 +61,6 @@ class Net(nn.Cell):
return out
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
def test_gatherv2_semi_auto0():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((1, 8), (1, 1))
@ -134,7 +133,6 @@ def test_gatherv2_semi_auto5():
_executor.compile(net, x, y)
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
def test_gatherv2_semi_auto6():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy2 = ((4, 2, 1), (4, 2, 1))
@ -169,7 +167,6 @@ def test_gatherv2_semi_auto8():
_executor.compile(net, x, y)
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
def test_gatherv2_auto0():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
net = GradWrap(NetWithLoss(Net(0)))