diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
index 74274853675..4b0c608837b 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
@@ -240,7 +240,7 @@ bool GraphKernelComplexExpander::CanExpand(const CNodePtr &node) const {
   bool has_complex = false;
   auto all_inputs_type = AnfAlgo::GetAllInputDeviceTypes(node);
   for (size_t i = 0; i < all_inputs_type.size(); ++i) {
-    if (all_inputs_type[i] == kNumberTypeFloat64 || all_inputs_type[i] == kNumberTypeComplex64) {
+    if (all_inputs_type[i] == kNumberTypeComplex64) {
       has_complex = true;
       break;
     }
@@ -254,32 +254,7 @@ ExpanderPtr GraphKernelComplexExpander::GetExpander(const AnfNodePtr &node) {
 }
 bool ComplexOpExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
   auto cnode = node->cast<CNodePtr>();
-  auto all_inputs_type = AnfAlgo::GetAllInputDeviceTypes(cnode);
-  for (size_t i = 0; i < all_inputs_type.size(); ++i) {
-    if (all_inputs_type[i] == kNumberTypeFloat64 || all_inputs_type[i] == kNumberTypeComplex64) {
-      all_inputs_type[i] = kNumberTypeComplex64;
-    }
-  }
-
-  auto all_outputs_type = AnfAlgo::GetAllOutputDeviceTypes(cnode);
-  for (size_t i = 0; i < all_outputs_type.size(); ++i) {
-    if (all_outputs_type[i] == kNumberTypeFloat64) {
-      all_outputs_type[i] = kNumberTypeComplex64;
-    }
-  }
-  auto all_inputs_format = AnfAlgo::GetAllInputFormats(cnode);
-  auto all_outputs_format = AnfAlgo::GetAllOutputFormats(cnode);
-  auto graph_sel_info =
-    BuildSelectKernelBuildInfo(all_inputs_format, all_inputs_type, all_outputs_format, all_outputs_type);
-  AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, cnode.get());
-  std::vector<size_t> original_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
-  ShapeVector real_shape;
-  (void)std::copy(original_shape.begin(), original_shape.end(), std::back_inserter(real_shape));
-  auto complex_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(real_shape));
-  TypeId complex_type = kNumberTypeComplex64;
-  auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(complex_type), complex_shape_ptr);
-  cnode->set_abstract(abstract);
-  if (!DefaultExpander::ExpandJsonInfo(cnode, kernel_json)) return false;
+  if (!PyExpander::ExpandJsonInfo(cnode, kernel_json)) return false;
   (*kernel_json)["name"] = std::string("C") + AnfAlgo::GetCNodeName(cnode);
   return true;
 }
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc
index a768c532cdf..1e396847f71 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc
@@ -83,7 +83,7 @@ PassManagerPtr GraphKernelOptimizer::Cluster() const {
   auto pm = std::make_shared<GraphKernelPassManager>(1, "cluster");
 
   // Expand complex op to composite kernels
-  pm->AddPass(std::make_shared<GraphKernelComplexExpander>(), OptLevel_1, false);
+  pm->AddPass(std::make_shared<GraphKernelComplexExpander>(), OptLevel_1, is_gpu);
 
   // Expand complex basic kernels to composite kernels
   pm->AddPass(std::make_shared<GraphKernelExpander>(), OptLevel_1);
@@ -125,7 +125,7 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() const {
   pm->AddPass(std::make_shared<ReorderOps>(), OptLevel_2);
 
   // Eliminate Redundant Complex op
-  pm->AddPass(std::make_shared<ComplexOpEliminater>(), OptLevel_2, false);
+  pm->AddPass(std::make_shared<ComplexOpEliminater>(), OptLevel_2, is_gpu);
 
   // Eliminate unnecessary transform ops
   auto level = GetPassLevelByFlag(context::GraphKernelFlags::GetInstance().enable_trans_op_optimize);
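
Summary for reviewers: the behavioral core of this change is small. (1) CanExpand now requires an actual kNumberTypeComplex64 input; a Float64 input alone no longer triggers complex expansion. (2) ComplexOpExpander::ExpandJsonInfo no longer hand-rewrites the kernel build info and the node's abstract; it delegates to PyExpander::ExpandJsonInfo and only prefixes the kernel name with "C". (3) The two complex-op passes are registered with is_gpu instead of false, so they run only when the device target is GPU rather than being disabled unconditionally. Below is a minimal standalone C++ sketch of the gate-and-rename behavior; TypeId and the two enumerators are local stand-ins, not MindSpore's real headers or API.

// Standalone sketch of the new gate-and-rename behavior (stand-in types,
// not MindSpore's real headers or API).
#include <iostream>
#include <string>
#include <vector>

enum TypeId { kNumberTypeFloat64, kNumberTypeComplex64 };  // stand-ins for the real enum

// Mirrors GraphKernelComplexExpander::CanExpand after this change:
// only a genuine Complex64 input makes the node expandable.
bool CanExpand(const std::vector<TypeId> &all_inputs_type) {
  for (TypeId t : all_inputs_type) {
    if (t == kNumberTypeComplex64) return true;  // Float64 alone no longer qualifies
  }
  return false;
}

// Mirrors the rename step in ComplexOpExpander::ExpandJsonInfo: the base
// expander fills the kernel json, then the op name gets a "C" prefix.
std::string ComplexKernelName(const std::string &op_name) {
  return std::string("C") + op_name;
}

int main() {
  std::cout << std::boolalpha;
  std::cout << CanExpand({kNumberTypeFloat64, kNumberTypeComplex64}) << "\n";  // true
  std::cout << CanExpand({kNumberTypeFloat64}) << "\n";                        // false: the changed behavior
  std::cout << ComplexKernelName("Add") << "\n";                               // CAdd
  return 0;
}

Built with any C++11 compiler, this prints true / false / CAdd, matching the new gating semantics: Float64-only nodes are left to the regular expander, and expanded complex kernels are distinguishable by their "C"-prefixed names.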