fix bug of trace info in resolve process

This commit is contained in:
chenfei 2022-07-29 14:32:33 +08:00
parent 89e3a499b1
commit 1915654910
4 changed files with 126 additions and 21 deletions

View File

@ -32,6 +32,7 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
if (IsPrimitiveCNode(object_node, prim::kPrimResolve)) {
// node is get_attr node
return parse::ResolveSymbolWithAttr(optimizer->manager(), object_node, attr_node, node);
}
// {prim::kPrimGetAttr, namespace, attr}

View File

@ -456,20 +456,21 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
}
// Resolve Cell GetAttr operation.
AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj, const AnfNodePtr &node,
const AnfNodePtr &attr) {
MS_EXCEPTION_IF_NULL(node);
AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj,
const AnfNodePtr &resolve_node, const AnfNodePtr &attr,
const AnfNodePtr &get_attr_node) {
MS_EXCEPTION_IF_NULL(resolve_node);
MS_EXCEPTION_IF_NULL(attr);
if (manager == nullptr) {
MS_LOG(EXCEPTION) << "Manager is nullptr.";
}
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", attr: " << attr->ToString();
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceResolve>(get_attr_node->debug_info()));
if (!data_converter::IsCellInstance(obj)) {
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, resolve_node);
AnfNodePtrList inputs = {NewValueNode(prim::kPrimGetAttr), resolved_node, attr};
MS_EXCEPTION_IF_NULL(node->func_graph());
AnfNodePtr res_node = node->func_graph()->NewCNode(std::move(inputs));
MS_EXCEPTION_IF_NULL(get_attr_node->func_graph());
AnfNodePtr res_node = get_attr_node->func_graph()->NewCNodeInOrder(std::move(inputs));
TraceManager::ClearParseOrResolveDebugInfo();
return res_node;
}
@ -482,15 +483,15 @@ AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::obj
MS_LOG(DEBUG) << "name_space: " << new_namespace->ToString() << ", symbol: " << new_symbol->ToString();
AnfNodePtrList inputs = {NewValueNode(prim::kPrimResolve), NewValueNode(new_namespace), NewValueNode(new_symbol)};
MS_EXCEPTION_IF_NULL(node->func_graph());
AnfNodePtr resolved_node = node->func_graph()->NewCNode(std::move(inputs));
MS_EXCEPTION_IF_NULL(get_attr_node->func_graph());
AnfNodePtr resolved_node = get_attr_node->func_graph()->NewCNodeInOrder(std::move(inputs));
TraceManager::ClearParseOrResolveDebugInfo();
return resolved_node;
}
AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj,
const AnfNodePtr &resolve_node, const AnfNodePtr &attr,
const CNodePtr &operand_cnode) {
const CNodePtr &get_attr_node) {
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
auto sequence = obj.cast<py::sequence>();
@ -508,14 +509,14 @@ AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py:
if (count_cell == sequence_size) {
// Resolve Cell instances.
for (size_t i = 0; i < sequence_size; ++i) {
auto res = ResolveCellWithAttr(manager, sequence[i], resolve_node, attr);
auto res = ResolveCellWithAttr(manager, sequence[i], resolve_node, attr, get_attr_node);
inputs.emplace_back(res);
}
} else if (count_msclass == sequence_size) {
// Resolve MsClass instances.
for (size_t i = 0; i < sequence_size; ++i) {
auto attr_str = GetValue<std::string>(GetValueNode(attr));
auto res = ResolveMsClassWithAttr(manager, sequence[i], attr_str, operand_cnode);
auto res = ResolveMsClassWithAttr(manager, sequence[i], attr_str, get_attr_node);
(void)inputs.emplace_back(res);
}
} else {
@ -524,33 +525,33 @@ AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py:
constexpr auto prim_index = 0;
constexpr auto index_index = 2;
auto fg = operand_cnode->func_graph();
auto fg = get_attr_node->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto make_tuple_node = fg->NewCNodeInOrder(inputs);
return fg->NewCNodeInOrder({operand_cnode->input(prim_index), make_tuple_node, operand_cnode->input(index_index)});
return fg->NewCNodeInOrder({get_attr_node->input(prim_index), make_tuple_node, get_attr_node->input(index_index)});
}
AnfNodePtr ResolveSymbolWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &object_node,
const AnfNodePtr &attr_node, const AnfNodePtr &node) {
const AnfNodePtr &attr_node, const AnfNodePtr &get_attr_node) {
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
auto [name_space, symbol] = GetNamespaceAndSymbol(object_node);
auto module_name = name_space->module();
constexpr std::string_view parse_super_name = "namespace";
if (module_name.find(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER) != std::string::npos &&
symbol->symbol() != parse_super_name) {
auto symbol_obj = GetSymbolObject(name_space, symbol, node);
return ResolveCellWithAttr(manager, symbol_obj, object_node, attr_node);
auto symbol_obj = GetSymbolObject(name_space, symbol, get_attr_node);
return ResolveCellWithAttr(manager, symbol_obj, object_node, attr_node, get_attr_node);
}
return nullptr;
}
AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const py::object &cls_obj,
const std::string &attr, const AnfNodePtr &node) {
const std::string &attr, const AnfNodePtr &get_attr_node) {
// Get attribute or method from ms_class obj.
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(get_attr_node);
MS_LOG(DEBUG) << "Resolve ms_class obj (" << py::str(cls_obj) << ") with attr " << attr << ".";
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
TraceGuard trace_guard(std::make_shared<TraceResolve>(get_attr_node->debug_info()));
constexpr size_t prefix_index = 0;
if (attr.size() > 0 && attr[prefix_index] == '_') {
@ -560,7 +561,7 @@ AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const py::
MS_LOG(EXCEPTION) << py::str(cls_obj) << " has not attribute: " << attr << ".";
}
py::object attr_obj = py::getattr(cls_obj, common::SafeCStr(attr));
AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, node);
AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, get_attr_node);
TraceManager::ClearParseOrResolveDebugInfo();
return res_node;
}

View File

@ -0,0 +1,52 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import nn
from mindspore import ops
from mindspore import Tensor
import numpy as np
class ArgsPares:
def __init__(self):
self.tt1 = 1
class Conv2dMean(nn.Cell):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=1)
self.mean = ops.ReduceMean(keep_dims=False)
self.relu = ops.ReLU()
self.y = ArgsPares()
def construct(self, x):
x = self.relu(x)
for _ in range(3):
x = self.y.tt1
x = self.conv1(x)
x = self.mean(x, (2, 3))
return x
def test_catch_exception_of_get_outer_class_attr():
"""
Feature: Resolve.
Description: execute this testcase to raise a exception, and print code stack info
for testcase:test_check_for_body_get_outer_class_attr_log.py::test_catch_exception_stack_trace_log
Expectation: raise exception with expected code stack info.
"""
x = Tensor(np.ones((3, 32, 32)).astype(np.float32))
Conv2dMean()(x)

View File

@ -0,0 +1,51 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_catch_exception_stack_trace_log():
"""
Feature: Resolve.
Description: execute the testcase 'for_body_get_outer_class_attr.py::test_catch_exception_of_get_outer_class_attr'
and check the log info.
Expectation: the error code exist in log info.
"""
file_name = "for_body_get_outer_class_attr.py"
log_file_name = "for_body_get_outer_class_attr.log"
function_name = "::test_catch_exception_of_get_outer_class_attr"
_cur_dir = os.path.dirname(os.path.realpath(__file__))
file_name = os.path.join(_cur_dir, file_name)
assert os.path.exists(file_name)
log_file_name = os.path.join(_cur_dir, log_file_name)
if os.path.exists(log_file_name):
os.remove(log_file_name)
assert not os.path.exists(log_file_name)
cmd_first = f"GLOG_v=2 pytest -s " + file_name + function_name + " > " + log_file_name + " 2>&1"
out = os.popen(cmd_first)
out.read()
assert os.path.exists(log_file_name)
with open(log_file_name, "r") as f_first:
data_first = f_first.read()
assert "Not supported to get attribute" in data_first
assert "x = self.y.tt1" in data_first
# Clean files
os.remove(log_file_name)