!51092 fix backend valuenode type select
Merge pull request !51092 from 王禹程/main_clean
This commit is contained in:
commit
a1cdafc534
|
@ -173,13 +173,14 @@ abstract::AbstractBasePtr MakeNewAbstract(const AnfNodePtr &input, const tensor:
|
|||
auto abs = input->abstract();
|
||||
abstract::AbstractBasePtr new_abs;
|
||||
if (abs->isa<abstract::AbstractTensor>()) {
|
||||
new_abs = abs->Clone();
|
||||
new_abs->set_value(depended_value);
|
||||
|
||||
// Set user data for PyExecute infer.
|
||||
if (input->has_user_data<kernel::PyExecuteOutputUserData>()) {
|
||||
new_abs = abs->Clone();
|
||||
new_abs->set_value(depended_value);
|
||||
const auto &output_data = input->user_data<kernel::PyExecuteOutputUserData>();
|
||||
new_abs->set_user_data<kernel::PyExecuteOutputUserData>(output_data);
|
||||
} else {
|
||||
return depended_value->ToAbstract();
|
||||
}
|
||||
} else if (abs->isa<abstract::AbstractScalar>()) {
|
||||
auto type = depended_value->Dtype()->type_id();
|
||||
|
|
|
@ -575,15 +575,12 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, std::vector<string> outp
|
|||
std::vector<string> format = {host_tensor_ptr->device_info().host_format_};
|
||||
output_format = format[0] == kOpFormat_DEFAULT ? output_format : format;
|
||||
builder->SetOutputsFormat(output_format);
|
||||
std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
|
||||
builder->SetOutputsDeviceType(output_type);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
|
||||
} else {
|
||||
builder->SetOutputsFormat(output_format);
|
||||
std::vector<TypeId> output_type = {common::AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
|
||||
builder->SetOutputsDeviceType(output_type);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
|
||||
}
|
||||
std::vector<TypeId> output_type = {common::AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
|
||||
builder->SetOutputsDeviceType(output_type);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue