add sparse api docs
This commit is contained in:
parent
9b8d38eab4
commit
963bd67a60
|
@ -387,16 +387,10 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
|
|||
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be positive, but got "
|
||||
<< dense_shape_vec[i];
|
||||
}
|
||||
if (i == 0) {
|
||||
if (dense_shape_vec[i] < values_shp[i]) {
|
||||
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape should be greator or equal to the " << i
|
||||
<< "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
|
||||
}
|
||||
} else {
|
||||
if (dense_shape_vec[i] != values_shp[i]) {
|
||||
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i
|
||||
<< "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
|
||||
}
|
||||
// The 0th mode might be less or exceed dense_shape[0] due to duplicated selection
|
||||
if (i != 0 && dense_shape_vec[i] != values_shp[i]) {
|
||||
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i
|
||||
<< "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
|
||||
}
|
||||
}
|
||||
auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
|
||||
|
|
|
@ -213,9 +213,73 @@ class Tensor(Tensor_):
|
|||
|
||||
|
||||
class IndexedSlices:
|
||||
"""
|
||||
A sparse representation of a set of tensor slices at given indices.
|
||||
|
||||
An IndexedSlices is typically used to represent a subset of a larger
|
||||
tensor dense of shape [LARGE0, D1, .. , DN] where LARGE0 >> D0.
|
||||
The values in indices are the indices in the first dimension of the slices
|
||||
that have been extracted from the larger tensor.
|
||||
The dense tensor dense represented by an IndexedSlices slices has
|
||||
`dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]`.
|
||||
IndexedSlices can only be used in `Cell`'s contruct method.
|
||||
|
||||
Args:
|
||||
indices (Tensor): A 1-D integer Tensor of shape [D0].
|
||||
values (Tensor): A Tensor of any dtype of shape [D0, D1, ..., Dn].
|
||||
dense_shape: (tuple): A integer tuple containing the shape
|
||||
of the corresponding dense tensor.
|
||||
|
||||
Returns:
|
||||
IndexedSlices, composed of `indices`, `values`, `dense_shape`.
|
||||
|
||||
Examples:
|
||||
>>> # Create a IndexedSlices.
|
||||
>>> indices = Tensor([1, 2])
|
||||
>>> values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
|
||||
>>> dense_shape = (3, 2)
|
||||
>>> indexed_slices = IndexedSlices(indices, values, dense_shape)
|
||||
>>>
|
||||
>>> # Get atrr.
|
||||
>>> indices = indexed_slices.indices()
|
||||
>>> values = indexed_slices.values()
|
||||
>>> dense_shape = indexed_slices.dense_shape()
|
||||
"""
|
||||
def __init__(self, indices, values, dense_shape):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SparseTensor:
|
||||
"""
|
||||
A sparse representation of a set of nonzero elememts from a tensor at given indices.
|
||||
|
||||
SparseTensor can only be used in `Cell`'s contruct method.
|
||||
For a tensor dense, its SparseTensor(indices, values, dense_shape) has
|
||||
`dense[indices[i]] = values[i]`.
|
||||
|
||||
Args:
|
||||
indices (Tensor): A 2-D integer Tensor of shape `[N, ndims]`,
|
||||
where N and ndims are the number of values and number of dimensions in
|
||||
the SparseTensor, respectively.
|
||||
values (Tensor): A 1-D tensor of any type and shape `[N]`, which
|
||||
supplies the values for each element in indices.
|
||||
dense_shape: (tuple): A integer tuple of size `ndims`,
|
||||
which specifies the dense_shape of the sparse tensor.
|
||||
|
||||
Returns:
|
||||
SparseTensor, composed of `indices`, `values`, `dense_shape`.
|
||||
|
||||
Examples:
|
||||
>>> # Create a SparseTensor.
|
||||
>>> indices = Tensor([[0, 1], [1, 2]])
|
||||
>>> values = Tensor([1, 2], dtype=ms.float32)
|
||||
>>> dense_shape = (3, 4)
|
||||
>>> sparse_tensor = SparseTensor(indices, values, dense_shape)
|
||||
>>>
|
||||
>>> # Get atrr.
|
||||
>>> indices = sparse_tensor.indices()
|
||||
>>> values = sparse_tensor.values()
|
||||
>>> dense_shape = sparse_tensor.dense_shape()
|
||||
"""
|
||||
def __init__(self, indices, values, dense_shape):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
@ -100,7 +99,6 @@ def test_embeddinglookup_reducescatter_true_grad():
|
|||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
|
||||
def test_embeddinglookup_semi_auto1():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
shape = [64, 32]
|
||||
|
@ -115,7 +113,6 @@ def test_embeddinglookup_semi_auto1():
|
|||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
|
||||
def test_embeddinglookup_semi_auto2():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
shape = [64, 32]
|
||||
|
|
|
@ -61,7 +61,6 @@ class Net(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
|
||||
def test_gatherv2_semi_auto0():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((1, 8), (1, 1))
|
||||
|
@ -134,7 +133,6 @@ def test_gatherv2_semi_auto5():
|
|||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
|
||||
def test_gatherv2_semi_auto6():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
|
@ -169,7 +167,6 @@ def test_gatherv2_semi_auto8():
|
|||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
|
||||
def test_gatherv2_auto0():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
|
||||
net = GradWrap(NetWithLoss(Net(0)))
|
||||
|
|
Loading…
Reference in New Issue