show accurate code line when use uninitialized var in for

(cherry picked from commit e3056ed9b2)
This commit is contained in:
buxue 2021-02-23 18:24:22 +08:00
parent 3b9843f57e
commit de343a0e00
5 changed files with 151 additions and 33 deletions

View File

@ -205,6 +205,7 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
MS_EXCEPTION_IF_NULL(pred);
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString();
AnfNodePtr arg_node = pred->ReadVariable(var);
arg_node->set_debug_info(phi->debug_info());
CNodePtr jump = pred->jumps_[this];
jump->add_input(arg_node);
}
@ -257,12 +258,13 @@ bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
MS_EXCEPTION_IF_NULL(phi);
std::string var = phi_nodes_[phi];
MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var;
if (prev_blocks_.size() == 0) {
if (prev_blocks_.empty()) {
MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var;
return false;
}
AnfNodePtr arg_node = SearchReplaceNode(var, phi);
if (arg_node != nullptr) {
arg_node->set_debug_info(phi->debug_info());
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " phi " << phi->ToString() << " can be replaced with "
<< arg_node->DebugString();
// Replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1."
@ -299,7 +301,7 @@ void FunctionBlock::Mature() {
const auto &graphParamVec = func_graph_->parameters();
for (auto &paramItr : graphParamVec) {
MS_EXCEPTION_IF_NULL(paramItr);
ParameterPtr param = paramItr->cast<ParameterPtr>();
auto param = paramItr->cast<ParameterPtr>();
if (phi_nodes_.find(param) != phi_nodes_.cend()) {
SetPhiArgument(param);
}
@ -321,7 +323,7 @@ CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
}
// Perform a jump from this block to target block
void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node) {
void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const AnfNodePtr &node) {
if (func_graph()->get_return() != nullptr) {
MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
<< trace::GetDebugInfo(func_graph()->get_return()->debug_info());
@ -407,7 +409,7 @@ void FunctionBlock::FindIsolatedNodes() {
void FunctionBlock::AddIsolatedNode(const AnfNodePtr &target) { isolated_nodes_.add(target); }
void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
if (isolated_nodes_.size() == 0) {
if (isolated_nodes_.empty()) {
return;
}
@ -415,7 +417,7 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
states.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (auto &node : isolated_nodes_) {
MS_LOG(DEBUG) << "Adding dependency, node: " << node->DebugString(2) << " in " << func_graph()->ToString();
if (node->func_graph() == func_graph()) {
if (node->func_graph() == func_graph() || node->isa<Parameter>()) {
states.emplace_back(node);
} else {
MS_LOG(INFO) << "Ignored FV dependency, node: " << node->DebugString(2) << " in " << func_graph()->ToString();
@ -438,7 +440,7 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
AnfNodePtr old_output = nullptr;
auto return_node = func_graph()->get_return();
if (return_node) {
if (return_node->inputs().size() < 1) {
if (return_node->inputs().empty()) {
MS_LOG(EXCEPTION) << "Length of inputs of output node is less than 2";
}
old_output = return_node->input(1);

View File

@ -57,7 +57,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
void Mature();
CNodePtr ForceToBoolNode(const AnfNodePtr &cond);
CNodePtr ForceToWhileCond(const AnfNodePtr &cond);
void Jump(const FunctionBlockPtr &block, AnfNodePtr node);
void Jump(const FunctionBlockPtr &block, const AnfNodePtr &node);
AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi);
void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock,
bool unroll_loop = true);

View File

@ -156,7 +156,7 @@ AnfNodePtr AppendParameterObj(const FuncGraphPtr &func_graph, const py::object &
auto top_graph = func_graph;
// If the parameter node has been created , return it
AnfNodePtr para_node = nullptr;
for (auto param : top_graph->parameters()) {
for (const auto &param : top_graph->parameters()) {
auto param_node = dyn_cast<Parameter>(param);
if (param_node != nullptr && param_node->name() == param_name) {
para_node = param;
@ -179,15 +179,15 @@ AnfNodePtr AppendParameterObj(const FuncGraphPtr &func_graph, const py::object &
void UpdataParam(const FuncGraphPtr &top_graph, const py::object &cell) {
auto params = py::list(cell.attr("get_parameters")()).cast<std::vector<py::object>>();
for (size_t i = 0; i < params.size(); i++) {
(void)AppendParameterObj(top_graph, params[i]);
for (const auto &param : params) {
(void)AppendParameterObj(top_graph, param);
}
}
void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &ast) {
// Check whether the functions referred by this function and itself are missing 'return' statement
auto mng = Manage(fn, false);
for (auto func_graph : mng->func_graphs()) {
for (const auto &func_graph : mng->func_graphs()) {
if (func_graph->get_return() != nullptr) {
continue;
}
@ -198,7 +198,7 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &as
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
}
// Clear manager info after checking missing return
for (auto fg : mng->func_graphs()) {
for (const auto &fg : mng->func_graphs()) {
fg->ClearAllManagerInfo();
}
}
@ -306,7 +306,7 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
current_fg->debug_info()->set_name(function_name);
MS_EXCEPTION_IF_NULL(ast_);
py::list deco_list = node.attr("decorator_list");
if (deco_list.size() > 0) {
if (!deco_list.empty()) {
current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
}
@ -548,7 +548,7 @@ AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
} else {
// no else actually
MS_LOG(ERROR) << "Unsupported Num type : " << (std::string)py::str(obj);
errcode_ = PARSE_NODE_TYPE_UNKOWN;
errcode_ = PARSE_NODE_TYPE_UNKNOWN;
return nullptr;
}
}
@ -600,7 +600,7 @@ AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object
} else {
// no else actually
MS_LOG(ERROR) << "Unsupported NameConstant type: " << (std::string)py::str(obj);
errcode_ = PARSE_NODE_TYPE_UNKOWN;
errcode_ = PARSE_NODE_TYPE_UNKNOWN;
return nullptr;
}
}
@ -1528,7 +1528,7 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob
AnfNodePtr target_node = ParseExprNode(block, targ);
MS_EXCEPTION_IF_NULL(target_node);
std::string attr_name = targ.attr("attr").cast<std::string>();
auto attr_name = targ.attr("attr").cast<std::string>();
std::string var_name = "self." + attr_name;
// Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
@ -1560,7 +1560,7 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje
// Getitem apply should return the sequence data structure itself
std::string var_name;
if (ast_->IsClassMember(value_obj)) {
std::string attr_name = value_obj.attr("attr").cast<std::string>();
auto attr_name = value_obj.attr("attr").cast<std::string>();
var_name = "self." + attr_name;
if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
MS_EXCEPTION(TypeError) << "'" << var_name << "' was not defined in the class '__init__' function.";
@ -1675,9 +1675,9 @@ void Parser::RemoveUnnecessaryPhis() {
MS_EXCEPTION_IF_NULL(block);
removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end());
std::transform(block->removable_phis().begin(), block->removable_phis().end(), std::back_inserter(phis),
[](std::pair<ParameterPtr, AnfNodePtr> pair) { return pair.first; });
[](const std::pair<ParameterPtr, AnfNodePtr> &pair) { return pair.first; });
}
if (removable_phis.size() == 0) {
if (removable_phis.empty()) {
return;
}
auto fg_name = func_graph_->ToString();
@ -1693,14 +1693,14 @@ void Parser::RemoveUnnecessaryPhis() {
for (FunctionBlockPtr &block : func_block_list_) {
MS_EXCEPTION_IF_NULL(block);
auto &local_removable_phis = block->removable_phis();
if (local_removable_phis.size() == 0) {
if (local_removable_phis.empty()) {
continue;
}
auto func_graph = block->func_graph();
auto &parameters = func_graph->parameters();
std::vector<AnfNodePtr> new_parameters(parameters.size());
auto it = std::copy_if(
parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](AnfNodePtr param) {
parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](const AnfNodePtr &param) {
return local_removable_phis.find(param->cast<ParameterPtr>()) == local_removable_phis.end();
});
@ -1708,7 +1708,7 @@ void Parser::RemoveUnnecessaryPhis() {
new_parameters.resize(std::distance(new_parameters.begin(), it));
func_graph->set_parameters(new_parameters);
}
for (auto fg : mng->func_graphs()) {
for (const auto &fg : mng->func_graphs()) {
fg->ClearAllManagerInfo();
}
}
@ -1812,7 +1812,7 @@ bool ParseAst::IsClassMember(const py::object &node) {
return ret.cast<bool>();
}
bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "FuncGraph is null";
return false;
@ -1846,7 +1846,7 @@ bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
}
// Generate and copy a ValueNode, or a CNode with its child nodes
static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr func_graph, const AnfNodePtr &param_node) {
static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr &func_graph, const AnfNodePtr &param_node) {
MS_EXCEPTION_IF_NULL(param_node);
if (param_node->isa<ValueNode>()) {
return std::make_shared<ValueNode>(param_node->cast<ValueNodePtr>()->value());

View File

@ -42,7 +42,7 @@ enum ParseStatusCode : int64_t {
PARSE_PARAMETER_INVALID, // parameter is invalid
PARSE_NO_RETURN, // function no return node
PARSE_NODE_TYPE_NO_MATCH, // ast node type is error
PARSE_NODE_TYPE_UNKOWN, // node type is unkown
PARSE_NODE_TYPE_UNKNOWN, // node type is unknown
PARSE_NODE_METHOD_UNSUPPORTED, // no method to parse the node
PARSE_DONT_RESOLVE_SYMBOL, // can't resolve the string
PARSE_NOT_SUPPORTED_COMPARE_EXPR, // the comparison is not supported
@ -54,7 +54,7 @@ enum ParseStatusCode : int64_t {
// NOTE: Since when the for loop was unrolled, it depends backend operators `tuple_getitem` and `scalar_add` which were
// not implemented, so here set MAX_FOR_LOOP_COUNT to int64_t max limit to override default value `600`. This will make
// the for loop will always be unrolled, but don't worry about the memory were exhausted, an exception will be raised
// when function call depth execeeds the limit `context.get_context('max_call_depth')`.
// when function call depth exceeds the limit `context.get_context('max_call_depth')`.
const int64_t MAX_FOR_LOOP_COUNT = std::numeric_limits<int64_t>::max();
class AstNodeType;
@ -191,7 +191,7 @@ class Parser {
FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node);
// parse one ast statement node
FunctionBlockPtr ParseStatement(const FunctionBlockPtr &block, const py::object &node);
// parse an ast expresion node
// parse an ast expression node
AnfNodePtr ParseExprNode(const FunctionBlockPtr &block, const py::object &node);
void MakeConditionBlocks(const FunctionBlockPtr &block, const FunctionBlockPtr &trueBlock,
@ -363,7 +363,7 @@ class ParseAst {
};
// update the graph flags
bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph);
AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph);

View File

@ -30,12 +30,14 @@ def test_use_undefined_var():
self.value = [11, 22, 33, 44]
def construct(self, x):
ret = x + c
ret = x + a
return ret
net = Net()
with pytest.raises(NameError) as err:
net(Tensor(np.arange(4)))
assert "The name 'c' is not defined" in str(err.value)
assert "The name 'a' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_var.py(33)" in str(err.value)
assert "ret = x + a" in str(err.value)
def test_insert_undefined_var():
@ -45,13 +47,15 @@ def test_insert_undefined_var():
self.value = [11, 22, 33, 44]
def construct(self, x):
c
b
ret = x + x
return ret
net = Net()
with pytest.raises(NameError) as err:
net(Tensor(np.arange(4)))
assert "The name 'c' is not defined" in str(err.value)
assert "The name 'b' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_var.py(50)" in str(err.value)
assert "b" in str(err.value)
def test_insert_undefined_var_compute():
@ -61,13 +65,125 @@ def test_insert_undefined_var_compute():
self.value = [11, 22, 33, 44]
def construct(self, x):
c + d
c + x
ret = x + x
return ret
net = Net()
with pytest.raises(NameError) as err:
net(Tensor(np.arange(4)))
assert "The name 'c' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_var.py(68)" in str(err.value)
assert "c + x" in str(err.value)
def test_use_undefined_var_in_for():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = [11, 22, 33, 44]
def construct(self, x):
for i in self.value:
x = x + d + i
return x
net = Net()
with pytest.raises(NameError) as err:
net(Tensor(np.arange(4)))
assert "The name 'd' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_var.py(86)" in str(err.value)
assert "x = x + d + i" in str(err.value)
def test_insert_undefined_var_in_for():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = [11, 22, 33, 44]
def construct(self, x):
for i in self.value:
e
x = x + i
return x
net = Net()
with pytest.raises(NameError) as err:
net(Tensor(np.arange(4)))
assert "The name 'e' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_var.py(104)" in str(err.value)
assert "e" in str(err.value)
def test_insert_undefined_var_compute_for():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = [11, 22, 33, 44]
def construct(self, x):
for i in self.value:
f + i
x = x + i
return x
net = Net()
with pytest.raises(NameError) as err:
net(Tensor(np.arange(4)))
assert "The name 'f' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_var.py(123)" in str(err.value)
assert "f + i" in str(err.value)
def test_use_undefined_var_in_while():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
while x < 0:
x = x - g
return x
net = Net()
with pytest.raises(NameError) as err:
net(Tensor(np.arange(4)))
assert "The name 'g' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_var.py(141)" in str(err.value)
assert "x = x - g" in str(err.value)
def test_insert_undefined_var_in_while():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = [11, 22, 33, 44]
def construct(self, x):
while x < 0:
h
x = x - 1
return x
net = Net()
with pytest.raises(NameError) as err:
net(Tensor(np.arange(4)))
assert "The name 'h' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_var.py(159)" in str(err.value)
assert "h" in str(err.value)
def test_insert_undefined_var_compute_while():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = [11, 22, 33, 44]
def construct(self, x):
while x < 0:
x + i
x = x - 1
return x
net = Net()
with pytest.raises(NameError) as err:
net(Tensor(np.arange(4)))
assert "The name 'i' is not defined" in str(err.value)
assert "tests/ut/python/pipeline/parse/test_use_undefined_var.py(178)" in str(err.value)
assert "x + i" in str(err.value)
def test_insert_defined_var():