forked from OSSInnovation/mindspore
!3311 add sparse feature test cases for auto parallel
Merge pull request !3311 from lirongzhen1/master
This commit is contained in:
commit
7f1ccc5f3b
|
@ -150,6 +150,8 @@ Status ReshapeInfo::ComputeReplaceOp() {
|
|||
ConstructOperator constructor;
|
||||
replace_op_ = constructor.SkipRedisReshapeOP(output_layout_.slice_shape().array());
|
||||
replace_op_info_.clear();
|
||||
MS_LOG(INFO) << "skip reshape redistribution and reshape slice_shape is "
|
||||
<< ShapeToString(output_layout_.slice_shape().array());
|
||||
} else {
|
||||
RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
|
||||
if (redistribution_oplist_ptr == nullptr) {
|
||||
|
|
|
@ -316,12 +316,6 @@ void Redistribution(const std::pair<AnfNodePtr, int> &node_pair, const OperatorI
|
|||
TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
|
||||
TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator);
|
||||
|
||||
if (tensorlayout_in.skip_redistribution() || tensorlayout_out.skip_redistribution()) {
|
||||
MS_LOG(INFO) << "skip the reshape redistribution, operator name is" << distribute_operator->name()
|
||||
<< "next distribute operator, operator name is" << next_distribute_operator->name();
|
||||
return;
|
||||
}
|
||||
|
||||
if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
|
||||
MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name;
|
||||
MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "
|
||||
|
@ -1380,10 +1374,6 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
|
|||
<< cloned_index << ", but not found the be cloned parameter";
|
||||
}
|
||||
}
|
||||
std::string env = common::GetEnv("SLICE_ENV");
|
||||
if (!env.empty()) {
|
||||
MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env;
|
||||
}
|
||||
}
|
||||
|
||||
void SetVirtualDatasetStrategy(const CNodePtr &node) {
|
||||
|
|
|
@ -18,16 +18,14 @@ import numpy as np
|
|||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.tensor import Tensor, IndexedSlices
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import composite as C, operations as P
|
||||
from mindspore.ops.operations.comm_ops import AllReduce, _MirrorOperator
|
||||
from mindspore.ops._grad.grad_base import bprop_getters
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.communication.management import HCCL_WORLD_COMM_GROUP
|
||||
from mindspore.nn import TrainOneStepCell, Adam
|
||||
|
||||
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
|
@ -37,40 +35,9 @@ class GradWrap(nn.Cell):
|
|||
def construct(self, x):
|
||||
return C.grad_all(self.network)(x)
|
||||
|
||||
class VirtualGatherV2(PrimitiveWithInfer):
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init index_select"""
|
||||
super(VirtualGatherV2, self).__init__('VirtualGatherV2')
|
||||
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
|
||||
|
||||
def __infer__(self, params, indices, axis):
|
||||
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
||||
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
|
||||
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
|
||||
axis_v = axis['value']
|
||||
params_shp = params['shape']
|
||||
rank = len(params_shp)
|
||||
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
|
||||
if axis_v < 0:
|
||||
axis_v += rank
|
||||
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
|
||||
out = {'shape': out_shape,
|
||||
'dtype': params['dtype'],
|
||||
'value': None}
|
||||
return out
|
||||
|
||||
@bprop_getters.register(VirtualGatherV2)
|
||||
def get_bprop_gather_v2(self):
|
||||
"""Generate bprop for GatherV2"""
|
||||
|
||||
def bprop(x, indices, axis, out, dout):
|
||||
return IndexedSlices(indices, dout, x), axis, out
|
||||
|
||||
return bprop
|
||||
|
||||
def test_bprop_with_sparse_feature_allreduce():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")
|
||||
context.set_context(enable_sparse=True)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, axis=0, shape=None):
|
||||
|
@ -78,7 +45,7 @@ def test_bprop_with_sparse_feature_allreduce():
|
|||
if shape is None:
|
||||
shape = [8, 8]
|
||||
self.all_reduce = AllReduce()
|
||||
self.gatherv2 = P.GatherV2()
|
||||
self.gatherv2 = P.SparseGatherV2()
|
||||
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
||||
self.axis = axis
|
||||
|
||||
|
@ -95,6 +62,7 @@ def test_bprop_with_sparse_feature_allreduce():
|
|||
|
||||
def test_bprop_with_sparse_feature_mirror():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")
|
||||
context.set_context(enable_sparse=True)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, axis=0, shape=None):
|
||||
|
@ -102,7 +70,7 @@ def test_bprop_with_sparse_feature_mirror():
|
|||
if shape is None:
|
||||
shape = [8, 8]
|
||||
self.mirror = _MirrorOperator(group=HCCL_WORLD_COMM_GROUP)
|
||||
self.gatherv2 = P.GatherV2()
|
||||
self.gatherv2 = P.SparseGatherV2()
|
||||
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
||||
self.axis = axis
|
||||
|
||||
|
@ -116,3 +84,35 @@ def test_bprop_with_sparse_feature_mirror():
|
|||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
|
||||
_executor.compile(net, x)
|
||||
|
||||
|
||||
def test_bprop_with_sparse_feature_dataparallel():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="data_parallel")
|
||||
context.set_context(enable_sparse=True)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, axis=0, shape=None):
|
||||
super(Net, self).__init__()
|
||||
if shape is None:
|
||||
shape = [8, 8]
|
||||
weight = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
self.weight = Parameter(weight, "w")
|
||||
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
||||
self.axis = axis
|
||||
self.gatherv2 = P.SparseGatherV2()
|
||||
|
||||
def construct(self, x, b):
|
||||
out = self.gatherv2(self.weight, self.index, self.axis)
|
||||
|
||||
return out
|
||||
|
||||
_x = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)
|
||||
_b = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)
|
||||
|
||||
def compile_net(net):
|
||||
optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
|
||||
train_net = TrainOneStepCell(net, optimizer)
|
||||
_executor.compile(train_net, _x, _b)
|
||||
|
||||
net = Net()
|
||||
compile_net(net)
|
||||
|
|
Loading…
Reference in New Issue