!30139 Fix host format bug

Merge pull request !30139 from liangzelang/dev_master_1
This commit is contained in:
i-robot 2022-02-17 11:41:33 +00:00 committed by Gitee
commit b87cc92a94
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 4 additions and 0 deletions

View File

@ -392,6 +392,10 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, std::vector<string> outp
MS_EXCEPTION_IF_NULL(selected_kernel_info);
if (IsValueNode<tensor::Tensor>(real_input_node) &&
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) {
auto host_tensor_ptr = GetValueNode<tensor::TensorPtr>(real_input_node);
MS_EXCEPTION_IF_NULL(host_tensor_ptr);
std::vector<string> format = {host_tensor_ptr->device_info().host_format_};
output_format = format[0] == kOpFormat_DEFAULT ? output_format : format;
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);