add find error operator in source code tool.

This commit is contained in:
jzg 2020-10-09 14:57:15 +08:00
parent fa24b35043
commit fa3c1d8980
2 changed files with 492 additions and 0 deletions

View File

@ -0,0 +1,104 @@
# 算子报错定位对应脚本源码
## 文档功能与适用场景
在MindSpore进行计算调试遇到算子报错时用户希望能够通过报错信息查找到对应的Python源码。现有的算子编译(PreCompileProcessFailed)和执行失败(run task error)只是将算子名字打印到终端无法直接指出用户代码中算子调用处的行号。本文的主要目的为指导用户通过算子的报错信息然后在ANF图文件里找到对应算子的源码。
此指导文档适合运行在 **Ascend硬件** 环境下的计算,且忽略算子融合场景和算子反向报错的情况。
## 解析流程与辅助工具使用
1. 查找对应源码行号的主要解析流程为:
① 获取报错算子的全称full_name(full_name由scope和算子名组成scope名会以Default或Gradients开头Default表示正向Gradients表示反向这里我们只关注Default的情况)。
② 获取报错算子的输入输出属性等信息。
③ 通过full_name和输入输出信息在[0-10]_py_pre_ad.dat(优先)或[0-10]_validate.dat文件中找到对应的算子和对应的代码行号。
2. 算子报错时使用脚本的3步操作
① 用户在训练脚本里设置context.set_context(mode=context.GRAPH_MODE, save_graphs=True),进行图文件的保存。
② 用户在执行代码时将日志信息定向到文件中使用方法为python xxx.py &> log_name &
③ 通过辅助脚本解析报错算子全称、输入输出信息,同时尝试查找对应代码行号。
  脚本名: **find_error_operator_in_code.py**
  执行方式:
