fix sparse related issues

This commit is contained in:
panyifeng 2020-07-22 14:11:20 +08:00
parent e4c8365dfe
commit 8a89f003eb
9 changed files with 225 additions and 14 deletions

View File

@ -378,10 +378,25 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
auto elem = GetValue<int>(e);
return elem;
});
for (auto dense_shape_elem : dense_shape_vec) {
if (dense_shape_elem < 0) {
MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
<< dense_shape_value->ToString();
if (dense_shape_vec.size() != values_shp.size()) {
MS_EXCEPTION(TypeError) << "The size of dense_shape must be the same with the dimension of values "
<< values_shp.size() << ", but got " << dense_shape_value->size();
}
for (size_t i = 0; i < dense_shape_vec.size(); i++) {
if (dense_shape_vec[i] < 0) {
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be positive, but got "
<< dense_shape_vec[i];
}
if (i == 0) {
if (dense_shape_vec[i] < values_shp[i]) {
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape should be greator or equal to the " << i
<< "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
}
} else {
if (dense_shape_vec[i] != values_shp[i]) {
MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i
<< "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
}
}
}
auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);

View File

@ -386,6 +386,16 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
dic["shape"] = arg_tensor->shape()->shape();
dic["dtype"] = arg_tensor->BuildType();
dic["value"] = BuildValue(arg_tensor->BuildValue());
} else if (abs_base->isa<AbstractIndexedSlices>()) {
auto arg = dyn_cast<AbstractIndexedSlices>(abs_base);
dic["shape"] = arg->shape()->shape();
dic["dtype"] = arg->BuildType();
dic["value"] = BuildValue(arg->BuildValue());
} else if (abs_base->isa<AbstractSparseTensor>()) {
auto arg = dyn_cast<AbstractSparseTensor>(abs_base);
dic["shape"] = arg->shape()->shape();
dic["dtype"] = arg->BuildType();
dic["value"] = BuildValue(arg->BuildValue());
} else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
std::vector<int> shape;
dic["shape"] = shape;

View File

@ -99,6 +99,8 @@ slice_type = typing.Slice
ellipsis_type = typing.TypeEllipsis
list_type = typing.List
tuple_type = typing.Tuple
index_slices = typing.IndexedSlicesType()
sparse_tensor = typing.SparseTensorType()
number_type = (int8,
int16,

View File

@ -18,7 +18,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from ..cell import Cell
from ...common import Tensor
from ...common import Tensor, IndexedSlices
from ...common.parameter import Parameter
from ...ops import functional as F
from ...ops import composite as C
@ -35,6 +35,12 @@ reciprocal = P.Reciprocal()
def tensor_grad_scale(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad))
@_grad_scale.register("Tensor", "IndexedSlices")
def tensor_grad_scale_indexed_slices(scale, grad):
return IndexedSlices(grad.indices(),
grad.values() * F.cast(reciprocal(scale), F.dtype(grad.values())),
grad.dense_shape())
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()

View File

