[ME] Support Raise list and tuple, prompt user variable scenario error when raising Tensor.
This commit is contained in:
parent
087d122f9f
commit
a4c2cdf2b2
|
@ -1089,6 +1089,7 @@ AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &arg
|
|||
void Parser::ParseStrInError(const FunctionBlockPtr &block, const py::list &args, std::vector<AnfNodePtr> *str_nodes) {
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
AnfNodePtr node = ParseExprNode(block, args[i]);
|
||||
node = HandleInterpret(block, node, args[i]);
|
||||
(void)str_nodes->emplace_back(node);
|
||||
}
|
||||
}
|
||||
|
@ -2350,6 +2351,7 @@ AnfNodePtr Parser::ParseJoinedStr(const FunctionBlockPtr &block, const py::objec
|
|||
std::vector<AnfNodePtr> value_nodes{NewValueNode(prim::kPrimMakeTuple)};
|
||||
for (size_t i = 0; i < py_values.size(); ++i) {
|
||||
AnfNodePtr str_value = ParseExprNode(block, py_values[i]);
|
||||
str_value = HandleInterpret(block, str_value, py_values[i]);
|
||||
(void)value_nodes.emplace_back(str_value);
|
||||
}
|
||||
auto func_graph = block->func_graph();
|
||||
|
|
|
@ -2040,8 +2040,8 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
auto cur_graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(cur_graph);
|
||||
if (cur_graph->is_tensor_condition_branch()) {
|
||||
MS_LOG(EXCEPTION) << "Currently only supports raise in constant scenarios."
|
||||
<< "Tensor type data cannot exist in the conditional statement."
|
||||
MS_LOG(EXCEPTION) << "Currently only supports raise in constant scenarios. "
|
||||
<< "Tensor type data cannot exist in the conditional statement. "
|
||||
<< "Please check your conditions which raise node is located at: "
|
||||
<< trace::GetDebugInfo(node->debug_info());
|
||||
}
|
||||
|
@ -2050,7 +2050,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
MS_LOG(EXCEPTION) << "No active exception to reraise.";
|
||||
}
|
||||
|
||||
std::string exception_type = GetScalarStringValue(args_spec_list[0]);
|
||||
std::string exception_type = GetExceptionType(args_spec_list[0]);
|
||||
auto iter = exception_types_map.find(exception_type);
|
||||
if (iter == exception_types_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Unsupported exception type: " << exception_type
|
||||
|
@ -2062,47 +2062,170 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
// Process raise ValueError()
|
||||
MS_EXCEPTION(type);
|
||||
}
|
||||
std::string exception_string = "";
|
||||
for (size_t index = 1; index < args_spec_list.size(); ++index) {
|
||||
exception_string += GetExceptionString(args_spec_list[index]);
|
||||
std::string exception_string;
|
||||
// Processed in units of nodes. Raise ValueError(xxxx)
|
||||
size_t index_begin = 2;
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto inputs = cnode->inputs();
|
||||
bool need_out_symbol = inputs.size() > 3;
|
||||
if (need_out_symbol) {
|
||||
exception_string += "(";
|
||||
}
|
||||
for (size_t index = index_begin; index < inputs.size(); ++index) {
|
||||
const auto input = inputs[index];
|
||||
auto input_abs = args_spec_list[index - 1];
|
||||
MS_EXCEPTION_IF_NULL(input_abs);
|
||||
bool need_symbol = CheckNeedSymbol(input, input_abs);
|
||||
if (need_symbol) {
|
||||
exception_string += "'";
|
||||
}
|
||||
exception_string += GetExceptionString(input_abs, input, node);
|
||||
if (need_symbol) {
|
||||
exception_string += "'";
|
||||
}
|
||||
if (index != inputs.size() - 1) {
|
||||
exception_string += ", ";
|
||||
}
|
||||
}
|
||||
if (need_out_symbol) {
|
||||
exception_string += ")";
|
||||
}
|
||||
MS_EXCEPTION(type) << exception_string;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string GetExceptionString(const AbstractBasePtr &arg) {
|
||||
std::string exception_str = "";
|
||||
if (arg->isa<abstract::AbstractTuple>()) {
|
||||
// Process raise ValueError("str")
|
||||
auto arg_tuple = arg->cast<abstract::AbstractTuplePtr>();
|
||||
const auto &arg_tuple_elements = arg_tuple->elements();
|
||||
if (arg_tuple_elements.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "The arg_tuple_elements can't be empty.";
|
||||
// string need add quotation marks
|
||||
bool CheckNeedSymbol(const AnfNodePtr &input, const AbstractBasePtr &abs) {
|
||||
bool need_symbol = false;
|
||||
if (abs->isa<abstract::AbstractScalar>()) {
|
||||
auto scalar = abs->cast<abstract::AbstractScalarPtr>();
|
||||
auto scalar_value = scalar->BuildValue();
|
||||
if (scalar_value->isa<StringImm>()) {
|
||||
need_symbol = true;
|
||||
}
|
||||
for (size_t index = 0; index < arg_tuple_elements.size(); ++index) {
|
||||
auto &element = arg_tuple_elements[index];
|
||||
exception_str += GetScalarStringValue(element);
|
||||
} else if (abs->isa<abstract::AbstractSequence>()) {
|
||||
auto abs_list = abs->cast<abstract::AbstractSequencePtr>();
|
||||
const auto &elements = abs_list->elements();
|
||||
for (auto &element : elements) {
|
||||
if (element->isa<abstract::AbstractScalar>()) {
|
||||
auto scalar = element->cast<abstract::AbstractScalarPtr>();
|
||||
auto scalar_value = scalar->BuildValue();
|
||||
if (scalar_value->isa<StringImm>()) {
|
||||
need_symbol = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return need_symbol;
|
||||
}
|
||||
std::string GetExceptionString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node) {
|
||||
std::string exception_str;
|
||||
if (arg->isa<abstract::AbstractTensor>()) {
|
||||
MS_LOG(EXCEPTION) << "Currently only supports raise in constant scenarios. "
|
||||
<< "Tensor type data cannot exist in the raise statement. "
|
||||
<< "Please check your raise statement which is located at: "
|
||||
<< trace::GetDebugInfo(node->debug_info());
|
||||
} else if (arg->isa<abstract::AbstractTuple>()) {
|
||||
return GetTupleString(arg, input, node);
|
||||
} else if (arg->isa<abstract::AbstractList>()) {
|
||||
return GetListString(arg, input, node);
|
||||
} else {
|
||||
// Process raise ValueError
|
||||
exception_str += GetScalarStringValue(arg);
|
||||
exception_str += GetScalarStringValue(arg, node);
|
||||
}
|
||||
return exception_str;
|
||||
}
|
||||
|
||||
std::string GetScalarStringValue(const AbstractBasePtr &abs) {
|
||||
std::string str = "";
|
||||
std::string GetTupleString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node) {
|
||||
std::string exception_str;
|
||||
// Process raise ValueError("str")
|
||||
auto arg_tuple = arg->cast<abstract::AbstractTuplePtr>();
|
||||
const auto &arg_tuple_elements = arg_tuple->elements();
|
||||
if (arg_tuple_elements.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "The arg_tuple_elements can't be empty.";
|
||||
}
|
||||
if (arg_tuple_elements.size() > 1) {
|
||||
exception_str += "(";
|
||||
}
|
||||
for (size_t index = 0; index < arg_tuple_elements.size(); ++index) {
|
||||
auto &element = arg_tuple_elements[index];
|
||||
exception_str += GetExceptionString(element, input, node);
|
||||
if (index != arg_tuple_elements.size() - 1 && !IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
|
||||
exception_str += ", ";
|
||||
}
|
||||
}
|
||||
if (arg_tuple_elements.size() > 1) {
|
||||
exception_str += ")";
|
||||
}
|
||||
return exception_str;
|
||||
}
|
||||
|
||||
std::string GetListString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node) {
|
||||
std::string exception_str;
|
||||
// Process raise ValueError("str")
|
||||
auto arg_list = arg->cast<abstract::AbstractListPtr>();
|
||||
const auto &arg_list_elements = arg_list->elements();
|
||||
if (arg_list_elements.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "The arg_list_elements can't be empty.";
|
||||
}
|
||||
if (arg_list_elements.size() > 1) {
|
||||
exception_str += "[";
|
||||
}
|
||||
for (size_t index = 0; index < arg_list_elements.size(); ++index) {
|
||||
auto &element = arg_list_elements[index];
|
||||
exception_str += GetExceptionString(element, input, node);
|
||||
if (index != arg_list_elements.size() - 1 && !IsPrimitiveCNode(input, prim::kPrimMakeList)) {
|
||||
exception_str += ", ";
|
||||
}
|
||||
}
|
||||
if (arg_list_elements.size() > 1) {
|
||||
exception_str += "]";
|
||||
}
|
||||
return exception_str;
|
||||
}
|
||||
|
||||
std::string GetExceptionType(const AbstractBasePtr &abs) {
|
||||
std::string str;
|
||||
if (abs->isa<abstract::AbstractScalar>()) {
|
||||
auto scalar = abs->cast<abstract::AbstractScalarPtr>();
|
||||
auto scalar_value = scalar->BuildValue();
|
||||
if (scalar_value->isa<Int64Imm>()) {
|
||||
str = std::to_string(GetValue<int64_t>(scalar_value));
|
||||
} else if (scalar_value->isa<StringImm>()) {
|
||||
if (scalar_value->isa<StringImm>()) {
|
||||
str = GetValue<std::string>(scalar_value);
|
||||
}
|
||||
return str;
|
||||
}
|
||||
return str;
|
||||
MS_LOG(EXCEPTION) << "The abstract of exception type is not scalar: " << abs->ToString();
|
||||
}
|
||||
|
||||
std::string GetScalarStringValue(const AbstractBasePtr &abs, const AnfNodePtr &node) {
|
||||
std::string str;
|
||||
if (abs->isa<abstract::AbstractScalar>()) {
|
||||
auto scalar = abs->cast<abstract::AbstractScalarPtr>();
|
||||
auto scalar_value = scalar->BuildValue();
|
||||
auto scalar_type = scalar->BuildType();
|
||||
if (scalar_value->isa<Int64Imm>()) {
|
||||
str = std::to_string(GetValue<int64_t>(scalar_value));
|
||||
} else if (scalar_value->isa<Int32Imm>()) {
|
||||
str = std::to_string(GetValue<int32_t>(scalar_value));
|
||||
} else if (scalar_type->isa<Float>()) {
|
||||
str = std::to_string(GetValue<float>(scalar_value));
|
||||
} else if (scalar_value->isa<BoolImm>()) {
|
||||
str = std::to_string(GetValue<bool>(scalar_value));
|
||||
} else if (scalar_value->isa<StringImm>()) {
|
||||
str = GetValue<std::string>(scalar_value);
|
||||
} else {
|
||||
str = scalar_value->ToString();
|
||||
}
|
||||
return str;
|
||||
}
|
||||
MS_LOG(DEBUG) << "The abstract is not scalar: " << abs->ToString();
|
||||
MS_LOG(EXCEPTION) << "Currently only supports raise in constant scenarios. "
|
||||
<< "Tensor type data cannot exist in the raise statement. "
|
||||
<< "Please check your raise statement which is located at: "
|
||||
<< trace::GetDebugInfo(node->debug_info());
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -286,7 +286,7 @@ def test_raise_11():
|
|||
net = RaiseNet()
|
||||
res = net(11)
|
||||
print("res:", res)
|
||||
assert "The input can not be 11." in str(raise_info_11.value)
|
||||
assert "('The input can not be ', 11, '.')" in str(raise_info_11.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -535,3 +535,248 @@ def test_raise_21():
|
|||
net = RaiseNet()
|
||||
res = net(1)
|
||||
print("res:", res)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_tensor_1():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
x = Tensor([1])
|
||||
raise AssertionError(x)
|
||||
|
||||
with pytest.raises(RuntimeError) as raise_info_tensor_1:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "Currently only supports raise in constant scenarios." in str(raise_info_tensor_1.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_tensor_2():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
raise AssertionError(Tensor(1))
|
||||
|
||||
with pytest.raises(RuntimeError) as raise_info_tensor_2:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "Currently only supports raise in constant scenarios." in str(raise_info_tensor_2.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_list():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
x = [1, 2, 3, 4]
|
||||
raise ValueError(x)
|
||||
|
||||
with pytest.raises(ValueError) as raise_info_list:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "[1, 2, 3, 4]" in str(raise_info_list.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_tuple():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
x = (1, 2, 3, 4)
|
||||
raise ValueError(x)
|
||||
|
||||
with pytest.raises(ValueError) as raise_info_tuple:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "(1, 2, 3, 4)" in str(raise_info_tuple.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_string_tuple():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
x = (1, 2, 3, 4)
|
||||
raise ValueError("test_string_tuple", x)
|
||||
|
||||
with pytest.raises(ValueError) as raise_info_string_tuple:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "'test_string_tuple', (1, 2, 3, 4)" in str(raise_info_string_tuple.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_string_list():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
x = [1, 2, 3, 4]
|
||||
raise ValueError("test_string_list", x)
|
||||
|
||||
with pytest.raises(ValueError) as raise_info_string_list:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "'test_string_list', [1, 2, 3, 4]" in str(raise_info_string_list.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_float():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
x = 1.1
|
||||
raise ValueError(x)
|
||||
|
||||
with pytest.raises(ValueError) as raise_info_float:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "1.100000" in str(raise_info_float.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_nested_list():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
x = [1, 2.0]
|
||||
y = [x, x]
|
||||
raise ValueError(x, y)
|
||||
|
||||
with pytest.raises(ValueError) as raise_info_nested_list:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "([1, 2.000000], [[1, 2.000000], [1, 2.000000]])" in str(raise_info_nested_list.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_nested_tuple():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
x = (1, 2.0)
|
||||
y = (x, x)
|
||||
raise ValueError(x, y)
|
||||
|
||||
with pytest.raises(ValueError) as raise_info_nested_tuple:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "((1, 2.000000), ((1, 2.000000), (1, 2.000000)))" in str(raise_info_nested_tuple.value)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support dict yet')
|
||||
def test_raise_dict():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
x = {'a': 1, 'b': 2}
|
||||
raise ValueError(x)
|
||||
|
||||
with pytest.raises(ValueError) as raise_info_dict:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "{'a': 1, 'b': 2}" in str(raise_info_dict.value)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support Tensor in Joined string yet')
|
||||
def test_raise_joinedstr_tensor():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
raise RuntimeError(f"The input should not be {Tensor([1])}.")
|
||||
|
||||
with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "The input should not be [1]" in str(raise_info_joinedstr_tensor.value)
|
||||
|
|
Loading…
Reference in New Issue