add support for bool tensor and scalar impilicit convert

This commit is contained in:
huangdongrun 2020-07-10 17:12:08 +08:00
parent 541456044d
commit d70b4c1b62
2 changed files with 36 additions and 0 deletions

View File

@ -106,6 +106,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
TypeId max_type_id = kTypeUnknown;
size_t max_type_number = 0;
bool has_int8 = false;
bool has_scalar_int32 = false;
bool has_scalar_float32 = false;
for (const auto &index : indices) {
TypeId arg_type_id = kTypeUnknown;
TypeId arg_type = kTypeUnknown;
@ -114,6 +116,11 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
continue;
}
if (arg_type != kObjectTypeTensorType) {
if (arg_type_id == kNumberTypeInt32) {
has_scalar_int32 = true;
} else if (arg_type_id == kNumberTypeFloat32) {
has_scalar_float32 = true;
}
continue;
}
auto it = type_map.find(arg_type_id);
@ -135,6 +142,17 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
if (max_type_id == kNumberTypeUInt8 && has_int8 == true) {
max_type_id = kNumberTypeInt16;
}
// if bool is the max type, see if there is scalar input
// if so, it means that max is bool tensor, use scalar type instead.
// for example: Tensor([True, True]) * 2, expect result is Tensor([2, 2])
if (max_type_id == kNumberTypeBool) {
if (has_scalar_int32) {
max_type_id = kNumberTypeInt32;
}
if (has_scalar_float32) {
max_type_id = kNumberTypeFloat32;
}
}
return max_type_id;
}

View File

@ -246,3 +246,21 @@ def test_tensor_auto_cast():
bnet(t_fp32)
with pytest.raises(TypeError):
bnet(t_fp64)
def test_bool_tensor_and_float():
context.set_context(mode=context.GRAPH_MODE)
t_bool = Tensor(np.ones([2, 1, 2, 2]).astype(np.bool), mstype.bool_)
t_int32 = Tensor(np.ones([2, 1, 2, 2]), mstype.int32)
t_fp16 = Tensor(np.ones([2, 1, 2, 2]), mstype.float16)
t_fp32 = Tensor(np.ones([2, 1, 2, 2]), mstype.float32)
net = TensorFPAutoCast()
out = net(t_bool)
assert out.dtype == mstype.float32
net = TensorIntAutoCast()
out = net(t_bool)
assert out.dtype == mstype.int32
out = net(t_fp16)
assert out.dtype == mstype.float16
out = net(t_fp32)
assert out.dtype == mstype.float32
out = net(t_int32)
assert out.dtype == mstype.int32