!2876 set reshape operator no redistribution for auto parallel

Merge pull request !2876 from lirongzhen1/master
This commit is contained in:
mindspore-ci-bot 2020-07-17 10:01:06 +08:00 committed by Gitee
commit edec821c50
15 changed files with 205 additions and 20 deletions

View File

@ -44,7 +44,14 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
auto device_arrangement = tensor_layout->device_arrangement().array();
auto tensor_map = tensor_layout->tensor_map().array();
auto slice_shape = tensor_layout->slice_shape().array();
std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape};
int32_t _field_size = tensor_layout->get_field_size();
std::vector<int32_t> field_size;
if (_field_size != 0) {
field_size.push_back(_field_size);
} else {
field_size = {0};
}
std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape, field_size};
dict[py::str(name)] = layout;
MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
}

View File

@ -105,6 +105,17 @@ Status MatMulBase::GetAttrs() {
}
}
auto field_size_iter = attrs_.find(FIELD_SIZE);
if (field_size_iter != attrs_.end()) {
MS_EXCEPTION_IF_NULL(field_size_iter->second);
if (field_size_iter->second->isa<Int32Imm>()) {
field_size_ = field_size_iter->second->cast<Int32ImmPtr>()->value();
} else {
MS_LOG(ERROR) << name_ << " : The value of field_size is not int.";
return FAILED;
}
}
// infer inputs dimension size
if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) {
MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
@ -346,6 +357,10 @@ Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts
return FAILED;
}
if (field_size_ != 0) {
mat_b_layout.set_field_size(field_size_);
}
inputs_layout->push_back(mat_a_layout);
inputs_layout->push_back(mat_b_layout);
outputs_layout->push_back(output_layout);

View File

@ -62,6 +62,7 @@ class MatMulBase : public OperatorInfo {
bool transpose_a_ = false;
bool transpose_b_ = false;
bool forward_reduce_scatter_ = false;
int32_t field_size_ = 0;
size_t mat_a_dimension_ = 0;
size_t mat_b_dimension_ = 0;
};

View File

@ -100,6 +100,7 @@ constexpr char CONCAT_DIM[] = "concat_dim";
constexpr char FORWARD[] = "forward";
constexpr char BACKWARD[] = "backward";
constexpr char REDISTRIBUTION[] = "redistribution";
constexpr char SKIP_REDISTRIBUTION[] = "skip_redistribution";
constexpr char REPLACE[] = "replace";
constexpr char CONNSYMBOL[] = "/";
constexpr char INSTANCE_NAME[] = "instance_name";
@ -131,6 +132,7 @@ constexpr char FORWARD_OP[] = "forward_op";
constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
constexpr char DARA_PARALLEL[] = "data_parallel";
constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
constexpr char FIELD_SIZE[] = "field_size";
constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
constexpr char DEVICE[] = "Device";

View File

