Merge pull request !49560 from liubuyu/bug_fix
This commit is contained in:
i-robot 2023-03-02 12:46:05 +00:00 committed by Gitee
commit bea6862c57
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 20 additions and 3 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2022-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -198,6 +198,7 @@ int AclKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector
AscendKernelMod::UpdateOutputSizeList();
// Update Output desc list
UpdateOutput(node, node_op_runtime_info);
need_skip_execute_ = AnfAlgo::IsDynamicShapeSkipExecute(cnode);
return 0;
}
@ -246,6 +247,21 @@ bool AclKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
}
auto node = anf_node_.lock();
MS_EXCEPTION_IF_NULL(node);
if (need_skip_execute_) {
  // Skip reduce if axis is an empty Tensor (shape = 0): the operator is not launched; instead the first input
  // buffer is copied to the first output buffer with an asynchronous device-to-device memcpy on stream_ptr.
  MS_LOG(INFO) << "The node " << node->fullname_with_scope() << " need skip.";
  // The copy below dereferences inputs[0] and outputs[0], so fail loudly instead of reading out of range.
  if (inputs.empty() || outputs.empty()) {
    MS_LOG(EXCEPTION) << "The node " << node->fullname_with_scope()
                      << " has no input or output to copy, input num: " << inputs.size()
                      << ", output num: " << outputs.size();
  }
  MS_EXCEPTION_IF_NULL(inputs[0]);
  MS_EXCEPTION_IF_NULL(outputs[0]);
  // The lock object is only kept alive for the remainder of this scope and never read.
  // cppcheck-suppress unreadVariable
  auto lock = device::KernelRuntime::LockRuntime(stream_ptr);
  // The second argument (destMax) is the capacity of the destination buffer, so it must come from the output
  // address. Passing the input size there would make the runtime's bounds check pass unconditionally.
  // NOTE(review): aclrtMemcpyAsync is an ACL API; confirm rtError_t/RT_ERROR_NONE is the intended status type here.
  rtError_t status = aclrtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size,
                                      ACL_MEMCPY_DEVICE_TO_DEVICE, stream_ptr);
  if (status != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "AclrtMemcpyAsync failed for " << node->fullname_with_scope() << ", ret: " << status;
  }
  MS_LOG(INFO) << "Execute node:" << node->fullname_with_scope() << " success.";
  return true;
}
auto node_op_runtime_info = node->user_data<runtime::OpRuntimeInfo>();
bool node_acl_runtime_info_legal = node_op_runtime_info != nullptr &&
node_op_runtime_info->acl_runtime_info_ != nullptr &&

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2022-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -58,6 +58,7 @@ class AclKernelMod : public AscendKernelMod {
std::vector<GeTensorDescPtr> output_desc_list_{};
std::string op_type_{};
bool is_dynamic_{false};
bool need_skip_execute_ = false;
};
using AclKernelModPtr = std::shared_ptr<AclKernelMod>;

View File

@ -126,7 +126,7 @@ class TestUnsortedSegmentArithmeticNet(nn.Cell):
return self.func(x, segment_ids, self.num_segments)
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard