[ME] Support Raise list and tuple, prompt user variable scenario error when raising Tensor.

This commit is contained in:
Margaret_wangrui 2022-05-26 21:01:43 +08:00
parent 087d122f9f
commit a4c2cdf2b2
3 changed files with 395 additions and 25 deletions

View File

@ -1089,6 +1089,7 @@ AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &arg
void Parser::ParseStrInError(const FunctionBlockPtr &block, const py::list &args, std::vector<AnfNodePtr> *str_nodes) {
for (size_t i = 0; i < args.size(); ++i) {
AnfNodePtr node = ParseExprNode(block, args[i]);
node = HandleInterpret(block, node, args[i]);
(void)str_nodes->emplace_back(node);
}
}
@ -2350,6 +2351,7 @@ AnfNodePtr Parser::ParseJoinedStr(const FunctionBlockPtr &block, const py::objec
std::vector<AnfNodePtr> value_nodes{NewValueNode(prim::kPrimMakeTuple)};
for (size_t i = 0; i < py_values.size(); ++i) {
AnfNodePtr str_value = ParseExprNode(block, py_values[i]);
str_value = HandleInterpret(block, str_value, py_values[i]);
(void)value_nodes.emplace_back(str_value);
}
auto func_graph = block->func_graph();

View File

@ -2040,8 +2040,8 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
auto cur_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(cur_graph);
if (cur_graph->is_tensor_condition_branch()) {
MS_LOG(EXCEPTION) << "Currently only supports raise in constant scenarios."
<< "Tensor type data cannot exist in the conditional statement."
MS_LOG(EXCEPTION) << "Currently only supports raise in constant scenarios. "
<< "Tensor type data cannot exist in the conditional statement. "
<< "Please check your conditions which raise node is located at: "
<< trace::GetDebugInfo(node->debug_info());
}
@ -2050,7 +2050,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
MS_LOG(EXCEPTION) << "No active exception to reraise.";
}
std::string exception_type = GetScalarStringValue(args_spec_list[0]);
std::string exception_type = GetExceptionType(args_spec_list[0]);
auto iter = exception_types_map.find(exception_type);
if (iter == exception_types_map.end()) {
MS_LOG(EXCEPTION) << "Unsupported exception type: " << exception_type
@ -2062,48 +2062,171 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
// Process raise ValueError()
MS_EXCEPTION(type);
}
std::string exception_string = "";
for (size_t index = 1; index < args_spec_list.size(); ++index) {
exception_string += GetExceptionString(args_spec_list[index]);
std::string exception_string;
// Processed in units of nodes. Raise ValueError(xxxx)
size_t index_begin = 2;
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto inputs = cnode->inputs();
bool need_out_symbol = inputs.size() > 3;
if (need_out_symbol) {
exception_string += "(";
}
for (size_t index = index_begin; index < inputs.size(); ++index) {
const auto input = inputs[index];
auto input_abs = args_spec_list[index - 1];
MS_EXCEPTION_IF_NULL(input_abs);
bool need_symbol = CheckNeedSymbol(input, input_abs);
if (need_symbol) {
exception_string += "'";
}
exception_string += GetExceptionString(input_abs, input, node);
if (need_symbol) {
exception_string += "'";
}
if (index != inputs.size() - 1) {
exception_string += ", ";
}
}
if (need_out_symbol) {
exception_string += ")";
}
MS_EXCEPTION(type) << exception_string;
return nullptr;
}
private:
std::string GetExceptionString(const AbstractBasePtr &arg) {
std::string exception_str = "";
if (arg->isa<abstract::AbstractTuple>()) {
// string need add quotation marks
bool CheckNeedSymbol(const AnfNodePtr &input, const AbstractBasePtr &abs) {
bool need_symbol = false;
if (abs->isa<abstract::AbstractScalar>()) {
auto scalar = abs->cast<abstract::AbstractScalarPtr>();
auto scalar_value = scalar->BuildValue();
if (scalar_value->isa<StringImm>()) {
need_symbol = true;
}
} else if (abs->isa<abstract::AbstractSequence>()) {
auto abs_list = abs->cast<abstract::AbstractSequencePtr>();
const auto &elements = abs_list->elements();
for (auto &element : elements) {
if (element->isa<abstract::AbstractScalar>()) {
auto scalar = element->cast<abstract::AbstractScalarPtr>();
auto scalar_value = scalar->BuildValue();
if (scalar_value->isa<StringImm>()) {
need_symbol = true;
break;
}
}
}
}
return need_symbol;
}
std::string GetExceptionString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node) {
std::string exception_str;
if (arg->isa<abstract::AbstractTensor>()) {
MS_LOG(EXCEPTION) << "Currently only supports raise in constant scenarios. "
<< "Tensor type data cannot exist in the raise statement. "
<< "Please check your raise statement which is located at: "
<< trace::GetDebugInfo(node->debug_info());
} else if (arg->isa<abstract::AbstractTuple>()) {
return GetTupleString(arg, input, node);
} else if (arg->isa<abstract::AbstractList>()) {
return GetListString(arg, input, node);
} else {
// Process raise ValueError
exception_str += GetScalarStringValue(arg, node);
}
return exception_str;
}
std::string GetTupleString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node) {
std::string exception_str;
// Process raise ValueError("str")
auto arg_tuple = arg->cast<abstract::AbstractTuplePtr>();
const auto &arg_tuple_elements = arg_tuple->elements();
if (arg_tuple_elements.size() == 0) {
MS_LOG(EXCEPTION) << "The arg_tuple_elements can't be empty.";
}
if (arg_tuple_elements.size() > 1) {
exception_str += "(";
}
for (size_t index = 0; index < arg_tuple_elements.size(); ++index) {
auto &element = arg_tuple_elements[index];
exception_str += GetScalarStringValue(element);
exception_str += GetExceptionString(element, input, node);
if (index != arg_tuple_elements.size() - 1 && !IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
exception_str += ", ";
}
} else {
// Process raise ValueError
exception_str += GetScalarStringValue(arg);
}
if (arg_tuple_elements.size() > 1) {
exception_str += ")";
}
return exception_str;
}
std::string GetScalarStringValue(const AbstractBasePtr &abs) {
std::string str = "";
std::string GetListString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node) {
std::string exception_str;
// Process raise ValueError("str")
auto arg_list = arg->cast<abstract::AbstractListPtr>();
const auto &arg_list_elements = arg_list->elements();
if (arg_list_elements.size() == 0) {
MS_LOG(EXCEPTION) << "The arg_list_elements can't be empty.";
}
if (arg_list_elements.size() > 1) {
exception_str += "[";
}
for (size_t index = 0; index < arg_list_elements.size(); ++index) {
auto &element = arg_list_elements[index];
exception_str += GetExceptionString(element, input, node);
if (index != arg_list_elements.size() - 1 && !IsPrimitiveCNode(input, prim::kPrimMakeList)) {
exception_str += ", ";
}
}
if (arg_list_elements.size() > 1) {
exception_str += "]";
}
return exception_str;
}
std::string GetExceptionType(const AbstractBasePtr &abs) {
std::string str;
if (abs->isa<abstract::AbstractScalar>()) {
auto scalar = abs->cast<abstract::AbstractScalarPtr>();
auto scalar_value = scalar->BuildValue();
if (scalar_value->isa<Int64Imm>()) {
str = std::to_string(GetValue<int64_t>(scalar_value));
} else if (scalar_value->isa<StringImm>()) {
if (scalar_value->isa<StringImm>()) {
str = GetValue<std::string>(scalar_value);
}
return str;
}
MS_LOG(EXCEPTION) << "The abstract of exception type is not scalar: " << abs->ToString();
}
std::string GetScalarStringValue(const AbstractBasePtr &abs, const AnfNodePtr &node) {
std::string str;
if (abs->isa<abstract::AbstractScalar>()) {
auto scalar = abs->cast<abstract::AbstractScalarPtr>();
auto scalar_value = scalar->BuildValue();
auto scalar_type = scalar->BuildType();
if (scalar_value->isa<Int64Imm>()) {
str = std::to_string(GetValue<int64_t>(scalar_value));
} else if (scalar_value->isa<Int32Imm>()) {
str = std::to_string(GetValue<int32_t>(scalar_value));
} else if (scalar_type->isa<Float>()) {
str = std::to_string(GetValue<float>(scalar_value));
} else if (scalar_value->isa<BoolImm>()) {
str = std::to_string(GetValue<bool>(scalar_value));
} else if (scalar_value->isa<StringImm>()) {
str = GetValue<std::string>(scalar_value);
} else {
str = scalar_value->ToString();
}
return str;
}
MS_LOG(DEBUG) << "The abstract is not scalar: " << abs->ToString();
MS_LOG(EXCEPTION) << "Currently only supports raise in constant scenarios. "
<< "Tensor type data cannot exist in the raise statement. "
<< "Please check your raise statement which is located at: "
<< trace::GetDebugInfo(node->debug_info());
}
};
struct PrimitiveImplInferValue {

View File

@ -286,7 +286,7 @@ def test_raise_11():
net = RaiseNet()
res = net(11)
print("res:", res)
assert "The input can not be 11." in str(raise_info_11.value)
assert "('The input can not be ', 11, '.')" in str(raise_info_11.value)
@pytest.mark.level0
@ -535,3 +535,248 @@ def test_raise_21():
net = RaiseNet()
res = net(1)
print("res:", res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_tensor_1():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
x = Tensor([1])
raise AssertionError(x)
with pytest.raises(RuntimeError) as raise_info_tensor_1:
net = RaiseNet()
res = net()
print("res:", res)
assert "Currently only supports raise in constant scenarios." in str(raise_info_tensor_1.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_tensor_2():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
raise AssertionError(Tensor(1))
with pytest.raises(RuntimeError) as raise_info_tensor_2:
net = RaiseNet()
res = net()
print("res:", res)
assert "Currently only supports raise in constant scenarios." in str(raise_info_tensor_2.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_list():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
x = [1, 2, 3, 4]
raise ValueError(x)
with pytest.raises(ValueError) as raise_info_list:
net = RaiseNet()
res = net()
print("res:", res)
assert "[1, 2, 3, 4]" in str(raise_info_list.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_tuple():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
x = (1, 2, 3, 4)
raise ValueError(x)
with pytest.raises(ValueError) as raise_info_tuple:
net = RaiseNet()
res = net()
print("res:", res)
assert "(1, 2, 3, 4)" in str(raise_info_tuple.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_string_tuple():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
x = (1, 2, 3, 4)
raise ValueError("test_string_tuple", x)
with pytest.raises(ValueError) as raise_info_string_tuple:
net = RaiseNet()
res = net()
print("res:", res)
assert "'test_string_tuple', (1, 2, 3, 4)" in str(raise_info_string_tuple.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_string_list():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
x = [1, 2, 3, 4]
raise ValueError("test_string_list", x)
with pytest.raises(ValueError) as raise_info_string_list:
net = RaiseNet()
res = net()
print("res:", res)
assert "'test_string_list', [1, 2, 3, 4]" in str(raise_info_string_list.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_float():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
x = 1.1
raise ValueError(x)
with pytest.raises(ValueError) as raise_info_float:
net = RaiseNet()
res = net()
print("res:", res)
assert "1.100000" in str(raise_info_float.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_nested_list():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
x = [1, 2.0]
y = [x, x]
raise ValueError(x, y)
with pytest.raises(ValueError) as raise_info_nested_list:
net = RaiseNet()
res = net()
print("res:", res)
assert "([1, 2.000000], [[1, 2.000000], [1, 2.000000]])" in str(raise_info_nested_list.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_nested_tuple():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
x = (1, 2.0)
y = (x, x)
raise ValueError(x, y)
with pytest.raises(ValueError) as raise_info_nested_tuple:
net = RaiseNet()
res = net()
print("res:", res)
assert "((1, 2.000000), ((1, 2.000000), (1, 2.000000)))" in str(raise_info_nested_tuple.value)
@pytest.mark.skip(reason='Not support dict yet')
def test_raise_dict():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
x = {'a': 1, 'b': 2}
raise ValueError(x)
with pytest.raises(ValueError) as raise_info_dict:
net = RaiseNet()
res = net()
print("res:", res)
assert "{'a': 1, 'b': 2}" in str(raise_info_dict.value)
@pytest.mark.skip(reason='Not support Tensor in Joined string yet')
def test_raise_joinedstr_tensor():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
raise RuntimeError(f"The input should not be {Tensor([1])}.")
with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor:
net = RaiseNet()
res = net()
print("res:", res)
assert "The input should not be [1]" in str(raise_info_joinedstr_tensor.value)