forked from mindspore-Ecosystem/mindspore
add sparse feature test cases for auto parallel
This commit is contained in:
parent 8fc6d1d81f
commit 5d63c60135
@@ -150,6 +150,8 @@ Status ReshapeInfo::ComputeReplaceOp() {
     ConstructOperator constructor;
     replace_op_ = constructor.SkipRedisReshapeOP(output_layout_.slice_shape().array());
     replace_op_info_.clear();
+    MS_LOG(INFO) << "skip reshape redistribution and reshape slice_shape is "
+                 << ShapeToString(output_layout_.slice_shape().array());
   } else {
     RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
     if (redistribution_oplist_ptr == nullptr) {
@@ -316,12 +316,6 @@ void Redistribution(const std::pair<AnfNodePtr, int> &node_pair, const OperatorI
   TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
   TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator);
 
-  if (tensorlayout_in.skip_redistribution() || tensorlayout_out.skip_redistribution()) {
-    MS_LOG(INFO) << "skip the reshape redistribution, operator name is" << distribute_operator->name()
-                 << "next distribute operator, operator name is" << next_distribute_operator->name();
-    return;
-  }
-
   if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
     MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name;
     MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "
@@ -1380,10 +1374,6 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
                     << cloned_index << ", but not found the be cloned parameter";
       }
     }
-    std::string env = common::GetEnv("SLICE_ENV");
-    if (!env.empty()) {
-      MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env;
-    }
   }
 }
 
 void SetVirtualDatasetStrategy(const CNodePtr &node) {
@@ -18,16 +18,14 @@ import numpy as np
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
-from mindspore.common import dtype as mstype
-from mindspore.common.tensor import Tensor, IndexedSlices
+from mindspore.common.parameter import Parameter
+from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C, operations as P
 from mindspore.ops.operations.comm_ops import AllReduce, _MirrorOperator
-from mindspore.ops._grad.grad_base import bprop_getters
-from mindspore._checkparam import Validator as validator
-from mindspore._checkparam import Rel
-from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
 from mindspore.common.api import _executor
 from mindspore.communication.management import HCCL_WORLD_COMM_GROUP
+from mindspore.nn import TrainOneStepCell, Adam
 
 
 class GradWrap(nn.Cell):
     def __init__(self, network):
@@ -37,40 +35,9 @@ class GradWrap(nn.Cell):
     def construct(self, x):
         return C.grad_all(self.network)(x)
 
-class VirtualGatherV2(PrimitiveWithInfer):
-    @prim_attr_register
-    def __init__(self):
-        """init index_select"""
-        super(VirtualGatherV2, self).__init__('VirtualGatherV2')
-        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
-
-    def __infer__(self, params, indices, axis):
-        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
-        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
-        validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
-        axis_v = axis['value']
-        params_shp = params['shape']
-        rank = len(params_shp)
-        validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
-        if axis_v < 0:
-            axis_v += rank
-        out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
-        out = {'shape': out_shape,
-               'dtype': params['dtype'],
-               'value': None}
-        return out
-
-@bprop_getters.register(VirtualGatherV2)
-def get_bprop_gather_v2(self):
-    """Generate bprop for GatherV2"""
-
-    def bprop(x, indices, axis, out, dout):
-        return IndexedSlices(indices, dout, x), axis, out
-
-    return bprop
-
 def test_bprop_with_sparse_feature_allreduce():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")
+    context.set_context(enable_sparse=True)
 
     class Net(nn.Cell):
         def __init__(self, axis=0, shape=None):
@@ -78,7 +45,7 @@ def test_bprop_with_sparse_feature_allreduce():
             if shape is None:
                 shape = [8, 8]
             self.all_reduce = AllReduce()
-            self.gatherv2 = P.GatherV2()
+            self.gatherv2 = P.SparseGatherV2()
             self.index = Tensor(np.ones(shape), dtype=ms.int32)
             self.axis = axis
 
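For reference, the hunk above shows only the changed constructor line of the allreduce case. Below is a minimal sketch of how the whole test plausibly fits together; the construct body and the GradWrap wiring are assumptions (the diff does not show them), while the primitive names, context flags, and shapes are taken from the hunks themselves.

# Sketch only: construct() and GradWrap internals are assumed, not
# taken from the diff; everything else mirrors the hunks above.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce
from mindspore.common.api import _executor


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return C.grad_all(self.network)(x)


def test_bprop_with_sparse_feature_allreduce():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")
    context.set_context(enable_sparse=True)

    class Net(nn.Cell):
        def __init__(self, axis=0, shape=None):
            super(Net, self).__init__()
            if shape is None:
                shape = [8, 8]
            self.all_reduce = AllReduce()
            self.gatherv2 = P.SparseGatherV2()
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.axis = axis

        def construct(self, x):
            # Assumed ordering: all-reduce the dense input, then do the
            # sparse lookup so the backward pass emits a sparse gradient
            # that must flow through AllReduce's bprop.
            out = self.all_reduce(x)
            out = self.gatherv2(out, self.index, self.axis)
            return out

    net = GradWrap(Net())
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    _executor.compile(net, x)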
@@ -95,6 +62,7 @@ def test_bprop_with_sparse_feature_allreduce():
 
 def test_bprop_with_sparse_feature_mirror():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")
+    context.set_context(enable_sparse=True)
 
     class Net(nn.Cell):
         def __init__(self, axis=0, shape=None):
@@ -102,7 +70,7 @@ def test_bprop_with_sparse_feature_mirror():
             if shape is None:
                 shape = [8, 8]
             self.mirror = _MirrorOperator(group=HCCL_WORLD_COMM_GROUP)
-            self.gatherv2 = P.GatherV2()
+            self.gatherv2 = P.SparseGatherV2()
             self.index = Tensor(np.ones(shape), dtype=ms.int32)
             self.axis = axis
 
@@ -116,3 +84,35 @@ def test_bprop_with_sparse_feature_mirror():
     x = Tensor(np.ones([64, 64]), dtype=ms.float32)
 
     _executor.compile(net, x)
+
+
+def test_bprop_with_sparse_feature_dataparallel():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="data_parallel")
+    context.set_context(enable_sparse=True)
+
+    class Net(nn.Cell):
+        def __init__(self, axis=0, shape=None):
+            super(Net, self).__init__()
+            if shape is None:
+                shape = [8, 8]
+            weight = Tensor(np.ones([64, 64]), dtype=ms.float32)
+            self.weight = Parameter(weight, "w")
+            self.index = Tensor(np.ones(shape), dtype=ms.int32)
+            self.axis = axis
+            self.gatherv2 = P.SparseGatherV2()
+
+        def construct(self, x, b):
+            out = self.gatherv2(self.weight, self.index, self.axis)
+
+            return out
+
+    _x = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)
+    _b = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)
+
+    def compile_net(net):
+        optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
+        train_net = TrainOneStepCell(net, optimizer)
+        _executor.compile(train_net, _x, _b)
+
+    net = Net()
+    compile_net(net)
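These cases only build and compile graphs (via _executor.compile, directly or through TrainOneStepCell), so they read as ordinary unit tests. A hypothetical driver is sketched below, assuming the unit-test environment compiles the parallel graphs without real devices; the function names come from the diff, the driver itself is an illustration.

# Hypothetical driver: call the three new test functions directly.
# Assumption: graph compilation alone (no execution) is the goal here,
# which is why each case ends in _executor.compile(...).
if __name__ == "__main__":
    test_bprop_with_sparse_feature_allreduce()
    test_bprop_with_sparse_feature_mirror()
    test_bprop_with_sparse_feature_dataparallel()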