support convert bool scalar and tensor to number tensor

This commit is contained in:
buxue 2020-06-01 16:04:41 +08:00
parent 32a72c1979
commit 62f7dc49e5
3 changed files with 2 additions and 14 deletions

View File

@ -255,9 +255,6 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) { if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
continue; continue;
} }
if ((arg_type_id == kNumberTypeBool || it->second == kNumberTypeBool) && arg_type_id != it->second) {
continue;
}
(*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph); (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph);
} }
} }

View File

@ -101,9 +101,7 @@ def test_pow():
result = testpow(input_tensor, power) result = testpow(input_tensor, power)
assert np.all(result.asnumpy() == expect) assert np.all(result.asnumpy() == expect)
net = PowNet() net = PowNet()
with pytest.raises(TypeError):
net(input_tensor, True) net(input_tensor, True)
with pytest.raises(TypeError):
net(input_tensor, power2) net(input_tensor, power2)

View File

@ -293,13 +293,6 @@ raise_set = [
'desc_inputs': [5.0], 'desc_inputs': [5.0],
'skip': ['backward']}), 'skip': ['backward']}),
# input x is Tensor(bool)
('Pow1', {
'block': (P.Pow(),
{'exception': TypeError, 'error_keywords': ['Pow']}),
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_)), 2.0],
'skip': ['backward']}),
# input is not Tensor # input is not Tensor
('Exp1', { ('Exp1', {
'block': (P.Exp(), 'block': (P.Exp(),