diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index 177836a1d1a..2d7311efcc2 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -1089,6 +1089,7 @@ AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &arg void Parser::ParseStrInError(const FunctionBlockPtr &block, const py::list &args, std::vector *str_nodes) { for (size_t i = 0; i < args.size(); ++i) { AnfNodePtr node = ParseExprNode(block, args[i]); + node = HandleInterpret(block, node, args[i]); (void)str_nodes->emplace_back(node); } } @@ -2350,6 +2351,7 @@ AnfNodePtr Parser::ParseJoinedStr(const FunctionBlockPtr &block, const py::objec std::vector value_nodes{NewValueNode(prim::kPrimMakeTuple)}; for (size_t i = 0; i < py_values.size(); ++i) { AnfNodePtr str_value = ParseExprNode(block, py_values[i]); + str_value = HandleInterpret(block, str_value, py_values[i]); (void)value_nodes.emplace_back(str_value); } auto func_graph = block->func_graph(); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 1f93956ca8a..c4712df84b9 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -2040,8 +2040,8 @@ class RaiseEvaluator : public TransitionPrimEvaluator { auto cur_graph = node->func_graph(); MS_EXCEPTION_IF_NULL(cur_graph); if (cur_graph->is_tensor_condition_branch()) { - MS_LOG(EXCEPTION) << "Currently only supports raise in constant scenarios." - << "Tensor type data cannot exist in the conditional statement." + MS_LOG(EXCEPTION) << "Currently only supports raise in constant scenarios. " + << "Tensor type data cannot exist in the conditional statement. " << "Please check your conditions which raise node is located at: " << trace::GetDebugInfo(node->debug_info()); } @@ -2050,7 +2050,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator { MS_LOG(EXCEPTION) << "No active exception to reraise."; } - std::string exception_type = GetScalarStringValue(args_spec_list[0]); + std::string exception_type = GetExceptionType(args_spec_list[0]); auto iter = exception_types_map.find(exception_type); if (iter == exception_types_map.end()) { MS_LOG(EXCEPTION) << "Unsupported exception type: " << exception_type @@ -2062,47 +2062,170 @@ class RaiseEvaluator : public TransitionPrimEvaluator { // Process raise ValueError() MS_EXCEPTION(type); } - std::string exception_string = ""; - for (size_t index = 1; index < args_spec_list.size(); ++index) { - exception_string += GetExceptionString(args_spec_list[index]); + std::string exception_string; + // Processed in units of nodes. Raise ValueError(xxxx) + size_t index_begin = 2; + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto inputs = cnode->inputs(); + bool need_out_symbol = inputs.size() > 3; + if (need_out_symbol) { + exception_string += "("; + } + for (size_t index = index_begin; index < inputs.size(); ++index) { + const auto input = inputs[index]; + auto input_abs = args_spec_list[index - 1]; + MS_EXCEPTION_IF_NULL(input_abs); + bool need_symbol = CheckNeedSymbol(input, input_abs); + if (need_symbol) { + exception_string += "'"; + } + exception_string += GetExceptionString(input_abs, input, node); + if (need_symbol) { + exception_string += "'"; + } + if (index != inputs.size() - 1) { + exception_string += ", "; + } + } + if (need_out_symbol) { + exception_string += ")"; } MS_EXCEPTION(type) << exception_string; return nullptr; } private: - std::string GetExceptionString(const AbstractBasePtr &arg) { - std::string exception_str = ""; - if (arg->isa()) { - // Process raise ValueError("str") - auto arg_tuple = arg->cast(); - const auto &arg_tuple_elements = arg_tuple->elements(); - if (arg_tuple_elements.size() == 0) { - MS_LOG(EXCEPTION) << "The arg_tuple_elements can't be empty."; + // string need add quotation marks + bool CheckNeedSymbol(const AnfNodePtr &input, const AbstractBasePtr &abs) { + bool need_symbol = false; + if (abs->isa()) { + auto scalar = abs->cast(); + auto scalar_value = scalar->BuildValue(); + if (scalar_value->isa()) { + need_symbol = true; } - for (size_t index = 0; index < arg_tuple_elements.size(); ++index) { - auto &element = arg_tuple_elements[index]; - exception_str += GetScalarStringValue(element); + } else if (abs->isa()) { + auto abs_list = abs->cast(); + const auto &elements = abs_list->elements(); + for (auto &element : elements) { + if (element->isa()) { + auto scalar = element->cast(); + auto scalar_value = scalar->BuildValue(); + if (scalar_value->isa()) { + need_symbol = true; + break; + } + } } + } + return need_symbol; + } + std::string GetExceptionString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node) { + std::string exception_str; + if (arg->isa()) { + MS_LOG(EXCEPTION) << "Currently only supports raise in constant scenarios. " + << "Tensor type data cannot exist in the raise statement. " + << "Please check your raise statement which is located at: " + << trace::GetDebugInfo(node->debug_info()); + } else if (arg->isa()) { + return GetTupleString(arg, input, node); + } else if (arg->isa()) { + return GetListString(arg, input, node); } else { // Process raise ValueError - exception_str += GetScalarStringValue(arg); + exception_str += GetScalarStringValue(arg, node); } return exception_str; } - std::string GetScalarStringValue(const AbstractBasePtr &abs) { - std::string str = ""; + std::string GetTupleString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node) { + std::string exception_str; + // Process raise ValueError("str") + auto arg_tuple = arg->cast(); + const auto &arg_tuple_elements = arg_tuple->elements(); + if (arg_tuple_elements.size() == 0) { + MS_LOG(EXCEPTION) << "The arg_tuple_elements can't be empty."; + } + if (arg_tuple_elements.size() > 1) { + exception_str += "("; + } + for (size_t index = 0; index < arg_tuple_elements.size(); ++index) { + auto &element = arg_tuple_elements[index]; + exception_str += GetExceptionString(element, input, node); + if (index != arg_tuple_elements.size() - 1 && !IsPrimitiveCNode(input, prim::kPrimMakeTuple)) { + exception_str += ", "; + } + } + if (arg_tuple_elements.size() > 1) { + exception_str += ")"; + } + return exception_str; + } + + std::string GetListString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node) { + std::string exception_str; + // Process raise ValueError("str") + auto arg_list = arg->cast(); + const auto &arg_list_elements = arg_list->elements(); + if (arg_list_elements.size() == 0) { + MS_LOG(EXCEPTION) << "The arg_list_elements can't be empty."; + } + if (arg_list_elements.size() > 1) { + exception_str += "["; + } + for (size_t index = 0; index < arg_list_elements.size(); ++index) { + auto &element = arg_list_elements[index]; + exception_str += GetExceptionString(element, input, node); + if (index != arg_list_elements.size() - 1 && !IsPrimitiveCNode(input, prim::kPrimMakeList)) { + exception_str += ", "; + } + } + if (arg_list_elements.size() > 1) { + exception_str += "]"; + } + return exception_str; + } + + std::string GetExceptionType(const AbstractBasePtr &abs) { + std::string str; if (abs->isa()) { auto scalar = abs->cast(); auto scalar_value = scalar->BuildValue(); - if (scalar_value->isa()) { - str = std::to_string(GetValue(scalar_value)); - } else if (scalar_value->isa()) { + if (scalar_value->isa()) { str = GetValue(scalar_value); } + return str; } - return str; + MS_LOG(EXCEPTION) << "The abstract of exception type is not scalar: " << abs->ToString(); + } + + std::string GetScalarStringValue(const AbstractBasePtr &abs, const AnfNodePtr &node) { + std::string str; + if (abs->isa()) { + auto scalar = abs->cast(); + auto scalar_value = scalar->BuildValue(); + auto scalar_type = scalar->BuildType(); + if (scalar_value->isa()) { + str = std::to_string(GetValue(scalar_value)); + } else if (scalar_value->isa()) { + str = std::to_string(GetValue(scalar_value)); + } else if (scalar_type->isa()) { + str = std::to_string(GetValue(scalar_value)); + } else if (scalar_value->isa()) { + str = std::to_string(GetValue(scalar_value)); + } else if (scalar_value->isa()) { + str = GetValue(scalar_value); + } else { + str = scalar_value->ToString(); + } + return str; + } + MS_LOG(DEBUG) << "The abstract is not scalar: " << abs->ToString(); + MS_LOG(EXCEPTION) << "Currently only supports raise in constant scenarios. " + << "Tensor type data cannot exist in the raise statement. " + << "Please check your raise statement which is located at: " + << trace::GetDebugInfo(node->debug_info()); } }; diff --git a/tests/st/raise/test_graph_raise.py b/tests/st/raise/test_graph_raise.py index 32cdeff39bf..a97cbd0f302 100644 --- a/tests/st/raise/test_graph_raise.py +++ b/tests/st/raise/test_graph_raise.py @@ -286,7 +286,7 @@ def test_raise_11(): net = RaiseNet() res = net(11) print("res:", res) - assert "The input can not be 11." in str(raise_info_11.value) + assert "('The input can not be ', 11, '.')" in str(raise_info_11.value) @pytest.mark.level0 @@ -535,3 +535,248 @@ def test_raise_21(): net = RaiseNet() res = net(1) print("res:", res) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_tensor_1(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self): + x = Tensor([1]) + raise AssertionError(x) + + with pytest.raises(RuntimeError) as raise_info_tensor_1: + net = RaiseNet() + res = net() + print("res:", res) + assert "Currently only supports raise in constant scenarios." in str(raise_info_tensor_1.value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_tensor_2(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self): + raise AssertionError(Tensor(1)) + + with pytest.raises(RuntimeError) as raise_info_tensor_2: + net = RaiseNet() + res = net() + print("res:", res) + assert "Currently only supports raise in constant scenarios." in str(raise_info_tensor_2.value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_list(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self): + x = [1, 2, 3, 4] + raise ValueError(x) + + with pytest.raises(ValueError) as raise_info_list: + net = RaiseNet() + res = net() + print("res:", res) + assert "[1, 2, 3, 4]" in str(raise_info_list.value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_tuple(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self): + x = (1, 2, 3, 4) + raise ValueError(x) + + with pytest.raises(ValueError) as raise_info_tuple: + net = RaiseNet() + res = net() + print("res:", res) + assert "(1, 2, 3, 4)" in str(raise_info_tuple.value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_string_tuple(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self): + x = (1, 2, 3, 4) + raise ValueError("test_string_tuple", x) + + with pytest.raises(ValueError) as raise_info_string_tuple: + net = RaiseNet() + res = net() + print("res:", res) + assert "'test_string_tuple', (1, 2, 3, 4)" in str(raise_info_string_tuple.value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_string_list(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self): + x = [1, 2, 3, 4] + raise ValueError("test_string_list", x) + + with pytest.raises(ValueError) as raise_info_string_list: + net = RaiseNet() + res = net() + print("res:", res) + assert "'test_string_list', [1, 2, 3, 4]" in str(raise_info_string_list.value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_float(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self): + x = 1.1 + raise ValueError(x) + + with pytest.raises(ValueError) as raise_info_float: + net = RaiseNet() + res = net() + print("res:", res) + assert "1.100000" in str(raise_info_float.value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_nested_list(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self): + x = [1, 2.0] + y = [x, x] + raise ValueError(x, y) + + with pytest.raises(ValueError) as raise_info_nested_list: + net = RaiseNet() + res = net() + print("res:", res) + assert "([1, 2.000000], [[1, 2.000000], [1, 2.000000]])" in str(raise_info_nested_list.value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_nested_tuple(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self): + x = (1, 2.0) + y = (x, x) + raise ValueError(x, y) + + with pytest.raises(ValueError) as raise_info_nested_tuple: + net = RaiseNet() + res = net() + print("res:", res) + assert "((1, 2.000000), ((1, 2.000000), (1, 2.000000)))" in str(raise_info_nested_tuple.value) + + +@pytest.mark.skip(reason='Not support dict yet') +def test_raise_dict(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self): + x = {'a': 1, 'b': 2} + raise ValueError(x) + + with pytest.raises(ValueError) as raise_info_dict: + net = RaiseNet() + res = net() + print("res:", res) + assert "{'a': 1, 'b': 2}" in str(raise_info_dict.value) + + +@pytest.mark.skip(reason='Not support Tensor in Joined string yet') +def test_raise_joinedstr_tensor(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self): + raise RuntimeError(f"The input should not be {Tensor([1])}.") + + with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor: + net = RaiseNet() + res = net() + print("res:", res) + assert "The input should not be [1]" in str(raise_info_joinedstr_tensor.value)