```
python3 find_error_operator_in_code.py
--log_path [the path of log, default is the current path](option)
--log_name [the file name of log](required, log_name)
```
3. 解析效果
解析文件时通常有3种情况
① 全匹配(算子名称、输入输出的shape与dtype都匹配时)
```
[INFO] Detect "task exception error". 【检测报错类型】
[INFO] Find operation 1 times! 【在python源码中查找匹配到的次数】
[INFO] In file test_mul.py(29)/ return self.mul(x, y) 【查找到的文件名、行号与该行号的代码内容】
[INFO] Exception operator is "Default/Mul". 【出错的算子】
[INFO] Have 2 input in operator: 【该出错算子有2个输入】
input 1/1th: dtype is float16, shape is [526338, 21]. 【第一个输入的dtype和shape1/1th表示第一个输入的第1个值】
input 1/2th: dtype is float16, shape is [526338, 1]. 【第二个输入的dtype和shape】
[INFO] Have 1 output in operator: 【该出错算子有1个输出】
output 1/1th: dtype is float16, shape is [526338, 21]. 【第一个输出的dtype和shape1/1th表示第一个输出的第1个值】
```
② 单匹配(算子名称匹配、但只匹配输入或输出的shape与dtype都匹配时):
```
[INFO] Detect "compile error". 【检测报错类型】
[INFO] Find operation 1 times! 【在python源码中查找匹配到的次数】
[INFO] In file split_ops.py(17)/ return self.net(input) 【查找到的文件名、行号与该行号的代码内容】
[INFO] Exception operator is "Default/Split". 【出错的算子】
[INFO] Have 1 input in operator: 【该出错算子有1个输入】
input 1/1th: dtype is float32, shape is [32, 192, 56, 56]. 【第一个输入的dtype和shape1/1th表示第一个输入的第1个值】
[WARNING] Cannot match output information! Please check whether the operator's output is: 【输出未完全匹配告警】
[INFO] Have 1 output in operator:             【该出错算子有1个输出】
output 1/1th: dtype is float32, shape is [32, 6, 56, 56].            【第一个输出的dtype和shape1/1th表示第一个输出的第1个值】
output 2/1th: dtype is float32, shape is [32, 6, 56, 56].            【第一个输出的dtype和shape2/1th表示第一个输出的第2个值】
output 3/1th: dtype is float32, shape is [32, 6, 56, 56].            【第一个输出的dtype和shape3/1th表示第一个输出的第3个值】
```
③ 未匹配(未匹配到算子名称,或者匹配算子名称但输入输出未匹配时):
```
[INFO] Detect "task exception error". 【检测报错类型】
[WARNING] Cannot find operation! Need to find in the script based on the following information: 【未在源码中匹配到算子告警】
[INFO] Exception operator full name is "Default/test". 【出错的算子】
[INFO] Have 2 input in operator: 【该出错算子有2个输入】
input 1/1th: dtype is float16, shape is [526338, 21]. 【第一个输入的dtype和shape1/1th表示第一个输入的第1个值】
input 1/2th: dtype is float16, shape is [526338, 1]. 【第二个输入的dtype和shape】
[INFO] Have 1 output in operator:     【该出错算子有1个输出】
output 1/1th: dtype is float16, shape is [526338, 21]. 【第一个输出的dtype和shape1/1th表示第一个输出的第1个值】
[WARNING] Do you want to research in source code? set source code path to research or press enter to research in current path, input n/no to exit.
Input: 【未匹配到代码时会提示是否在脚本中进行搜索n/no表示不搜索传入代码路径在指定路径搜索回车默认表示当前路径搜索】
```
4. 手动代码查找
这里还会存在些特殊情况例如用户在源码中调用框架提供的nn.cell网络层会发现查找出来的为框架里的代码行号。此时用户需要利用工具查找出的full_name和输入输出信息回到源码中进行对应代码的查找。
举个例子说明如何手动在代码中查找指定full_name和shape的算子例如full_name为: Default/network/network/aspp/aspp_pooling/ResizeNearestNeighbor输入的shape为[8, 256, 1, 1] dtype为float32。
可以观察到其scope为: Default/network/network/aspp/aspp_pooling算子名为: ResizeNearestNeighbor。注意scope中会存在Default、network自动填充Default表示正向network为网络名。
查看以下用户定义的代码首先我们先分析scope: Default/network/network/aspp/aspp_pooling。由network/aspp可定位到算子的定义与调用处分别为26行与31行继续由network/aspp/aspp_pooling可以定位到定义与调用处分别为4行与8行然后通过算子名ResizeNearestNeighbor可以定位至定义与调用处分别为16行与19行。最后若存在相同scope下存在相同的算子名时需要通过输入的shape和dtype进行进一步判断。
```
1 class ASPP(nn.Cell):
2 def __init__(self):
3 super(ASPP, self).__init__()
4 self.aspp_pooling = ASPPPooling()
5 self.drop = nn.Dropout(0.3)
6
7 def construct(self, x):
8 x5 = self.aspp_pooling(x)
9 x = self.drop(x)
10 return x
11
12 class ASPPPooling(nn.Cell):
13 def __init__(self):
14 super(ASPPPooling, self).__init__()
15 self.shape = P.Shape()
16 self.resizenearestneighbor = P.ResizeNearestNeighbor((size[2], size[3]), True)
17 def construct(self, x):
18 size = self.shape(x)
19 out = self.resizenearestneighbor(out)
20 return out
21
22 # 主结构
23 class DeepLabV3(nn.Cell):
24 def __init__(self, phase='train', num_classes=21, output_stride=16, freeze_bn=False):
25 super(DeepLabV3, self).__init__()
26 self.aspp = ASPP()
27 self.shape = P.Shape()
28
29 def construct(self, x):
30 size = self.shape(x)
31 out = self.aspp(out)
32 return out
```

View File

