forked from mindspore-Ecosystem/mindspore
support convert bool scalar and tensor to number tensor
This commit is contained in:
parent
32a72c1979
commit
62f7dc49e5
|
@ -255,9 +255,6 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
|
||||||
if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
|
if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if ((arg_type_id == kNumberTypeBool || it->second == kNumberTypeBool) && arg_type_id != it->second) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
(*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph);
|
(*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -101,9 +101,7 @@ def test_pow():
|
||||||
result = testpow(input_tensor, power)
|
result = testpow(input_tensor, power)
|
||||||
assert np.all(result.asnumpy() == expect)
|
assert np.all(result.asnumpy() == expect)
|
||||||
net = PowNet()
|
net = PowNet()
|
||||||
with pytest.raises(TypeError):
|
|
||||||
net(input_tensor, True)
|
net(input_tensor, True)
|
||||||
with pytest.raises(TypeError):
|
|
||||||
net(input_tensor, power2)
|
net(input_tensor, power2)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -293,13 +293,6 @@ raise_set = [
|
||||||
'desc_inputs': [5.0],
|
'desc_inputs': [5.0],
|
||||||
'skip': ['backward']}),
|
'skip': ['backward']}),
|
||||||
|
|
||||||
# input x is Tensor(bool)
|
|
||||||
('Pow1', {
|
|
||||||
'block': (P.Pow(),
|
|
||||||
{'exception': TypeError, 'error_keywords': ['Pow']}),
|
|
||||||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_)), 2.0],
|
|
||||||
'skip': ['backward']}),
|
|
||||||
|
|
||||||
# input is not Tensor
|
# input is not Tensor
|
||||||
('Exp1', {
|
('Exp1', {
|
||||||
'block': (P.Exp(),
|
'block': (P.Exp(),
|
||||||
|
|
Loading…
Reference in New Issue