!7187 improve the implicit conversion rule when there are an int tensor and a float number

Merge pull request !7187 from zhangbuxue/improve_the_implicit_conversion_rule_when_there_are_int_tensor_and_float_number
mindspore-ci-bot 2020-10-12 17:36:15 +08:00 committed by Gitee
commit 8d77d4fa90
3 changed files with 14 additions and 9 deletions
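
In effect, combining an integer tensor with a Python float scalar now promotes the result to float32 instead of leaving it integral. A minimal sketch of the intended behavior, assuming the standard MindSpore Python API (the import path and dtype spelling are illustrative, not taken from this commit):

    import numpy as np
    from mindspore import Tensor

    x = Tensor(np.array([1, 2, 3], dtype=np.int32))
    y = x * 2.5     # int32 tensor combined with a float scalar
    print(y.dtype)  # expected: Float32 once this rule applies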


@@ -68,8 +68,7 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
   *max_type_number = type_number;
 }
-bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, bool is_write, TypeId *arg_type_id,
-                               TypeId *arg_type = nullptr) {
+bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, TypeId *arg_type_id, TypeId *arg_type = nullptr) {
   if (arg_type_origin->isa<TensorType>()) {
     auto tensor = arg_type_origin->cast<TensorTypePtr>();
     auto tensor_type = tensor->element();
@@ -102,8 +101,7 @@ TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t>
   for (const auto &index : indices) {
     TypeId arg_type_id = kTypeUnknown;
     TypeId arg_type = kTypeUnknown;
-    auto is_write = (write_indices.find(index) != write_indices.end());
-    if (!GetTensorOrScalarTypeInfo(input_types[index], is_write, &arg_type_id, &arg_type)) {
+    if (!GetTensorOrScalarTypeInfo(input_types[index], &arg_type_id, &arg_type)) {
       continue;
     }
     if (arg_type != kObjectTypeTensorType) {
@@ -144,6 +142,10 @@ TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t>
       max_type_id = kNumberTypeFloat32;
     }
   }
+  if (max_type_id != kNumberTypeFloat16 && max_type_id != kNumberTypeFloat32 && max_type_id != kNumberTypeFloat64 &&
+      has_scalar_float32) {
+    max_type_id = kNumberTypeFloat32;
+  }
   return max_type_id;
 }
@@ -218,7 +220,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
     TypeId arg_type_id = kTypeUnknown;
     auto arg_value = input_types[i];
-    (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id);
+    (void)GetTensorOrScalarTypeInfo(arg_value, &arg_type_id);
     auto it_map = type_name_map.find(arg_type_id);
     if (it_map == type_name_map.end()) {
       continue;
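
The first two hunks drop the now-unused is_write plumbing from GetTensorOrScalarTypeInfo; the new branch in GetMaxTypeId is the heart of the fix: when the maximum tensor type is not already a floating-point type and a float32 scalar participates in the call, the destination type is forced to float32. A hedged Python transcription of just that branch (the names mirror the C++ locals; this is a sketch, not MindSpore code):

    FLOAT_TYPE_IDS = {"Float16", "Float32", "Float64"}  # kNumberTypeFloat16/32/64

    def apply_float_scalar_promotion(max_type_id: str, has_scalar_float32: bool) -> str:
        # Mirrors the added C++ block: a non-float max type is promoted
        # to Float32 whenever a float scalar is among the arguments.
        if max_type_id not in FLOAT_TYPE_IDS and has_scalar_float32:
            return "Float32"
        return max_type_id

    assert apply_float_scalar_promotion("Int32", True) == "Float32"    # int tensor + float scalar
    assert apply_float_scalar_promotion("Float16", True) == "Float16"  # float max type is kept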


@@ -223,6 +223,10 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
       max_type = TypeId::kNumberTypeFloat32;
     }
   }
+  if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
+      max_type != TypeId::kNumberTypeFloat64 && has_float) {
+    max_type = TypeId::kNumberTypeFloat32;
+  }
   if (max_type == TypeId::kNumberTypeUInt8 && has_int8) {
     max_type = TypeId::kNumberTypeInt16;
   }
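
GetDstType plays the same role for PyNative (eager) execution, so both the compiled and eager paths agree on the destination type. A usage sketch under PyNative mode, assuming the context API of this MindSpore generation (module paths are an assumption, not part of the commit):

    import numpy as np
    from mindspore import Tensor, context

    context.set_context(mode=context.PYNATIVE_MODE)  # eager path resolves types via GetDstType
    a = Tensor(np.ones((2, 2), dtype=np.int32))
    b = a * 0.5       # the float scalar operand sets has_float
    print(b.dtype)    # expected: Float32 with the new rule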


@@ -126,7 +126,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
   std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
   std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
-  std::set<AnfNodePtr> key_ward_para_nodes;
+  std::set<AnfNodePtr> kwarg_nodes;
   for (const auto &kwarg : kwarg_list) {
     MS_EXCEPTION_IF_NULL(kwarg);
     std::string kw_param_name = kwarg->get_key();
@@ -160,14 +160,13 @@
   } else {
     auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node);
     // multiply values found given for parameter
-    if (node_itr != specialized_parameter_list->end() &&
-        key_ward_para_nodes.find(param_node) == key_ward_para_nodes.end()) {
+    if (node_itr != specialized_parameter_list->end() && kwarg_nodes.find(param_node) == kwarg_nodes.end()) {
       MS_EXCEPTION(TypeError) << "Multiply values for specific argument: " << kw_param_name;
     } else {
       specialized_parameter_list->push_back(param_node);
       auto extract_node = specialized_graph->NewCNode(
         {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node});
-      key_ward_para_nodes.insert(param_node);
+      kwarg_nodes.insert(param_node);
       (void)repl_nodes->emplace(param_node, extract_node);
     }
   }