!28445 [ME][Fallback] Add the error code line information in validate.

Merge pull request !28445 from Margaret_wangrui/validate_location_2
This commit is contained in:
i-robot 2021-12-31 01:58:53 +00:00 committed by Gitee
commit 78f1e87ba5
6 changed files with 37 additions and 3 deletions

View File

@ -612,6 +612,10 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
MS_LOG(INFO) << "Attached for side-effect nodes, depend_node: " << depend_node->DebugString()
<< ", state: " << state->DebugString(recursive_level);
func_graph_->set_output(depend_node, true);
if (return_node && return_node->debug_info()) {
auto new_return = func_graph_->get_return();
new_return->set_debug_info(return_node->debug_info());
}
}
void FunctionBlock::SetAsDeadBlock() { is_dead_block_ = true; }

View File

@ -1445,6 +1445,7 @@ class AutoMonadConverter {
void AttachToOutput(const AnfNodePtr &node) const {
auto output = GetGraphOutput();
TraceGuard guard(std::make_shared<TraceCopy>(output->debug_info()));
auto depend = NewValueNode(prim::kPrimDepend);
// If isolated nodes dependencies exist.
if (IsPrimitiveCNode(output, prim::kPrimDepend) &&

View File

@ -294,6 +294,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
ScopeGuard scope_guard(node->scope());
AnfNodeConfigPtr conf = MakeConfig(node);
TraceGuard guard(std::make_shared<TraceCopy>(node->debug_info()));
AnfNodePtr new_node = GetReplicatedNode(node);
MS_EXCEPTION_IF_NULL(new_node);
if (new_node->func_graph() != specialized_func_graph_) {

View File

@ -25,6 +25,7 @@
#include "ir/dtype.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "pipeline/jit/parse/resolve.h"
#include "debug/trace.h"
namespace mindspore {
namespace validator {
@ -138,6 +139,7 @@ void ValidateValueNode(const AnfNodePtr &node) {
if (IsValueNode<parse::InterpretedObject>(node)) {
MS_LOG(EXCEPTION)
<< "Should not use Python object in runtime, node: " << node->DebugString()
<< ". \nLine: " << trace::GetDebugInfo(node->debug_info())
<< "\n\nWe suppose all nodes generated by JIT Fallback would not return to outside of graph. "
<< "For more information about JIT Fallback, please refer to the FAQ at https://www.mindspore.cn.";
}
@ -163,7 +165,8 @@ void Validate(const FuncGraphPtr &fg) {
MS_EXCEPTION_IF_NULL(mgr);
AnfNodeSet &all_nodes = mgr->all_nodes();
for (auto &node : all_nodes) {
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
TraceGuard guard(std::make_shared<TraceCopy>(node->debug_info()));
while (IsPrimitiveCNode(node, prim::kPrimReturn) || IsPrimitiveCNode(node, prim::kPrimDepend)) {
node = node->cast<CNodePtr>()->input(1);
}
if (IsValueNode<ValueTuple>(node)) {

View File

@ -4187,7 +4187,7 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
>>> from mindspore.common.tensor import Tensor
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>> def __init__(self):
... def __init__(self):
... super().__init__()
... self.alloc_status = P.NPUAllocFloatStatus()
... self.get_status = P.NPUGetFloatStatus()
@ -4261,7 +4261,7 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
>>> from mindspore.common.tensor import Tensor
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>> def __init__(self):
... def __init__(self):
... super().__init__()
... self.alloc_status = P.NPUAllocFloatStatus()
... self.get_status = P.NPUGetFloatStatus()

View File

@ -200,3 +200,28 @@ def test_print_validate_tuple():
print("res1: ", res1)
print("res2: ", res2)
assert "Should not use Python object in runtime" in str(err.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_print_validate():
"""
Feature: JIT Fallback
Description: Support print.
Expectation: No exception.
"""
@ms_function
def print_func():
np_x = np.array([1, 2, 3, 4, 5])
np_y = np.array([1, 2, 3, 4, 5])
np_sum = np_x + np_y
print("np_sum: ", np_sum)
return np_sum
with pytest.raises(RuntimeError) as err:
res = print_func()
print("res: ", res)
assert "Should not use Python object in runtime" in str(err.value)