!25155 Support ops, nn and numpy namespaces, and add test case.

Merge pull request !25155 from 张清华/opt_fallback
This commit is contained in:
i-robot 2021-10-23 03:50:53 +00:00 committed by Gitee
commit 6f413ee29d
9 changed files with 132 additions and 79 deletions

View File

@ -174,21 +174,21 @@ def resolve_symbol(namespace, symbol):
# If need trope the obj
if resolve_ in convert_object_map:
resolve_ = convert_object_map.get(resolve_)
logger.debug("convert resolve = %r", resolve_)
logger.debug("Convert resolve = %r", resolve_)
if resolve_ == NO_IMPLEMENT:
raise NotImplementedError(f"Not support for `{symbol}`.")
raise NotImplementedError(f"Not support for '{symbol}'.")
except Exception as e:
if isinstance(e, NotImplementedError):
raise e
resolve_ = None
logger.debug("resolve exception occurred, value = %r", e)
logger.debug("resolve type is invalid, namespace = %s, symbol = %s",
logger.debug("Resolve exception occurred, value = %r", e)
logger.debug("Resolve type is invalid, namespace = %s, symbol = %s",
namespace.__str__(), symbol)
if isinstance(resolve_, _MindsporeFunctionExecutor):
logger.debug("resolve class _MindsporeFunctionExecutor, resolve fn instead.")
logger.debug("Resolve class _MindsporeFunctionExecutor, resolve fn instead.")
resolve_ = resolve_.fn
logger.debug(f'found: {symbol} in {namespace.__str__()}, resolve: {resolve_} / {type(resolve_)}')
logger.debug(f"Found '{symbol}' in {namespace.__str__()}, resolved: {resolve_} / {type(resolve_)}")
return resolve_
@ -267,8 +267,8 @@ def get_obj_type(obj):
else:
# here for ndarray, just print its shape (in case of the array to large and print many data in screen)
is_ndarray = type(obj).__name__ == 'ndarray' and hasattr(obj, 'shape')
raise TypeError(f'Not support for this object with type `{type(obj)}` and '
f'{"shape" if is_ndarray else "value"} `{obj.shape if is_ndarray else obj}`.')
raise TypeError(f"Not support for this object with type '{type(obj)}' and "
f"{'shape' if is_ndarray else 'value'} '{obj.shape if is_ndarray else obj}'.")
return obj_type
@ -383,10 +383,10 @@ def get_object_description(obj, fname, fline):
"""return method or funcition description for error report, include location, class name, etc."""
if isinstance(obj, types.MethodType):
obj_cls = obj.__self__.__class__
class_name = f'{obj_cls.__module__}.{obj_cls.__qualname__}'
class_name = f"{obj_cls.__module__}.{obj_cls.__qualname__}"
cls_fname = inspect.getfile(obj_cls)
_, cls_fline = inspect.getsourcelines(obj_cls)
class_loc = f'{cls_fname}:{cls_fline}'
class_loc = f"{cls_fname}:{cls_fline}"
return f"bound method '{obj.__name__}' at {fname}:{fline} of <{class_name} at {class_loc} object>"
if isinstance(obj, types.FunctionType):
return f"function '{obj.__name__}' at {fname}:{fline}"
@ -460,7 +460,7 @@ def get_ast_type(node):
def get_node_type(node):
"""Process an ast node."""
method_name = f'{node.__class__.__name__}'
method_name = f"{node.__class__.__name__}"
node_type = [method_name]
# judge the ast main type
if isinstance(node, ast.stmt):
@ -511,7 +511,7 @@ def eval_script(exp_str, params):
if len(params) != 2:
raise ValueError(f"eval_script(), params tuple length is wrong, params: {params}")
logger.debug(f'exp_str: {exp_str}, params: {params}')
logger.debug(f"exp_str: '{exp_str}', params: '{params}'")
global_params = params[0]
local_params = params[1]
obj = eval(exp_str, global_params, local_params)
@ -539,6 +539,7 @@ class Parser:
# Used to resolve mindspore builtin ops namespace.
self.ms_common_ns = CellNamespace('mindspore.common')
self.ms_nn_ns = CellNamespace('mindspore.nn')
self.ms_ops_ns = CellNamespace('mindspore.ops')
self.ms_ops_c_ns = CellNamespace('mindspore.ops.composite')
self.ms_ops_c_multitype_ns = CellNamespace('mindspore.ops.composite.multitype_ops')
@ -586,93 +587,106 @@ class Parser:
def is_unsupported_namespace(self, value):
unsupported = isinstance(value, _builtin_function_or_method_type) and value not in convert_object_map
logger.debug(f'`{value}` unsupported: {unsupported}.')
logger.debug(f"'{value}' unsupported: {unsupported}.")
return unsupported
def get_namespace_symbol(self, var: str):
"""Get symbol type and namespace and symbol."""
if var in self.closure_namespace:
logger.debug(f"Found `{var}` in closure_namespace {self.closure_namespace.__str__()}")
logger.debug(f"Found '{var}' in closure_namespace {self.closure_namespace.__str__()}")
return self.closure_namespace, var
if var in self.global_namespace:
logger.debug(f"Found `{var}` in global_namespace {self.global_namespace.__str__()}")
logger.debug(f"Found '{var}' in global_namespace {self.global_namespace.__str__()}")
value = self.global_namespace[var]
if self.is_unsupported_namespace(value):
error_info = f"The builtin function '{var}' of python is not supported in graph mode."
return None, var, error_info
return self.global_namespace, var
error_info = f"The symbol '{var}' is not defined in function '{self.function_name}'."
error_info = f"The name '{var}' is not defined in function '{self.function_name}'."
return None, var, error_info
def is_unsupported_builtin_type(self, value_type):
"""To check if not supported builtin type"""
logger.debug(f'value_type: {value_type}, {type([])}, {type(())}.')
logger.debug(f"value_type: {value_type}, {type([])}, {type(())}.")
return value_type in (list, tuple)
def is_supported_namespace_module(self, value):
"""To check if the module is allowed to support."""
# Check `mindspore` namespace.
if not hasattr(value, '__name__'):
logger.debug(f'`{str(value)}` has no `__name__` attribute.')
logger.debug(f"'{str(value)}' has no '__name__' attribute, we suppose it's supported.")
return True
name = value.__name__
if name == 'mindspore':
logger.debug(f'Found `{name}` in mindspore root namespace.')
logger.debug(f"Found 'mindspore' root namespace.")
return True
if name == 'mindspore.ops':
logger.debug(f"Found 'mindspore.ops' namespace.")
return True
if name == 'mindspore.nn':
logger.debug(f"Found 'mindspore.nn' namespace.")
return True
if name == 'mindspore.numpy':
logger.debug(f"Found 'mindspore.numpy' namespace.")
return True
# Check `Tensor` namespace.
if value == Tensor:
logger.debug(f'Not support `{name}`.')
logger.debug(f"Not support '{name}'.")
return False
# Check `builtins` namespace.
if hasattr(value, '__module__'): # Not types.ModuleType
mod = value.__module__
if mod == 'builtins':
logger.debug(f'Found `{name}` in `builtins` namespace.')
logger.debug(f"Found '{name}' in 'builtins' namespace.")
return True
# We suppose it's supported if not a Module.
if not isinstance(value, types.ModuleType):
logger.debug(f'Found `{name}`, not a module.')
logger.debug(f"Found '{name}', not a module.")
return True
# Check supported Module namespace.
rightmost_name = name.split('.')[-1]
if rightmost_name in self.ms_ops_ns:
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {str(self.ms_ops_ns)}.')
logger.debug(f"Found '{name}'({rightmost_name}) in ops namespace: {str(self.ms_ops_ns)}.")
return True
if rightmost_name in self.ms_ops_c_ns:
logger.debug(f'Found `{name}`({rightmost_name}) in C namespace: {str(self.ms_ops_c_ns)}.')
logger.debug(f"Found '{name}'({rightmost_name}) in C namespace: {str(self.ms_ops_c_ns)}.")
return True
if rightmost_name in self.ms_ops_c_multitype_ns:
logger.debug(
f'Found `{name}`({rightmost_name}) in C.multitype namespace: {str(self.ms_ops_c_multitype_ns)}.')
f"Found '{name}'({rightmost_name}) in C.multitype namespace: {str(self.ms_ops_c_multitype_ns)}.")
return True
if rightmost_name in self.ms_ops_p_ns:
logger.debug(f'Found `{name}`({rightmost_name}) in P namespace: {str(self.ms_ops_p_ns)}.')
logger.debug(f"Found '{name}'({rightmost_name}) in P namespace: {str(self.ms_ops_p_ns)}.")
return True
if rightmost_name in self.ms_common_ns:
logger.debug(f'Found `{name}`({rightmost_name}) in P namespace: {str(self.ms_common_ns)}.')
logger.debug(f"Found '{name}'({rightmost_name}) in common namespace: {str(self.ms_common_ns)}.")
return True
# Support nn.layer. To check if exclude other module.
if rightmost_name in self.ms_nn_ns:
logger.info(f"Found '{name}'({rightmost_name}) in nn namespace: {str(self.ms_nn_ns)}.")
return True
if rightmost_name in trope_ns:
logger.debug(f'Found `{name}`({rightmost_name}) in trope namespace: {str(trope_ns)}.')
logger.debug(f"Found '{name}'({rightmost_name}) in trope namespace: {str(trope_ns)}.")
return True
logger.error(f'Not found `{name}` in mindspore supported namespace.')
logger.error(f"Not found '{name}' in mindspore supported namespace.")
return False
def get_builtin_namespace_symbol(self, var: str):
"""Get mindspore builtin namespace and symbol."""
if var in self.closure_namespace:
logger.debug(f"Found `{var}` in closure_namespace {self.closure_namespace.__str__()}.")
logger.debug(f"Found '{var}' in closure_namespace {self.closure_namespace.__str__()}.")
return self.closure_namespace, var
if var in self.global_namespace:
logger.debug(f"Found `{var}` in global_namespace {self.global_namespace.__str__()}.")
logger.debug(f"Found '{var}' in global_namespace {self.global_namespace.__str__()}.")
value = self.global_namespace[var]
value_str = value.__name__ if hasattr(value, '__name__') else str(value)
logger.debug(f"value: {type(value)}, `{value_str}`, hasattr(__name__): {hasattr(value, '__name__')}.")
logger.debug(f"value: {type(value)}, '{value_str}', hasattr(__name__): {hasattr(value, '__name__')}.")
# To check if allowed to support.
if self.is_unsupported_namespace(value):
return self.global_namespace, var, value
@ -683,7 +697,7 @@ class Parser:
return self.global_namespace, var
error_info = f"The name '{var}' is not defined, or not supported in graph mode."
logger.debug(f'error info: {error_info}')
logger.debug(f"error_info: {error_info}")
return None, var, error_info
def analyze_super(self, class_type_node, subclass_instance):

