forked from mindspore-Ecosystem/mindspore
open complex pass for graphkernel
This commit is contained in:
parent b1be8dfd31
commit c9c912851a
@@ -240,7 +240,7 @@ bool GraphKernelComplexExpander::CanExpand(const CNodePtr &node) const {
   bool has_complex = false;
   auto all_inputs_type = AnfAlgo::GetAllInputDeviceTypes(node);
   for (size_t i = 0; i < all_inputs_type.size(); ++i) {
-    if (all_inputs_type[i] == kNumberTypeFloat64 || all_inputs_type[i] == kNumberTypeComplex64) {
+    if (all_inputs_type[i] == kNumberTypeComplex64) {
       has_complex = true;
       break;
     }
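The hunk above stops treating float64 inputs as a trigger for complex expansion; only kNumberTypeComplex64 inputs now count. A minimal standalone sketch of the narrowed check, with a simplified stand-in TypeId enum and a hypothetical HasComplexInput helper (not the MindSpore API):

    #include <vector>

    // Simplified stand-in for the framework's TypeId enum.
    enum TypeId { kNumberTypeFloat64, kNumberTypeComplex64 };

    // Returns true only when at least one input is complex64; float64 inputs
    // alone no longer trigger the complex expander after this change.
    bool HasComplexInput(const std::vector<TypeId> &all_inputs_type) {
      for (TypeId t : all_inputs_type) {
        if (t == kNumberTypeComplex64) {
          return true;
        }
      }
      return false;
    }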
@@ -254,32 +254,7 @@ ExpanderPtr GraphKernelComplexExpander::GetExpander(const AnfNodePtr &node) {
 }
 bool ComplexOpExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
   auto cnode = node->cast<CNodePtr>();
-  auto all_inputs_type = AnfAlgo::GetAllInputDeviceTypes(cnode);
-  for (size_t i = 0; i < all_inputs_type.size(); ++i) {
-    if (all_inputs_type[i] == kNumberTypeFloat64 || all_inputs_type[i] == kNumberTypeComplex64) {
-      all_inputs_type[i] = kNumberTypeComplex64;
-    }
-  }
-
-  auto all_outputs_type = AnfAlgo::GetAllOutputDeviceTypes(cnode);
-  for (size_t i = 0; i < all_outputs_type.size(); ++i) {
-    if (all_outputs_type[i] == kNumberTypeFloat64) {
-      all_outputs_type[i] = kNumberTypeComplex64;
-    }
-  }
-  auto all_inputs_format = AnfAlgo::GetAllInputFormats(cnode);
-  auto all_outputs_format = AnfAlgo::GetAllOutputFormats(cnode);
-  auto graph_sel_info =
-    BuildSelectKernelBuildInfo(all_inputs_format, all_inputs_type, all_outputs_format, all_outputs_type);
-  AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, cnode.get());
-  std::vector<size_t> original_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
-  ShapeVector real_shape;
-  (void)std::copy(original_shape.begin(), original_shape.end(), std::back_inserter(real_shape));
-  auto complex_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(real_shape));
-  TypeId complex_type = kNumberTypeComplex64;
-  auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(complex_type), complex_shape_ptr);
-  cnode->set_abstract(abstract);
-  if (!DefaultExpander::ExpandJsonInfo(cnode, kernel_json)) return false;
+  if (!PyExpander::ExpandJsonInfo(cnode, kernel_json)) return false;
   (*kernel_json)["name"] = std::string("C") + AnfAlgo::GetCNodeName(cnode);
   return true;
 }
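The rewritten ExpandJsonInfo above no longer rebuilds the kernel build info and abstract by hand; it delegates the json expansion to PyExpander::ExpandJsonInfo and then prefixes the kernel name with "C". A minimal sketch of that delegate-then-rename pattern, using nlohmann::json with hypothetical BaseExpander/ComplexExpanderSketch classes (the base class stands in for PyExpander and does not reproduce the real signature):

    #include <string>
    #include <nlohmann/json.hpp>

    struct BaseExpander {
      // Hypothetical stand-in for the base expander's json expansion.
      virtual bool ExpandJsonInfo(const std::string &op_name, nlohmann::json *kernel_json) {
        (*kernel_json)["name"] = op_name;
        return true;
      }
      virtual ~BaseExpander() = default;
    };

    struct ComplexExpanderSketch : BaseExpander {
      bool ExpandJsonInfo(const std::string &op_name, nlohmann::json *kernel_json) override {
        // Let the base class fill in the kernel json first...
        if (!BaseExpander::ExpandJsonInfo(op_name, kernel_json)) return false;
        // ...then mark the complex variant by prefixing the kernel name with "C".
        (*kernel_json)["name"] = std::string("C") + op_name;
        return true;
      }
    };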
@@ -83,7 +83,7 @@ PassManagerPtr GraphKernelOptimizer::Cluster() const {
   auto pm = std::make_shared<GraphKernelPassManager>(1, "cluster");

   // Expand complex op to composite kernels
-  pm->AddPass(std::make_shared<GraphKernelComplexExpander>(), OptLevel_1, false);
+  pm->AddPass(std::make_shared<GraphKernelComplexExpander>(), OptLevel_1, is_gpu);

   // Expand complex basic kernels to composite kernels
   pm->AddPass(std::make_shared<GraphKernelExpander>(), OptLevel_1);
@@ -125,7 +125,7 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() const {
   pm->AddPass(std::make_shared<GraphKernelCSE>(), OptLevel_2);

   // Eliminate Redundant Complex op
-  pm->AddPass(std::make_shared<EliminateRedundantComplex>(), OptLevel_2, false);
+  pm->AddPass(std::make_shared<EliminateRedundantComplex>(), OptLevel_2, is_gpu);

   // Eliminate unnecessary transform ops
   auto level = GetPassLevelByFlag(context::GraphKernelFlags::GetInstance().enable_trans_op_optimize);
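In both of the last two hunks the hard-coded false is replaced by is_gpu, so GraphKernelComplexExpander and EliminateRedundantComplex are registered only when the optimizer targets GPU. A minimal sketch of that enable-flag pattern with a hypothetical PassManagerSketch whose AddPass omits the opt-level argument for brevity (not the real GraphKernelPassManager interface):

    #include <memory>
    #include <vector>

    struct Pass {
      virtual void Run() = 0;
      virtual ~Pass() = default;
    };

    // Hypothetical pass manager: a pass registered with enable == false is
    // skipped, so gating on an is_gpu flag turns the complex-number passes
    // on for GPU targets only.
    class PassManagerSketch {
     public:
      void AddPass(const std::shared_ptr<Pass> &pass, bool enable = true) {
        if (enable) {
          passes_.push_back(pass);
        }
      }
      void RunAll() {
        for (auto &p : passes_) {
          p->Run();
        }
      }

     private:
      std::vector<std::shared_ptr<Pass>> passes_;
    };

With this pattern, a call such as AddPass(std::make_shared<SomePass>(), is_gpu) registers the pass on GPU builds and silently drops it elsewhere.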