@ -18,6 +18,7 @@
#include <memory>
#include <vector>
#include <utility>
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/device_matrix.h"
@ -145,17 +146,23 @@ Status ReshapeInfo::ComputeReplaceOp() {
MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString();
MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString();
MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size();
RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
if (redistribution_oplist_ptr == nullptr) {
if (is_generating_costs_) {
MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed.";
} else {
MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed.";
if (is_skip_) {
ConstructOperator constructor;
replace_op_ = constructor.SkipRedisReshapeOP(output_layout_.slice_shape().array());
replace_op_info_.clear();
} else {
RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
if (redistribution_oplist_ptr == nullptr) {
if (is_generating_costs_) {
MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed.";
} else {
MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed.";
}
return FAILED;
}
return FAILED;
replace_op_ = redistribution_oplist_ptr->first;
replace_op_info_ = redistribution_oplist_ptr->second;
}
replace_op_ = redistribution_oplist_ptr->first;
replace_op_info_ = redistribution_oplist_ptr->second;
MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size();
return SUCCESS;
}
@ -255,6 +262,19 @@ Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayout
}
Status ReshapeInfo::InferTensorInfo() {
// skip reshape infer if skip_redistribution is true
if (is_skip_) {
TensorLayout layout;
Shape shape;
Shape slice_shape;
layout.set_skip_redistribution(true);
TensorInfo tensor_info_in(layout, shape, slice_shape);
inputs_tensor_info_.push_back(tensor_info_in);
outputs_tensor_info_.push_back(tensor_info_in);
MS_LOG(DEBUG) << name() << "skip redistribution reshape InferTensorInfo";
return SUCCESS;
}
Shapes inputs_slice_shape, outputs_slice_shape;
Strategys inputs_strategy = strategy_->GetInputDim();
Strategys outputs_strategy = GetOutputsStrategy();
@ -316,6 +336,16 @@ Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const l
}
Status ReshapeInfo::Init(const StrategyPtr &strategy) {
auto reshape_skip_redis_iter = attrs_.find(SKIP_REDISTRIBUTION);
if (reshape_skip_redis_iter != attrs_.end()) {
MS_EXCEPTION_IF_NULL(reshape_skip_redis_iter->second);
if (!reshape_skip_redis_iter->second->isa<BoolImm>()) {
MS_LOG(ERROR) << name_ << ": skip_redistribution is not a bool.";
return FAILED;
}
is_skip_ = reshape_skip_redis_iter->second->cast<BoolImmPtr>()->value();
}
ResetQueueMember();
device_number(strategy);
if (strategy) {

View File

@ -98,6 +98,7 @@ class ReshapeInfo : public OperatorInfo {
bool input_layout_set_flag_;
bool output_layout_set_flag_;
bool is_generating_costs_;
bool is_skip_ = false;
std::string pre_operator_name_;
std::string next_operator_name_;
};

View File

@ -302,16 +302,26 @@ void Redistribution(const std::pair<AnfNodePtr, int> &node_pair, const OperatorI
MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " << next_node->ToString();
// extract tensor layout in and out
if (distribute_operator->outputs_tensor_info().empty()) {
MS_LOG(EXCEPTION) << "Failure:pre_node's tensorinfo_in is empty";
MS_LOG(WARNING) << "pre_node's tensorinfo_in is empty, operator name is " << distribute_operator->name();
return;
}
if (IntToSize(index - 1) >= next_distribute_operator->inputs_tensor_info().size()) {
MS_LOG(EXCEPTION) << "The index is out of range, the index is " << index - 1 << ", the vector size is "
<< next_distribute_operator->inputs_tensor_info().size();
MS_LOG(WARNING) << "The index is out of range, the index is " << index - 1 << ", the vector size is "
<< next_distribute_operator->inputs_tensor_info().size() << "next operator name is "
<< next_distribute_operator->name();
return;
}
TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[IntToSize(index - 1)];
TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator);
if (tensorlayout_in.skip_redistribution() || tensorlayout_out.skip_redistribution()) {
MS_LOG(INFO) << "skip the reshape redistribution, operator name is" << distribute_operator->name()
<< "next distribute operator, operator name is" << next_distribute_operator->name();
return;
}
if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name;
MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "

View File

@ -28,6 +28,19 @@ Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix
return Status::SUCCESS;
}
// skip redistribution for reshape operator
OperatorVector ConstructOperator::SkipRedisReshapeOP(Shape shape) {
OperatorAttrs attrs;
ValuePtr param_value = MakeValue(shape);
Attr param = std::make_pair(SHAPE, param_value);
OperatorParams params = {std::make_pair(param, 2)};
OperatorArgs args = std::make_pair(attrs, params);
Operator op = std::make_pair(RESHAPE, args);
OperatorVector opvector;
opvector.push_back(op);
return opvector;
}
Status ConstructOperator::ReshapeOP(Shape shape) {
int32_t prod = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
int32_t prod_expect = std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies<int>());

View File

@ -35,6 +35,7 @@ class ConstructOperator {
ConstructOperator() : dev_size_(0) {}
~ConstructOperator() = default;
Status Init(const RankList &dev_list, const Shape &dev_matrix_shape);
OperatorVector SkipRedisReshapeOP(Shape shape);
Status ReshapeOP(Shape shape);
Status StridedSliceOP(Args args);
Status AllGatherOP(int32_t dev_dim);

View File

@ -41,6 +41,14 @@ class TensorLayout {
Status InitFromVector(const std::vector<int32_t> &device_arrangement, const std::vector<int32_t> &tensor_map,
const std::vector<int32_t> &tensor_shape);
bool skip_redistribution() const { return skip_redistribution_; }
void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; }
int32_t get_field_size() const { return field_size_; }
void set_field_size(int32_t field_size) { field_size_ = field_size; }
Arrangement device_arrangement() const { return device_arrangement_; }
Map tensor_map() const { return tensor_map_; }
@ -92,6 +100,8 @@ class TensorLayout {
Arrangement device_arrangement_;
Map tensor_map_;
Arrangement tensor_shape_;
bool skip_redistribution_ = false;
int32_t field_size_ = 0;
};
} // namespace parallel
} // namespace mindspore

View File

@ -247,8 +247,8 @@ class Parameter:
if not isinstance(layout, list):
raise TypeError("The layout should be list! layout is {}."
.format(layout))
if len(layout) != 3:
raise ValueError("The length of layout must be 3! layout is {}."
if len(layout) < 3:
raise ValueError("The length of layout must be larger than 3! layout is {}."
.format(layout))
slice_index = int(_get_slice_index(layout[0], layout[1]))
self.default_input = self.init_mode.to_tensor(slice_index, layout[2])

View File

@ -229,8 +229,8 @@ def _load_tensor_by_layout(tensor, layout):
"""
if not isinstance(layout, list):
raise TypeError("The layout should be list! layout is {}".format(layout))
if len(layout) != 3:
raise ValueError("The length of layout must be 3! layout is {}".format(layout))
if len(layout) < 3:
raise ValueError("The length of layout must be larger than 3! layout is {}".format(layout))
dev_mat = layout[0]
tensor_map = layout[1]
if tensor.size() == 1:
@ -290,3 +290,37 @@ def _reshape_param_data(param_data, dev_mat, tensor_map):
tensor_slices_new = tensor_slices_new_inner
return Tensor(tensor_slices_new[0])
def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
"""
Combine param slice by the device matrix, used in model parallel scenario.
Args:
param_data (Tensor): The tensor to be reshaped and rearrangement,
generated from all the device from AllGatherParamNet.
dev_mat (list): The device matrix of devices.
Returns:
Tensor, the combined tensor which with the whole data value.
Examples:
>>> param_data = _allgather_param_net(param_data)
>>> dev_mat = [2, 2]
>>> field_size = [39]
>>> tensor = _reshape_param_data_with_weight(param_data, dev_mat, field_size)
"""
device_count = 1
for dim in dev_mat:
device_count *= dim
tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0)
tensor_slices_col = []
for i in range(len(tensor_slices[0][0])):
tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size[0], -1)
for j in range(1, device_count):
tensor_slices_new = np.concatenate((tensor_slices_new,\
np.array(tensor_slices[j][:, i]).reshape(field_size[0], -1)), axis=1)
tensor_slices_col.append(tensor_slices_new)
new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1)
for i in range(1, len(tensor_slices_col)):
new_tensor = np.concatenate((new_tensor, np.array(tensor_slices_col[i]).reshape(-1, 1)), axis=1)
return Tensor(new_tensor)

View File

@ -384,14 +384,17 @@ def _get_merged_param_data(net, param_name, param_data):
dev_mat = layout[0]
tensor_map = layout[1]
field_size = layout[3]
from mindspore.parallel._cell_wrapper import get_allgather_cell
from mindspore.parallel._tensor import _reshape_param_data
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
# while any dim is not equal to -1, means param is splited and needs to be merged
for dim in tensor_map:
if dim != -1:
allgather_net = get_allgather_cell()
param_data = allgather_net(param_data)
if field_size[0]:
return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
return _reshape_param_data(param_data, dev_mat, tensor_map)
return param_data

View File

@ -49,8 +49,8 @@ def test_get_parameter_layout():
net.set_auto_parallel()
exe = me._executor
exe.compile(net, x, phase='train', auto_parallel_mode=True)
x_layout = [[2, 4], [1, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [1, -1]
weight_layout = [[2, 4], [0, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [0, -1]
x_layout = [[2, 4], [1, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [1, -1]
weight_layout = [[2, 4], [0, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [0, -1]
expect_dict = {'x': x_layout, 'w1': weight_layout}
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
assert net.parameter_layout_dict == expect_dict

View File

@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
class Net(Cell):
def __init__(self, matmul_weight, strategy1=None):
super().__init__()
self.gatherv2 = P.GatherV2().set_strategy(strategy1)
self.reshape = P.Reshape().add_prim_attr("skip_redistribution", True)
self.matmul = P.MatMul(transpose_b=False)
self.index = Tensor(np.ones([64, 64]), dtype=ms.int32)
self.matmul_weight = Parameter(matmul_weight, "w1")
self.axis = 0
def construct(self, x, b):
out = self.gatherv2(x, self.index, self.axis)
out = self.reshape(out, (64, -1))
out = self.matmul(out, self.matmul_weight)
return out
_w1 = Tensor(np.ones([4096, 32]), dtype=ms.float32)
_x = Tensor(np.ones([64, 64]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
def compile_net(net):
context.set_context(save_graphs=True)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
_executor.compile(train_net, _x, _b)
context.reset_auto_parallel_context()
def test_reshape_skip_redistribution():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8), (1, 1))
net = Net(_w1, strategy1)
compile_net(net)