forked from OSSInnovation/mindspore
!1664 revert parameter set kernel build info
Merge pull request !1664 from lianliguang/r0.3
This commit is contained in:
commit
a74e238e21
|
@ -165,8 +165,17 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
|
|||
}
|
||||
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
|
||||
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
// we set special device info of a input tensor.
|
||||
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) {
|
||||
bool is_ref = false;
|
||||
auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
|
||||
if (op_info != nullptr) {
|
||||
is_ref = op_info->is_ref();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
||||
if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
|
||||
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
|
||||
continue;
|
||||
}
|
||||
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
|
||||
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
|
||||
builder->SetOutputsFormat(output_format);
|
||||
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
|
||||
|
|
Loading…
Reference in New Issue