forked from mindspore-Ecosystem/mindspore
place layernormgrad split pass before kernel select
This commit is contained in:
parent
e7b7abc581
commit
cf87218fb7
|
@ -145,7 +145,6 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_g
|
|||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||
auto data_layout_pm = std::make_shared<PassManager>("pynative_transop_pm");
|
||||
data_layout_pm->AddPass(std::make_shared<LayerNormGradSplit>());
|
||||
data_layout_pm->AddPass(std::make_shared<RunOpInsertTransData>());
|
||||
data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
|
||||
|
@ -182,7 +181,6 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
|
|||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||
auto data_layout_pm = std::make_shared<PassManager>("transop_pm");
|
||||
data_layout_pm->AddPass(std::make_shared<LayerNormGradSplit>());
|
||||
data_layout_pm->AddPass(std::make_shared<InsertTransOp>());
|
||||
data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
|
||||
|
@ -238,6 +236,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
|
|||
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
|
||||
} else {
|
||||
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion>());
|
||||
}
|
||||
|
@ -281,6 +280,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
|||
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
|
||||
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
|
||||
|
|
|
@ -32,7 +32,6 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop(
|
|||
std::vector<AnfNodePtr> *layer_norm_x_backprop_outputs) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(layer_norm_grad);
|
||||
MS_EXCEPTION_IF_NULL(kernel_select_);
|
||||
auto prim = std::make_shared<Primitive>(kLayerNormXBackpropOpName);
|
||||
std::vector<AnfNodePtr> layer_norm_x_backprop_inputs = {NewValueNode(prim)};
|
||||
for (size_t i = 1; i < layer_norm_grad->inputs().size(); ++i) {
|
||||
|
@ -46,7 +45,6 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop(
|
|||
auto shapes = {AnfAlgo::GetOutputInferShape(layer_norm_grad, 0)};
|
||||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_x_backprop.get());
|
||||
|
||||
kernel_select_->SelectKernel(layer_norm_x_backprop);
|
||||
(*layer_norm_x_backprop_outputs).push_back(layer_norm_x_backprop);
|
||||
}
|
||||
|
||||
|
@ -55,7 +53,6 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackprop(
|
|||
std::vector<AnfNodePtr> *layer_norm_beta_gamma_backprop_outputs) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(layer_norm_grad);
|
||||
MS_EXCEPTION_IF_NULL(kernel_select_);
|
||||
auto prim = std::make_shared<Primitive>(kLayerNormBetaGammaBackpropOpName);
|
||||
std::vector<AnfNodePtr> layer_norm_beta_gamma_backprop_inputs = {NewValueNode(prim)};
|
||||
for (size_t i = 1; i < layer_norm_grad->inputs().size() - 1; ++i) {
|
||||
|
@ -73,10 +70,9 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackprop(
|
|||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_beta_gamma_backprop.get());
|
||||
|
||||
// get device shape of LayerNormGrad's 5th Input, and convert it to attr
|
||||
std::vector<size_t> shape_gamma = AnfAlgo::GetInputDeviceShape(layer_norm_grad, 4);
|
||||
std::vector<size_t> shape_gamma = AnfAlgo::GetPrevNodeOutputInferShape(layer_norm_grad, 4);
|
||||
AnfAlgo::SetNodeAttr(kAttrShapeGamma, MakeValue(opt::Convert2Int(shape_gamma)), layer_norm_beta_gamma_backprop);
|
||||
|
||||
kernel_select_->SelectKernel(layer_norm_beta_gamma_backprop);
|
||||
CreateMultipleOutputsOfAnfNode(graph, layer_norm_beta_gamma_backprop, kLayerNormBetaGammaBackpropOutputNum,
|
||||
layer_norm_beta_gamma_backprop_outputs);
|
||||
}
|
||||
|
|
|
@ -26,8 +26,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class LayerNormGradSplit : public PatternProcessPass {
|
||||
public:
|
||||
explicit LayerNormGradSplit(bool multigraph = true)
|
||||
: PatternProcessPass("layer_norm_grad_split", multigraph), kernel_select_(std::make_shared<KernelSelect>()) {}
|
||||
explicit LayerNormGradSplit(bool multigraph = true) : PatternProcessPass("layer_norm_grad_split", multigraph) {}
|
||||
~LayerNormGradSplit() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
@ -37,7 +36,6 @@ class LayerNormGradSplit : public PatternProcessPass {
|
|||
std::vector<AnfNodePtr> *layer_norm_grad_outputs) const;
|
||||
void CreateOutputsOfLayerNormBetaGammaBackprop(const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad,
|
||||
std::vector<AnfNodePtr> *layer_norm_beta_gamma_outputs) const;
|
||||
KernelSelectPtr kernel_select_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -39,36 +39,6 @@ class TestHWLayerNormGradSplit : public BackendCommon {
|
|||
UT::PyFuncGraphFetcher get_py_fun_;
|
||||
};
|
||||
|
||||
class MockLayerNormGradSplitKernelSelect : public KernelSelect {
|
||||
public:
|
||||
MockLayerNormGradSplitKernelSelect() = default;
|
||||
~MockLayerNormGradSplitKernelSelect() override = default;
|
||||
void SelectKernel(const CNodePtr &cnode) override {
|
||||
auto name = AnfAlgo::GetCNodeName(cnode);
|
||||
|
||||
if (name == kLayerNormXBackpropOpName) {
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsFormat(
|
||||
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
|
||||
builder.SetInputsDeviceType(
|
||||
{kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16});
|
||||
builder.SetOutputsFormat({kOpFormat_NC1HWC0});
|
||||
builder.SetOutputsDeviceType({kNumberTypeFloat16});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||
return;
|
||||
}
|
||||
if (name == kLayerNormBetaGammaBackpropOpName) {
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
|
||||
builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16});
|
||||
builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
|
||||
builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||
return;
|
||||
}
|
||||
}
|
||||
}; // namespace opt
|
||||
|
||||
TEST_F(TestHWLayerNormGradSplit, test_layer_norm_grad_split) {
|
||||
get_py_fun_.SetDoResolve(true);
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_layer_norm_grad_split", "before");
|
||||
|
@ -81,49 +51,9 @@ TEST_F(TestHWLayerNormGradSplit, test_layer_norm_grad_split) {
|
|||
auto kernel_graph = GetKernelGraph(g, args_spec_list);
|
||||
EXPECT_NE(kernel_graph, nullptr);
|
||||
|
||||
// get LayerNormGrad
|
||||
CNodePtr ret = kernel_graph->get_return();
|
||||
EXPECT_NE(ret, nullptr);
|
||||
EXPECT_NE(ret->input(1), nullptr);
|
||||
EXPECT_TRUE(ret->input(1)->isa<CNode>());
|
||||
auto make_tuple1 = ret->input(1)->cast<CNodePtr>();
|
||||
EXPECT_NE(make_tuple1->input(1), nullptr);
|
||||
EXPECT_TRUE(make_tuple1->input(1)->isa<CNode>());
|
||||
auto make_tuple2 = make_tuple1->input(1)->cast<CNodePtr>();
|
||||
EXPECT_NE(make_tuple2->input(1), nullptr);
|
||||
EXPECT_TRUE(make_tuple2->input(1)->isa<CNode>());
|
||||
auto tuple_getitem = make_tuple2->input(1)->cast<CNodePtr>();
|
||||
EXPECT_NE(tuple_getitem->input(1), nullptr);
|
||||
EXPECT_TRUE(tuple_getitem->input(1)->isa<CNode>());
|
||||
auto layer_norm_grad = tuple_getitem->input(1)->cast<CNodePtr>();
|
||||
|
||||
// set kernel for LayerNormGrad
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
|
||||
builder1.SetInputsFormat(
|
||||
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
|
||||
builder1.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
|
||||
builder1.SetInputsDeviceType(
|
||||
{kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16});
|
||||
builder1.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16});
|
||||
builder1.SetKernelType(TBE_KERNEL);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), layer_norm_grad.get());
|
||||
|
||||
// get param5
|
||||
EXPECT_NE(layer_norm_grad->input(5), nullptr);
|
||||
auto param = layer_norm_grad->input(5);
|
||||
|
||||
// set kernel for param5
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder2;
|
||||
builder2.SetOutputsFormat({kOpFormat_NC1HWC0});
|
||||
builder2.SetOutputsDeviceType({kNumberTypeFloat16});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder2.Build(), param.get());
|
||||
|
||||
// do layer_norm_grad_split pass
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
auto pass = std::make_shared<opt::LayerNormGradSplit>();
|
||||
auto kernel_select = std::make_shared<MockLayerNormGradSplitKernelSelect>();
|
||||
pass->kernel_select_ = kernel_select;
|
||||
pm->AddPass(pass);
|
||||
optimizer->AddPassManager(pm);
|
||||
auto new_graph = optimizer->Optimize(kernel_graph);
|
||||
|
|
Loading…
Reference in New Issue