forked from mindspore-Ecosystem/mindspore
!454 Fix the check of bool type in auto cast
Merge pull request !454 from candanzg/bug_cast_op
This commit is contained in:
commit
0bf6717e9a
|
@ -137,6 +137,19 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
|
|||
if (it == dst_type.end() || it->second == i || !arg_value->isa<abstract::AbstractScalar>()) {
|
||||
continue;
|
||||
}
|
||||
// When scalar is of bool type, the type of tensor must also be of bool type,
|
||||
// otherwise the cast operator will not be added.
|
||||
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
|
||||
auto scalar_type = scalar->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(scalar_type);
|
||||
if (scalar_type->type_id() == kNumberTypeBool) {
|
||||
auto tensor = args_spec_list[it->second]->cast<abstract::AbstractTensorPtr>();
|
||||
auto tensor_type = tensor->element()->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
if (tensor_type->type_id() != kNumberTypeBool) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// get source node for cast
|
||||
AnfNodePtr source_node = (*op_inputs)[it->second + 1];
|
||||
(*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], source_node, graph);
|
||||
|
|
|
@ -745,7 +745,7 @@ class Fill(PrimitiveWithInfer):
|
|||
out = {
|
||||
'value': Tensor(ret),
|
||||
'shape': dims['value'],
|
||||
'dtype': x_nptype,
|
||||
'dtype': x_dtype,
|
||||
}
|
||||
return out
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
|||
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
|
||||
from ....mindspore_test_framework.pipeline.forward.verify_exception \
|
||||
import pipeline_for_verify_exception_for_case_by_case_config
|
||||
import pytest
|
||||
|
||||
|
||||
# pylint: disable=W0613
|
||||
|
@ -81,14 +82,29 @@ def test_sqrt():
|
|||
assert np.all(result.asnumpy() == expect)
|
||||
|
||||
|
||||
class PowNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(PowNet, self).__init__()
|
||||
self.pow = P.Pow()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.pow(x, y)
|
||||
|
||||
|
||||
def test_pow():
|
||||
""" test_pow """
|
||||
input_tensor = Tensor(np.array([[2, 2], [3, 3]]))
|
||||
power = Tensor(np.array(3.0, np.int64))
|
||||
power2 = Tensor(np.array(True, np.bool))
|
||||
testpow = P.Pow()
|
||||
expect = np.array([[8, 8], [27, 27]])
|
||||
result = testpow(input_tensor, power)
|
||||
assert np.all(result.asnumpy() == expect)
|
||||
net = PowNet()
|
||||
with pytest.raises(TypeError):
|
||||
net(input_tensor, True)
|
||||
with pytest.raises(TypeError):
|
||||
net(input_tensor, power2)
|
||||
|
||||
|
||||
def test_exp():
|
||||
|
|
Loading…
Reference in New Issue