add sparse api docs
This commit is contained in:
parent
9b8d38eab4
commit
963bd67a60
|
@ -387,16 +387,10 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
|
||||||
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be positive, but got "
|
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be positive, but got "
|
||||||
<< dense_shape_vec[i];
|
<< dense_shape_vec[i];
|
||||||
}
|
}
|
||||||
if (i == 0) {
|
// The 0th mode might be less or exceed dense_shape[0] due to duplicated selection
|
||||||
if (dense_shape_vec[i] < values_shp[i]) {
|
if (i != 0 && dense_shape_vec[i] != values_shp[i]) {
|
||||||
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape should be greator or equal to the " << i
|
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i
|
||||||
<< "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
|
<< "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (dense_shape_vec[i] != values_shp[i]) {
|
|
||||||
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i
|
|
||||||
<< "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
|
auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
|
||||||
|
|
|
@ -213,9 +213,73 @@ class Tensor(Tensor_):
|
||||||
|
|
||||||
|
|
||||||
class IndexedSlices:
|
class IndexedSlices:
|
||||||
|
"""
|
||||||
|
A sparse representation of a set of tensor slices at given indices.
|
||||||
|
|
||||||
|
An IndexedSlices is typically used to represent a subset of a larger
|
||||||
|
tensor dense of shape [LARGE0, D1, .. , DN] where LARGE0 >> D0.
|
||||||
|
The values in indices are the indices in the first dimension of the slices
|
||||||
|
that have been extracted from the larger tensor.
|
||||||
|
The dense tensor dense represented by an IndexedSlices slices has
|
||||||
|
`dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]`.
|
||||||
|
IndexedSlices can only be used in `Cell`'s contruct method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indices (Tensor): A 1-D integer Tensor of shape [D0].
|
||||||
|
values (Tensor): A Tensor of any dtype of shape [D0, D1, ..., Dn].
|
||||||
|
dense_shape: (tuple): A integer tuple containing the shape
|
||||||
|
of the corresponding dense tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
IndexedSlices, composed of `indices`, `values`, `dense_shape`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # Create a IndexedSlices.
|
||||||
|
>>> indices = Tensor([1, 2])
|
||||||
|
>>> values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
|
||||||
|
>>> dense_shape = (3, 2)
|
||||||
|
>>> indexed_slices = IndexedSlices(indices, values, dense_shape)
|
||||||
|
>>>
|
||||||
|
>>> # Get atrr.
|
||||||
|
>>> indices = indexed_slices.indices()
|
||||||
|
>>> values = indexed_slices.values()
|
||||||
|
>>> dense_shape = indexed_slices.dense_shape()
|
||||||
|
"""
|
||||||
def __init__(self, indices, values, dense_shape):
|
def __init__(self, indices, values, dense_shape):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class SparseTensor:
|
class SparseTensor:
|
||||||
|
"""
|
||||||
|
A sparse representation of a set of nonzero elememts from a tensor at given indices.
|
||||||
|
|
||||||
|
SparseTensor can only be used in `Cell`'s contruct method.
|
||||||
|
For a tensor dense, its SparseTensor(indices, values, dense_shape) has
|
||||||
|
`dense[indices[i]] = values[i]`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indices (Tensor): A 2-D integer Tensor of shape `[N, ndims]`,
|
||||||
|
where N and ndims are the number of values and number of dimensions in
|
||||||
|
the SparseTensor, respectively.
|
||||||
|
values (Tensor): A 1-D tensor of any type and shape `[N]`, which
|
||||||
|
supplies the values for each element in indices.
|
||||||
|
dense_shape: (tuple): A integer tuple of size `ndims`,
|
||||||
|
which specifies the dense_shape of the sparse tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SparseTensor, composed of `indices`, `values`, `dense_shape`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # Create a SparseTensor.
|
||||||
|
>>> indices = Tensor([[0, 1], [1, 2]])
|
||||||
|
>>> values = Tensor([1, 2], dtype=ms.float32)
|
||||||
|
>>> dense_shape = (3, 4)
|
||||||
|
>>> sparse_tensor = SparseTensor(indices, values, dense_shape)
|
||||||
|
>>>
|
||||||
|
>>> # Get atrr.
|
||||||
|
>>> indices = sparse_tensor.indices()
|
||||||
|
>>> values = sparse_tensor.values()
|
||||||
|
>>> dense_shape = sparse_tensor.dense_shape()
|
||||||
|
"""
|
||||||
def __init__(self, indices, values, dense_shape):
|
def __init__(self, indices, values, dense_shape):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -13,7 +13,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
|
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
|
@ -100,7 +99,6 @@ def test_embeddinglookup_reducescatter_true_grad():
|
||||||
_executor.compile(net, x, y)
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
|
|
||||||
def test_embeddinglookup_semi_auto1():
|
def test_embeddinglookup_semi_auto1():
|
||||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
shape = [64, 32]
|
shape = [64, 32]
|
||||||
|
@ -115,7 +113,6 @@ def test_embeddinglookup_semi_auto1():
|
||||||
_executor.compile(net, x, y)
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
|
|
||||||
def test_embeddinglookup_semi_auto2():
|
def test_embeddinglookup_semi_auto2():
|
||||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
shape = [64, 32]
|
shape = [64, 32]
|
||||||
|
|
|
@ -61,7 +61,6 @@ class Net(nn.Cell):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
|
|
||||||
def test_gatherv2_semi_auto0():
|
def test_gatherv2_semi_auto0():
|
||||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
strategy1 = ((1, 8), (1, 1))
|
strategy1 = ((1, 8), (1, 1))
|
||||||
|
@ -134,7 +133,6 @@ def test_gatherv2_semi_auto5():
|
||||||
_executor.compile(net, x, y)
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
|
|
||||||
def test_gatherv2_semi_auto6():
|
def test_gatherv2_semi_auto6():
|
||||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||||
|
@ -169,7 +167,6 @@ def test_gatherv2_semi_auto8():
|
||||||
_executor.compile(net, x, y)
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
|
|
||||||
def test_gatherv2_auto0():
|
def test_gatherv2_auto0():
|
||||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
|
||||||
net = GradWrap(NetWithLoss(Net(0)))
|
net = GradWrap(NetWithLoss(Net(0)))
|
||||||
|
|
Loading…
Reference in New Issue