From e3ea96ecd0db9e4dd0bc41398a7f269bc7465804 Mon Sep 17 00:00:00 2001
From: buxue
Date: Mon, 12 Oct 2020 14:51:25 +0800
Subject: [PATCH] improve the implicit conversion rule when there are int tensor and float number

---
 .../frontend/operator/composite/do_signature.cc | 12 +++++++-----
 .../ccsrc/pipeline/pynative/pynative_execute.cc |  4 ++++
 mindspore/core/ir/func_graph_extends.cc         |  7 +++----
 3 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc
index 29e4d2f4cfd..29bafc7d57d 100644
--- a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc
+++ b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc
@@ -68,8 +68,7 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
   *max_type_number = type_number;
 }
 
-bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, bool is_write, TypeId *arg_type_id,
-                               TypeId *arg_type = nullptr) {
+bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, TypeId *arg_type_id, TypeId *arg_type = nullptr) {
   if (arg_type_origin->isa<TensorType>()) {
     auto tensor = arg_type_origin->cast<TensorTypePtr>();
     auto tensor_type = tensor->element();
@@ -102,8 +101,7 @@ TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t>
   for (const auto &index : indices) {
     TypeId arg_type_id = kTypeUnknown;
     TypeId arg_type = kTypeUnknown;
-    auto is_write = (write_indices.find(index) != write_indices.end());
-    if (!GetTensorOrScalarTypeInfo(input_types[index], is_write, &arg_type_id, &arg_type)) {
+    if (!GetTensorOrScalarTypeInfo(input_types[index], &arg_type_id, &arg_type)) {
       continue;
     }
     if (arg_type != kObjectTypeTensorType) {
@@ -144,6 +142,10 @@ TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t>
       max_type_id = kNumberTypeFloat32;
     }
   }
+  if (max_type_id != kNumberTypeFloat16 && max_type_id != kNumberTypeFloat32 && max_type_id != kNumberTypeFloat64 &&
+      has_scalar_float32) {
+    max_type_id = kNumberTypeFloat32;
+  }
   return max_type_id;
 }
 
@@ -218,7 +220,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
 
     TypeId arg_type_id = kTypeUnknown;
     auto arg_value = input_types[i];
-    (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id);
+    (void)GetTensorOrScalarTypeInfo(arg_value, &arg_type_id);
     auto it_map = type_name_map.find(arg_type_id);
     if (it_map == type_name_map.end()) {
       continue;
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index ab98602f8be..968d01301cf 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -223,6 +223,10 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
       max_type = TypeId::kNumberTypeFloat32;
     }
   }
+  if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
+      max_type != TypeId::kNumberTypeFloat64 && has_float) {
+    max_type = TypeId::kNumberTypeFloat32;
+  }
   if (max_type == TypeId::kNumberTypeUInt8 && has_int8) {
     max_type = TypeId::kNumberTypeInt16;
   }
diff --git a/mindspore/core/ir/func_graph_extends.cc b/mindspore/core/ir/func_graph_extends.cc
index 133a0725109..217cb5adf61 100644
--- a/mindspore/core/ir/func_graph_extends.cc
+++ b/mindspore/core/ir/func_graph_extends.cc
@@ -126,7 +126,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
   std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
   std::vector<AnfNodePtr> kwarg_values_tuple_nodes =
     {NewValueNode(prim::kPrimMakeTuple)};
-  std::set<AnfNodePtr> key_ward_para_nodes;
+  std::set<AnfNodePtr> kwarg_nodes;
   for (const auto &kwarg : kwarg_list) {
     MS_EXCEPTION_IF_NULL(kwarg);
     std::string kw_param_name = kwarg->get_key();
@@ -160,14 +160,13 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
     } else {
       auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node);
       // multiply values found given for parameter
-      if (node_itr != specialized_parameter_list->end() &&
-          key_ward_para_nodes.find(param_node) == key_ward_para_nodes.end()) {
+      if (node_itr != specialized_parameter_list->end() && kwarg_nodes.find(param_node) == kwarg_nodes.end()) {
         MS_EXCEPTION(TypeError) << "Multiply values for specific argument: " << kw_param_name;
       } else {
         specialized_parameter_list->push_back(param_node);
         auto extract_node = specialized_graph->NewCNode(
           {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node});
-        key_ward_para_nodes.insert(param_node);
+        kwarg_nodes.insert(param_node);
         (void)repl_nodes->emplace(param_node, extract_node);
       }
     }
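
Note on the behavior change: with this patch, when the strongest type among the arguments is an integer (or bool) type and a scalar float is also present, the common destination type is promoted to float32, so an int32 tensor combined with a Python float now produces a float32 result instead of staying integral. The same guard is added in both GetMaxTypeId (graph mode) and GetDstType (PyNative mode), so the two execution paths agree. Below is a minimal standalone sketch of the new branch; PromoteMaxType and the trimmed TypeId enum are illustrative stand-ins, not the MindSpore definitions:

  #include <cstdio>

  // Illustrative stand-in for the relevant TypeId values.
  enum TypeId { kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64 };

  bool IsFloatTypeId(TypeId t) {
    return t == kNumberTypeFloat16 || t == kNumberTypeFloat32 || t == kNumberTypeFloat64;
  }

  // Mirrors the added branch in GetMaxTypeId: a non-float max type combined
  // with a scalar float input is promoted to float32.
  TypeId PromoteMaxType(TypeId max_type_id, bool has_scalar_float32) {
    if (!IsFloatTypeId(max_type_id) && has_scalar_float32) {
      max_type_id = kNumberTypeFloat32;
    }
    return max_type_id;
  }

  int main() {
    // int32 tensor mixed with a float scalar -> float32 under the new rule.
    std::printf("%d\n", PromoteMaxType(kNumberTypeInt32, true));
    // A float16 max type is left unchanged; the branch is not taken.
    std::printf("%d\n", PromoteMaxType(kNumberTypeFloat16, true));
    return 0;
  }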