forked from OSSInnovation/mindspore
add support for bool tensor and scalar implicit convert
This commit is contained in:
parent 541456044d
commit d70b4c1b62
@@ -106,6 +106,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
  TypeId max_type_id = kTypeUnknown;
  size_t max_type_number = 0;
  bool has_int8 = false;
  bool has_scalar_int32 = false;
  bool has_scalar_float32 = false;
  for (const auto &index : indices) {
    TypeId arg_type_id = kTypeUnknown;
    TypeId arg_type = kTypeUnknown;
@@ -114,6 +116,11 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
      continue;
    }
    if (arg_type != kObjectTypeTensorType) {
      if (arg_type_id == kNumberTypeInt32) {
        has_scalar_int32 = true;
      } else if (arg_type_id == kNumberTypeFloat32) {
        has_scalar_float32 = true;
      }
      continue;
    }
    auto it = type_map.find(arg_type_id);
@@ -135,6 +142,17 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
  if (max_type_id == kNumberTypeUInt8 && has_int8 == true) {
    max_type_id = kNumberTypeInt16;
  }
  // If bool is the max type, check whether there is a scalar input;
  // if so, the max type comes from a bool tensor, so use the scalar type instead.
  // For example: Tensor([True, True]) * 2 is expected to produce Tensor([2, 2]).
  if (max_type_id == kNumberTypeBool) {
    if (has_scalar_int32) {
      max_type_id = kNumberTypeInt32;
    }
    if (has_scalar_float32) {
      max_type_id = kNumberTypeFloat32;
    }
  }
  return max_type_id;
}
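The comment in the hunk above describes the intended user-visible behavior. Below is a minimal, hypothetical sketch (not part of this commit) of how that int-scalar promotion could be checked end to end in graph mode; the cell name BoolMulTwo is made up here, and it assumes a MindSpore build that already includes this change.

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

class BoolMulTwo(nn.Cell):
    """Multiplies its input by the int scalar 2 inside a compiled graph."""
    def __init__(self):
        super(BoolMulTwo, self).__init__()
        self.mul = P.Mul()

    def construct(self, x):
        return self.mul(x, 2)

context.set_context(mode=context.GRAPH_MODE)
t_bool = Tensor(np.array([True, True]), mstype.bool_)
out = BoolMulTwo()(t_bool)
# Per the comment above, the int32 scalar should win over the bool tensor,
# so the expected result is Tensor([2, 2]) with dtype int32.
assert out.dtype == mstype.int32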
@@ -246,3 +246,21 @@ def test_tensor_auto_cast():
    bnet(t_fp32)
    with pytest.raises(TypeError):
        bnet(t_fp64)
def test_bool_tensor_and_float():
    context.set_context(mode=context.GRAPH_MODE)
    t_bool = Tensor(np.ones([2, 1, 2, 2]).astype(np.bool), mstype.bool_)
    t_int32 = Tensor(np.ones([2, 1, 2, 2]), mstype.int32)
    t_fp16 = Tensor(np.ones([2, 1, 2, 2]), mstype.float16)
    t_fp32 = Tensor(np.ones([2, 1, 2, 2]), mstype.float32)
    net = TensorFPAutoCast()
    out = net(t_bool)
    assert out.dtype == mstype.float32
    net = TensorIntAutoCast()
    out = net(t_bool)
    assert out.dtype == mstype.int32
    out = net(t_fp16)
    assert out.dtype == mstype.float16
    out = net(t_fp32)
    assert out.dtype == mstype.float32
    out = net(t_int32)
    assert out.dtype == mstype.int32
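The TensorFPAutoCast and TensorIntAutoCast cells used in the test are defined outside this hunk. As a self-contained companion, here is a hypothetical sketch of the float-scalar branch (has_scalar_float32), under the same assumptions as the sketch above; BoolMulHalf is an invented name, not part of the commit.

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

class BoolMulHalf(nn.Cell):
    """Multiplies its input by the float scalar 0.5 inside a compiled graph."""
    def __init__(self):
        super(BoolMulHalf, self).__init__()
        self.mul = P.Mul()

    def construct(self, x):
        return self.mul(x, 0.5)

context.set_context(mode=context.GRAPH_MODE)
t_bool = Tensor(np.array([True, False]), mstype.bool_)
# With a float32 scalar in the mix, the bool tensor should be promoted to float32.
assert BoolMulHalf()(t_bool).dtype == mstype.float32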