@ -15,6 +15,8 @@
"""array_ops"""
import mindspore as ms
from mindspore.ops import composite as C
from .. import operations as P
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
@ -35,6 +37,7 @@ reshape = P.Reshape()
size_op = P.Size()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()
is_sub_class = P.IsSubClass()
@bprop_getters.register(P.Fill)
@ -57,6 +60,29 @@ def get_bprop_dtype(self):
return bprop
dout_cast = C.MultitypeFuncGraph("dout_cast")
@dout_cast.register("Tensor", "Tensor")
def dout_cast_tensor(dout, x):
cast = P.Cast()
get_dtype = P.DType()
dx = cast(dout, get_dtype(x))
return dx
@dout_cast.register("Number", "Number")
def dout_cast_number(dout, x):
cast = P.Cast()
get_dtype = P.DType()
dx = cast(dout, get_dtype(x))
return dx
@dout_cast.register("IndexedSlices", "Tensor")
def dout_cast_indexed_slices(dout, x):
cast = P.Cast()
get_dtype = P.DType()
values = cast(dout.values(), get_dtype(x))
return IndexedSlices(dout.indices(), values, dout.dense_shape())
@bprop_getters.register(P.Cast)
def get_bprop_cast(self):
"""Generate bprop for Cast"""
@ -67,6 +93,13 @@ def get_bprop_cast(self):
dx = cast(dout, get_dtype(x))
return dx, zeros_like(t)
def bprop_sparse(x, t, out, dout):
dx = dout_cast(dout, x)
return dx, zeros_like(t)
if context.get_context('enable_sparse'):
return bprop_sparse
return bprop
@ -372,6 +405,11 @@ def get_bprop_pack(self):
def bprop(x, out, dout):
pack_grad = P.Unpack(axis)
out = pack_grad(dout)
if is_sub_class(F.typeof(x), ms.list_):
ret = []
for item in out:
ret.append(item)
return (ret,)
return (out,)
return bprop

View File

@ -18,8 +18,10 @@ import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from tests.ut.python.ut_filter import non_graph_engine
from tests.mindspore_test_framework.mindspore_test import mindspore_test
@ -282,3 +284,26 @@ test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)
def test_exec():
context.set_context(mode=context.GRAPH_MODE)
return test_exec_case
def test_grad_make_list():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, idx, x):
return x[idx, :, :]
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
def construct(self, *inputs):
return C.grad_all(self.net)(*inputs)
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, x)

View File

