From 934958290e7bbff534f5e0afaae9894e54370e1c Mon Sep 17 00:00:00 2001 From: WilliamLian Date: Thu, 4 Jun 2020 16:55:03 +0800 Subject: [PATCH] set value node & parameter's device dtype to the node connected with's device info --- mindspore/ccsrc/device/ascend/kernel_select_ascend.cc | 2 +- mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index 6e6e7419fd0..df19e7708ab 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -176,7 +176,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; builder->SetOutputsFormat(output_format); - std::vector output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)}; + std::vector output_type = {AnfAlgo::GetInputDeviceDataType(kernel_node, input_index)}; builder->SetOutputsDeviceType(output_type); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); } diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc index b573cb33bb6..1203f4d406e 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc @@ -298,7 +298,10 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); auto is_weight_boundary = [](const AnfNodePtr &node) -> bool { - if (node->isa() || node->isa()) { + if (node->isa()) { + return true; + } + if (node->isa() && AnfAlgo::IsParameterWeight(node->cast())) { return true; } return false;