add sparse feature test cases for auto parallel

This commit is contained in:
lirongzhen1 2020-07-22 11:13:54 +08:00
parent 8fc6d1d81f
commit 5d63c60135
3 changed files with 42 additions and 50 deletions

View File

@ -150,6 +150,8 @@ Status ReshapeInfo::ComputeReplaceOp() {
ConstructOperator constructor;
replace_op_ = constructor.SkipRedisReshapeOP(output_layout_.slice_shape().array());
replace_op_info_.clear();
MS_LOG(INFO) << "skip reshape redistribution and reshape slice_shape is "
<< ShapeToString(output_layout_.slice_shape().array());
} else {
RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
if (redistribution_oplist_ptr == nullptr) {

View File

@ -316,12 +316,6 @@ void Redistribution(const std::pair<AnfNodePtr, int> &node_pair, const OperatorI
TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator);
if (tensorlayout_in.skip_redistribution() || tensorlayout_out.skip_redistribution()) {
MS_LOG(INFO) << "skip the reshape redistribution, operator name is" << distribute_operator->name()
<< "next distribute operator, operator name is" << next_distribute_operator->name();
return;
}
if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name;
MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "
@ -1380,10 +1374,6 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
<< cloned_index << ", but not found the be cloned parameter";
}
}
std::string env = common::GetEnv("SLICE_ENV");
if (!env.empty()) {
MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env;
}
}
void SetVirtualDatasetStrategy(const CNodePtr &node) {

View File

@ -18,16 +18,14 @@ import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor, IndexedSlices
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce, _MirrorOperator
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
from mindspore.common.api import _executor
from mindspore.communication.management import HCCL_WORLD_COMM_GROUP
from mindspore.nn import TrainOneStepCell, Adam
class GradWrap(nn.Cell):
def __init__(self, network):
@ -37,40 +35,9 @@ class GradWrap(nn.Cell):
def construct(self, x):
return C.grad_all(self.network)(x)
class VirtualGatherV2(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init index_select"""
super(VirtualGatherV2, self).__init__('VirtualGatherV2')
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
def __infer__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
axis_v = axis['value']
params_shp = params['shape']
rank = len(params_shp)
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
if axis_v < 0:
axis_v += rank
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
out = {'shape': out_shape,
'dtype': params['dtype'],
'value': None}
return out
@bprop_getters.register(VirtualGatherV2)
def get_bprop_gather_v2(self):
"""Generate bprop for GatherV2"""
def bprop(x, indices, axis, out, dout):
return IndexedSlices(indices, dout, x), axis, out
return bprop
def test_bprop_with_sparse_feature_allreduce():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")
context.set_context(enable_sparse=True)
class Net(nn.Cell):
def __init__(self, axis=0, shape=None):
@ -78,7 +45,7 @@ def test_bprop_with_sparse_feature_allreduce():
if shape is None:
shape = [8, 8]
self.all_reduce = AllReduce()
self.gatherv2 = P.GatherV2()
self.gatherv2 = P.SparseGatherV2()
self.index = Tensor(np.ones(shape), dtype=ms.int32)
self.axis = axis
@ -95,6 +62,7 @@ def test_bprop_with_sparse_feature_allreduce():
def test_bprop_with_sparse_feature_mirror():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")
context.set_context(enable_sparse=True)
class Net(nn.Cell):
def __init__(self, axis=0, shape=None):
@ -102,7 +70,7 @@ def test_bprop_with_sparse_feature_mirror():
if shape is None:
shape = [8, 8]
self.mirror = _MirrorOperator(group=HCCL_WORLD_COMM_GROUP)
self.gatherv2 = P.GatherV2()
self.gatherv2 = P.SparseGatherV2()
self.index = Tensor(np.ones(shape), dtype=ms.int32)
self.axis = axis
@ -116,3 +84,35 @@ def test_bprop_with_sparse_feature_mirror():
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
_executor.compile(net, x)
def test_bprop_with_sparse_feature_dataparallel():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="data_parallel")
context.set_context(enable_sparse=True)
class Net(nn.Cell):
def __init__(self, axis=0, shape=None):
super(Net, self).__init__()
if shape is None:
shape = [8, 8]
weight = Tensor(np.ones([64, 64]), dtype=ms.float32)
self.weight = Parameter(weight, "w")
self.index = Tensor(np.ones(shape), dtype=ms.int32)
self.axis = axis
self.gatherv2 = P.SparseGatherV2()
def construct(self, x, b):
out = self.gatherv2(self.weight, self.index, self.axis)
return out
_x = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)
def compile_net(net):
optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
train_net = TrainOneStepCell(net, optimizer)
_executor.compile(train_net, _x, _b)
net = Net()
compile_net(net)