forked from mindspore-Ecosystem/mindspore
!2876 set reshape operator no redistribution for auto parallel
Merge pull request !2876 from lirongzhen1/master
This commit is contained in:
commit
edec821c50
|
@ -44,7 +44,14 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
|
|||
auto device_arrangement = tensor_layout->device_arrangement().array();
|
||||
auto tensor_map = tensor_layout->tensor_map().array();
|
||||
auto slice_shape = tensor_layout->slice_shape().array();
|
||||
std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape};
|
||||
int32_t _field_size = tensor_layout->get_field_size();
|
||||
std::vector<int32_t> field_size;
|
||||
if (_field_size != 0) {
|
||||
field_size.push_back(_field_size);
|
||||
} else {
|
||||
field_size = {0};
|
||||
}
|
||||
std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape, field_size};
|
||||
dict[py::str(name)] = layout;
|
||||
MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
|
||||
}
|
||||
|
|
|
@ -105,6 +105,17 @@ Status MatMulBase::GetAttrs() {
|
|||
}
|
||||
}
|
||||
|
||||
auto field_size_iter = attrs_.find(FIELD_SIZE);
|
||||
if (field_size_iter != attrs_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(field_size_iter->second);
|
||||
if (field_size_iter->second->isa<Int32Imm>()) {
|
||||
field_size_ = field_size_iter->second->cast<Int32ImmPtr>()->value();
|
||||
} else {
|
||||
MS_LOG(ERROR) << name_ << " : The value of field_size is not int.";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
// infer inputs dimension size
|
||||
if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) {
|
||||
MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
|
||||
|
@ -346,6 +357,10 @@ Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
if (field_size_ != 0) {
|
||||
mat_b_layout.set_field_size(field_size_);
|
||||
}
|
||||
|
||||
inputs_layout->push_back(mat_a_layout);
|
||||
inputs_layout->push_back(mat_b_layout);
|
||||
outputs_layout->push_back(output_layout);
|
||||
|
|
|
@ -62,6 +62,7 @@ class MatMulBase : public OperatorInfo {
|
|||
bool transpose_a_ = false;
|
||||
bool transpose_b_ = false;
|
||||
bool forward_reduce_scatter_ = false;
|
||||
int32_t field_size_ = 0;
|
||||
size_t mat_a_dimension_ = 0;
|
||||
size_t mat_b_dimension_ = 0;
|
||||
};
|
||||
|
|
|
@ -100,6 +100,7 @@ constexpr char CONCAT_DIM[] = "concat_dim";
|
|||
constexpr char FORWARD[] = "forward";
|
||||
constexpr char BACKWARD[] = "backward";
|
||||
constexpr char REDISTRIBUTION[] = "redistribution";
|
||||
constexpr char SKIP_REDISTRIBUTION[] = "skip_redistribution";
|
||||
constexpr char REPLACE[] = "replace";
|
||||
constexpr char CONNSYMBOL[] = "/";
|
||||
constexpr char INSTANCE_NAME[] = "instance_name";
|
||||
|
@ -131,6 +132,7 @@ constexpr char FORWARD_OP[] = "forward_op";
|
|||
constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
|
||||
constexpr char DARA_PARALLEL[] = "data_parallel";
|
||||
constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
|
||||
constexpr char FIELD_SIZE[] = "field_size";
|
||||
constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
|
||||
constexpr char DEVICE[] = "Device";
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#include "frontend/parallel/device_manager.h"
|
||||
#include "frontend/parallel/device_matrix.h"
|
||||
|
@ -145,6 +146,11 @@ Status ReshapeInfo::ComputeReplaceOp() {
|
|||
MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString();
|
||||
MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString();
|
||||
MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size();
|
||||
if (is_skip_) {
|
||||
ConstructOperator constructor;
|
||||
replace_op_ = constructor.SkipRedisReshapeOP(output_layout_.slice_shape().array());
|
||||
replace_op_info_.clear();
|
||||
} else {
|
||||
RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
|
||||
if (redistribution_oplist_ptr == nullptr) {
|
||||
if (is_generating_costs_) {
|
||||
|
@ -156,6 +162,7 @@ Status ReshapeInfo::ComputeReplaceOp() {
|
|||
}
|
||||
replace_op_ = redistribution_oplist_ptr->first;
|
||||
replace_op_info_ = redistribution_oplist_ptr->second;
|
||||
}
|
||||
MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size();
|
||||
return SUCCESS;
|
||||
}
|
||||
|
@ -255,6 +262,19 @@ Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayout
|
|||
}
|
||||
|
||||
Status ReshapeInfo::InferTensorInfo() {
|
||||
// skip reshape infer if skip_redistribution is true
|
||||
if (is_skip_) {
|
||||
TensorLayout layout;
|
||||
Shape shape;
|
||||
Shape slice_shape;
|
||||
layout.set_skip_redistribution(true);
|
||||
TensorInfo tensor_info_in(layout, shape, slice_shape);
|
||||
inputs_tensor_info_.push_back(tensor_info_in);
|
||||
outputs_tensor_info_.push_back(tensor_info_in);
|
||||
MS_LOG(DEBUG) << name() << "skip redistribution reshape InferTensorInfo";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Shapes inputs_slice_shape, outputs_slice_shape;
|
||||
Strategys inputs_strategy = strategy_->GetInputDim();
|
||||
Strategys outputs_strategy = GetOutputsStrategy();
|
||||
|
@ -316,6 +336,16 @@ Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const l
|
|||
}
|
||||
|
||||
Status ReshapeInfo::Init(const StrategyPtr &strategy) {
|
||||
auto reshape_skip_redis_iter = attrs_.find(SKIP_REDISTRIBUTION);
|
||||
if (reshape_skip_redis_iter != attrs_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(reshape_skip_redis_iter->second);
|
||||
if (!reshape_skip_redis_iter->second->isa<BoolImm>()) {
|
||||
MS_LOG(ERROR) << name_ << ": skip_redistribution is not a bool.";
|
||||
return FAILED;
|
||||
}
|
||||
is_skip_ = reshape_skip_redis_iter->second->cast<BoolImmPtr>()->value();
|
||||
}
|
||||
|
||||
ResetQueueMember();
|
||||
device_number(strategy);
|
||||
if (strategy) {
|
||||
|
|
|
@ -98,6 +98,7 @@ class ReshapeInfo : public OperatorInfo {
|
|||
bool input_layout_set_flag_;
|
||||
bool output_layout_set_flag_;
|
||||
bool is_generating_costs_;
|
||||
bool is_skip_ = false;
|
||||
std::string pre_operator_name_;
|
||||
std::string next_operator_name_;
|
||||
};
|
||||
|
|
|
@ -302,16 +302,26 @@ void Redistribution(const std::pair<AnfNodePtr, int> &node_pair, const OperatorI
|
|||
MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " << next_node->ToString();
|
||||
// extract tensor layout in and out
|
||||
if (distribute_operator->outputs_tensor_info().empty()) {
|
||||
MS_LOG(EXCEPTION) << "Failure:pre_node's tensorinfo_in is empty";
|
||||
MS_LOG(WARNING) << "pre_node's tensorinfo_in is empty, operator name is " << distribute_operator->name();
|
||||
return;
|
||||
}
|
||||
|
||||
if (IntToSize(index - 1) >= next_distribute_operator->inputs_tensor_info().size()) {
|
||||
MS_LOG(EXCEPTION) << "The index is out of range, the index is " << index - 1 << ", the vector size is "
|
||||
<< next_distribute_operator->inputs_tensor_info().size();
|
||||
MS_LOG(WARNING) << "The index is out of range, the index is " << index - 1 << ", the vector size is "
|
||||
<< next_distribute_operator->inputs_tensor_info().size() << "next operator name is "
|
||||
<< next_distribute_operator->name();
|
||||
return;
|
||||
}
|
||||
TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[IntToSize(index - 1)];
|
||||
TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
|
||||
TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator);
|
||||
|
||||
if (tensorlayout_in.skip_redistribution() || tensorlayout_out.skip_redistribution()) {
|
||||
MS_LOG(INFO) << "skip the reshape redistribution, operator name is" << distribute_operator->name()
|
||||
<< "next distribute operator, operator name is" << next_distribute_operator->name();
|
||||
return;
|
||||
}
|
||||
|
||||
if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
|
||||
MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name;
|
||||
MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "
|
||||
|
|
|
@ -28,6 +28,19 @@ Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix
|
|||
return Status::SUCCESS;
|
||||
}
|
||||
|
||||
// skip redistribution for reshape operator
|
||||
OperatorVector ConstructOperator::SkipRedisReshapeOP(Shape shape) {
|
||||
OperatorAttrs attrs;
|
||||
ValuePtr param_value = MakeValue(shape);
|
||||
Attr param = std::make_pair(SHAPE, param_value);
|
||||
OperatorParams params = {std::make_pair(param, 2)};
|
||||
OperatorArgs args = std::make_pair(attrs, params);
|
||||
Operator op = std::make_pair(RESHAPE, args);
|
||||
OperatorVector opvector;
|
||||
opvector.push_back(op);
|
||||
return opvector;
|
||||
}
|
||||
|
||||
Status ConstructOperator::ReshapeOP(Shape shape) {
|
||||
int32_t prod = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
|
||||
int32_t prod_expect = std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies<int>());
|
||||
|
|
|
@ -35,6 +35,7 @@ class ConstructOperator {
|
|||
ConstructOperator() : dev_size_(0) {}
|
||||
~ConstructOperator() = default;
|
||||
Status Init(const RankList &dev_list, const Shape &dev_matrix_shape);
|
||||
OperatorVector SkipRedisReshapeOP(Shape shape);
|
||||
Status ReshapeOP(Shape shape);
|
||||
Status StridedSliceOP(Args args);
|
||||
Status AllGatherOP(int32_t dev_dim);
|
||||
|
|
|
@ -41,6 +41,14 @@ class TensorLayout {
|
|||
Status InitFromVector(const std::vector<int32_t> &device_arrangement, const std::vector<int32_t> &tensor_map,
|
||||
const std::vector<int32_t> &tensor_shape);
|
||||
|
||||
bool skip_redistribution() const { return skip_redistribution_; }
|
||||
|
||||
void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; }
|
||||
|
||||
int32_t get_field_size() const { return field_size_; }
|
||||
|
||||
void set_field_size(int32_t field_size) { field_size_ = field_size; }
|
||||
|
||||
Arrangement device_arrangement() const { return device_arrangement_; }
|
||||
|
||||
Map tensor_map() const { return tensor_map_; }
|
||||
|
@ -92,6 +100,8 @@ class TensorLayout {
|
|||
Arrangement device_arrangement_;
|
||||
Map tensor_map_;
|
||||
Arrangement tensor_shape_;
|
||||
bool skip_redistribution_ = false;
|
||||
int32_t field_size_ = 0;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -247,8 +247,8 @@ class Parameter:
|
|||
if not isinstance(layout, list):
|
||||
raise TypeError("The layout should be list! layout is {}."
|
||||
.format(layout))
|
||||
if len(layout) != 3:
|
||||
raise ValueError("The length of layout must be 3! layout is {}."
|
||||
if len(layout) < 3:
|
||||
raise ValueError("The length of layout must be larger than 3! layout is {}."
|
||||
.format(layout))
|
||||
slice_index = int(_get_slice_index(layout[0], layout[1]))
|
||||
self.default_input = self.init_mode.to_tensor(slice_index, layout[2])
|
||||
|
|
|
@ -229,8 +229,8 @@ def _load_tensor_by_layout(tensor, layout):
|
|||
"""
|
||||
if not isinstance(layout, list):
|
||||
raise TypeError("The layout should be list! layout is {}".format(layout))
|
||||
if len(layout) != 3:
|
||||
raise ValueError("The length of layout must be 3! layout is {}".format(layout))
|
||||
if len(layout) < 3:
|
||||
raise ValueError("The length of layout must be larger than 3! layout is {}".format(layout))
|
||||
dev_mat = layout[0]
|
||||
tensor_map = layout[1]
|
||||
if tensor.size() == 1:
|
||||
|
@ -290,3 +290,37 @@ def _reshape_param_data(param_data, dev_mat, tensor_map):
|
|||
tensor_slices_new = tensor_slices_new_inner
|
||||
|
||||
return Tensor(tensor_slices_new[0])
|
||||
|
||||
def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
|
||||
"""
|
||||
Combine param slice by the device matrix, used in model parallel scenario.
|
||||
|
||||
Args:
|
||||
param_data (Tensor): The tensor to be reshaped and rearrangement,
|
||||
generated from all the device from AllGatherParamNet.
|
||||
dev_mat (list): The device matrix of devices.
|
||||
Returns:
|
||||
Tensor, the combined tensor which with the whole data value.
|
||||
|
||||
Examples:
|
||||
>>> param_data = _allgather_param_net(param_data)
|
||||
>>> dev_mat = [2, 2]
|
||||
>>> field_size = [39]
|
||||
>>> tensor = _reshape_param_data_with_weight(param_data, dev_mat, field_size)
|
||||
"""
|
||||
device_count = 1
|
||||
for dim in dev_mat:
|
||||
device_count *= dim
|
||||
|
||||
tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0)
|
||||
tensor_slices_col = []
|
||||
for i in range(len(tensor_slices[0][0])):
|
||||
tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size[0], -1)
|
||||
for j in range(1, device_count):
|
||||
tensor_slices_new = np.concatenate((tensor_slices_new,\
|
||||
np.array(tensor_slices[j][:, i]).reshape(field_size[0], -1)), axis=1)
|
||||
tensor_slices_col.append(tensor_slices_new)
|
||||
new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1)
|
||||
for i in range(1, len(tensor_slices_col)):
|
||||
new_tensor = np.concatenate((new_tensor, np.array(tensor_slices_col[i]).reshape(-1, 1)), axis=1)
|
||||
return Tensor(new_tensor)
|
||||
|
|
|
@ -384,14 +384,17 @@ def _get_merged_param_data(net, param_name, param_data):
|
|||
|
||||
dev_mat = layout[0]
|
||||
tensor_map = layout[1]
|
||||
field_size = layout[3]
|
||||
|
||||
from mindspore.parallel._cell_wrapper import get_allgather_cell
|
||||
from mindspore.parallel._tensor import _reshape_param_data
|
||||
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
|
||||
# while any dim is not equal to -1, means param is splited and needs to be merged
|
||||
for dim in tensor_map:
|
||||
if dim != -1:
|
||||
allgather_net = get_allgather_cell()
|
||||
param_data = allgather_net(param_data)
|
||||
if field_size[0]:
|
||||
return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
|
||||
return _reshape_param_data(param_data, dev_mat, tensor_map)
|
||||
|
||||
return param_data
|
||||
|
|
|
@ -49,8 +49,8 @@ def test_get_parameter_layout():
|
|||
net.set_auto_parallel()
|
||||
exe = me._executor
|
||||
exe.compile(net, x, phase='train', auto_parallel_mode=True)
|
||||
x_layout = [[2, 4], [1, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [1, -1]
|
||||
weight_layout = [[2, 4], [0, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [0, -1]
|
||||
x_layout = [[2, 4], [1, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [1, -1]
|
||||
weight_layout = [[2, 4], [0, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [0, -1]
|
||||
expect_dict = {'x': x_layout, 'w1': weight_layout}
|
||||
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
|
||||
assert net.parameter_layout_dict == expect_dict
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, Parameter
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn import Cell, TrainOneStepCell, Momentum
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Net(Cell):
|
||||
def __init__(self, matmul_weight, strategy1=None):
|
||||
super().__init__()
|
||||
self.gatherv2 = P.GatherV2().set_strategy(strategy1)
|
||||
self.reshape = P.Reshape().add_prim_attr("skip_redistribution", True)
|
||||
self.matmul = P.MatMul(transpose_b=False)
|
||||
self.index = Tensor(np.ones([64, 64]), dtype=ms.int32)
|
||||
self.matmul_weight = Parameter(matmul_weight, "w1")
|
||||
self.axis = 0
|
||||
|
||||
def construct(self, x, b):
|
||||
out = self.gatherv2(x, self.index, self.axis)
|
||||
out = self.reshape(out, (64, -1))
|
||||
out = self.matmul(out, self.matmul_weight)
|
||||
return out
|
||||
|
||||
|
||||
_w1 = Tensor(np.ones([4096, 32]), dtype=ms.float32)
|
||||
_x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
|
||||
|
||||
def compile_net(net):
|
||||
context.set_context(save_graphs=True)
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
train_net = TrainOneStepCell(net, optimizer)
|
||||
train_net.set_auto_parallel()
|
||||
_executor.compile(train_net, _x, _b)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_reshape_skip_redistribution():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 8), (1, 1))
|
||||
net = Net(_w1, strategy1)
|
||||
compile_net(net)
|
Loading…
Reference in New Issue