From 5d63c60135baa43decda2f9910502fc3e1ed2db6 Mon Sep 17 00:00:00 2001
From: lirongzhen1
Date: Wed, 22 Jul 2020 11:13:54 +0800
Subject: [PATCH] add sparse feature test cases for auto parallel

---
 .../parallel/ops_info/reshape_info.cc         |  2 +
 .../ccsrc/frontend/parallel/step_parallel.cc  | 10 ---
 .../parallel/test_sparse_feature_bprop.py     | 80 +++++++++----------
 3 files changed, 42 insertions(+), 50 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
index cc37da4b1e9..11ef3e43d38 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
@@ -150,6 +150,8 @@ Status ReshapeInfo::ComputeReplaceOp() {
     ConstructOperator constructor;
     replace_op_ = constructor.SkipRedisReshapeOP(output_layout_.slice_shape().array());
     replace_op_info_.clear();
+    MS_LOG(INFO) << "skip reshape redistribution and reshape slice_shape is "
+                 << ShapeToString(output_layout_.slice_shape().array());
   } else {
     RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
     if (redistribution_oplist_ptr == nullptr) {
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index dec37030c71..689f3055151 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -316,12 +316,6 @@ void Redistribution(const std::pair<AnfNodePtr, int> &node_pair, const OperatorI
   TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
   TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator);
 
-  if (tensorlayout_in.skip_redistribution() || tensorlayout_out.skip_redistribution()) {
-    MS_LOG(INFO) << "skip the reshape redistribution, operator name is" << distribute_operator->name()
-                 << "next distribute operator, operator name is" << next_distribute_operator->name();
-    return;
-  }
-
   if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
     MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name;
     MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "
@@ -1380,10 +1374,6 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
                         << cloned_index << ", but not found the be cloned parameter";
     }
   }
-  std::string env = common::GetEnv("SLICE_ENV");
-  if (!env.empty()) {
-    MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env;
-  }
 }
 
 void SetVirtualDatasetStrategy(const CNodePtr &node) {
diff --git a/tests/ut/python/parallel/test_sparse_feature_bprop.py b/tests/ut/python/parallel/test_sparse_feature_bprop.py
index 515be06e450..14f794c2920 100644
--- a/tests/ut/python/parallel/test_sparse_feature_bprop.py
+++ b/tests/ut/python/parallel/test_sparse_feature_bprop.py
@@ -18,16 +18,14 @@ import numpy as np
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
-from mindspore.common import dtype as mstype
-from mindspore.common.tensor import Tensor, IndexedSlices
+from mindspore.common.parameter import Parameter
+from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C, operations as P
 from mindspore.ops.operations.comm_ops import AllReduce, _MirrorOperator
-from mindspore.ops._grad.grad_base import bprop_getters
-from mindspore._checkparam import Validator as validator
-from mindspore._checkparam import Rel
-from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
 from mindspore.common.api import _executor
 from mindspore.communication.management import HCCL_WORLD_COMM_GROUP
+from mindspore.nn import TrainOneStepCell, Adam
+
 
 class GradWrap(nn.Cell):
     def __init__(self, network):
@@ -37,40 +35,9 @@ class GradWrap(nn.Cell):
     def construct(self, x):
         return C.grad_all(self.network)(x)
 
-class VirtualGatherV2(PrimitiveWithInfer):
-    @prim_attr_register
-    def __init__(self):
-        """init index_select"""
-        super(VirtualGatherV2, self).__init__('VirtualGatherV2')
-        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
-
-    def __infer__(self, params, indices, axis):
-        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
-        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
-        validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
-        axis_v = axis['value']
-        params_shp = params['shape']
-        rank = len(params_shp)
-        validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
-        if axis_v < 0:
-            axis_v += rank
-        out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
-        out = {'shape': out_shape,
-               'dtype': params['dtype'],
-               'value': None}
-        return out
-
-@bprop_getters.register(VirtualGatherV2)
-def get_bprop_gather_v2(self):
-    """Generate bprop for GatherV2"""
-
-    def bprop(x, indices, axis, out, dout):
-        return IndexedSlices(indices, dout, x), axis, out
-
-    return bprop
-
 def test_bprop_with_sparse_feature_allreduce():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")
+    context.set_context(enable_sparse=True)
 
     class Net(nn.Cell):
         def __init__(self, axis=0, shape=None):
@@ -78,7 +45,7 @@ def test_bprop_with_sparse_feature_allreduce():
             if shape is None:
                 shape = [8, 8]
             self.all_reduce = AllReduce()
-            self.gatherv2 = P.GatherV2()
+            self.gatherv2 = P.SparseGatherV2()
             self.index = Tensor(np.ones(shape), dtype=ms.int32)
             self.axis = axis
 
@@ -95,6 +62,7 @@
 
 def test_bprop_with_sparse_feature_mirror():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")
+    context.set_context(enable_sparse=True)
 
     class Net(nn.Cell):
         def __init__(self, axis=0, shape=None):
@@ -102,7 +70,7 @@ def test_bprop_with_sparse_feature_mirror():
             if shape is None:
                 shape = [8, 8]
             self.mirror = _MirrorOperator(group=HCCL_WORLD_COMM_GROUP)
-            self.gatherv2 = P.GatherV2()
+            self.gatherv2 = P.SparseGatherV2()
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
             self.axis = axis
 
@@ -116,3 +84,35 @@
 
     x = Tensor(np.ones([64, 64]), dtype=ms.float32)
     _executor.compile(net, x)
+
+
+def test_bprop_with_sparse_feature_dataparallel():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="data_parallel")
+    context.set_context(enable_sparse=True)
+
+    class Net(nn.Cell):
+        def __init__(self, axis=0, shape=None):
+            super(Net, self).__init__()
+            if shape is None:
+                shape = [8, 8]
+            weight = Tensor(np.ones([64, 64]), dtype=ms.float32)
+            self.weight = Parameter(weight, "w")
+            self.index = Tensor(np.ones(shape), dtype=ms.int32)
+            self.axis = axis
+            self.gatherv2 = P.SparseGatherV2()
+
+        def construct(self, x, b):
+            out = self.gatherv2(self.weight, self.index, self.axis)
+
+            return out
+
+    _x = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)
+    _b = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)
+
+    def compile_net(net):
+        optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
+        train_net = TrainOneStepCell(net, optimizer)
+        _executor.compile(train_net, _x, _b)
+
+    net = Net()
+    compile_net(net)
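
Note (illustration only, not part of the patch): the deleted VirtualGatherV2 stub
and its bprop_getters registration existed only to hand GatherV2 an IndexedSlices
gradient for these tests. With context.set_context(enable_sparse=True), the real
P.SparseGatherV2 primitive produces that sparse gradient natively, which is what
the updated tests compile against. Below is a minimal sketch of that behaviour,
reusing the C.grad_all pattern from GradWrap above; SparseGatherNet and its
tensors are hypothetical names, not part of the patch:

    # Illustration only, not part of the patch. With enable_sparse=True the
    # gradient of SparseGatherV2 w.r.t. its params input is an IndexedSlices
    # (gathered indices plus row slices) rather than a dense tensor, so the
    # parallel passes exercised by these tests must redistribute it correctly.
    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import context
    from mindspore.common.tensor import Tensor
    from mindspore.ops import composite as C, operations as P

    context.set_context(enable_sparse=True)

    class SparseGatherNet(nn.Cell):
        def __init__(self):
            super(SparseGatherNet, self).__init__()
            self.gatherv2 = P.SparseGatherV2()
            self.index = Tensor(np.array([0, 2, 4]), dtype=ms.int32)

        def construct(self, x):
            return self.gatherv2(x, self.index, 0)

    x = Tensor(np.ones([8, 8]), dtype=ms.float32)
    # C.grad_all returns one gradient per network input; the gradient for x
    # comes back sparse instead of as an 8x8 dense tensor.
    grads = C.grad_all(SparseGatherNet())(x)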