From 3102c4ff8d12ffcb560b2b1965eabe631bd50dac Mon Sep 17 00:00:00 2001
From: Yi Huaijie
Date: Mon, 26 Oct 2020 19:42:07 +0800
Subject: [PATCH] support split ValueList

---
 .../frontend/parallel/step_auto_parallel.cc   | 27 ++++++
 .../ccsrc/frontend/parallel/step_parallel.cc  | 93 ++++++++++++++++++-
 tests/ut/python/parallel/test_pack.py         | 86 +++++++++++++++++
 3 files changed, 203 insertions(+), 3 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
index 326be244a4e..d1c2cb00343 100644
--- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
@@ -117,6 +117,17 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
 std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
   std::vector<bool> is_parameter;
   std::vector<AnfNodePtr> node_inputs{node->inputs()};
+  // If the input is a ValueList or ValueTuple, none of its elements is a parameter.
+  if ((node_inputs.size() == 2) &&
+      (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
+    std::vector<ValuePtr> inputs_seq;
+    if (IsValueNode<ValueList>(node_inputs[1])) {
+      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
+    } else {
+      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
+    }
+    return std::vector<bool>(inputs_seq.size(), false);
+  }
   if ((node_inputs.size() == 2) &&
       (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
     node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
@@ -195,6 +206,22 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
   std::vector<size_t> inputs_type_len;
   std::vector<AnfNodePtr> node_inputs{node->inputs()};
 
+  if ((node_inputs.size() == 2) &&
+      (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
+    std::vector<ValuePtr> inputs_seq;
+    if (IsValueNode<ValueList>(node_inputs[1])) {
+      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
+    } else {
+      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
+    }
+    for (auto &ele : inputs_seq) {
+      auto tensor = ele->cast<tensor::TensorPtr>();
+      MS_EXCEPTION_IF_NULL(tensor);
+      inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype()));
+    }
+    return inputs_type_len;
+  }
+
   if ((node_inputs.size() == 2) &&
       (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
     node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
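The cost-model changes above target one specific graph pattern: an operator whose single input is a constant sequence of tensors, which the frontend lowers to a ValueList or ValueTuple node, as with Pack applied to Python constants. Every element of such a sequence is a compile-time constant, so none of them is a trainable parameter, and the type length is read per element. A minimal sketch of the triggering pattern, assuming a MindSpore build with P.Pack as used in the tests below:

    import numpy as np
    from mindspore import Tensor, context
    from mindspore.ops import operations as P

    context.set_context(mode=context.PYNATIVE_MODE)

    # Pack over a tuple of constant tensors: in a compiled graph this single
    # input becomes a ValueTuple, so every element is a non-parameter constant
    # whose dtype length the cost model now reads element by element.
    const = Tensor(np.full((8, 8), 0.01, dtype=np.float32))
    out = P.Pack(0)((const, const, const, const))
    print(out.shape)  # (4, 8, 8)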
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index f182e8e405c..bcda26eed94 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -533,6 +533,58 @@ void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) {
   }
 }
 
+void SplitTensorList(const AnfNodePtr &node, const CNodePtr &next_node, int index) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(next_node);
+  if (next_node->inputs().size() != 2 || index != 1) {
+    MS_LOG(INFO) << next_node->fullname_with_scope() << " must have exactly one input, but got "
+                 << next_node->inputs().size() - 1 << "; the index should be 1, but got " << index;
+    return;
+  }
+  OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>();
+  MS_EXCEPTION_IF_NULL(op_info);
+
+  std::vector<ValuePtr> inputs_values;
+  if (IsValueNode<ValueList>(node)) {
+    inputs_values = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
+  } else {
+    inputs_values = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
+  }
+  if (inputs_values.size() != op_info->inputs_tensor_info().size()) {
+    MS_LOG(EXCEPTION) << "The inputs size " << inputs_values.size() << " is not equal to the inputs shape size "
+                      << op_info->inputs_tensor_info().size();
+  }
+  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+  FuncGraphPtr func_graph = next_node->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  FuncGraphManagerPtr manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  ScopePtr scope = next_node->scope();
+  MS_EXCEPTION_IF_NULL(scope);
+  for (size_t i = 0; i < inputs_values.size(); ++i) {
+    auto value_ptr = inputs_values[i];
+    auto tensor = value_ptr->cast<tensor::TensorPtr>();
+    MS_EXCEPTION_IF_NULL(tensor);
+    TensorInfo tensor_info = op_info->inputs_tensor_info()[i];
+    TensorLayout tensor_layout = tensor_info.tensor_layout();
+    auto value_node = NewValueNode(value_ptr)->cast<AnfNodePtr>();
+    Operator op = CreateGetTensorSliceOp(tensor_layout);
+    std::vector<AnfNodePtr> node_input = CreateInput(op, value_node, SPLIT_TENSOR);
+    CNodePtr new_node = func_graph->NewCNode(node_input);
+    new_node->set_in_forward_flag(true);
+    auto new_node_value = node_input[0]->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(new_node_value);
+    PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
+    new_node_prim->set_instance_name(SPLIT_TENSOR);
+    new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
+    new_node->set_scope(scope);
+    node_input[0]->set_scope(scope);
+    make_tuple_inputs.push_back(new_node);
+  }
+  CNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
+  manager->Replace(node, make_tuple);
+}
+
 void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(manager);
@@ -550,7 +602,11 @@ void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager)
       continue;
     }
     if (IsParallelCareNode(use_cnode)) {
-      SplitTensor(node, use_cnode, node_pair.second);
+      if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
+        SplitTensorList(node, use_cnode, node_pair.second);
+      } else {
+        SplitTensor(node, use_cnode, node_pair.second);
+      }
     }
   }
 }
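SplitTensorList mirrors SplitTensor, but per element: each constant in the sequence is wrapped in its own GetTensorSlice operator built from that element's TensorLayout, and the resulting slices are regathered by a MakeTuple that replaces the original value node. Numerically, a shard strategy such as (4, 1) cuts dimension 0 into four blocks and leaves dimension 1 whole. A conceptual numpy sketch of which block one rank keeps (this illustrates the intent of CreateGetTensorSliceOp, not MindSpore's actual layout code):

    import numpy as np

    def slice_for_rank(value, strategy, rank):
        # strategy[i] is the number of cuts along dimension i; ranks are laid
        # out row-major over the cut grid and each rank keeps one block.
        coords = np.unravel_index(rank, strategy)
        index = tuple(slice(c * (dim // cuts), (c + 1) * (dim // cuts))
                      for dim, cuts, c in zip(value.shape, strategy, coords))
        return value[index]

    const = np.full((8, 8), 0.01, dtype=np.float32)
    # With strategy (4, 1), each of the four slices of an (8, 8) constant is (2, 8).
    print(slice_for_rank(const, (4, 1), rank=2).shape)  # (2, 8)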
@@ -852,6 +908,11 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
   FuncGraphManagerPtr manager = func_graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
 
+  if ((node->inputs().size() == 2) &&
+      (IsValueNode<ValueList>(node->input(1)) || IsValueNode<ValueTuple>(node->input(1)))) {
+    MS_LOG(INFO) << "Input is a ValueList or ValueTuple, skip the mirror operators.";
+    return;
+  }
+
   if ((node->inputs().size() == 2) &&
       (AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) {
     MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has been handled by the make_tuple node";
@@ -1049,9 +1110,34 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
   return strategyPtr;
 }
 
+Shapes GetValueListShape(const AnfNodePtr &node) {
+  Shapes shapes;
+  std::vector<ValuePtr> inputs_seq;
+  if (IsValueNode<ValueList>(node)) {
+    inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
+  } else if (IsValueNode<ValueTuple>(node)) {
+    inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
+  } else {
+    MS_LOG(EXCEPTION) << "The node must be a ValueList or a ValueTuple";
+  }
+  for (auto &ele : inputs_seq) {
+    auto tensor = ele->cast<tensor::TensorPtr>();
+    MS_EXCEPTION_IF_NULL(tensor);
+    auto one_shape = tensor->shape();
+    Shape shape_64;
+    (void)std::transform(one_shape.begin(), one_shape.end(), std::back_inserter(shape_64),
+                         [](const int &value) { return static_cast<int64_t>(value); });
+    shapes.push_back(shape_64);
+  }
+  return shapes;
+}
+
 Shapes GetNodeShape(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   Shapes shapes;
+  if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
+    return GetValueListShape(node);
+  }
   BaseShapePtr base_shape_ptr = node->Shape();
   if (node->isa<CNode>()) {
     auto cnode = node->cast<CNodePtr>();
@@ -1177,7 +1263,8 @@ std::vector<Shapes> ExtractShape(const CNodePtr &node) {
       std::pair<AnfNodePtr, int> node_pair = std::make_pair(node, SizeToInt(i));
       g_RefMap[parameters[0]] = node_pair;
       input_shapes = GetRefKeyNodeShape(input, func_graph);
-    } else if (IsValueNode<Tensor>(input) || input->isa<CNode>() || input->isa<Parameter>()) {
+    } else if (IsValueNode<Tensor>(input) || input->isa<CNode>() || input->isa<Parameter>() ||
+               ((IsValueNode<ValueList>(input) || IsValueNode<ValueTuple>(input)) && (inputs_size == 2))) {
       input_shapes = GetNodeShape(input);
     } else {
       continue;
@@ -2258,7 +2345,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
-    } else if (IsValueNode<Tensor>(node)) {
+    } else if (IsValueNode<Tensor>(node) || IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
       StepSplitTensor(node, manager);
     }
   }
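With GetValueListShape wired into GetNodeShape and ExtractShape, a constant sequence contributes one input shape per element, so the shard strategy of the consuming operator must supply one tuple per element. A small sketch of that correspondence, using the same values as the tests below:

    import numpy as np

    # Eight packed (8, 8) constants yield eight input shapes, and the Pack
    # strategy must therefore hold eight tuples, one per packed element.
    constants = [np.full((8, 8), 0.01, dtype=np.float32)] * 8
    input_shapes = [list(c.shape) for c in constants]  # [[8, 8]] * 8
    strategy = ((4, 1),) * 8
    assert len(input_shapes) == len(strategy)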
diff --git a/tests/ut/python/parallel/test_pack.py b/tests/ut/python/parallel/test_pack.py
index ccc83567038..8de77f01d42 100644
--- a/tests/ut/python/parallel/test_pack.py
+++ b/tests/ut/python/parallel/test_pack.py
@@ -20,6 +20,7 @@ import mindspore.nn as nn
 from mindspore.common.api import _executor
 from mindspore.nn import TrainOneStepCell, Momentum
 from mindspore.ops import operations as P
+from mindspore.nn import Dense, Flatten
 
 
 class Net(nn.Cell):
@@ -71,12 +72,67 @@ class Net2(nn.Cell):
         return out
 
 
+class PackConstantNet1(nn.Cell):
+    def __init__(self, dense_in_channel, dense_out_channel, axis=0, shape=None, strategy=None):
+        super().__init__()
+        weight_np = np.full((dense_out_channel, dense_in_channel), 0.01, dtype=np.float32)
+        bias_np = np.full((dense_out_channel), 0.01, dtype=np.float32)
+        self.pack_con = Tensor(np.full(shape, 0.01, dtype=np.float32))
+        self.flat = Flatten()
+        self.dense = Dense(in_channels=dense_in_channel,
+                           out_channels=dense_out_channel,
+                           weight_init=Tensor(weight_np),
+                           bias_init=Tensor(bias_np),
+                           has_bias=True)
+        self.mul = P.Mul()
+        self.pack = P.Pack(axis)
+        if strategy is not None:
+            self.pack.shard(strategy)
+
+    def construct(self, inputs):
+        x = self.pack([self.pack_con, self.pack_con, self.pack_con, self.pack_con,
+                       self.pack_con, self.pack_con, self.pack_con, self.pack_con])
+        x1 = self.flat(x)
+        x2 = self.flat(inputs)
+        x = self.mul(x1, x2)
+        x = self.dense(x)
+        return x
+
+
+class PackConstantNet2(nn.Cell):
+    def __init__(self, dense_in_channel, dense_out_channel, axis=0, shape=None, strategy=None):
+        super().__init__()
+        weight_np = np.full((dense_out_channel, dense_in_channel), 0.01, dtype=np.float32)
+        bias_np = np.full((dense_out_channel), 0.01, dtype=np.float32)
+        self.pack_con = Tensor(np.full(shape, 0.01, dtype=np.float32))
+        self.flat = Flatten()
+        self.dense = Dense(in_channels=dense_in_channel,
+                           out_channels=dense_out_channel,
+                           weight_init=Tensor(weight_np),
+                           bias_init=Tensor(bias_np),
+                           has_bias=True)
+        self.mul = P.Mul()
+        self.pack = P.Pack(axis)
+        if strategy is not None:
+            self.pack.shard(strategy)
+
+    def construct(self, inputs):
+        x = self.pack((self.pack_con, self.pack_con, self.pack_con, self.pack_con,
+                       self.pack_con, self.pack_con, self.pack_con, self.pack_con))
+        x1 = self.flat(x)
+        x2 = self.flat(inputs)
+        x = self.mul(x1, x2)
+        x = self.dense(x)
+        return x
+
+
 _w1 = Tensor(np.ones([48, 64]), dtype=ms.float32)
 _w2 = Tensor(np.ones([48, 64]), dtype=ms.float32)
 _w3 = Tensor(np.ones([48, 64]), dtype=ms.float32)
 _x = Tensor(np.ones([2, 48, 64]), dtype=ms.float32)
 _x1 = Tensor(np.ones([48, 64]), dtype=ms.float32)
 _x2 = Tensor(np.ones([3, 48, 64]), dtype=ms.float32)
+_x_c = Tensor(np.ones([8, 8, 8]), dtype=ms.float32)
 
 
 def compile_net(net):
@@ -106,6 +162,15 @@ def compile_net2(net):
     context.reset_auto_parallel_context()
 
 
+def compile_net_con(net):
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
+    _executor.compile(train_net, _x_c)
+    context.reset_auto_parallel_context()
+
+
 def test_pack_parameter():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((4, 2), (4, 2))
@@ -186,3 +251,24 @@ def test_pack_auto_parallel_3_tensor():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
     net = Net2(_w1, _w2, _w3)
     compile_net2(net)
+
+
+def test_pack_constant1():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    net = PackConstantNet1(dense_in_channel=64, dense_out_channel=4, axis=0, shape=(8, 8),
+                           strategy=((4, 1), (4, 1), (4, 1), (4, 1), (4, 1), (4, 1), (4, 1), (4, 1)))
+    compile_net_con(net)
+
+
+def test_pack_constant2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    net = PackConstantNet2(dense_in_channel=64, dense_out_channel=4, axis=0, shape=(8, 8),
+                           strategy=((4, 1), (4, 1), (4, 1), (4, 1), (4, 1), (4, 1), (4, 1), (4, 1)))
+    compile_net_con(net)
+
+
+def test_pack_auto_constant():
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
+    net = PackConstantNet1(dense_in_channel=64, dense_out_channel=4, axis=0, shape=(8, 8),
+                           strategy=((8, 1), (8, 1), (8, 1), (8, 1), (8, 1), (8, 1), (8, 1), (8, 1)))
+    compile_net_con(net)
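The two new networks differ only in how construct feeds Pack: PackConstantNet1 passes a Python list, lowered to a ValueList, while PackConstantNet2 passes a tuple, lowered to a ValueTuple, so both branches of the new C++ checks are exercised. A standalone run of the semi-auto case would follow the tests' own compile flow; a sketch, assuming the test module's definitions are importable from this repository layout:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, context
    from mindspore.common.api import _executor
    from mindspore.nn import TrainOneStepCell, Momentum

    # PackConstantNet1 as defined in tests/ut/python/parallel/test_pack.py above.
    from tests.ut.python.parallel.test_pack import PackConstantNet1

    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    net = PackConstantNet1(dense_in_channel=64, dense_out_channel=4, axis=0,
                           shape=(8, 8), strategy=((4, 1),) * 8)
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    _executor.compile(train_net, Tensor(np.ones([8, 8, 8]), dtype=ms.float32))
    context.reset_auto_parallel_context()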