!35122 support is_dynamic_impl

Merge pull request !35122 from liubuyu/dynamic
This commit is contained in:
i-robot 2022-06-06 06:02:26 +00:00 committed by Gitee
commit f5814afc26
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 24 additions and 17 deletions

View File

@ -388,10 +388,14 @@ void TbeJsonCreator::GenComputeCommonJson(const AnfNodePtr &anf_node, nlohmann::
python_module_path = kPyPath;
}
auto dynamic_compile_static = op_info_ptr->dynamic_compile_static();
auto is_dynamic = op_info_ptr->dynamic_shape() && tbe::TbeDynamicShapeUtil::GetDynamicShapeAttr(anf_node);
auto is_dynamic_impl = is_dynamic || dynamic_compile_static;
auto iter = tbe::opTypeAdapter.find(op_name);
(*compute_json)[kJType] = (iter != tbe::opTypeAdapter.end()) ? iter->second : op_name;
(*compute_json)[kJPyModulePath] = python_module_path;
(*compute_json)[kJDynamicCompileStatic] = op_info_ptr->dynamic_compile_static();
(*compute_json)[kJDynamicCompileStatic] = dynamic_compile_static;
(*compute_json)[kJIsDynamicImpl] = is_dynamic_impl;
(*compute_json)[kJInt64Mode] = false;
(*compute_json)[kJName] = cnode->fullname_with_scope();
(*compute_json)[kJPattern] = kernel::GetFusionNameByType(AnfAlgo::GetFusionType(cnode));

View File

@ -74,6 +74,7 @@ constexpr auto kJSliceOffset = "slice_offset";
constexpr auto kJSplitIndex = "split_index";
constexpr auto kJTotalShape = "total_shape";
constexpr auto kJDynamicCompileStatic = "dynamic_compile_static";
constexpr auto kJIsDynamicImpl = "is_dynamic_impl";
constexpr auto kJInt64Mode = "int64mode";
constexpr auto kJValidShape = "valid_shape";
constexpr auto kJModuleName = "module_name";

View File

@ -389,6 +389,7 @@ def _pre_build_compute_op_info(compute_op, job):
set_L1_info("op_L1_space", l1_size)
_normalize_module_name(op_module_name, py_module_path)
unknown_shape = compute_op["unknown_shape"]
is_dynamic_impl = compute_op["is_dynamic_impl"]
int64_mode = compute_op["int64mode"]
res = check_op_impl_mode(op_module_name, op_func_name)
op_impl_mode = job.content["SocInfo"]["op_impl_mode"]
@ -402,7 +403,7 @@ def _pre_build_compute_op_info(compute_op, job):
options = get_options_info(job.content)
dispatch_prebuild_task(job.source_id, job.id, l1_size, op_module_name, op_full_name,
op_type, op_func_name, unknown_shape,
(inputs, outputs, attrs, options), int64_mode, unknown_shape,
(inputs, outputs, attrs, options), int64_mode, is_dynamic_impl,
None, job.pass_list)
@ -465,13 +466,14 @@ def build_single_pre_op(job: TbeJob):
op_func_name = compute_op_info["func_name"]
_normalize_module_name(op_module_name, py_module_path)
unknown_shape = compute_op_info["unknown_shape"]
is_dynamic_impl = compute_op_info["is_dynamic_impl"]
int64_mode = compute_op_info["int64mode"]
op_pattern = compute_op_info["pattern"]
options = get_options_info(job.content)
fuzz_build_info = get_fuzz_build_info(job.content)
dispatch_single_op_compile_task(job.source_id, job.id, l1_size, op_module_name, op_name, op_type, op_func_name,
op_kernel_name, unknown_shape, (inputs, outputs, attrs, options), int64_mode,
None, None, unknown_shape, op_pattern,
None, None, is_dynamic_impl, op_pattern,
json.dumps(fuzz_build_info), None, job.pass_list)
return True

View File

@ -77,11 +77,11 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_common) {
auto tbe_json_creator_build = std::make_shared<BuildTbeJsonCreator>();
nlohmann::json kernel_json;
EXPECT_TRUE(tbe_json_creator_select->GenJson(relu1, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 13146561810461380838U);
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 5780584009322070553U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(relu1, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 17413190217831512531U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 17322530358240753834U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(relu1, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 17413190217831512531U);
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 17322530358240753834U);
}
TEST_F(TestHWTBEJsonCreator, test_tbe_single_conv2d_backprop_filter) {
@ -120,11 +120,11 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_conv2d_backprop_filter) {
auto tbe_json_creator_build = std::make_shared<BuildTbeJsonCreator>();
nlohmann::json kernel_json;
EXPECT_TRUE(tbe_json_creator_select->GenJson(conv2d_backprop_filter, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 3927968868169541779U);
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 7656283680331759978U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(conv2d_backprop_filter, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 5438793620486689761U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 3632095151624181824U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(conv2d_backprop_filter, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 5438793620486689761U);
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 3632095151624181824U);
}
TEST_F(TestHWTBEJsonCreator, test_tbe_single_dynamic_rnn) {
@ -178,11 +178,11 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_dynamic_rnn) {
auto tbe_json_creator_build = std::make_shared<BuildTbeJsonCreator>();
nlohmann::json kernel_json;
EXPECT_TRUE(tbe_json_creator_select->GenJson(dynamic_rnn, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 13288675099420394285U);
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 8179988591608352552U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(dynamic_rnn, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 17084598473306810717U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 11572005077409464386U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(dynamic_rnn, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 17084598473306810717U);
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 11572005077409464386U);
}
TEST_F(TestHWTBEJsonCreator, test_tbe_single_layer_norm) {
@ -232,11 +232,11 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_layer_norm) {
auto tbe_json_creator_build = std::make_shared<BuildTbeJsonCreator>();
nlohmann::json kernel_json;
EXPECT_TRUE(tbe_json_creator_select->GenJson(layer_norm, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 6545088373747371515U);
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 1374295440061239938U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(layer_norm, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 10583210293426000299U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 4359214283733046791U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(layer_norm, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 10583210293426000299U);
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 4359214283733046791U);
}
TEST_F(TestHWTBEJsonCreator, test_tbe_fusion_common) {
@ -308,7 +308,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_fusion_common) {
nlohmann::json fusion_json;
auto tbe_json_creator = std::make_shared<FusionBuildTbeJsonCreator>();
EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json));
EXPECT_EQ(tbe_json_creator->GetJsonHash(), 12371036326019427133U);
EXPECT_EQ(tbe_json_creator->GetJsonHash(), 3090761817012021496U);
}
TEST_F(TestHWTBEJsonCreator, test_fusion_add_conv2d) {
@ -368,7 +368,7 @@ TEST_F(TestHWTBEJsonCreator, test_fusion_add_conv2d) {
nlohmann::json fusion_json;
auto tbe_json_creator = std::make_shared<FusionBuildTbeJsonCreator>();
EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json));
EXPECT_EQ(tbe_json_creator->GetJsonHash(), 3371595473173037387U);
EXPECT_EQ(tbe_json_creator->GetJsonHash(), 15855944752652799179U);
}
} // namespace mindspore::kernel