!40778 disable argmax and argmin fusion

Merge pull request !40778 from Gaoxiong/master
This commit is contained in:
i-robot 2022-08-24 06:52:04 +00:00 committed by Gitee
commit 123baf6473
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 5 additions and 4 deletions

View File

@ -65,8 +65,8 @@ std::vector<PrimitivePtr> GraphKernelCluster::GetClusterOps() {
// gpu
{kGPUDevice, OpLevel_0, prim::kPrimACos},
{kGPUDevice, OpLevel_0, prim::kPrimAcosh},
{kGPUDevice, OpLevel_1, prim::kPrimArgMax},
{kGPUDevice, OpLevel_1, prim::kPrimArgMin},
{kGPUDevice, OpLevel_2, prim::kPrimArgMax},
{kGPUDevice, OpLevel_2, prim::kPrimArgMin},
{kGPUDevice, OpLevel_0, prim::kPrimAsin},
{kGPUDevice, OpLevel_0, prim::kPrimAsinh},
{kGPUDevice, OpLevel_0, prim::kPrimAssign},

View File

@ -275,7 +275,7 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
reg.AddFlag("reduce_fuse_depth", &reduce_fuse_depth);
reg.AddFlag("online_tuning", &online_tuning);
reg.AddFlag("cpu_refer_thread_num", &cpu_refer_thread_num);
reg.AddFlag("fusion_ops_level", &fusion_ops_level, is_ascend ? OpLevel_0 : OpLevel_MAX);
reg.AddFlag("fusion_ops_level", &fusion_ops_level, is_ascend ? OpLevel_0 : OpLevel_1);
reg.AddFlag("parallel_ops_level", &parallel_ops_level);
reg.AddFlag("recompute_increment_threshold", &recompute_increment_threshold);
reg.AddFlag("recompute_peak_threshold", &recompute_peak_threshold);

View File

@ -32,7 +32,8 @@ constexpr unsigned int OptLevel_MAX = 4;
constexpr unsigned int OpLevel_0 = 0;
constexpr unsigned int OpLevel_1 = 1;
constexpr unsigned int OpLevel_MAX = 2;
constexpr unsigned int OpLevel_2 = 2;
constexpr unsigned int OpLevel_MAX = 3;
constexpr unsigned int default_cpu_refer_tread_num = 8;
class GraphKernelFlags {