!22206 fix gather_v2

Merge pull request !22206 from liubuyu/bug_fix
This commit is contained in:
i-robot 2021-08-23 06:16:06 +00:00 committed by Gitee
commit af34229974
3 changed files with 4 additions and 2 deletions

View File

@ -175,7 +175,8 @@ std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t
auto kernel_pack = TbeUtils::InsertCache(json_name, processor); auto kernel_pack = TbeUtils::InsertCache(json_name, processor);
if (kernel_pack == nullptr) { if (kernel_pack == nullptr) {
if (set_kernel_mod) { if (set_kernel_mod) {
MS_EXCEPTION(ArgumentError) << "build kernel name:" << task_iter->second.json_name << " failed."; MS_EXCEPTION(ArgumentError) << "Can not find .json file or the binary .o file for op "
<< task_iter->second.json_name << ", go check the cache files in kernel_meta/";
} else { } else {
MS_LOG(INFO) << "fusion build kernel name:" << task_iter->second.json_name << "failed."; MS_LOG(INFO) << "fusion build kernel name:" << task_iter->second.json_name << "failed.";
auto fusion_kernel_mod = std::make_pair(task_iter->second.scope_id, nullptr); auto fusion_kernel_mod = std::make_pair(task_iter->second.scope_id, nullptr);

View File

@ -24,6 +24,7 @@ gather_v2_op_info = TBERegOp("Gather") \
.kernel_name("gather_v2") \ .kernel_name("gather_v2") \
.partial_flag(True) \ .partial_flag(True) \
.dynamic_shape(True) \ .dynamic_shape(True) \
.attr("batch_dims", "required", "int", "all", "0") \
.input(0, "x", False, "required", "all") \ .input(0, "x", False, "required", "all") \
.input(1, "indices", False, "required", "all") \ .input(1, "indices", False, "required", "all") \
.input(2, "axis", False, "required", "all") \ .input(2, "axis", False, "required", "all") \

View File

@ -56,7 +56,7 @@ def test_ftrl_net():
[[0.6821311, 0.6821311]], [[0.6821311, 0.6821311]],
[[0.6821311, 0.6821311]]])) [[0.6821311, 0.6821311]]]))
@pytest.mark.level1 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard