!69338 split_micro_interleaved_copies_transpose_ops.

Merge pull request !69338 from yao_yf/split_micro_interleaved_copies_transpose_ops
yao_yf 2024-05-14 06:05:23 +00:00 committed by Gitee
commit 73c6aabe1a
8 changed files with 96 additions and 26 deletions

View File

@@ -470,24 +470,6 @@ Status MatMul::CheckInputLayout() {
return SUCCESS;
}
void MatMul::UpdateOutputTensorInfoForInterleaved() {
if (inputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_interleaved().array().empty()) {
return;
}
if (!outputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_interleaved().array().empty()) {
return;
}
auto interleaved_num = ParallelContext::GetInstance()->fine_grained_micro_interleaved_size();
auto output_dev_matrix = outputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_origin().array();
output_dev_matrix[output_dev_matrix.size() - 1] = interleaved_num;
Arrangement out_device_arrangement_interleaved;
out_device_arrangement_interleaved.Init(output_dev_matrix);
auto new_tensor_layout = outputs_tensor_info_[kIndex0].tensor_layout();
new_tensor_layout.set_device_arrangement_interleaved(out_device_arrangement_interleaved);
TensorInfo new_output_tensor_info(new_tensor_layout);
outputs_tensor_info_[kIndex0] = new_output_tensor_info;
}
Status MatMul::CheckOutputLayout() {
// Check all device matrix should be the same
if (outputs_tensor_info_.size() != kSizeOne) {

View File

@@ -85,7 +85,6 @@ class MatMul : public MatMulBase {
Status CheckInputStrategy(const Shape &mat_a_strategy, const Shape &mat_b_strategy);
TensorLayout InferOutputLayout();
TensorLayout output_infer_tensor_layout_;
void UpdateOutputTensorInfoForInterleaved();
};
class MatMulInfo : public MatMul {

View File

@@ -537,6 +537,24 @@ Status OperatorInfo::InferMirrorOpsByLayout() {
return SUCCESS;
}
void OperatorInfo::UpdateOutputTensorInfoForInterleaved() {
if (inputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_interleaved().array().empty()) {
return;
}
if (!outputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_interleaved().array().empty()) {
return;
}
auto interleaved_num = ParallelContext::GetInstance()->fine_grained_micro_interleaved_size();
auto output_dev_matrix = outputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_origin().array();
output_dev_matrix[output_dev_matrix.size() - 1] = interleaved_num;
Arrangement out_device_arrangement_interleaved;
out_device_arrangement_interleaved.Init(output_dev_matrix);
auto new_tensor_layout = outputs_tensor_info_[kIndex0].tensor_layout();
new_tensor_layout.set_device_arrangement_interleaved(out_device_arrangement_interleaved);
TensorInfo new_output_tensor_info(new_tensor_layout);
outputs_tensor_info_[kIndex0] = new_output_tensor_info;
}
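
The helper above (moved here from MatMul so that other operators can reuse it via the new virtual declaration) copies the micro-interleaved factor from the input layout into the output layout when only the input is interleaved: the last axis of the output device arrangement is overwritten with fine_grained_micro_interleaved_size. A minimal plain-Python sketch of that device-matrix update, with an assumed interleaved size and an illustrative arrangement:

    interleaved_num = 2                      # stands in for fine_grained_micro_interleaved_size()
    output_dev_matrix = [4, 2, 1]            # illustrative device_arrangement_origin of the output
    output_dev_matrix[-1] = interleaved_num  # last axis now carries the interleaved copies -> [4, 2, 2]
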
Status OperatorInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";

View File

@@ -292,6 +292,7 @@ class OperatorInfo {
virtual Status InferTensorInfo();
virtual void InferReplaceOps() {}
virtual void UpdateOutputTensorInfoForInterleaved();
virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy);
virtual Status CheckStrategyForDynamicShape(const StrategyPtr &strategy) { return SUCCESS; }
Status CheckStrategyByVector(const Shapes &strategy, const Shapes &inputs_shape);

View File

@@ -27,6 +27,7 @@
#include "frontend/parallel/dynamic_creator.h"
#include "frontend/parallel/step_parallel.h"
#include "include/common/utils/convert_utils.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "utils/log_adapter.h"
namespace mindspore {
@@ -200,6 +201,7 @@ Status TransposeInfo::CheckOutputLayout() {
}
if (!output_infer_tensor_layout_.tensor_shape_before().array().empty()) {
MS_LOG(INFO) << name_ << ": Using output tensor layout infer by input tensor layout.";
UpdateOutputTensorInfoForInterleaved();
return SUCCESS;
}
@@ -210,6 +212,7 @@ Status TransposeInfo::CheckOutputLayout() {
auto out_tensor_layout = InferOutputLayout();
// output layout is the same as inferred (transpose the tensor map)
if (out_layout == out_tensor_layout) {
UpdateOutputTensorInfoForInterleaved();
return SUCCESS;
}
@@ -271,6 +274,7 @@ Status TransposeInfo::CheckOutputLayout() {
MS_LOG(ERROR) << name_ << ": The output device arrangement is not equal to the expected device arrangement.";
return FAILED;
}
UpdateOutputTensorInfoForInterleaved();
return SUCCESS;
}
@@ -304,6 +308,47 @@ TensorLayout TransposeInfo::InferOutputLayout() {
return output_tensor_layout;
}
Status TransposeInfo::ComputeReplaceGraphForInterleaved(const CNodePtr &cnode) {
GenerateGraph gen_g = GenerateGraph(attrs_);
if (gen_g.Init(cnode) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": GenerateGraph Init failed";
return FAILED;
}
auto interleaved_num = ParallelContext::GetInstance()->fine_grained_micro_interleaved_size();
Attr output_nums_attr = {"output_nums", MakeValue(interleaved_num)};
OperatorAttrs virtual_converter_begin_attrs = {output_nums_attr};
auto virtual_converter_begin = gen_g.PushBack(
{gen_g.NewOpInst(VIRTUAL_CONVERTER_BEGIN, virtual_converter_begin_attrs), gen_g.virtual_input_node()});
std::vector<AnfNodePtr> virtual_converter_end_inputs_vector;
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(virtual_converter_begin, 1)};
for (int64_t i = 0; i < interleaved_num; ++i) {
auto tuple_get_item = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), virtual_converter_begin, CreatInt64Imm(i)});
auto trans_value = CreateTuple(axis_v_);
auto transpose = gen_g.PushBack({gen_g.NewOpInst(TRANSPOSE), tuple_get_item, trans_value});
virtual_converter_end_inputs_vector.push_back(transpose);
}
Attr input_nums_attr = {"input_nums", MakeValue(interleaved_num)};
OperatorAttrs virtual_converter_end_attrs = {input_nums_attr};
std::vector<AnfNodePtr> virtual_converter_end_inputs = {
gen_g.NewOpInst(VIRTUAL_CONVERTER_END, virtual_converter_end_attrs)};
std::copy(virtual_converter_end_inputs_vector.begin(), virtual_converter_end_inputs_vector.end(),
std::back_inserter(virtual_converter_end_inputs));
auto virtual_converter_end = gen_g.PushBack(virtual_converter_end_inputs);
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
std::make_pair(input_nodes, virtual_converter_end));
return SUCCESS;
}
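
ComputeReplaceGraphForInterleaved builds the replacement subgraph for a transpose whose input layout carries an interleaved axis: a VirtualConverterBegin splits the input into interleaved_num copies, each copy is fetched with TupleGetItem and transposed with the same permutation axis_v_, and a VirtualConverterEnd collects the per-copy results. A rough plain-numpy sketch of that dataflow (the actual splitting and regrouping is handled by the framework; here the copies are just a list):

    import numpy as np

    def transpose_per_copy(copies, perm):
        # one Transpose per interleaved copy; the results are regrouped afterwards
        return [np.transpose(c, perm) for c in copies]

    copies = [np.zeros((2, 3, 4)) for _ in range(2)]  # assumed interleaved_num = 2
    outs = transpose_per_copy(copies, (2, 0, 1))      # each copy becomes shape (4, 2, 3)
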
ReplaceGraphPtr TransposeInfo::replace_graph(const CNodePtr &cnode) {
if (!inputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_interleaved().array().empty()) {
if (ComputeReplaceGraphForInterleaved(cnode) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << " splitting micro interleaved failed.";
}
return replace_graph_;
}
return nullptr;
}
std::vector<StrategyPtr> TransposeInfo::GenerateOpStrategies(int64_t stage_id) {
Shape input0_split(inputs_shape_[0].size(), 1);
Shapes splittable_inputs = {input0_split};

View File

@@ -39,6 +39,7 @@ class TransposeInfo : public OperatorInfo {
~TransposeInfo() override = default;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
@@ -56,6 +57,7 @@ class TransposeInfo : public OperatorInfo {
TensorLayout InferOutputLayout();
TensorLayout output_infer_tensor_layout_;
Status ComputeAxis();
Status ComputeReplaceGraphForInterleaved(const CNodePtr &cnode);
std::vector<int64_t> axis_v_;
Dimensions input_strategy_;
};

View File

@@ -2523,15 +2523,15 @@ std::vector<std::vector<CNodePtr>> CreateInterleavedNeedReplaceOpLists(const CNo
return need_replace_op_lists;
}
void ReplaceInterleavedAllGatherToConcat(const FuncGraphPtr &func_graph, const std::vector<CNodePtr> &ag_vector,
const std::vector<std::vector<int64_t>> &new_group_ranks_vector,
size_t independent_size) {
CNodePtr ReplaceInterleavedAllGatherToConcat(const FuncGraphPtr &func_graph, const std::vector<CNodePtr> &ag_vector,
const std::vector<std::vector<int64_t>> &new_group_ranks_vector,
size_t independent_size) {
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple->Clone())};
std::transform(ag_vector.begin(), ag_vector.end(), std::back_inserter(make_tuple_inputs),
[&](auto node) { return independent_size == 1 ? node->input(kIndex1) : node; });
auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
auto replace_nodes = InterleavedReplacedConcatNodes(ag_vector);
bool replace_concat = (!replace_nodes.empty() && replace_nodes.size() == ag_vector.size() && independent_size == 1);
bool replace_concat = (!replace_nodes.empty() && independent_size == 1);
AnfNodePtr axis = NewValueNode(MakeValue<int64_t>(0));
if (replace_concat) {
axis = replace_nodes.front()->input(kIndex2);
@@ -2549,10 +2549,15 @@ void ReplaceInterleavedAllGatherToConcat(const FuncGraphPtr &func_graph, const s
}
if (!replace_concat) {
manager->Replace(ag, concat);
continue;
}
}
if (!replace_concat) {
return concat;
}
for (size_t i = 0; i < replace_nodes.size(); ++i) {
manager->Replace(replace_nodes[i], concat);
}
return concat;
}
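
The rewritten helper now returns the Concat it creates so that ConvertInterleaveAllGatherToConcat (below) can keep rewriting around it; per the existing comment it replaces the group of interleaved AllGather nodes with a single Concat, using axis 0 unless a previously replaced Concat supplies one. A numpy analog of that fusion step, with illustrative shapes:

    import numpy as np

    copies = [np.full((2, 8), i) for i in range(4)]  # assumed: 4 interleaved pieces
    fused = np.concatenate(copies, axis=0)           # one Concat stands in for the per-copy AllGathers
    print(fused.shape)                               # (8, 8)
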
void MergeOpBeforeInterleaveSlice(const FuncGraphPtr &func_graph, const CNodePtr &virtual_converter_end) {
@@ -2586,8 +2591,10 @@ void MergeOpBeforeInterleaveSlice(const FuncGraphPtr &func_graph, const CNodePtr
}
// merge nodes before multi slice
auto slice_input = need_replace_op_lists[kIndex0][col]->input(kIndex1);
need_replace_op_lists[kIndex0][col]->AddAttr(INTERLEAVED_PARALLEL, MakeValue(true));
for (size_t row = 1; row < need_replace_op_lists.size(); ++row) {
auto slice_cnode = need_replace_op_lists[row][col];
slice_cnode->AddAttr(INTERLEAVED_PARALLEL, MakeValue(true));
manager->SetEdge(slice_cnode, kIndex1, slice_input);
}
}
@@ -2650,12 +2657,26 @@ void ConvertInterleaveAllGatherToConcat(const FuncGraphPtr &func_graph, const CN
}
// replace allgathers to one concat.
ReplaceInterleavedAllGatherToConcat(func_graph, ag_vector, new_group_ranks_vector, independent_size);
auto replaced_concat =
ReplaceInterleavedAllGatherToConcat(func_graph, ag_vector, new_group_ranks_vector, independent_size);
auto manager = func_graph->manager();
auto replaced_concat_users =
GetOutputNodesWithFilter(replaced_concat, [&](const AnfNodePtr &anode) { return false; });
if (replaced_concat_users.size() == kSizeOne) {
continue;
}
if (std::all_of(replaced_concat_users.begin(), replaced_concat_users.end(),
[](const std::pair<AnfNodePtr, int> &pair) {
return IsPrimitiveCNode(pair.first, prim::kPrimStridedSlice) &&
pair.first->cast<CNodePtr>()->HasAttr(INTERLEAVED_PARALLEL);
})) {
continue;
}
// merge the nodes after the interleaved parallel concat.
auto virtual_end_input1 = virtual_converter_end->input(kIndex1)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(virtual_end_input1);
auto new_virtual_converter_end = CreateVirtualConverterEndNode(func_graph, {virtual_end_input1});
auto manager = func_graph->manager();
manager->Replace(virtual_converter_end, new_virtual_converter_end);
}
}
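
The new user check above decides whether the merge past the concat is needed: it is skipped when the replaced concat has a single consumer, or when every consumer is one of the StridedSlice nodes that MergeOpBeforeInterleaveSlice just tagged with INTERLEAVED_PARALLEL. A plain-Python analog of that condition (node.kind and node.attrs are illustrative stand-ins, not MindSpore API):

    def skip_merge(users):
        # users mirrors GetOutputNodesWithFilter's result: (node, input_index) pairs
        if len(users) == 1:
            return True
        return all(node.kind == "StridedSlice" and node.attrs.get("interleaved_parallel", False)
                   for node, _ in users)
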

View File

@@ -78,6 +78,7 @@ class NetWithReshape(nn.Cell):
super().__init__()
self.matmul1 = P.MatMul().shard(in_strategy=in_layout1, out_strategy=out_layout1)
self.matmul2 = P.MatMul().shard(in_strategy=in_layout2, out_strategy=out_layout2)
self.transpose = P.Transpose().shard(out_layout2)
self.matmul2.add_prim_attr("recompute_comm_op", True)
self.reshape = P.Reshape().add_prim_attr("recompute_comm_op", True)
self.relu = P.ReLU()
@@ -93,6 +94,7 @@ class NetWithReshape(nn.Cell):
y_new = self.reshape(y_new, (1024, 1024))
out1 = self.matmul1(y_new, self.w1)
out1 = self.cast(out1, ms.float16)
out1 = self.transpose(out1, (1, 0))
out1 = self.reshape(out1, (512, 2048))
out2 = self.matmul2(out1, self.w2)
out2 = self.reshape(out2, (1024, 1024))
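
The transpose inserted in this test sits between matmul1 and the existing reshape; from the surrounding reshapes its input must hold 1024 * 1024 elements, so the (1, 0) permutation swaps two equal-sized axes before the (512, 2048) reshape. A quick numpy check of that shape arithmetic (values are illustrative, only the shapes matter):

    import numpy as np

    out1 = np.zeros((1024, 1024), dtype=np.float16)  # matmul1 output after the cast
    out1_t = np.transpose(out1, (1, 0))              # the newly inserted transpose
    out1_r = out1_t.reshape(512, 2048)               # existing reshape still fits: 1024 * 1024 == 512 * 2048
    print(out1_t.shape, out1_r.shape)                # (1024, 1024) (512, 2048)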