acl bug fix
This commit is contained in:
parent
b7937c3f57
commit
129854dd74
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -198,6 +198,7 @@ int AclKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector
|
|||
AscendKernelMod::UpdateOutputSizeList();
|
||||
// Update Output desc list
|
||||
UpdateOutput(node, node_op_runtime_info);
|
||||
need_skip_execute_ = AnfAlgo::IsDynamicShapeSkipExecute(cnode);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
@ -246,6 +247,21 @@ bool AclKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
|
|||
}
|
||||
auto node = anf_node_.lock();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (need_skip_execute_) {
|
||||
// Skip reduce if axis is an empty Tensor (shape = 0)
|
||||
MS_LOG(INFO) << "The node " << node->fullname_with_scope() << " need skip.";
|
||||
// cppcheck-suppress unreadVariable
|
||||
auto lock = device::KernelRuntime::LockRuntime(stream_ptr);
|
||||
rtError_t status = aclrtMemcpyAsync(outputs[0]->addr, inputs[0]->size, inputs[0]->addr, inputs[0]->size,
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE, stream_ptr);
|
||||
if (status != RT_ERROR_NONE) {
|
||||
MS_LOG(EXCEPTION) << "AclrtMemcpyAsync failed for " << node->fullname_with_scope();
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Execute node:" << node->fullname_with_scope() << " success.";
|
||||
return true;
|
||||
}
|
||||
|
||||
auto node_op_runtime_info = node->user_data<runtime::OpRuntimeInfo>();
|
||||
bool node_acl_runtime_info_legal = node_op_runtime_info != nullptr &&
|
||||
node_op_runtime_info->acl_runtime_info_ != nullptr &&
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -58,6 +58,7 @@ class AclKernelMod : public AscendKernelMod {
|
|||
std::vector<GeTensorDescPtr> output_desc_list_{};
|
||||
std::string op_type_{};
|
||||
bool is_dynamic_{false};
|
||||
bool need_skip_execute_ = false;
|
||||
};
|
||||
|
||||
using AclKernelModPtr = std::shared_ptr<AclKernelMod>;
|
||||
|
|
|
@ -126,7 +126,7 @@ class TestUnsortedSegmentArithmeticNet(nn.Cell):
|
|||
return self.func(x, segment_ids, self.num_segments)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
|
|
Loading…
Reference in New Issue