forked from mindspore-Ecosystem/mindspore
!43056 Fix ms_funciton control flow bug
Merge pull request !43056 from luochao60/fix_ms_funtion_control_flow_tensor_address_error_20220926
This commit is contained in:
commit
bd64c6e234
|
@ -127,6 +127,9 @@ class CompareSwitchSimplify : public OptimizerCaller {
|
|||
return true;
|
||||
}
|
||||
auto value = GetValue<tensor::TensorPtr>(GetValueNode(node));
|
||||
if (value->device_address() != nullptr) {
|
||||
return true;
|
||||
}
|
||||
if (value->DataSize() > 1) {
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -270,10 +270,18 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index
|
|||
MS_EXCEPTION_IF_NULL(build_info);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(build_info, front_node.get());
|
||||
|
||||
// Create device tensor.
|
||||
std::string output_format = AnfAlgo::GetOutputFormat(backend_node, 0);
|
||||
device::DeviceAddressPtr address = device_context->device_res_manager_->CreateDeviceAddress(
|
||||
nullptr, tensor_size, output_format, output_type_id, ShapeVector());
|
||||
device::DeviceAddressPtr address = nullptr;
|
||||
if (node_value->isa<tensor::Tensor>() && node_value->cast<TensorPtr>()->is_forward_output()) {
|
||||
// If is_forward_output, get address from tensor
|
||||
auto tensor = node_value->cast<TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
|
||||
} else {
|
||||
// Create device tensor.
|
||||
std::string output_format = AnfAlgo::GetOutputFormat(backend_node, 0);
|
||||
address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, tensor_size, output_format,
|
||||
output_type_id, ShapeVector());
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
MS_LOG(DEBUG) << "Create address for node:" << common::AnfAlgo::GetNodeDebugString(front_node) << " addr:" << address
|
||||
<< " size:" << tensor_size;
|
||||
|
@ -287,6 +295,7 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index
|
|||
void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index, const DeviceContext *device_context) {
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
const auto &node = front_node_with_index.first;
|
||||
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << "Start create device tensor for front node:" << front_node_with_index.first->DebugString();
|
||||
|
||||
|
@ -329,11 +338,26 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index
|
|||
builder->SetOutputsDeviceType(types);
|
||||
}
|
||||
|
||||
// Create device tensor.
|
||||
size_t size = 0;
|
||||
size = AnfAlgo::GetOutputTensorMemSize(node, front_node_with_index.second);
|
||||
device::DeviceAddressPtr address =
|
||||
device_context->device_res_manager_->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id, ShapeVector());
|
||||
device::DeviceAddressPtr address = nullptr;
|
||||
if (node->isa<ValueNode>()) {
|
||||
const auto &node_value = node->cast<ValueNodePtr>()->value();
|
||||
if (node_value->isa<tensor::Tensor>() && node_value->cast<TensorPtr>()->is_forward_output()) {
|
||||
// If is_forward_output, get address from tensor
|
||||
auto tensor = node_value->cast<TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
|
||||
} else {
|
||||
// Create device tensor.
|
||||
address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id,
|
||||
ShapeVector());
|
||||
}
|
||||
} else {
|
||||
// Create device tensor.
|
||||
address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id,
|
||||
ShapeVector());
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
MS_LOG(INFO) << "Create address for node that has no corresponding backend node:"
|
||||
<< common::AnfAlgo::GetNodeDebugString(node) << " addr:" << address << " size:" << size
|
||||
|
|
Loading…
Reference in New Issue