@ -0,0 +1,388 @@
#!/usr/bin/env python3
# coding=UTF-8
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""find error operator in code"""
import os
import json
import argparse
# step 1: get error_full_name, error_kernel_name, error_context, and trace_back_log.
def get_log_error_and_traceback(log_file):
"""
get error_full_name, error_kernel_name, error_context, and trace_back_log.
"""
with open(log_file, 'r+') as log:
log_context_list = log.readlines()
# get error log, error full_name, kernel_name, stream_id and task_id in log
task_stream_log = []
error_context = ""
error_stream_id = ""
error_task_id = ""
error_kernel_name = ""
for log_context in log_context_list:
if "stream_id" in log_context and "task_id" in log_context:
task_stream_log.append(log_context)
if "[ERROR]" not in log_context:
continue
error_context = error_context + log_context
if "fault kernel_name" in log_context and "stream_id" in log_context and "task_id" in log_context:
error_stream_id = log_context.split("stream_id")[-1].split(",")[0].split("=")[-1]
error_task_id = log_context.split("task_id")[-1].split(",")[0].split("=")[-1]
error_kernel_name = log_context.split("kernel_name")[-1].split(",")[0].split("=")[-1]
# get full name
error_full_name = ""
for task_stream in task_stream_log:
if "[ERROR]" in task_stream:
continue
if "stream_id:" + error_stream_id in task_stream and "task_id:" + error_task_id in task_stream:
error_full_name = task_stream.split("ask name:")[-1].split(" ")[0]
if error_full_name == "":
exit("[ERROR] Find full name in task exception need to "
"set the log level to warning at least [export GLOG_v=2].")
# get Traceback log
log_context_str = "".join(log_context_list)
trace_back_log = "Traceback" + log_context_str.split("Traceback")[-1]
return error_full_name, error_kernel_name, error_context, trace_back_log
# step 2.1: get error operator shape and dytpe in kernel_meta.
def get_shape_dtyep_in_kernel_meta(log_file_path, kernel_meta):
"""
get error operator shape and dytpe in kernel_meta.
"""
error_full_name, error_kernel_name, _, _ = get_log_error_and_traceback(log_file_path)
# the hash value of error kernel name
oper_info_dict = {}
error_kernel_name_file = error_kernel_name.split("__kernel")[0] + ".info"
error_kernel_name_file_path = os.path.join(kernel_meta, error_kernel_name_file)
with open(error_kernel_name_file_path, 'r+') as kernel_info:
kernel_info_str = kernel_info.readline()
kernel_info_json = json.loads(kernel_info_str)
oper_info_dict["full_name"] = error_full_name
oper_info_dict["inputs"] = kernel_info_json["inputs"]
oper_info_dict["input_num"] = len(kernel_info_json["inputs"])
oper_info_dict["outputs"] = kernel_info_json["outputs"]
oper_info_dict["output_num"] = len(oper_info_dict["outputs"])
return oper_info_dict
# step 2.2: get error operator shape and dytpe in tbe compile.
def get_shape_dtyep_in_tbe_compile(tbe_compile_log):
"""
get error operator shape and dytpe in tbe compile.
"""
oper_info = ""
oper_info_dict = {}
with open(tbe_compile_log, 'r+') as log:
log_context_list = log.readlines()
for compile_log in log_context_list:
if compile_log.startswith("input_args:"):
oper_info = compile_log.split("input_args:")[-1]
oper_info_json = json.loads(oper_info)
oper_info_dict["full_name"] = oper_info_json["full_name"]
oper_info_dict["inputs"] = oper_info_json["op_info"]["inputs"]
oper_info_dict["input_num"] = len(oper_info_dict["inputs"])
oper_info_dict["outputs"] = oper_info_json["op_info"]["outputs"]
oper_info_dict["output_num"] = len(oper_info_dict["outputs"])
return oper_info_dict
# step 3: show full name, input_ouput_atts info and the line in source code.
def read_all_source_code(full_name):
"""
research operation in source code
"""
code_path = input("[WARNING] Do you want to research in source code? set source code path to research or "
"press enter to research in current path, input n/no to exit.\nInput: ")
if code_path.lower() == "n" or code_path.lower() == "no":
exit("[INFO] Exit.")
code_path = code_path
if code_path == '':
code_path = os.path.abspath(os.path.dirname(__file__))
print("[INFO] Research in {}".format(code_path))
current_script_file = os.path.basename(__file__)
last_scope = full_name.split("/")[-2]
oper_name = full_name.split("/")[-1]
find_name = last_scope if last_scope not in ("Default", "network") else oper_name
source_code_file = []
for code_path, _, file_list in os.walk(code_path):
for file_name in file_list:
if file_name.endswith(".py") and file_name != current_script_file:
code_file_path = os.path.join(code_path, file_name)
source_code_file.append(code_file_path)
find_code_flag = False
for code_file_path in source_code_file:
with open(code_file_path, 'r+') as code:
code_context_list = code.readlines()
for line, code_context in enumerate(code_context_list):
if find_name in code_context:
find_code_flag = True
print("[INFO] find in {}\nLine: {}\nCode: {}\n".format(code_file_path, line + 1, code_context))
if not find_code_flag:
exit("[WARNING] Cannot find error operation in source code!")
def dtype_map(tensor_dtype, oper_name):
"""
type map dict
"""
dtype_dict = {"F8": "float8", "F16": "float16", "F32": "float32", "F64": "float64",
"I8": "int8", "I16": "int16", "I32": "int32", "I64": "int64",
"U8": "uint8", "U16": "uint16", "U32": "uint32", "U64": "uint64",
"TypeType": "TypeType", "Bool": "bool"}
if tensor_dtype not in dtype_dict:
if len(tensor_dtype) <= 10:
print("[WARNING] Type convert fail! {} not in dtype map in operator {}.".format(tensor_dtype, oper_name))
return tensor_dtype
return dtype_dict[tensor_dtype]
def operator_check(operator_name):
"""
check operator name
"""
block_list = ["tuple_getitem", "make_tuple", "Depend", "GetNext"]
return operator_name in block_list
def convert_validate_operator_info(input_tensor_str, oper_name, val_context):
"""
convert operator info of validate.dat file
"""
tensor_dtype, tensor_shape = "", ""
if "Tuple" in input_tensor_str:
pass
elif "Ref[" in input_tensor_str:
tensor_dtype = dtype_map(input_tensor_str.split("sor(")[-1].split(")][")[0], oper_name)
tensor_shape = "[" + input_tensor_str.split(")][")[-1].split("]")[0] + "]"
else:
tensor_dtype = dtype_map(input_tensor_str.split("sor(")[-1].split(")[")[0], oper_name)
tensor_shape = "[" + input_tensor_str.split(")[")[-1].split("]")[0] + "]"
if tensor_dtype == "TypeType":
tensor_dtype = dtype_map(val_context.split("DstT=")[-1].split(", ")[0], oper_name)
return tensor_dtype, tensor_shape
def get_operator_info(val_context, input_info_dict):
"""
get full name, input shape and type from validate file
"""
if "PrimitivePy::" in val_context:
oper_name = val_context.split("PrimitivePy::")[-1].split("@")[0].strip()
elif "Primitive" in val_context:
oper_name = val_context.split("Primitive::")[-1].split("(%")[0].strip()
scope_name = val_context.split("#scope:")[-1].split("#")[0].strip()
input_info = val_context.split("#scope:")[0].split("#")[-1].strip()
full_name = scope_name + "/" + oper_name
if "Tuple" in input_info:
input_info_dict["tuple_input"] = input_info
input_info_dict["tuple_input_name"] = full_name
input_info_list = input_info.lstrip("(").rstrip(")").split(", T")
input_info_dict["input_num"] = len(input_info_list)
output_info = val_context.split(" = Primitive")[0].split(":")[-1].strip()
if "Tuple" in output_info:
input_info_dict["tuple_output"] = output_info
input_info_dict["tuple_output_name"] = full_name
output_info_list = output_info.lstrip("(").rstrip(")").split(", T")
input_info_dict["output_num"] = len(output_info_list)
if not operator_check(oper_name):
for i, input_tensor_str in enumerate(input_info_list):
if i > 0: input_tensor_str = "T" + input_tensor_str
tensor_dtype, tensor_shape = convert_validate_operator_info(input_tensor_str, oper_name, val_context)
input_info_dict["input_" + str(i)] = [tensor_dtype, tensor_shape]
for i, output_tensor_str in enumerate(output_info_list):
if i > 0: output_tensor_str = "T" + output_tensor_str
tensor_dtype, tensor_shape = convert_validate_operator_info(output_tensor_str, oper_name, val_context)
input_info_dict["output_" + str(i)] = [tensor_dtype, tensor_shape]
input_info_dict["full_name"] = full_name
return input_info_dict
def py_pre_ad_parse(val_file):
"""
parse the validate file and get code line
"""
oper_start_info_num = 0
oper_end_char_flag = False
oper_start_char_flag = False
input_dict = {}
with open(val_file, 'r+') as val:
val_context_list = val.readlines()
input_info_dict = {}
for val_context in val_context_list:
val_context = val_context.strip()
if val_context.startswith(") {"):
oper_start_char_flag = True
oper_end_char_flag = False
if val_context == "}":
oper_end_char_flag = True
oper_start_info_flag = False
if val_context.startswith("%") and val_context[1].isdigit():
input_info_dict = {}
oper_start_info_num += 1
oper_start_info_flag = True
if oper_start_char_flag and not oper_end_char_flag:
if oper_start_info_num >= 1 and oper_start_info_flag:
if "PrimitivePy::" in val_context or "Primitive" in val_context:
input_info_dict = get_operator_info(val_context, input_info_dict)
if "In file" not in val_context or "in_file" in input_info_dict:
continue
input_info_dict["in_file"] = val_context.split("#")[-1].strip().rstrip("/")
input_dict[oper_start_info_num] = input_info_dict
return input_dict
def show_input_output_info(kernel_info_json, show_name):
"""
print input or ouput info to stdout
"""
if show_name + "_num" in kernel_info_json:
print("[INFO] Have {} {} in operator:".format(kernel_info_json[show_name + "_num"], show_name))
for i in range(kernel_info_json[show_name + "_num"]):
for j in range(len(kernel_info_json[show_name + "s"][i])):
if "ori_shape" not in kernel_info_json[show_name + "s"][i][j] and "dtype" not in \
kernel_info_json[show_name + "s"][i][j]:
print(" {} {}/{}th: {}.".format(show_name, j + 1, i + 1,
kernel_info_json[show_name + "s"][i][j]))
else:
kernel_dtype = kernel_info_json[show_name + "s"][i][j].setdefault("dtype", None)
kernel_shape = str(kernel_info_json[show_name + "s"][i][j].setdefault("ori_shape", None))
print(" {} {}/{}th: dtype is {}, shape is {}.".format(show_name, j + 1, i + 1,
kernel_dtype, kernel_shape))
def show_error_operator_info(val_file, kernel_info_json):
"""
find error operator
"""
error_full_name = kernel_info_json["full_name"].split("-op")[0]
input_dict = py_pre_ad_parse(val_file)
find_num = 0
for _, oper_info in input_dict.items():
if oper_info["full_name"] != error_full_name:
continue
if kernel_info_json["input_num"] > oper_info["input_num"] \
or kernel_info_json["output_num"] > oper_info["output_num"]:
continue
find_oper_input = True
for i in range(kernel_info_json["input_num"]):
for j in range(len(kernel_info_json["inputs"][i])):
kernel_dtype = kernel_info_json["inputs"][i][j].setdefault("dtype", None)
kernel_shape = str(kernel_info_json["inputs"][i][j].setdefault("ori_shape", None))
validate_dtype = oper_info["input_" + str(i)][0]
validate_shape = oper_info["input_" + str(i)][1]
if kernel_shape != validate_shape or kernel_dtype != validate_dtype:
find_oper_input = False
find_oper_output = True
for i in range(kernel_info_json["output_num"]):
for j in range(len(kernel_info_json["outputs"][i])):
kernel_dtype = kernel_info_json["outputs"][i][j].setdefault("dtype", None)
kernel_shape = str(kernel_info_json["outputs"][i][j].setdefault("ori_shape", None))
validate_dtype = oper_info["output_" + str(i)][0]
validate_shape = oper_info["output_" + str(i)][1]
if kernel_shape != validate_shape or kernel_dtype != validate_dtype:
find_oper_output = False
if find_oper_input or find_oper_output:
find_num += 1
print("[INFO] Find operation {} times!".format(find_num))
print("[INFO] {}".format(oper_info["in_file"]))
print("[INFO] Exception operator is \"{}\".".format(error_full_name))
if not find_oper_input:
print("[WARNING] Cannot match input information! Please check whether the operator's input is:")
show_input_output_info(kernel_info_json, "input")
if not find_oper_output:
print("[WARNING] Cannot match output information! Please check whether the operator's output is:")
show_input_output_info(kernel_info_json, "output")
if find_num == 0:
print("[WARNING] Cannot find operation! Need to find in the script based on the following information:")
print("[INFO] Exception operator full name is \"{}\".".format(error_full_name))
show_input_output_info(kernel_info_json, "input")
show_input_output_info(kernel_info_json, "output")
read_all_source_code(error_full_name)
def get_task_type(log_file):
"""
get the error type
"""
with open(log_file, 'r+') as log:
log_context_list = log.readlines()
for log_context in log_context_list:
if "PreCompileProcessFailed" in log_context:
return "compile"
if "run task error" in log_context:
return "exception"
return None
def start_find(log_file_path, code_line_file_path, kernel_meta_path):
"""
start find error operation in code.
"""
task = get_task_type(log_file_path)
if task == "exception":
print("[INFO] Detect \"task exception error\".")
kernel_info_json = get_shape_dtyep_in_kernel_meta(log_file_path, kernel_meta_path)
show_error_operator_info(code_line_file_path, kernel_info_json)
elif task == "compile":
print("[INFO] Detect \"compile error\".")
the_kernel_json = get_shape_dtyep_in_tbe_compile(log_file_path)
show_error_operator_info(code_line_file_path, the_kernel_json)
else:
exit("[ERROR] Currently only support task exception or tbe compile error!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Show error info')
parser.add_argument('--log_path', '-p', type=str.lower, default="", help='log file path (option)')
parser.add_argument('--log_name', '-n', type=str.lower, default="", required=True, help='log file name')
args_opt = parser.parse_args()
current_path = args_opt.log_path
if current_path == "":
print("[WARNING] No log_path parameter, use current path as log path.")
current_path = os.path.abspath(os.path.dirname(__file__))
log_name_file = os.path.join(current_path, args_opt.log_name)
kernel_meta_file = os.path.join(current_path, "kernel_meta")
code_line_file = ""
for filename in os.listdir(os.path.join(current_path)):
if "_validate.dat" in filename:
code_line_file = filename
if "_py_pre_ad.dat" in filename:
code_line_file = filename
if code_line_file == "":
exit("[ERROR] Please set \"save_graphs=True\" in context to save py_pre_ad.dat or validate.dat file!")
code_line_file = os.path.join(current_path, code_line_file)
start_find(log_name_file, code_line_file, kernel_meta_file)