forked from mindspore-Ecosystem/mindspore
!7187 improve the implicit conversion rule when there are int tensor and float number
Merge pull request !7187 from zhangbuxue/improve_the_implicit_conversion_rule_when_there_are_int_tensor_and_float_number
This commit is contained in:
commit
8d77d4fa90
|
@ -68,8 +68,7 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
|
||||||
*max_type_number = type_number;
|
*max_type_number = type_number;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, bool is_write, TypeId *arg_type_id,
|
bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, TypeId *arg_type_id, TypeId *arg_type = nullptr) {
|
||||||
TypeId *arg_type = nullptr) {
|
|
||||||
if (arg_type_origin->isa<TensorType>()) {
|
if (arg_type_origin->isa<TensorType>()) {
|
||||||
auto tensor = arg_type_origin->cast<TensorTypePtr>();
|
auto tensor = arg_type_origin->cast<TensorTypePtr>();
|
||||||
auto tensor_type = tensor->element();
|
auto tensor_type = tensor->element();
|
||||||
|
@ -102,8 +101,7 @@ TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t>
|
||||||
for (const auto &index : indices) {
|
for (const auto &index : indices) {
|
||||||
TypeId arg_type_id = kTypeUnknown;
|
TypeId arg_type_id = kTypeUnknown;
|
||||||
TypeId arg_type = kTypeUnknown;
|
TypeId arg_type = kTypeUnknown;
|
||||||
auto is_write = (write_indices.find(index) != write_indices.end());
|
if (!GetTensorOrScalarTypeInfo(input_types[index], &arg_type_id, &arg_type)) {
|
||||||
if (!GetTensorOrScalarTypeInfo(input_types[index], is_write, &arg_type_id, &arg_type)) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (arg_type != kObjectTypeTensorType) {
|
if (arg_type != kObjectTypeTensorType) {
|
||||||
|
@ -144,6 +142,10 @@ TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t>
|
||||||
max_type_id = kNumberTypeFloat32;
|
max_type_id = kNumberTypeFloat32;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (max_type_id != kNumberTypeFloat16 && max_type_id != kNumberTypeFloat32 && max_type_id != kNumberTypeFloat64 &&
|
||||||
|
has_scalar_float32) {
|
||||||
|
max_type_id = kNumberTypeFloat32;
|
||||||
|
}
|
||||||
return max_type_id;
|
return max_type_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -218,7 +220,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
|
||||||
|
|
||||||
TypeId arg_type_id = kTypeUnknown;
|
TypeId arg_type_id = kTypeUnknown;
|
||||||
auto arg_value = input_types[i];
|
auto arg_value = input_types[i];
|
||||||
(void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id);
|
(void)GetTensorOrScalarTypeInfo(arg_value, &arg_type_id);
|
||||||
auto it_map = type_name_map.find(arg_type_id);
|
auto it_map = type_name_map.find(arg_type_id);
|
||||||
if (it_map == type_name_map.end()) {
|
if (it_map == type_name_map.end()) {
|
||||||
continue;
|
continue;
|
||||||
|
|
|
@ -223,6 +223,10 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
|
||||||
max_type = TypeId::kNumberTypeFloat32;
|
max_type = TypeId::kNumberTypeFloat32;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
|
||||||
|
max_type != TypeId::kNumberTypeFloat64 && has_float) {
|
||||||
|
max_type = TypeId::kNumberTypeFloat32;
|
||||||
|
}
|
||||||
if (max_type == TypeId::kNumberTypeUInt8 && has_int8) {
|
if (max_type == TypeId::kNumberTypeUInt8 && has_int8) {
|
||||||
max_type = TypeId::kNumberTypeInt16;
|
max_type = TypeId::kNumberTypeInt16;
|
||||||
}
|
}
|
||||||
|
|
|
@ -126,7 +126,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
|
||||||
std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
|
std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
|
||||||
std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
|
std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
|
||||||
|
|
||||||
std::set<AnfNodePtr> key_ward_para_nodes;
|
std::set<AnfNodePtr> kwarg_nodes;
|
||||||
for (const auto &kwarg : kwarg_list) {
|
for (const auto &kwarg : kwarg_list) {
|
||||||
MS_EXCEPTION_IF_NULL(kwarg);
|
MS_EXCEPTION_IF_NULL(kwarg);
|
||||||
std::string kw_param_name = kwarg->get_key();
|
std::string kw_param_name = kwarg->get_key();
|
||||||
|
@ -160,14 +160,13 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
|
||||||
} else {
|
} else {
|
||||||
auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node);
|
auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node);
|
||||||
// multiply values found given for parameter
|
// multiply values found given for parameter
|
||||||
if (node_itr != specialized_parameter_list->end() &&
|
if (node_itr != specialized_parameter_list->end() && kwarg_nodes.find(param_node) == kwarg_nodes.end()) {
|
||||||
key_ward_para_nodes.find(param_node) == key_ward_para_nodes.end()) {
|
|
||||||
MS_EXCEPTION(TypeError) << "Multiply values for specific argument: " << kw_param_name;
|
MS_EXCEPTION(TypeError) << "Multiply values for specific argument: " << kw_param_name;
|
||||||
} else {
|
} else {
|
||||||
specialized_parameter_list->push_back(param_node);
|
specialized_parameter_list->push_back(param_node);
|
||||||
auto extract_node = specialized_graph->NewCNode(
|
auto extract_node = specialized_graph->NewCNode(
|
||||||
{NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node});
|
{NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node});
|
||||||
key_ward_para_nodes.insert(param_node);
|
kwarg_nodes.insert(param_node);
|
||||||
(void)repl_nodes->emplace(param_node, extract_node);
|
(void)repl_nodes->emplace(param_node, extract_node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue