fix tbe op gather_v2
This commit is contained in:
parent
4b98f0da6d
commit
e551d1614d
|
@ -175,7 +175,8 @@ std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t
|
|||
auto kernel_pack = TbeUtils::InsertCache(json_name, processor);
|
||||
if (kernel_pack == nullptr) {
|
||||
if (set_kernel_mod) {
|
||||
MS_EXCEPTION(ArgumentError) << "build kernel name:" << task_iter->second.json_name << " failed.";
|
||||
MS_EXCEPTION(ArgumentError) << "Can not find .json file or the binary .o file for op "
|
||||
<< task_iter->second.json_name << ", go check the cache files in kernel_meta/";
|
||||
} else {
|
||||
MS_LOG(INFO) << "fusion build kernel name:" << task_iter->second.json_name << "failed.";
|
||||
auto fusion_kernel_mod = std::make_pair(task_iter->second.scope_id, nullptr);
|
||||
|
|
|
@ -24,6 +24,7 @@ gather_v2_op_info = TBERegOp("Gather") \
|
|||
.kernel_name("gather_v2") \
|
||||
.partial_flag(True) \
|
||||
.dynamic_shape(True) \
|
||||
.attr("batch_dims", "required", "int", "all", "0") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "indices", False, "required", "all") \
|
||||
.input(2, "axis", False, "required", "all") \
|
||||
|
|
|
@ -56,7 +56,7 @@ def test_ftrl_net():
|
|||
[[0.6821311, 0.6821311]],
|
||||
[[0.6821311, 0.6821311]]]))
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
|
|
Loading…
Reference in New Issue