@ -19,6 +19,7 @@
@Desc : test mindspore indexed_slices's operation
"""
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
@ -222,7 +223,7 @@ def test_indexed_slices_make_indexed_slices():
class MakeIndexedSlices(nn.Cell):
def __init__(self):
super(MakeIndexedSlices, self).__init__()
self.dense_shape = (3, 4)
self.dense_shape = (3, 2)
def construct(self, indices, values):
ret = (IndexedSlices(indices, values, self.dense_shape),)
return ret[0]
@ -231,17 +232,19 @@ def test_indexed_slices_make_indexed_slices():
MakeIndexedSlices()(indices, values)
class IndexedSlicesGetAttr(nn.Cell):
def __init__(self, dense_shape):
super(IndexedSlicesGetAttr, self).__init__()
self.dense_shape = dense_shape
def construct(self, indices, values):
x = IndexedSlices(indices, values, self.dense_shape)
return x.values(), x.indices(), x.dense_shape()
def test_indexed_slices_attr():
class IndexedSlicesGetAttr(nn.Cell):
def __init__(self):
super(IndexedSlicesGetAttr, self).__init__()
self.dense_shape = (3, 4)
def construct(self, indices, values):
x = IndexedSlices(indices, values, self.dense_shape)
return x.values(), x.indices(), x.dense_shape()
indices = Tensor([0])
values = Tensor([[1, 2]], dtype=ms.float32)
IndexedSlicesGetAttr()(indices, values)
IndexedSlicesGetAttr((3, 2))(indices, values)
def test_indexed_slices_sparse_gatherv2_grad_all():
@ -342,3 +345,109 @@ def test_indexed_slices_model_train():
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(net, optimizer=optimizer)
model.train(2, dataset, dataset_sink_mode=False)
def test_indexed_slices_values_dim_greater_than_dense_shape_dim():
indices = Tensor(np.array([0, 1], dtype=np.int32))
values = Tensor(np.random.randn(2, 4, 5).astype(np.float32))
dense_shape = (3, 4)
with pytest.raises(TypeError):
IndexedSlicesGetAttr(dense_shape)(indices, values)
def test_indexed_slices_values_dim_less_than_dense_shape_dim():
indices = Tensor(np.array([0, 1], dtype=np.int32))
values = Tensor(np.random.randn(2, 4).astype(np.float32))
dense_shape = (3, 4, 5)
with pytest.raises(TypeError):
IndexedSlicesGetAttr(dense_shape)(indices, values)
def test_indexed_slices_value_and_dense_shape_illegal():
indices = Tensor(np.array([0, 1], dtype=np.int32))
values = Tensor(np.random.randn(2, 4).astype(np.float32))
dense_shape = (3, 5)
with pytest.raises(TypeError):
IndexedSlicesGetAttr(dense_shape)(indices, values)
class IndexedSlicesValuesDouble(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x):
indices = x.indices()
values = x.values() * 2
dense_shape = x.dense_shape()
return IndexedSlices(indices, values, dense_shape)
class IndexedSlicesValuesAdd2(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x):
indices = x.indices()
values = x.values() + 2
dense_shape = x.dense_shape()
return IndexedSlices(indices, values, dense_shape)
class IndexedSlicesWithControlIf(nn.Cell):
def __init__(self, dense_shape):
super().__init__()
self.op1 = IndexedSlicesValuesDouble()
self.op2 = IndexedSlicesValuesAdd2()
self.dense_shape = dense_shape
def construct(self, a, b, indices, values):
x = IndexedSlices(indices, values, self.dense_shape)
if a > b:
x = self.op1(x)
else:
x = self.op2(x)
return x.indices(), x.values()
def test_indexed_slices_with_control_flow_if():
a = Tensor(np.array(0).astype(np.int32))
b = Tensor(np.array(2).astype(np.int32))
indices = Tensor(np.array([0, 2]).astype(np.int32))
values = Tensor(np.ones([2, 2]).astype(np.float32))
dense_shape = (5, 2)
net = IndexedSlicesWithControlIf(dense_shape)
net(a, b, indices, values)
class EmbeddingLookUpBnNet(nn.Cell):
def __init__(self, param_np, target='CPU'):
super().__init__()
self.param = Parameter(Tensor(param_np), name="w1")
self.embedding_lookup = nn.EmbeddingLookup(target=target)
self.bn = nn.BatchNorm2d(num_features=3)
self.mul = P.Mul()
self.reshape = P.Reshape()
self.relu = nn.PReLU()
def construct(self, indices):
x = self.embedding_lookup(self.param, indices)
x = self.reshape(x, (2, 3, 2, 2))
x = self.relu(x)
x = self.bn(x)
return x
def test_embedding_lookup_with_mix_precision():
param_np = np.ones([8, 8]).astype(np.float32)
data = Tensor(np.array([0, 1, 2]).astype(np.int32))
label = Tensor(np.random.randn(*(2, 3, 2, 2)).astype(np.float32))
net = EmbeddingLookUpBnNet(param_np, target='CPU')
criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1)
optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
train_network = ms.amp.build_train_network(net, optimizer, criterion, level="O2")
train_network.set_train()
for _ in range(2):
train_network(data, label)

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
@ -99,6 +100,7 @@ def test_embeddinglookup_reducescatter_true_grad():
_executor.compile(net, x, y)
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
def test_embeddinglookup_semi_auto1():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 32]
@ -113,6 +115,7 @@ def test_embeddinglookup_semi_auto1():
_executor.compile(net, x, y)
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
def test_embeddinglookup_semi_auto2():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 32]

View File

@ -61,6 +61,7 @@ class Net(nn.Cell):
return out
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
def test_gatherv2_semi_auto0():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((1, 8), (1, 1))
@ -133,6 +134,7 @@ def test_gatherv2_semi_auto5():
_executor.compile(net, x, y)
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
def test_gatherv2_semi_auto6():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy2 = ((4, 2, 1), (4, 2, 1))
@ -167,6 +169,7 @@ def test_gatherv2_semi_auto8():
_executor.compile(net, x, y)
@pytest.mark.skip(reason="waiting for fix by parallel strategy")
def test_gatherv2_auto0():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
net = GradWrap(NetWithLoss(Net(0)))