!69338 split_micro_interleaved_copies_transpose_ops.
Merge pull request !69338 from yao_yf/split_micro_interleaved_copies_transpose_ops
commit 73c6aabe1a
@@ -470,24 +470,6 @@ Status MatMul::CheckInputLayout() {
   return SUCCESS;
 }
 
-void MatMul::UpdateOutputTensorInfoForInterleaved() {
-  if (inputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_interleaved().array().empty()) {
-    return;
-  }
-  if (!outputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_interleaved().array().empty()) {
-    return;
-  }
-  auto interleaved_num = ParallelContext::GetInstance()->fine_grained_micro_interleaved_size();
-  auto output_dev_matrix = outputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_origin().array();
-  output_dev_matrix[output_dev_matrix.size() - 1] = interleaved_num;
-  Arrangement out_device_arrangement_interleaved;
-  out_device_arrangement_interleaved.Init(output_dev_matrix);
-  auto new_tensor_layout = outputs_tensor_info_[kIndex0].tensor_layout();
-  new_tensor_layout.set_device_arrangement_interleaved(out_device_arrangement_interleaved);
-  TensorInfo new_output_tensor_info(new_tensor_layout);
-  outputs_tensor_info_[kIndex0] = new_output_tensor_info;
-}
-
 Status MatMul::CheckOutputLayout() {
   // Check all device matrix should be the same
   if (outputs_tensor_info_.size() != kSizeOne) {
@@ -85,7 +85,6 @@ class MatMul : public MatMulBase {
   Status CheckInputStrategy(const Shape &mat_a_strategy, const Shape &mat_b_strategy);
   TensorLayout InferOutputLayout();
   TensorLayout output_infer_tensor_layout_;
-  void UpdateOutputTensorInfoForInterleaved();
 };
 
 class MatMulInfo : public MatMul {
@@ -537,6 +537,24 @@ Status OperatorInfo::InferMirrorOpsByLayout() {
   return SUCCESS;
 }
 
+void OperatorInfo::UpdateOutputTensorInfoForInterleaved() {
+  if (inputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_interleaved().array().empty()) {
+    return;
+  }
+  if (!outputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_interleaved().array().empty()) {
+    return;
+  }
+  auto interleaved_num = ParallelContext::GetInstance()->fine_grained_micro_interleaved_size();
+  auto output_dev_matrix = outputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_origin().array();
+  output_dev_matrix[output_dev_matrix.size() - 1] = interleaved_num;
+  Arrangement out_device_arrangement_interleaved;
+  out_device_arrangement_interleaved.Init(output_dev_matrix);
+  auto new_tensor_layout = outputs_tensor_info_[kIndex0].tensor_layout();
+  new_tensor_layout.set_device_arrangement_interleaved(out_device_arrangement_interleaved);
+  TensorInfo new_output_tensor_info(new_tensor_layout);
+  outputs_tensor_info_[kIndex0] = new_output_tensor_info;
+}
+
 Status OperatorInfo::InferTensorInfo() {
   if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
     MS_LOG(ERROR) << name_ << ": Invalid args";
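Not part of the diff: read as plain data manipulation, the new helper simply overwrites the last dimension of the output device matrix with the fine-grained micro interleaved size, so the output tensor layout inherits the input's interleaving. A minimal Python sketch of that one step (names here are illustrative, not MindSpore APIs):

    # Illustrative only: the device matrix as a plain list, with its last
    # dimension replaced by the interleave factor.
    def update_output_dev_matrix(dev_matrix, interleaved_num):
        updated = list(dev_matrix)
        updated[-1] = interleaved_num
        return updated

    # e.g. update_output_dev_matrix([4, 2, 1], 2) -> [4, 2, 2]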
@@ -292,6 +292,7 @@ class OperatorInfo {
   virtual Status InferTensorInfo();
 
   virtual void InferReplaceOps() {}
+  virtual void UpdateOutputTensorInfoForInterleaved();
   virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy);
   virtual Status CheckStrategyForDynamicShape(const StrategyPtr &strategy) { return SUCCESS; }
   Status CheckStrategyByVector(const Shapes &strategy, const Shapes &inputs_shape);
@@ -27,6 +27,7 @@
 #include "frontend/parallel/dynamic_creator.h"
 #include "frontend/parallel/step_parallel.h"
 #include "include/common/utils/convert_utils.h"
+#include "frontend/parallel/graph_util/generate_graph.h"
 #include "utils/log_adapter.h"
 
 namespace mindspore {
@@ -200,6 +201,7 @@ Status TransposeInfo::CheckOutputLayout() {
   }
   if (!output_infer_tensor_layout_.tensor_shape_before().array().empty()) {
     MS_LOG(INFO) << name_ << ": Using output tensor layout infer by input tensor layout.";
+    UpdateOutputTensorInfoForInterleaved();
     return SUCCESS;
   }
 
@@ -210,6 +212,7 @@ Status TransposeInfo::CheckOutputLayout() {
   auto out_tensor_layout = InferOutputLayout();
   // output layout is the same as inferred (transpose the tensor map)
   if (out_layout == out_tensor_layout) {
+    UpdateOutputTensorInfoForInterleaved();
     return SUCCESS;
   }
 
@@ -271,6 +274,7 @@ Status TransposeInfo::CheckOutputLayout() {
     MS_LOG(ERROR) << name_ << ": The output device arrangement is not equal to the expected device arrangement.";
     return FAILED;
   }
+  UpdateOutputTensorInfoForInterleaved();
   return SUCCESS;
 }
 
@@ -304,6 +308,47 @@ TensorLayout TransposeInfo::InferOutputLayout() {
   return output_tensor_layout;
 }
 
+Status TransposeInfo::ComputeReplaceGraphForInterleaved(const CNodePtr &cnode) {
+  GenerateGraph gen_g = GenerateGraph(attrs_);
+  if (gen_g.Init(cnode) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << "GenerateGraph Init failed";
+    return FAILED;
+  }
+  auto interleaved_num = ParallelContext::GetInstance()->fine_grained_micro_interleaved_size();
+  Attr output_nums_attr = {"output_nums", MakeValue(interleaved_num)};
+  OperatorAttrs virtual_converter_begin_attrs = {output_nums_attr};
+  auto virtual_converter_begin = gen_g.PushBack(
+    {gen_g.NewOpInst(VIRTUAL_CONVERTER_BEGIN, virtual_converter_begin_attrs), gen_g.virtual_input_node()});
+  std::vector<AnfNodePtr> virtual_converter_end_inputs_vector;
+  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(virtual_converter_begin, 1)};
+  for (int64_t i = 0; i < interleaved_num; ++i) {
+    auto tuple_get_item = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), virtual_converter_begin, CreatInt64Imm(i)});
+    auto trans_value = CreateTuple(axis_v_);
+    auto transpose = gen_g.PushBack({gen_g.NewOpInst(TRANSPOSE), tuple_get_item, trans_value});
+    virtual_converter_end_inputs_vector.push_back(transpose);
+  }
+  Attr input_nums_attr = {"input_nums", MakeValue(interleaved_num)};
+  OperatorAttrs virtual_converter_end_attrs = {input_nums_attr};
+  std::vector<AnfNodePtr> virtual_converter_end_inputs = {
+    gen_g.NewOpInst(VIRTUAL_CONVERTER_END, virtual_converter_end_attrs)};
+  std::copy(virtual_converter_end_inputs_vector.begin(), virtual_converter_end_inputs_vector.end(),
+            std::back_inserter(virtual_converter_end_inputs));
+  auto virtual_converter_end = gen_g.PushBack(virtual_converter_end_inputs);
+  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
+    std::make_pair(input_nodes, virtual_converter_end));
+  return SUCCESS;
+}
+
+ReplaceGraphPtr TransposeInfo::replace_graph(const CNodePtr &cnode) {
+  if (!inputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_interleaved().array().empty()) {
+    if (ComputeReplaceGraphForInterleaved(cnode) != SUCCESS) {
+      MS_LOG(EXCEPTION) << name_ << " splitting micro interleaved failed.";
+    }
+    return replace_graph_;
+  }
+  return nullptr;
+}
+
 std::vector<StrategyPtr> TransposeInfo::GenerateOpStrategies(int64_t stage_id) {
   Shape input0_split(inputs_shape_[0].size(), 1);
   Shapes splittable_inputs = {input0_split};
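Not part of the diff: a plain NumPy sketch (function name made up) of the identity the replace graph above relies on. Splitting the input into micro-interleaved slices (VirtualConverterBegin), transposing each slice separately, and stitching the pieces back together (VirtualConverterEnd) reproduces the transpose of the whole tensor:

    import numpy as np

    def interleaved_transpose(x, perm, interleaved_num=2):
        # Split along the first axis into micro-interleaved slices, transpose each
        # slice independently, then concatenate along the axis where the original
        # first axis ends up after the permutation.
        parts = np.split(x, interleaved_num, axis=0)
        outs = [np.transpose(p, perm) for p in parts]
        return np.concatenate(outs, axis=perm.index(0))

    x = np.arange(24).reshape(4, 6)
    assert np.array_equal(interleaved_transpose(x, (1, 0)), x.T)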
@@ -39,6 +39,7 @@ class TransposeInfo : public OperatorInfo {
   ~TransposeInfo() override = default;
   std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
+  ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
 
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
@@ -56,6 +57,7 @@ class TransposeInfo : public OperatorInfo {
   TensorLayout InferOutputLayout();
   TensorLayout output_infer_tensor_layout_;
   Status ComputeAxis();
+  Status ComputeReplaceGraphForInterleaved(const CNodePtr &cnode);
   std::vector<int64_t> axis_v_;
   Dimensions input_strategy_;
 };
@@ -2523,15 +2523,15 @@ std::vector<std::vector<CNodePtr>> CreateInterleavedNeedReplaceOpLists(const CNo
   return need_replace_op_lists;
 }
 
-void ReplaceInterleavedAllGatherToConcat(const FuncGraphPtr &func_graph, const std::vector<CNodePtr> &ag_vector,
-                                         const std::vector<std::vector<int64_t>> &new_group_ranks_vector,
-                                         size_t independent_size) {
+CNodePtr ReplaceInterleavedAllGatherToConcat(const FuncGraphPtr &func_graph, const std::vector<CNodePtr> &ag_vector,
+                                             const std::vector<std::vector<int64_t>> &new_group_ranks_vector,
+                                             size_t independent_size) {
   std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple->Clone())};
   std::transform(ag_vector.begin(), ag_vector.end(), std::back_inserter(make_tuple_inputs),
                  [&](auto node) { return independent_size == 1 ? node->input(kIndex1) : node; });
   auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
   auto replace_nodes = InterleavedReplacedConcatNodes(ag_vector);
-  bool replace_concat = (!replace_nodes.empty() && replace_nodes.size() == ag_vector.size() && independent_size == 1);
+  bool replace_concat = (!replace_nodes.empty() && independent_size == 1);
   AnfNodePtr axis = NewValueNode(MakeValue<int64_t>(0));
   if (replace_concat) {
     axis = replace_nodes.front()->input(kIndex2);
@@ -2549,10 +2549,15 @@ void ReplaceInterleavedAllGatherToConcat(const FuncGraphPtr &func_graph, const s
     }
     if (!replace_concat) {
       manager->Replace(ag, concat);
       continue;
     }
   }
+  if (!replace_concat) {
+    return concat;
+  }
+  for (size_t i = 0; i < replace_nodes.size(); ++i) {
     manager->Replace(replace_nodes[i], concat);
+  }
+  return concat;
 }
 
 void MergeOpBeforeInterleaveSlice(const FuncGraphPtr &func_graph, const CNodePtr &virtual_converter_end) {
@@ -2586,8 +2591,10 @@ void MergeOpBeforeInterleaveSlice(const FuncGraphPtr &func_graph, const CNodePtr
     }
     // merge nodes before multi slice
     auto slice_input = need_replace_op_lists[kIndex0][col]->input(kIndex1);
+    need_replace_op_lists[kIndex0][col]->AddAttr(INTERLEAVED_PARALLEL, MakeValue(true));
     for (size_t row = 1; row < need_replace_op_lists.size(); ++row) {
       auto slice_cnode = need_replace_op_lists[row][col];
+      slice_cnode->AddAttr(INTERLEAVED_PARALLEL, MakeValue(true));
       manager->SetEdge(slice_cnode, kIndex1, slice_input);
     }
   }
@@ -2650,12 +2657,26 @@ void ConvertInterleaveAllGatherToConcat(const FuncGraphPtr &func_graph, const CN
     }
 
     // replace allgathers to one concat.
-    ReplaceInterleavedAllGatherToConcat(func_graph, ag_vector, new_group_ranks_vector, independent_size);
+    auto replaced_concat =
+      ReplaceInterleavedAllGatherToConcat(func_graph, ag_vector, new_group_ranks_vector, independent_size);
+    auto manager = func_graph->manager();
+    auto replaced_concat_users =
+      GetOutputNodesWithFilter(replaced_concat, [&](const AnfNodePtr &anode) { return false; });
+    if (replaced_concat_users.size() == kSizeOne) {
+      continue;
+    }
+    if (std::all_of(replaced_concat_users.begin(), replaced_concat_users.end(),
+                    [](const std::pair<AnfNodePtr, int> &pair) {
+                      return IsPrimitiveCNode(pair.first, prim::kPrimStridedSlice) &&
+                             pair.first->cast<CNodePtr>()->HasAttr(INTERLEAVED_PARALLEL);
+                    })) {
+      continue;
+    }
+    // merge the nodes afer the interleaved parallel concat.
     auto virtual_end_input1 = virtual_converter_end->input(kIndex1)->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(virtual_end_input1);
     auto new_virtual_converter_end = CreateVirtualConverterEndNode(func_graph, {virtual_end_input1});
-    auto manager = func_graph->manager();
 
     manager->Replace(virtual_converter_end, new_virtual_converter_end);
   }
 }
@@ -78,6 +78,7 @@ class NetWithReshape(nn.Cell):
         super().__init__()
         self.matmul1 = P.MatMul().shard(in_strategy=in_layout1, out_strategy=out_layout1)
         self.matmul2 = P.MatMul().shard(in_strategy=in_layout2, out_strategy=out_layout2)
+        self.transpose = P.Transpose().shard(out_layout2)
         self.matmul2.add_prim_attr("recompute_comm_op", True)
         self.reshape = P.Reshape().add_prim_attr("recompute_comm_op", True)
         self.relu = P.ReLU()
@@ -93,6 +94,7 @@ class NetWithReshape(nn.Cell):
         y_new = self.reshape(y_new, (1024, 1024))
         out1 = self.matmul1(y_new, self.w1)
         out1 = self.cast(out1, ms.float16)
+        out1 = self.transpose(out1, (1, 0))
         out1 = self.reshape(out1, (512, 2048))
         out2 = self.matmul2(out1, self.w2)
         out2 = self.reshape(out2, (1024, 1024))