!51092 fix backend valuenode type select

Merge pull request !51092 from 王禹程/main_clean
This commit is contained in:
i-robot 2023-03-30 02:09:41 +00:00 committed by Gitee
commit a1cdafc534
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 7 additions and 9 deletions

View File

@ -173,13 +173,14 @@ abstract::AbstractBasePtr MakeNewAbstract(const AnfNodePtr &input, const tensor:
auto abs = input->abstract();
abstract::AbstractBasePtr new_abs;
if (abs->isa<abstract::AbstractTensor>()) {
new_abs = abs->Clone();
new_abs->set_value(depended_value);
// Set user data for PyExecute infer.
if (input->has_user_data<kernel::PyExecuteOutputUserData>()) {
new_abs = abs->Clone();
new_abs->set_value(depended_value);
const auto &output_data = input->user_data<kernel::PyExecuteOutputUserData>();
new_abs->set_user_data<kernel::PyExecuteOutputUserData>(output_data);
} else {
return depended_value->ToAbstract();
}
} else if (abs->isa<abstract::AbstractScalar>()) {
auto type = depended_value->Dtype()->type_id();

View File

@ -575,15 +575,12 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, std::vector<string> outp
std::vector<string> format = {host_tensor_ptr->device_info().host_format_};
output_format = format[0] == kOpFormat_DEFAULT ? output_format : format;
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
} else {
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {common::AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
}
std::vector<TypeId> output_type = {common::AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
}
}