diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
index 21298697f44..effbdc17c7f 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
@@ -44,7 +44,14 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
       auto device_arrangement = tensor_layout->device_arrangement().array();
       auto tensor_map = tensor_layout->tensor_map().array();
       auto slice_shape = tensor_layout->slice_shape().array();
-      std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape};
+      int32_t _field_size = tensor_layout->get_field_size();
+      std::vector<int32_t> field_size;
+      if (_field_size != 0) {
+        field_size.push_back(_field_size);
+      } else {
+        field_size = {0};
+      }
+      std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape, field_size};
       dict[py::str(name)] = layout;
       MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
     }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
index 60a3d60b392..b2ff493a13a 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
@@ -105,6 +105,17 @@ Status MatMulBase::GetAttrs() {
     }
   }
 
+  auto field_size_iter = attrs_.find(FIELD_SIZE);
+  if (field_size_iter != attrs_.end()) {
+    MS_EXCEPTION_IF_NULL(field_size_iter->second);
+    if (field_size_iter->second->isa<Int32Imm>()) {
+      field_size_ = field_size_iter->second->cast<Int32ImmPtr>()->value();
+    } else {
+      MS_LOG(ERROR) << name_ << " : The value of field_size is not int.";
+      return FAILED;
+    }
+  }
+
   // infer inputs dimension size
   if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) {
     MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
@@ -346,6 +357,10 @@ Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts
     return FAILED;
   }
 
+  if (field_size_ != 0) {
+    mat_b_layout.set_field_size(field_size_);
+  }
+
   inputs_layout->push_back(mat_a_layout);
   inputs_layout->push_back(mat_b_layout);
   outputs_layout->push_back(output_layout);
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h
index d4e144c2b64..16f75abafce 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h
@@ -62,6 +62,7 @@ class MatMulBase : public OperatorInfo {
   bool transpose_a_ = false;
   bool transpose_b_ = false;
   bool forward_reduce_scatter_ = false;
+  int32_t field_size_ = 0;
   size_t mat_a_dimension_ = 0;
   size_t mat_b_dimension_ = 0;
 };
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
index 79dfb56693b..732d25f06b7 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -100,6 +100,7 @@ constexpr char CONCAT_DIM[] = "concat_dim";
 constexpr char FORWARD[] = "forward";
 constexpr char BACKWARD[] = "backward";
 constexpr char REDISTRIBUTION[] = "redistribution";
+constexpr char SKIP_REDISTRIBUTION[] = "skip_redistribution";
 constexpr char REPLACE[] = "replace";
 constexpr char CONNSYMBOL[] = "/";
 constexpr char INSTANCE_NAME[] = "instance_name";
@@ -131,6 +132,7 @@ constexpr char FORWARD_OP[] = "forward_op";
 constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
 constexpr char DARA_PARALLEL[] = "data_parallel";
 constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
+constexpr char FIELD_SIZE[] = "field_size";
 constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
 constexpr char DEVICE[] = "Device";
 
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
index fb62c1d02c0..cc37da4b1e9 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
@@ -18,6 +18,7 @@
 #include
 #include
+#include
 
 #include "frontend/parallel/device_manager.h"
 #include "frontend/parallel/device_matrix.h"
 
@@ -145,17 +146,23 @@ Status ReshapeInfo::ComputeReplaceOp() {
   MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString();
   MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString();
   MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size();
-  RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
-  if (redistribution_oplist_ptr == nullptr) {
-    if (is_generating_costs_) {
-      MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed.";
-    } else {
-      MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed.";
+  if (is_skip_) {
+    ConstructOperator constructor;
+    replace_op_ = constructor.SkipRedisReshapeOP(output_layout_.slice_shape().array());
+    replace_op_info_.clear();
+  } else {
+    RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
+    if (redistribution_oplist_ptr == nullptr) {
+      if (is_generating_costs_) {
+        MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed.";
+      } else {
+        MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed.";
+      }
+      return FAILED;
     }
-    return FAILED;
+    replace_op_ = redistribution_oplist_ptr->first;
+    replace_op_info_ = redistribution_oplist_ptr->second;
   }
-  replace_op_ = redistribution_oplist_ptr->first;
-  replace_op_info_ = redistribution_oplist_ptr->second;
   MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size();
   return SUCCESS;
 }
@@ -255,6 +262,19 @@ Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayout
 }
 
 Status ReshapeInfo::InferTensorInfo() {
+  // skip reshape infer if skip_redistribution is true
+  if (is_skip_) {
+    TensorLayout layout;
+    Shape shape;
+    Shape slice_shape;
+    layout.set_skip_redistribution(true);
+    TensorInfo tensor_info_in(layout, shape, slice_shape);
+    inputs_tensor_info_.push_back(tensor_info_in);
+    outputs_tensor_info_.push_back(tensor_info_in);
+    MS_LOG(DEBUG) << name() << ": skip redistribution reshape InferTensorInfo";
+    return SUCCESS;
+  }
+
   Shapes inputs_slice_shape, outputs_slice_shape;
   Strategys inputs_strategy = strategy_->GetInputDim();
   Strategys outputs_strategy = GetOutputsStrategy();
@@ -316,6 +336,16 @@ Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const l
 }
 
 Status ReshapeInfo::Init(const StrategyPtr &strategy) {
+  auto reshape_skip_redis_iter = attrs_.find(SKIP_REDISTRIBUTION);
+  if (reshape_skip_redis_iter != attrs_.end()) {
+    MS_EXCEPTION_IF_NULL(reshape_skip_redis_iter->second);
+    if (!reshape_skip_redis_iter->second->isa<BoolImm>()) {
+      MS_LOG(ERROR) << name_ << ": skip_redistribution is not a bool.";
+      return FAILED;
+    }
+    is_skip_ = reshape_skip_redis_iter->second->cast<BoolImmPtr>()->value();
+  }
+
   ResetQueueMember();
   device_number(strategy);
   if (strategy) {
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h
index 2463b440f81..c9c28602cc7 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h
@@ -98,6 +98,7 @@ class ReshapeInfo : public OperatorInfo {
   bool input_layout_set_flag_;
   bool output_layout_set_flag_;
   bool is_generating_costs_;
+  bool is_skip_ = false;
   std::string pre_operator_name_;
   std::string next_operator_name_;
 };
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 6b9cfd9d370..20eaf329cf2 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -302,16 +302,26 @@ void Redistribution(const std::pair<AnfNodePtr, int> &node_pair, const OperatorI
   MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "
                 << next_node->ToString();
   // extract tensor layout in and out
   if (distribute_operator->outputs_tensor_info().empty()) {
-    MS_LOG(EXCEPTION) << "Failure:pre_node's tensorinfo_in is empty";
+    MS_LOG(WARNING) << "pre_node's tensorinfo_in is empty, operator name is " << distribute_operator->name();
+    return;
   }
   if (IntToSize(index - 1) >= next_distribute_operator->inputs_tensor_info().size()) {
-    MS_LOG(EXCEPTION) << "The index is out of range, the index is " << index - 1 << ", the vector size is "
-                      << next_distribute_operator->inputs_tensor_info().size();
+    MS_LOG(WARNING) << "The index is out of range, the index is " << index - 1 << ", the vector size is "
+                    << next_distribute_operator->inputs_tensor_info().size() << ", next operator name is "
+                    << next_distribute_operator->name();
+    return;
   }
   TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[IntToSize(index - 1)];
   TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
   TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator);
+
+  if (tensorlayout_in.skip_redistribution() || tensorlayout_out.skip_redistribution()) {
+    MS_LOG(INFO) << "skip the reshape redistribution, operator name is " << distribute_operator->name()
+                 << ", next operator name is " << next_distribute_operator->name();
+    return;
+  }
+
   if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
     MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name;
     MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.cc
index 9395d3df89a..feb81a36ae7 100644
--- a/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.cc
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.cc
@@ -28,6 +28,19 @@ Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix
   return Status::SUCCESS;
 }
 
+// skip redistribution for reshape operator
+OperatorVector ConstructOperator::SkipRedisReshapeOP(Shape shape) {
+  OperatorAttrs attrs;
+  ValuePtr param_value = MakeValue(shape);
+  Attr param = std::make_pair(SHAPE, param_value);
+  OperatorParams params = {std::make_pair(param, 2)};
+  OperatorArgs args = std::make_pair(attrs, params);
+  Operator op = std::make_pair(RESHAPE, args);
+  OperatorVector opvector;
+  opvector.push_back(op);
+  return opvector;
+}
+
 Status ConstructOperator::ReshapeOP(Shape shape) {
   int32_t prod = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
   int32_t prod_expect = std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies<int>());
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h
index b06d70af364..cef2b3aa420 100644
--- a/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h
@@ -35,6 +35,7 @@ class ConstructOperator {
   ConstructOperator() : dev_size_(0) {}
   ~ConstructOperator() = default;
   Status Init(const RankList &dev_list, const Shape &dev_matrix_shape);
+  OperatorVector SkipRedisReshapeOP(Shape shape);
   Status ReshapeOP(Shape shape);
   Status StridedSliceOP(Args args);
   Status AllGatherOP(int32_t dev_dim);
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h
index a9fdc9610c8..fc891d6d9fb 100644
--- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h
@@ -41,6 +41,14 @@ class TensorLayout {
   Status InitFromVector(const std::vector<int32_t> &device_arrangement, const std::vector<int32_t> &tensor_map,
                         const std::vector<int32_t> &tensor_shape);
 
+  bool skip_redistribution() const { return skip_redistribution_; }
+
+  void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; }
+
+  int32_t get_field_size() const { return field_size_; }
+
+  void set_field_size(int32_t field_size) { field_size_ = field_size; }
+
   Arrangement device_arrangement() const { return device_arrangement_; }
 
   Map tensor_map() const { return tensor_map_; }
@@ -92,6 +100,8 @@ class TensorLayout {
   Arrangement device_arrangement_;
   Map tensor_map_;
   Arrangement tensor_shape_;
+  bool skip_redistribution_ = false;
+  int32_t field_size_ = 0;
 };
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index 1605ee4bc55..9405e7b2602 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -247,8 +247,8 @@ class Parameter:
         if not isinstance(layout, list):
             raise TypeError("The layout should be list! layout is {}."
                             .format(layout))
-        if len(layout) != 3:
-            raise ValueError("The length of layout must be 3! layout is {}."
+        if len(layout) < 3:
+            raise ValueError("The length of layout must be no less than 3! layout is {}."
                             .format(layout))
         slice_index = int(_get_slice_index(layout[0], layout[1]))
         self.default_input = self.init_mode.to_tensor(slice_index, layout[2])
diff --git a/mindspore/parallel/_tensor.py b/mindspore/parallel/_tensor.py
index fca8b889201..598046f66a6 100644
--- a/mindspore/parallel/_tensor.py
+++ b/mindspore/parallel/_tensor.py
@@ -229,8 +229,8 @@ def _load_tensor_by_layout(tensor, layout):
     """
     if not isinstance(layout, list):
         raise TypeError("The layout should be list! layout is {}".format(layout))
-    if len(layout) != 3:
-        raise ValueError("The length of layout must be 3! layout is {}".format(layout))
+    if len(layout) < 3:
+        raise ValueError("The length of layout must be no less than 3! layout is {}".format(layout))
     dev_mat = layout[0]
     tensor_map = layout[1]
     if tensor.size() == 1:
@@ -290,3 +290,37 @@ def _reshape_param_data(param_data, dev_mat, tensor_map):
         tensor_slices_new = tensor_slices_new_inner
 
     return Tensor(tensor_slices_new[0])
+
+def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
+    """
+    Combine param slice by the device matrix, used in model parallel scenario.
+
+    Args:
+        param_data (Tensor): The tensor to be reshaped and rearranged,
+        generated from all the devices by AllGatherParamNet.
+        dev_mat (list): The device matrix of devices.
+    Returns:
+        Tensor, the combined tensor with the whole data value.
+
+    Examples:
+        >>> param_data = _allgather_param_net(param_data)
+        >>> dev_mat = [2, 2]
+        >>> field_size = [39]
+        >>> tensor = _reshape_param_data_with_weight(param_data, dev_mat, field_size)
+    """
+    device_count = 1
+    for dim in dev_mat:
+        device_count *= dim
+
+    tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0)
+    tensor_slices_col = []
+    for i in range(len(tensor_slices[0][0])):
+        tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size[0], -1)
+        for j in range(1, device_count):
+            tensor_slices_new = np.concatenate((tensor_slices_new,\
+                                np.array(tensor_slices[j][:, i]).reshape(field_size[0], -1)), axis=1)
+        tensor_slices_col.append(tensor_slices_new)
+    new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1)
+    for i in range(1, len(tensor_slices_col)):
+        new_tensor = np.concatenate((new_tensor, np.array(tensor_slices_col[i]).reshape(-1, 1)), axis=1)
+    return Tensor(new_tensor)
diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py
index 3812698419c..c3f5d5c1f90 100644
--- a/mindspore/train/serialization.py
+++ b/mindspore/train/serialization.py
@@ -359,14 +359,17 @@ def _get_merged_param_data(net, param_name, param_data):
 
     dev_mat = layout[0]
     tensor_map = layout[1]
+    field_size = layout[3]
 
     from mindspore.parallel._cell_wrapper import get_allgather_cell
-    from mindspore.parallel._tensor import _reshape_param_data
+    from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
     # while any dim is not equal to -1, means param is splited and needs to be merged
     for dim in tensor_map:
         if dim != -1:
             allgather_net = get_allgather_cell()
             param_data = allgather_net(param_data)
+            if field_size[0]:
+                return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
             return _reshape_param_data(param_data, dev_mat, tensor_map)
 
     return param_data
diff --git a/tests/ut/python/parallel/test_get_parameter_layout.py b/tests/ut/python/parallel/test_get_parameter_layout.py
index a34ee94840a..23649b5f0c6 100644
--- a/tests/ut/python/parallel/test_get_parameter_layout.py
+++ b/tests/ut/python/parallel/test_get_parameter_layout.py
@@ -49,8 +49,8 @@ def test_get_parameter_layout():
     net.set_auto_parallel()
     exe = me._executor
     exe.compile(net, x, phase='train', auto_parallel_mode=True)
-    x_layout = [[2, 4], [1, -1], [16, 32]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
-    weight_layout = [[2, 4], [0, -1], [16, 32]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
+    x_layout = [[2, 4], [1, -1], [16, 32], [0]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
+    weight_layout = [[2, 4], [0, -1], [16, 32], [0]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
    expect_dict = {'x': x_layout, 'w1': weight_layout}
    # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
    assert net.parameter_layout_dict == expect_dict
diff --git a/tests/ut/python/parallel/test_reshape_skip_redistribution.py b/tests/ut/python/parallel/test_reshape_skip_redistribution.py
new file mode 100644
index 00000000000..cbaf20d1132
--- /dev/null
+++ b/tests/ut/python/parallel/test_reshape_skip_redistribution.py
@@ -0,0 +1,58 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.common.api import _executor
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+
+
+class Net(Cell):
+    def __init__(self, matmul_weight, strategy1=None):
+        super().__init__()
+        self.gatherv2 = P.GatherV2().set_strategy(strategy1)
+        self.reshape = P.Reshape().add_prim_attr("skip_redistribution", True)
+        self.matmul = P.MatMul(transpose_b=False)
+        self.index = Tensor(np.ones([64, 64]), dtype=ms.int32)
+        self.matmul_weight = Parameter(matmul_weight, "w1")
+        self.axis = 0
+
+    def construct(self, x, b):
+        out = self.gatherv2(x, self.index, self.axis)
+        out = self.reshape(out, (64, -1))
+        out = self.matmul(out, self.matmul_weight)
+        return out
+
+
+_w1 = Tensor(np.ones([4096, 32]), dtype=ms.float32)
+_x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+
+def compile_net(net):
+    context.set_context(save_graphs=True)
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
+    _executor.compile(train_net, _x, _b)
+    context.reset_auto_parallel_context()
+
+
+def test_reshape_skip_redistribution():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((1, 8), (1, 1))
+    net = Net(_w1, strategy1)
+    compile_net(net)
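
For reference, a minimal NumPy-only sketch of the slice-merging logic added above in _reshape_param_data_with_weight (mindspore/parallel/_tensor.py), so the rearrangement can be checked outside MindSpore. All sizes below (device_count, field_size, rows_per_device, cols) are made-up illustration values, not taken from the patch, and plain ndarrays stand in for mindspore.Tensor:

import numpy as np

# Made-up illustration sizes: device_count is the product of dev_mat,
# and field_size plays the role of layout[3][0] in the patch.
device_count = 4
field_size = 3
rows_per_device = 6   # must be divisible by field_size
cols = 2

# Simulated AllGather result: per-device slices stacked along axis 0.
gathered = np.arange(device_count * rows_per_device * cols,
                     dtype=np.float32).reshape(device_count * rows_per_device, cols)

# Same merging loop as the new helper, applied to a plain ndarray.
tensor_slices = np.split(gathered, device_count, axis=0)
tensor_slices_col = []
for i in range(cols):
    merged = tensor_slices[0][:, i].reshape(field_size, -1)
    for j in range(1, device_count):
        merged = np.concatenate((merged, tensor_slices[j][:, i].reshape(field_size, -1)), axis=1)
    tensor_slices_col.append(merged)

new_tensor = tensor_slices_col[0].reshape(-1, 1)
for i in range(1, cols):
    new_tensor = np.concatenate((new_tensor, tensor_slices_col[i].reshape(-1, 1)), axis=1)

# The overall shape is unchanged; rows are regrouped field-by-field across
# devices, which is what _get_merged_param_data relies on when layout[3][0]
# is non-zero.
assert new_tensor.shape == gathered.shape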