open complex pass for graphkernel

zengzitao 2021-09-07 18:50:16 +08:00
parent b1be8dfd31
commit c9c912851a
2 changed files with 4 additions and 29 deletions


@@ -240,7 +240,7 @@ bool GraphKernelComplexExpander::CanExpand(const CNodePtr &node) const {
   bool has_complex = false;
   auto all_inputs_type = AnfAlgo::GetAllInputDeviceTypes(node);
   for (size_t i = 0; i < all_inputs_type.size(); ++i) {
-    if (all_inputs_type[i] == kNumberTypeFloat64 || all_inputs_type[i] == kNumberTypeComplex64) {
+    if (all_inputs_type[i] == kNumberTypeComplex64) {
       has_complex = true;
       break;
     }
@@ -254,32 +254,7 @@ ExpanderPtr GraphKernelComplexExpander::GetExpander(const AnfNodePtr &node) {
 }
 bool ComplexOpExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
   auto cnode = node->cast<CNodePtr>();
-  auto all_inputs_type = AnfAlgo::GetAllInputDeviceTypes(cnode);
-  for (size_t i = 0; i < all_inputs_type.size(); ++i) {
-    if (all_inputs_type[i] == kNumberTypeFloat64 || all_inputs_type[i] == kNumberTypeComplex64) {
-      all_inputs_type[i] = kNumberTypeComplex64;
-    }
-  }
-  auto all_outputs_type = AnfAlgo::GetAllOutputDeviceTypes(cnode);
-  for (size_t i = 0; i < all_outputs_type.size(); ++i) {
-    if (all_outputs_type[i] == kNumberTypeFloat64) {
-      all_outputs_type[i] = kNumberTypeComplex64;
-    }
-  }
-  auto all_inputs_format = AnfAlgo::GetAllInputFormats(cnode);
-  auto all_outputs_format = AnfAlgo::GetAllOutputFormats(cnode);
-  auto graph_sel_info =
-    BuildSelectKernelBuildInfo(all_inputs_format, all_inputs_type, all_outputs_format, all_outputs_type);
-  AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, cnode.get());
-  std::vector<size_t> original_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
-  ShapeVector real_shape;
-  (void)std::copy(original_shape.begin(), original_shape.end(), std::back_inserter(real_shape));
-  auto complex_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(real_shape));
-  TypeId complex_type = kNumberTypeComplex64;
-  auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(complex_type), complex_shape_ptr);
-  cnode->set_abstract(abstract);
-  if (!DefaultExpander::ExpandJsonInfo(cnode, kernel_json)) return false;
+  if (!PyExpander::ExpandJsonInfo(cnode, kernel_json)) return false;
   (*kernel_json)["name"] = std::string("C") + AnfAlgo::GetCNodeName(cnode);
   return true;
 }
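For reference, the function as it reads once the hunk above is applied, reconstructed only from the surviving context lines plus the single added line (nothing below goes beyond that reconstruction):

bool ComplexOpExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
  auto cnode = node->cast<CNodePtr>();
  // The manual Float64 -> Complex64 type rewriting, the hand-built kernel build
  // info, and the abstract/shape reset are all removed; expansion is delegated
  // to PyExpander instead of DefaultExpander.
  if (!PyExpander::ExpandJsonInfo(cnode, kernel_json)) return false;
  // Prefix the kernel name with "C" to mark the complex variant.
  (*kernel_json)["name"] = std::string("C") + AnfAlgo::GetCNodeName(cnode);
  return true;
}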


@@ -83,7 +83,7 @@ PassManagerPtr GraphKernelOptimizer::Cluster() const {
   auto pm = std::make_shared<GraphKernelPassManager>(1, "cluster");
   // Expand complex op to composite kernels
-  pm->AddPass(std::make_shared<GraphKernelComplexExpander>(), OptLevel_1, false);
+  pm->AddPass(std::make_shared<GraphKernelComplexExpander>(), OptLevel_1, is_gpu);
   // Expand complex basic kernels to composite kernels
   pm->AddPass(std::make_shared<GraphKernelExpander>(), OptLevel_1);
@@ -125,7 +125,7 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() const {
   pm->AddPass(std::make_shared<GraphKernelCSE>(), OptLevel_2);
   // Eliminate Redundant Complex op
-  pm->AddPass(std::make_shared<EliminateRedundantComplex>(), OptLevel_2, false);
+  pm->AddPass(std::make_shared<EliminateRedundantComplex>(), OptLevel_2, is_gpu);
   // Eliminate unnecessary transform ops
   auto level = GetPassLevelByFlag(context::GraphKernelFlags::GetInstance().enable_trans_op_optimize);
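The last two hunks flip the third AddPass argument from a hard-coded false to is_gpu, which is what actually "opens" the complex passes. Assuming that argument behaves as a simple enable flag (an assumption drawn only from this diff, not from MindSpore's real PassManager), the gating pattern can be sketched with the self-contained toy below; every type and name in it is illustrative, not the project's API.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy stand-ins for the real pass machinery; names are illustrative only.
struct Pass {
  std::string name;
};
using PassPtr = std::shared_ptr<Pass>;

class PassManager {
 public:
  // Mirrors the three-argument AddPass call shape from the diff: the pass is
  // registered only when `enabled` is true. The opt level is accepted just to
  // keep the signature similar; this toy does not evaluate it.
  void AddPass(const PassPtr &pass, int opt_level, bool enabled = true) {
    (void)opt_level;
    if (enabled) {
      passes_.push_back(pass);
    }
  }
  void Dump() const {
    for (const auto &p : passes_) std::cout << p->name << '\n';
  }

 private:
  std::vector<PassPtr> passes_;
};

int main() {
  const bool is_gpu = true;  // assume the target backend is GPU
  PassManager pm;
  // Before this commit the flag was hard-coded to false, so these passes never
  // ran; gating them on is_gpu registers them for GPU targets only.
  pm.AddPass(std::make_shared<Pass>(Pass{"GraphKernelComplexExpander"}), 1, is_gpu);
  pm.AddPass(std::make_shared<Pass>(Pass{"EliminateRedundantComplex"}), 2, is_gpu);
  pm.Dump();
  return 0;
}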