View File

@ -497,7 +497,6 @@ bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager
manager->AddFuncGraph(root);
bool changed = false;
// Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
auto all_node = manager->all_nodes();
for (auto &node : all_node) {

View File

@ -115,8 +115,8 @@ class PyObjectWrapper : public Named {
// InterpretedObject class wrappers interpreted python object.
class InterpretedObject : public PyObjectWrapper {
public:
explicit InterpretedObject(const py::object &obj, const std::string &name = "Interpreted object")
: PyObjectWrapper(obj, name) {}
explicit InterpretedObject(const py::object &obj, const std::string &name = "null")
: PyObjectWrapper(obj, "InterpretedObject: '" + name + "'") {}
~InterpretedObject() override = default;
MS_DECLARE_PARENT(InterpretedObject, PyObjectWrapper);
abstract::AbstractBasePtr ToAbstract() override {

View File

@ -864,8 +864,8 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, con
auto data_v = args_spec_list[0]->BuildValue();
MS_EXCEPTION_IF_NULL(data_v);
if (!data_v->isa<parse::NameSpace>()) {
MS_LOG(EXCEPTION) << "Not supported to get attribute for " << data_v->ToString()
<< "\nThe first argument should be a NameSpace, but got " << args_spec_list[0]->ToString();
MS_EXCEPTION(TypeError) << "Not supported to get attribute for " << data_v->ToString()
<< "\nThe first argument should be a NameSpace, but got " << args_spec_list[0]->ToString();
}
auto item_value = args_spec_list[1]->BuildValue();

View File

@ -24,6 +24,7 @@
#include "ir/manager.h"
#include "ir/dtype.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "pipeline/jit/parse/resolve.h"
namespace mindspore {
namespace validator {
@ -72,19 +73,18 @@ void ValidateOperation(const AnfNodePtr &node) {
bool CheckAbstractScalar(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
AbstractBasePtr ptrBase = node->abstract();
if (ptrBase->isa<AbstractScalar>()) {
TypePtr ptrType = ptrBase->GetTypeTrack();
MS_EXCEPTION_IF_NULL(ptrType);
if (ptrType->isa<EnvType>()) {
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString() << " for node=" << node->DebugString();
AbstractBasePtr abstract = node->abstract();
if (abstract->isa<AbstractScalar>()) {
TypePtr type = abstract->GetTypeTrack();
MS_EXCEPTION_IF_NULL(type);
if (type->isa<EnvType>()) {
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString();
}
if (ptrType->isa<Problem>() || ptrType->isa<External>()) {
// only send string in external
if (type->isa<Problem>() || type->isa<External>()) {
// Only allow string type from external.
if (!IsValueNode<StringImm>(node)) {
// Validate a type.
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString()
<< " for node=" << node->DebugString();
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString();
}
}
return true;
@ -97,43 +97,58 @@ void ValidateAbstract(const AnfNodePtr &node) {
MS_LOG(DEBUG) << "Node to validate is invalid";
return;
}
AbstractBasePtr ptrBase = node->abstract();
if (ptrBase == nullptr) {
AbstractBasePtr abstract = node->abstract();
if (abstract == nullptr) {
MS_LOG(DEBUG) << "Abstract is null in node: " << node->DebugString();
return;
}
if (ptrBase->isa<AbstractClass>() || ptrBase->isa<AbstractJTagged>()) {
if (abstract->isa<AbstractClass>() || abstract->isa<AbstractJTagged>()) {
// Validate a type.
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString() << " for node=" << node->DebugString();
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString();
}
if (CheckAbstractScalar(node)) {
return;
}
if (ptrBase->isa<AbstractError>()) {
if (abstract->isa<AbstractError>()) {
// NOTICE: validate dead code?
MS_LOG(DEBUG) << "AbstractError in the graph: " << ptrBase->ToString();
MS_LOG(DEBUG) << "AbstractError in the graph: " << abstract->ToString();
return;
}
bool checkAbstractIslegal =
ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractRowTensor>() ||
ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>() || ptrBase->isa<AbstractRef>() ||
ptrBase->isa<abstract::AbstractNone>() || ptrBase->isa<abstract::AbstractMonad>();
if (checkAbstractIslegal) {
bool is_legal_abstract =
abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() || abstract->isa<AbstractTuple>() ||
abstract->isa<AbstractList>() || abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
abstract->isa<AbstractSparseTensor>() || abstract->isa<abstract::AbstractRefKey>() ||
abstract->isa<AbstractRef>() || abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>();
if (is_legal_abstract) {
return;
}
// Other types show exception
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString();
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString();
}
void ValidateValueNode(const AnfNodePtr &node) {
if (node == nullptr) {
MS_LOG(DEBUG) << "Node to validate is invalid";
return;
}
// InterpretedNode should be consumed during compile, not left to Runtime.
if (IsValueNode<parse::InterpretedObject>(node)) {
MS_LOG(EXCEPTION) << "Should not use Python object in runtime, node: " << node->DebugString()
<< "\n\nWe suppose all nodes generated by JIT Fallback not return to outside of graph.";
}
}
void Validate(const FuncGraphPtr &fg) {
FuncGraphManagerPtr mgr = Manage(fg, false);
MS_EXCEPTION_IF_NULL(mgr);
AnfNodeSet &all_nodes = mgr->all_nodes();
for (const auto &anf_node : all_nodes) {
ValidateOperation(anf_node);
ValidateAbstract(anf_node);
for (const auto &node : all_nodes) {
ValidateOperation(node);
ValidateValueNode(node);
}
for (const auto &node : all_nodes) {
ValidateAbstract(node);
}
}
} // namespace validator

View File

@ -32,6 +32,7 @@ namespace validator {
void Validate(const FuncGraphPtr &func_graph);
void ValidateAbstract(const AnfNodePtr &node);
void ValidateOperation(const AnfNodePtr &node);
void ValidateValueNode(const AnfNodePtr &node);
} // namespace validator
} // namespace mindspore

View File

@ -50,5 +50,15 @@ TEST_F(TestValidator, ValidateAbstract01) {
// normally, the above statement should not exit, so expected the following statement execute
EXPECT_TRUE(true);
}
/// Feature: JIT Fallback
/// Description: Make sure no interpreted node.
/// Expectation: No exception.
TEST_F(TestValidator, ValidateValueNode01) {
AnfNodePtr node = NewValueNode(static_cast<int64_t>(1));
ValidateValueNode(node);
// normally, the above statement should not exit, so expected the following statement execute
EXPECT_TRUE(true);
}
} // namespace validator
} // namespace mindspore

View File

@ -100,9 +100,23 @@ def div_mod_func2(x, y):
return Tensor(a)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_div_mod_func2():
def test_div_mod_func2_scalar():
"""
Feature: JIT Fallback
Description: Test divmod in graph.
Expectation: No exception.
"""
print(div_mod_func2(8, 3)) # (2, 2)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_div_mod_func2_tensor():
"""
Feature: JIT Fallback
Description: Test divmod in graph.
Expectation: No exception.
"""
print(div_mod_func2(Tensor(8), Tensor(3))) # name 'x' is not defined
# NameError: name 'Tensor' is not defined.
@ms_function

View File

@ -32,7 +32,7 @@ def test_use_undefined_name():
net = Net()
with pytest.raises(NameError) as err:
net(Tensor([1, 2, 3], mstype.float32))
assert "The symbol 'a' is not defined" in str(err.value)
assert "The name 'a' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_name_or_unsupported_builtin_function.py(29)" in \
str(err.value)
assert "ret = x + a" in str(err.value)
@ -48,7 +48,7 @@ def test_insert_undefined_name():
net = Net()
with pytest.raises(NameError) as err:
net(Tensor([1, 2, 3], mstype.float32))
assert "The symbol 'b' is not defined" in str(err.value)
assert "The name 'b' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_name_or_unsupported_builtin_function.py(44)" in \
str(err.value)
@ -63,7 +63,7 @@ def test_insert_undefined_name_compute():
net = Net()
with pytest.raises(NameError) as err:
net(Tensor([1, 2, 3], mstype.float32))
assert "The symbol 'c' is not defined" in str(err.value)
assert "The name 'c' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_name_or_unsupported_builtin_function.py(59)" in \
str(err.value)
assert "c + x" in str(err.value)
@ -80,7 +80,7 @@ def test_insert_undefined_name_in_if():
net = Net()
with pytest.raises(NameError) as err:
net(Tensor([1, 2, 3], mstype.float32))
assert "The symbol 'i' is not defined" in str(err.value)
assert "The name 'i' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_name_or_unsupported_builtin_function.py(76)" in \
str(err.value)
@ -98,7 +98,7 @@ def test_insert_undefined_name_in_while_inner_if():
net = Net()
with pytest.raises(NameError) as err:
net(Tensor([1, 2, 3], mstype.float32))
assert "The symbol 'j' is not defined" in str(err.value)
assert "The name 'j' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_name_or_unsupported_builtin_function.py(93)" in \
str(err.value)
@ -116,7 +116,7 @@ def test_insert_undefined_name_compute__in_while_inner_if():
net = Net()
with pytest.raises(NameError) as err:
net(Tensor([1, 2, 3], mstype.float32))
assert "The symbol 'p' is not defined" in str(err.value)
assert "The name 'p' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_name_or_unsupported_builtin_function.py(111)" in \
str(err.value)
assert "p + x" in str(err.value)
@ -139,7 +139,7 @@ def test_insert_undefined_name_compute__in_if_in_for():
net = Net()
with pytest.raises(NameError) as err:
net(Tensor([1, 2, 3], mstype.float32))
assert "The symbol 'w' is not defined" in str(err.value)
assert "The name 'w' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_name_or_unsupported_builtin_function.py(134)" in \
str(err.value)
assert "w" in str(err.value)
@ -161,7 +161,7 @@ def test_use_undefined_name_for_inner_if():
net = Net()
with pytest.raises(NameError) as err:
net(Tensor([1, 2, 3], mstype.float32))
assert "The symbol 'y' is not defined" in str(err.value)
assert "The name 'y' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_name_or_unsupported_builtin_function.py(157)" in \
str(err.value)
assert "y" in str(err.value)
@ -181,7 +181,7 @@ def test_use_undefined_name_in_for():
net = Net()
with pytest.raises(NameError) as err:
net(Tensor([1, 2, 3], mstype.float32))
assert "The symbol 'd' is not defined" in str(err.value)
assert "The name 'd' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_name_or_unsupported_builtin_function.py(178)" in \
str(err.value)
assert "x = x + d + i" in str(err.value)
@ -202,7 +202,7 @@ def test_insert_undefined_name_in_for():
net = Net()
with pytest.raises(NameError) as err:
net(Tensor([1, 2, 3], mstype.float32))
assert "The symbol 'e' is not defined" in str(err.value)
assert "The name 'e' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_name_or_unsupported_builtin_function.py(198)" in \
str(err.value)
assert "e" in str(err.value)
@ -223,7 +223,7 @@ def test_insert_undefined_name_compute_in_for():
net = Net()
with pytest.raises(NameError) as err:
net(Tensor([1, 2, 3], mstype.float32))
assert "The symbol 'f' is not defined" in str(err.value)
assert "The name 'f' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_name_or_unsupported_builtin_function.py(219)" in \
str(err.value)
assert "f + i" in str(err.value)
@ -239,7 +239,7 @@ def test_use_undefined_name_in_while():
net = Net()
with pytest.raises(NameError) as err:
net(Tensor([1, 2, 3], mstype.float32))
assert "The symbol 'g' is not defined" in str(err.value)
assert "The name 'g' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_name_or_unsupported_builtin_function.py(236)" in \
str(err.value)
assert "x = x - g" in str(err.value)
@ -256,7 +256,7 @@ def test_insert_undefined_name_in_while():
net = Net()
with pytest.raises(NameError) as err:
net(Tensor([1, 2, 3], mstype.float32))
assert "The symbol 'h' is not defined" in str(err.value)
assert "The name 'h' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_name_or_unsupported_builtin_function.py(252)" in \
str(err.value)
assert "h" in str(err.value)
@ -273,7 +273,7 @@ def test_insert_undefined_name_compute_while():
net = Net()
with pytest.raises(NameError) as err:
net(Tensor([1, 2, 3], mstype.float32))
assert "The symbol 'i' is not defined" in str(err.value)
assert "The name 'i' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_name_or_unsupported_builtin_function.py(269)" in \
str(err.value)
assert "x + i" in str(err.value)