fix sparse related issues

This commit is contained in:
panyifeng 2020-07-22 14:11:20 +08:00
parent 251fba00f5
commit 2cebc62bbf
8 changed files with 277 additions and 14 deletions

View File

@ -378,10 +378,19 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
auto elem = GetValue<int>(e);
return elem;
});
for (auto dense_shape_elem : dense_shape_vec) {
if (dense_shape_elem < 0) {
MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
<< dense_shape_value->ToString();
if (dense_shape_vec.size() != values_shp.size()) {
MS_EXCEPTION(TypeError) << "The size of dense_shape must be the same with the dimension of values "
<< values_shp.size() << ", but got " << dense_shape_value->size();
}
for (size_t i = 0; i < dense_shape_vec.size(); i++) {
if (dense_shape_vec[i] < 0) {
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be positive, but got "
<< dense_shape_vec[i];
}
// The 0th dimension of values may be less than or exceed dense_shape[0] due to duplicated selection
if (i != 0 && dense_shape_vec[i] != values_shp[i]) {
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i
<< "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
}
}
auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
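The added validation amounts to a shape-consistency rule between dense_shape and values. A minimal Python sketch of that rule (illustrative only, not the C++ implementation; the helper name is made up):

def check_dense_shape(dense_shape, values_shape):
    # dense_shape needs one entry per dimension of values.
    if len(dense_shape) != len(values_shape):
        raise TypeError("The size of dense_shape must be the same as the dimension "
                        "of values {}, but got {}".format(len(values_shape), len(dense_shape)))
    for i, dim in enumerate(dense_shape):
        if dim < 0:
            raise TypeError("The {}th element of dense_shape must be positive, "
                            "but got {}".format(i, dim))
        # Dimension 0 is exempt: duplicated index selection means values may
        # have fewer or more rows than dense_shape[0].
        if i != 0 and dim != values_shape[i]:
            raise TypeError("The {}th element of dense_shape must be the same as the "
                            "{}th dimension of values {}, but got {}"
                            .format(i, i, values_shape[i], dim))

# values of shape (2, 4) with dense_shape (3, 4) passes; dense_shape (3, 5)
# or (3, 4, 5) raises TypeError, matching the new test cases further below.
check_dense_shape((3, 4), (2, 4))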

View File

@ -386,6 +386,16 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
dic["shape"] = arg_tensor->shape()->shape();
dic["dtype"] = arg_tensor->BuildType();
dic["value"] = BuildValue(arg_tensor->BuildValue());
} else if (abs_base->isa<AbstractIndexedSlices>()) {
auto arg = dyn_cast<AbstractIndexedSlices>(abs_base);
dic["shape"] = arg->shape()->shape();
dic["dtype"] = arg->BuildType();
dic["value"] = BuildValue(arg->BuildValue());
} else if (abs_base->isa<AbstractSparseTensor>()) {
auto arg = dyn_cast<AbstractSparseTensor>(abs_base);
dic["shape"] = arg->shape()->shape();
dic["dtype"] = arg->BuildType();
dic["value"] = BuildValue(arg->BuildValue());
} else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
std::vector<int> shape;
dic["shape"] = shape;

View File

@ -99,6 +99,8 @@ slice_type = typing.Slice
ellipsis_type = typing.TypeEllipsis
list_type = typing.List
tuple_type = typing.Tuple
index_slices = typing.IndexedSlicesType()
sparse_tensor = typing.SparseTensorType()
number_type = (int8,
int16,

View File

@ -209,9 +209,73 @@ class Tensor(Tensor_):
class IndexedSlices:
"""
A sparse representation of a set of tensor slices at given indices.
An IndexedSlices is typically used to represent a subset of a larger
dense tensor of shape [LARGE0, D1, ..., DN] where LARGE0 >> D0.
The values in indices are the indices in the first dimension of the slices
that have been extracted from the larger tensor.
The dense tensor `dense` represented by an IndexedSlices `slices` has
`dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]`.
IndexedSlices can only be used in `Cell`'s construct method.
Args:
indices (Tensor): A 1-D integer Tensor of shape [D0].
values (Tensor): A Tensor of any dtype of shape [D0, D1, ..., Dn].
dense_shape (tuple): An integer tuple containing the shape
of the corresponding dense tensor.
Outputs:
IndexedSlices, composed of `indices`, `values`, `dense_shape`.
Examples:
>>> # Create an IndexedSlices.
>>> indices = Tensor([1, 2])
>>> values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
>>> dense_shape = (3, 2)
>>> indexed_slices = IndexedSlices(indices, values, dense_shape)
>>>
>>> # Get attributes.
>>> indices = indexed_slices.indices()
>>> values = indexed_slices.values()
>>> dense_shape = indexed_slices.dense_shape()
"""
def __init__(self, indices, values, dense_shape):
raise NotImplementedError
class SparseTensor:
"""
A sparse representation of a set of nonzero elements from a tensor at given indices.
SparseTensor can only be used in `Cell`'s construct method.
For a dense tensor `dense`, its SparseTensor(indices, values, dense_shape) has
`dense[indices[i]] = values[i]`
Args:
indices (Tensor): A 2-D integer Tensor of shape `[N, ndims]`,
where N and ndims are the number of values and number of dimensions in
the SparseTensor, respectively.
values (Tensor): A 1-D tensor of any type and shape `[N]`, which
supplies the values for each element in indices.
dense_shape (tuple): An integer tuple of size `ndims`,
which specifies the dense_shape of the sparse tensor.
Outputs:
SparseTensor, composed of `indices`, `values`, `dense_shape`.
Examples:
>>> # Create a SparseTensor.
>>> indices = Tensor([[0, 1], [1, 2]])
>>> values = Tensor([1, 2], dtype=ms.float32)
>>> dense_shape = (3, 4)
>>> sparse_tensor = SparseTensor(indices, values, dense_shape)
>>>
>>> # Get attributes.
>>> indices = sparse_tensor.indices()
>>> values = sparse_tensor.values()
>>> dense_shape = sparse_tensor.dense_shape()
"""
def __init__(self, indices, values, dense_shape):
raise NotImplementedError
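For reference, a small NumPy sketch of the dense tensors the two docstrings above describe, reusing the sample data from their Examples sections (NumPy is used purely for illustration; neither class materializes the dense tensor):

import numpy as np

# IndexedSlices: dense[indices[i], ...] = values[i, ...]
indices = np.array([1, 2])
values = np.array([[0, 0], [1, 2]], dtype=np.float32)
dense = np.zeros((3, 2), dtype=np.float32)
dense[indices] = values              # rows 1 and 2 filled, row 0 stays zero

# SparseTensor: dense[indices[i]] = values[i]
sp_indices = np.array([[0, 1], [1, 2]])
sp_values = np.array([1, 2], dtype=np.float32)
sp_dense = np.zeros((3, 4), dtype=np.float32)
sp_dense[sp_indices[:, 0], sp_indices[:, 1]] = sp_values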

View File

@ -18,7 +18,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from ..cell import Cell
from ...common import Tensor
from ...common import Tensor, IndexedSlices
from ...common.parameter import Parameter
from ...ops import functional as F
from ...ops import composite as C
@ -35,6 +35,12 @@ reciprocal = P.Reciprocal()
def tensor_grad_scale(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad))
@_grad_scale.register("Tensor", "IndexedSlices")
def tensor_grad_scale_indexed_slices(scale, grad):
return IndexedSlices(grad.indices(),
grad.values() * F.cast(reciprocal(scale), F.dtype(grad.values())),
grad.dense_shape())
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()
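Conceptually, the new "Tensor", "IndexedSlices" overload scales only the values of a sparse gradient, leaving indices and dense_shape untouched. A plain-Python sketch of the idea (hypothetical helper, not MindSpore API):

def scale_sparse_grad(scale, indices, values, dense_shape):
    # Mirrors the dense case: multiply by reciprocal(scale), but only on values.
    scaled_values = [v * (1.0 / scale) for v in values]
    return indices, scaled_values, dense_shape

# With a loss scale of 128, every gradient value is divided by 128 while
# the row indices and dense_shape pass through unchanged.
print(scale_sparse_grad(128.0, [0, 2], [256.0, 64.0], (5, 1)))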

View File

@ -15,6 +15,8 @@
"""array_ops"""
import mindspore as ms
from mindspore.ops import composite as C
from .. import operations as P
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
@ -35,6 +37,7 @@ reshape = P.Reshape()
size_op = P.Size()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()
is_sub_class = P.IsSubClass()
@bprop_getters.register(P.Fill)
@ -57,6 +60,29 @@ def get_bprop_dtype(self):
return bprop
dout_cast = C.MultitypeFuncGraph("dout_cast")
@dout_cast.register("Tensor", "Tensor")
def dout_cast_tensor(dout, x):
cast = P.Cast()
get_dtype = P.DType()
dx = cast(dout, get_dtype(x))
return dx
@dout_cast.register("Number", "Number")
def dout_cast_number(dout, x):
cast = P.Cast()
get_dtype = P.DType()
dx = cast(dout, get_dtype(x))
return dx
@dout_cast.register("IndexedSlices", "Tensor")
def dout_cast_indexed_slices(dout, x):
cast = P.Cast()
get_dtype = P.DType()
values = cast(dout.values(), get_dtype(x))
return IndexedSlices(dout.indices(), values, dout.dense_shape())
@bprop_getters.register(P.Cast)
def get_bprop_cast(self):
"""Generate bprop for Cast"""
@ -67,6 +93,13 @@ def get_bprop_cast(self):
dx = cast(dout, get_dtype(x))
return dx, zeros_like(t)
def bprop_sparse(x, t, out, dout):
dx = dout_cast(dout, x)
return dx, zeros_like(t)
if context.get_context('enable_sparse'):
return bprop_sparse
return bprop
@ -372,6 +405,11 @@ def get_bprop_pack(self):
def bprop(x, out, dout):
pack_grad = P.Unpack(axis)
out = pack_grad(dout)
if is_sub_class(F.typeof(x), ms.list_):
ret = []
for item in out:
ret.append(item)
return (ret,)
return (out,)
return bprop
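The list branch added to get_bprop_pack keeps the gradient container aligned with the forward input: when Pack received a list, the tuple produced by Unpack is rebuilt as a list. A plain-Python sketch of the idea (helper name is hypothetical):

def match_grad_container(forward_input, unpacked_grads):
    # Gradient structure has to mirror the input structure (list vs tuple).
    if isinstance(forward_input, list):
        return list(unpacked_grads)
    return unpacked_grads

print(match_grad_container([1, 2], (0.1, 0.2)))   # -> [0.1, 0.2]
print(match_grad_container((1, 2), (0.1, 0.2)))   # -> (0.1, 0.2)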

View File

@ -18,8 +18,10 @@ import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from tests.ut.python.ut_filter import non_graph_engine
from tests.mindspore_test_framework.mindspore_test import mindspore_test
@ -282,3 +284,26 @@ test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)
def test_exec():
context.set_context(mode=context.GRAPH_MODE)
return test_exec_case
def test_grad_make_list():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, idx, x):
return x[idx, :, :]
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
def construct(self, *inputs):
return C.grad_all(self.net)(*inputs)
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, x)

View File

@ -19,6 +19,7 @@
@Desc : test mindspore indexed_slices's operation
"""
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
@ -222,7 +223,7 @@ def test_indexed_slices_make_indexed_slices():
class MakeIndexedSlices(nn.Cell):
def __init__(self):
super(MakeIndexedSlices, self).__init__()
self.dense_shape = (3, 4)
self.dense_shape = (3, 2)
def construct(self, indices, values):
ret = (IndexedSlices(indices, values, self.dense_shape),)
return ret[0]
@ -231,17 +232,19 @@ def test_indexed_slices_make_indexed_slices():
MakeIndexedSlices()(indices, values)
class IndexedSlicesGetAttr(nn.Cell):
def __init__(self, dense_shape):
super(IndexedSlicesGetAttr, self).__init__()
self.dense_shape = dense_shape
def construct(self, indices, values):
x = IndexedSlices(indices, values, self.dense_shape)
return x.values(), x.indices(), x.dense_shape()
def test_indexed_slices_attr():
class IndexedSlicesGetAttr(nn.Cell):
def __init__(self):
super(IndexedSlicesGetAttr, self).__init__()
self.dense_shape = (3, 4)
def construct(self, indices, values):
x = IndexedSlices(indices, values, self.dense_shape)
return x.values(), x.indices(), x.dense_shape()
indices = Tensor([0])
values = Tensor([[1, 2]], dtype=ms.float32)
IndexedSlicesGetAttr()(indices, values)
IndexedSlicesGetAttr((3, 2))(indices, values)
def test_indexed_slices_sparse_gatherv2_grad_all():
@ -342,3 +345,109 @@ def test_indexed_slices_model_train():
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(net, optimizer=optimizer)
model.train(2, dataset, dataset_sink_mode=False)
def test_indexed_slices_values_dim_greater_than_dense_shape_dim():
indices = Tensor(np.array([0, 1], dtype=np.int32))
values = Tensor(np.random.randn(2, 4, 5).astype(np.float32))
dense_shape = (3, 4)
with pytest.raises(TypeError):
IndexedSlicesGetAttr(dense_shape)(indices, values)
def test_indexed_slices_values_dim_less_than_dense_shape_dim():
indices = Tensor(np.array([0, 1], dtype=np.int32))
values = Tensor(np.random.randn(2, 4).astype(np.float32))
dense_shape = (3, 4, 5)
with pytest.raises(TypeError):
IndexedSlicesGetAttr(dense_shape)(indices, values)
def test_indexed_slices_value_and_dense_shape_illegal():
indices = Tensor(np.array([0, 1], dtype=np.int32))
values = Tensor(np.random.randn(2, 4).astype(np.float32))
dense_shape = (3, 5)
with pytest.raises(TypeError):
IndexedSlicesGetAttr(dense_shape)(indices, values)
class IndexedSlicesValuesDouble(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x):
indices = x.indices()
values = x.values() * 2
dense_shape = x.dense_shape()
return IndexedSlices(indices, values, dense_shape)
class IndexedSlicesValuesAdd2(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x):
indices = x.indices()
values = x.values() + 2
dense_shape = x.dense_shape()
return IndexedSlices(indices, values, dense_shape)
class IndexedSlicesWithControlIf(nn.Cell):
def __init__(self, dense_shape):
super().__init__()
self.op1 = IndexedSlicesValuesDouble()
self.op2 = IndexedSlicesValuesAdd2()
self.dense_shape = dense_shape
def construct(self, a, b, indices, values):
x = IndexedSlices(indices, values, self.dense_shape)
if a > b:
x = self.op1(x)
else:
x = self.op2(x)
return x.indices(), x.values()
def test_indexed_slices_with_control_flow_if():
a = Tensor(np.array(0).astype(np.int32))
b = Tensor(np.array(2).astype(np.int32))
indices = Tensor(np.array([0, 2]).astype(np.int32))
values = Tensor(np.ones([2, 2]).astype(np.float32))
dense_shape = (5, 2)
net = IndexedSlicesWithControlIf(dense_shape)
net(a, b, indices, values)
class EmbeddingLookUpBnNet(nn.Cell):
def __init__(self, param_np, target='CPU'):
super().__init__()
self.param = Parameter(Tensor(param_np), name="w1")
self.embedding_lookup = nn.EmbeddingLookup(target=target)
self.bn = nn.BatchNorm2d(num_features=3)
self.mul = P.Mul()
self.reshape = P.Reshape()
self.relu = nn.PReLU()
def construct(self, indices):
x = self.embedding_lookup(self.param, indices)
x = self.reshape(x, (2, 3, 2, 2))
x = self.relu(x)
x = self.bn(x)
return x
def test_embedding_lookup_with_mix_precision():
param_np = np.ones([8, 8]).astype(np.float32)
data = Tensor(np.array([0, 1, 2]).astype(np.int32))
label = Tensor(np.random.randn(*(2, 3, 2, 2)).astype(np.float32))
net = EmbeddingLookUpBnNet(param_np, target='CPU')
criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1)
optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
train_network = ms.amp.build_train_network(net, optimizer, criterion, level="O2")
train_network.set_train()
for _ in range(2):
train_network(data, label)