!1503 don't set parameter's format when it's has been setted before

Merge pull request !1503 from lianliguang/master
This commit is contained in:
mindspore-ci-bot 2020-05-28 18:59:46 +08:00 committed by Gitee
commit 6d056e0c0a
2 changed files with 3 additions and 18 deletions

View File

@ -160,26 +160,13 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
if (real_input_node->isa<CNode>()) {
continue;
}
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
continue;
}
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// we set special device info of a input tensor.
bool is_ref = false;
auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
if (op_info != nullptr) {
is_ref = op_info->is_ref();
}
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
continue;
}
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
}

View File

@ -298,9 +298,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
auto is_weight_boundary = [](const AnfNodePtr &node) -> bool {
if (node->isa<ValueNode>()) {
return true;
} else if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
if (node->isa<ValueNode>() || node->isa<Parameter>()) {
return true;
}
return false;