!43056 Fix ms_funciton control flow bug

Merge pull request !43056 from luochao60/fix_ms_funtion_control_flow_tensor_address_error_20220926
This commit is contained in:
i-robot 2022-09-29 07:14:46 +00:00 committed by Gitee
commit bd64c6e234
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 34 additions and 7 deletions

View File

@ -127,6 +127,9 @@ class CompareSwitchSimplify : public OptimizerCaller {
return true;
}
auto value = GetValue<tensor::TensorPtr>(GetValueNode(node));
if (value->device_address() != nullptr) {
return true;
}
if (value->DataSize() > 1) {
return true;
}

View File

@ -270,10 +270,18 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index
MS_EXCEPTION_IF_NULL(build_info);
AnfAlgo::SetSelectKernelBuildInfo(build_info, front_node.get());
// Create device tensor.
std::string output_format = AnfAlgo::GetOutputFormat(backend_node, 0);
device::DeviceAddressPtr address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, tensor_size, output_format, output_type_id, ShapeVector());
device::DeviceAddressPtr address = nullptr;
if (node_value->isa<tensor::Tensor>() && node_value->cast<TensorPtr>()->is_forward_output()) {
// If is_forward_output, get address from tensor
auto tensor = node_value->cast<TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
} else {
// Create device tensor.
std::string output_format = AnfAlgo::GetOutputFormat(backend_node, 0);
address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, tensor_size, output_format,
output_type_id, ShapeVector());
}
MS_EXCEPTION_IF_NULL(address);
MS_LOG(DEBUG) << "Create address for node:" << common::AnfAlgo::GetNodeDebugString(front_node) << " addr:" << address
<< " size:" << tensor_size;
@ -287,6 +295,7 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index
void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index, const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(device_context);
const auto &node = front_node_with_index.first;
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Start create device tensor for front node:" << front_node_with_index.first->DebugString();
@ -329,11 +338,26 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index
builder->SetOutputsDeviceType(types);
}
// Create device tensor.
size_t size = 0;
size = AnfAlgo::GetOutputTensorMemSize(node, front_node_with_index.second);
device::DeviceAddressPtr address =
device_context->device_res_manager_->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id, ShapeVector());
device::DeviceAddressPtr address = nullptr;
if (node->isa<ValueNode>()) {
const auto &node_value = node->cast<ValueNodePtr>()->value();
if (node_value->isa<tensor::Tensor>() && node_value->cast<TensorPtr>()->is_forward_output()) {
// If is_forward_output, get address from tensor
auto tensor = node_value->cast<TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
} else {
// Create device tensor.
address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id,
ShapeVector());
}
} else {
// Create device tensor.
address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id,
ShapeVector());
}
MS_EXCEPTION_IF_NULL(address);
MS_LOG(INFO) << "Create address for node that has no corresponding backend node:"
<< common::AnfAlgo::GetNodeDebugString(node) << " addr:" << address << " size:" << size