forked from mindspore-Ecosystem/mindspore
Overlength functions rectification
This commit is contained in:
parent
53b3d187b9
commit
146ac1263e
|
@ -70,6 +70,35 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
// Registers the optional Ascend backend IR fusion passes on the given pass
// manager. Passes execute in registration order, so the sequence below is
// preserved verbatim from the original registration list.
void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
  MS_EXCEPTION_IF_NULL(ir_fusion_pm);
  // Local shorthand so every registration reads as one short call.
  const auto add_pass = [ir_fusion_pm](const auto &pass) { ir_fusion_pm->AddPass(pass); };
  add_pass(std::make_shared<SquareSumFusion>());
  add_pass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
  add_pass(std::make_shared<LambUpdateWithLRRuleFusion>());
  add_pass(std::make_shared<ConfusionSoftmaxGradRule>());
  add_pass(std::make_shared<LambNextMVWithDecayV1Rule>());
  add_pass(std::make_shared<LambNextMVRule>());
  add_pass(std::make_shared<LambNextMVWithDecayRule>());
  add_pass(std::make_shared<LambNextRightRule>());
  add_pass(std::make_shared<LambUpdateWithLrV2>());
  add_pass(std::make_shared<ReshapeTransposeFusion>());
  add_pass(std::make_shared<TransposeReshapeFusion>());
  add_pass(std::make_shared<ClipByValueFusion>());
  add_pass(std::make_shared<FusedBatchNormFusion>());
  add_pass(std::make_shared<TopKSplit>());
  add_pass(std::make_shared<AdamApplyOneWithDecayRule>());
  add_pass(std::make_shared<AdamApplyOneFusion>());
  add_pass(std::make_shared<MomentumLossscaleFusion>());
  add_pass(std::make_shared<MulAddFusion>());
  add_pass(std::make_shared<MulAddNFusion>());
  add_pass(std::make_shared<MatmulBiasaddFusion>());
  add_pass(std::make_shared<AddnFission>());
  add_pass(std::make_shared<GetitemTuple>());
  add_pass(std::make_shared<TransposeTransDataFusion>());
}
|
||||
} // namespace
|
||||
|
||||
void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||
|
@ -164,29 +193,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
|
|||
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
|
||||
if (context_ptr->ir_fusion_flag()) {
|
||||
ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayV1Rule>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRule>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<ClipByValueFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRule>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>());
|
||||
AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
|
||||
}
|
||||
|
||||
if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) {
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool first_flag,
|
||||
std::vector<CNodePtr> *trans_road) {
|
||||
if (node == nullptr) {
|
||||
|
@ -59,6 +60,24 @@ const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// Builds a new KernelBuildInfo for the given cast node. The caller supplies
// the format (applied to both input and output) and the input/output device
// types; the kernel type, fusion type and processor are copied unchanged from
// the node's currently selected build info.
kernel::KernelBuildInfoPtr GetKernelBuildInfo(const CNodePtr &cast, const string &format, TypeId input_type,
                                              TypeId output_type) {
  MS_EXCEPTION_IF_NULL(cast);
  auto kernel_info = cast->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto selected_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(selected_info);
  kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
  // The same format is used on both sides of the cast.
  info_builder.SetOutputsFormat({format});
  info_builder.SetInputsFormat({format});
  info_builder.SetInputsDeviceType({input_type});
  info_builder.SetOutputsDeviceType({output_type});
  // Carry over the attributes that are not being overridden.
  info_builder.SetKernelType(selected_info->kernel_type());
  info_builder.SetFusionType(selected_info->fusion_type());
  info_builder.SetProcessor(selected_info->processor());
  return info_builder.Build();
}
|
||||
} // namespace
|
||||
bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Func graph is nullptr";
|
||||
|
@ -95,17 +114,7 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
|
|||
auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0);
|
||||
|
||||
auto cast = trans_road[1];
|
||||
auto cast_format = AnfAlgo::GetOutputFormat(cast, 0);
|
||||
auto cast_build_info = cast->kernel_info()->select_kernel_build_info();
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
builder.SetOutputsFormat({format});
|
||||
builder.SetInputsFormat({format});
|
||||
builder.SetInputsDeviceType({param_dtype});
|
||||
builder.SetOutputsDeviceType({dtype});
|
||||
builder.SetKernelType(cast_build_info->kernel_type());
|
||||
builder.SetFusionType(cast_build_info->fusion_type());
|
||||
builder.SetProcessor(cast_build_info->processor());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get());
|
||||
if (param_format == format && param_dtype != dtype) {
|
||||
manager->Replace(trans_road[2], final_node);
|
||||
manager->Replace(cur_transop, cast);
|
||||
|
|
Loading…
Reference in New Issue