forked from mindspore-Ecosystem/mindspore
clean up codecheck warnings by deleting unused files
This commit is contained in:
parent
b0e9132745
commit
a63e823b62
|
@ -1,315 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Function:
|
||||
Use to control the federated learning cluster
|
||||
Usage:
|
||||
python fl_restful_tool.py [http_type] [ip] [port] [request_name] [server_num] [instance_param] [metrics_file_path]
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from enum import Enum
|
||||
import requests
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
"""
|
||||
Response Status
|
||||
"""
|
||||
SUCCESS = "0"
|
||||
FAILED = "1"
|
||||
|
||||
|
||||
class Restful(Enum):
|
||||
"""
|
||||
Define restful interface constant
|
||||
"""
|
||||
SCALE = "scale"
|
||||
SCALE_OUT = "scaleout"
|
||||
SCALE_IN = "scalein"
|
||||
NODES = "nodes"
|
||||
GET_INSTANCE_DETAIL = "getInstanceDetail"
|
||||
NEW_INSTANCE = "newInstance"
|
||||
QUERY_INSTANCE = "queryInstance"
|
||||
ENABLE_FLS = "enableFLS"
|
||||
DISABLE_FLS = "disableFLS"
|
||||
STATE = "state"
|
||||
SCALE_OUT_ROLLBACK = "scaleoutRollback"
|
||||
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--http_type", type=str, default="http", help="http or https")
|
||||
parser.add_argument("--ip", type=str, default="127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=6666)
|
||||
parser.add_argument("--request_name", type=str, default="")
|
||||
|
||||
parser.add_argument("--server_num", type=int, default=0)
|
||||
parser.add_argument("--instance_param", type=str, default="")
|
||||
parser.add_argument("--metrics_file_path", type=str, default="/opt/huawei/mindspore/hybrid_albert/metrics.json")
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
http_type = args.http_type
|
||||
ip = args.ip
|
||||
port = args.port
|
||||
request_name = args.request_name
|
||||
server_num = args.server_num
|
||||
instance_param = args.instance_param
|
||||
metrics_file_path = args.metrics_file_path
|
||||
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
session = requests.Session()
|
||||
BASE_URL = '{}://{}:{}/'.format(http_type, ip, str(port))
|
||||
|
||||
|
||||
def call_scale():
|
||||
"""
|
||||
call cluster scale out or scale in
|
||||
"""
|
||||
if server_num == 0:
|
||||
return process_self_define_json(Status.FAILED.value, "error. server_num is 0")
|
||||
|
||||
node_ids = json.loads(call_nodes())["result"]
|
||||
cluster_abstract_node_num = len(node_ids)
|
||||
if cluster_abstract_node_num == 0:
|
||||
return process_self_define_json(Status.FAILED.value, "error. cluster abstract node num is 0")
|
||||
|
||||
cluster_server_node_num = 0
|
||||
cluster_worker_node_num = 0
|
||||
cluster_server_node_base_name = ''
|
||||
for i in range(0, cluster_abstract_node_num):
|
||||
if node_ids[i]['role'] == 'WORKER':
|
||||
cluster_worker_node_num = cluster_worker_node_num + 1
|
||||
elif node_ids[i]['role'] == 'SERVER':
|
||||
cluster_server_node_num = cluster_server_node_num + 1
|
||||
cluster_server_node_name = str(node_ids[i]['nodeId'])
|
||||
index = cluster_server_node_name.rindex('-')
|
||||
cluster_server_node_base_name = cluster_server_node_name[0:index]
|
||||
else:
|
||||
pass
|
||||
if cluster_server_node_num == server_num:
|
||||
return process_self_define_json(Status.FAILED.value, "error. cluster server num is same with server_num.")
|
||||
if cluster_server_node_num > server_num:
|
||||
scale_in_len = cluster_server_node_num - server_num
|
||||
scale_in_node_ids = []
|
||||
for index in range(cluster_server_node_num - scale_in_len, cluster_server_node_num):
|
||||
scale_in_node_name = cluster_server_node_base_name + "-" + str(index)
|
||||
scale_in_node_ids.append(scale_in_node_name)
|
||||
return call_scalein(scale_in_node_ids)
|
||||
return call_scaleout(server_num - cluster_server_node_num)
|
||||
|
||||
|
||||
def call_scaleout(scale_out_server_num, scale_out_worker_num=0):
|
||||
"""
|
||||
call scaleout
|
||||
"""
|
||||
url = BASE_URL + Restful.SCALE_OUT.value
|
||||
data = {"server_num": scale_out_server_num, "worker_num": scale_out_worker_num}
|
||||
res = session.post(url, headers=headers, verify=False, data=json.dumps(data))
|
||||
res_json = json.loads(res.text)
|
||||
if res_json["code"] == Status.FAILED.value:
|
||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
||||
|
||||
result = "scale out server num is " + str(scale_out_server_num)
|
||||
return process_result_json(Status.SUCCESS.value, res_json["message"], result)
|
||||
|
||||
|
||||
def call_scaleout_rollback():
|
||||
"""
|
||||
call scaleout rollback
|
||||
"""
|
||||
url = BASE_URL + Restful.SCALE_OUT_ROLLBACK.value
|
||||
res = session.get(url, verify=False)
|
||||
res_json = json.loads(res.text)
|
||||
if res_json["code"] == Status.FAILED.value:
|
||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
||||
return process_self_define_json(Status.SUCCESS.value, res_json["message"])
|
||||
|
||||
|
||||
def call_scalein(scale_in_node_ids):
|
||||
"""
|
||||
call cluster to scale in
|
||||
"""
|
||||
if not scale_in_node_ids:
|
||||
return process_self_define_json(Status.FAILED.value, "error. node ids is empty.")
|
||||
|
||||
url = BASE_URL + Restful.SCALE_IN.value
|
||||
data = {"node_ids": scale_in_node_ids}
|
||||
res = session.post(url, headers=headers, verify=False, data=json.dumps(data))
|
||||
res_json = json.loads(res.text)
|
||||
if res_json["code"] == Status.FAILED.value:
|
||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
||||
result = "scale in node ids is " + str(scale_in_node_ids)
|
||||
return process_result_json(Status.SUCCESS.value, res_json["message"], result)
|
||||
|
||||
|
||||
def call_nodes():
|
||||
"""
|
||||
get nodes info
|
||||
"""
|
||||
url = BASE_URL + Restful.NODES.value
|
||||
res = session.get(url, verify=False)
|
||||
res_json = json.loads(res.text)
|
||||
if res_json["code"] == Status.FAILED.value:
|
||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
||||
return process_result_json(Status.SUCCESS.value, res_json["message"], res_json["nodeIds"])
|
||||
|
||||
|
||||
def call_get_instance_detail():
|
||||
"""
|
||||
get cluster instance detail
|
||||
"""
|
||||
if not os.path.exists(metrics_file_path):
|
||||
return process_self_define_json(Status.FAILED.value, "error. metrics file is not existed.")
|
||||
|
||||
ans_json_obj = {}
|
||||
|
||||
with open(metrics_file_path, 'r') as f:
|
||||
metrics_list = f.readlines()
|
||||
|
||||
if not metrics_list:
|
||||
return process_self_define_json(Status.FAILED.value, "error. metrics file has no content")
|
||||
|
||||
|
||||
last_metrics = metrics_list[len(metrics_list) - 1]
|
||||
last_metrics_obj = json.loads(last_metrics)
|
||||
|
||||
ans_json_obj["code"] = Status.SUCCESS.value
|
||||
ans_json_obj["describe"] = "get instance metrics detail successful."
|
||||
ans_json_obj["result"] = last_metrics_obj
|
||||
|
||||
return json.dumps(ans_json_obj)
|
||||
|
||||
|
||||
def call_new_instance():
|
||||
"""
|
||||
call cluster new instance
|
||||
"""
|
||||
if instance_param == "":
|
||||
return process_self_define_json(Status.FAILED.value, "error. instance_param is empty.")
|
||||
instance_param_list = instance_param.split(sep=",")
|
||||
instance_param_json_obj = {}
|
||||
|
||||
url = BASE_URL + Restful.NEW_INSTANCE.value
|
||||
for cur in instance_param_list:
|
||||
pair = cur.split(sep="=")
|
||||
instance_param_json_obj[pair[0]] = float(pair[1])
|
||||
|
||||
data = json.dumps(instance_param_json_obj)
|
||||
res = session.post(url, verify=False, data=data)
|
||||
res_json = json.loads(res.text)
|
||||
if res_json["code"] == Status.FAILED.value:
|
||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
||||
return process_self_define_json(Status.SUCCESS.value, res_json["message"])
|
||||
|
||||
|
||||
def call_query_instance():
|
||||
"""
|
||||
query cluster instance
|
||||
"""
|
||||
url = BASE_URL + Restful.QUERY_INSTANCE.value
|
||||
res = session.post(url, verify=False)
|
||||
res_json = json.loads(res.text)
|
||||
if res_json["code"] == Status.FAILED.value:
|
||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
||||
return process_result_json(Status.SUCCESS.value, res_json["message"], res_json["result"])
|
||||
|
||||
|
||||
def call_enable_fls():
|
||||
"""
|
||||
enable cluster fls
|
||||
"""
|
||||
url = BASE_URL + Restful.ENABLE_FLS.value
|
||||
res = session.post(url, verify=False)
|
||||
res_json = json.loads(res.text)
|
||||
if res_json["code"] == Status.FAILED.value:
|
||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
||||
return process_self_define_json(Status.SUCCESS.value, res_json["message"])
|
||||
|
||||
|
||||
def call_disable_fls():
|
||||
"""
|
||||
disable cluster fls
|
||||
"""
|
||||
url = BASE_URL + Restful.DISABLE_FLS.value
|
||||
res = session.post(url, verify=False)
|
||||
res_json = json.loads(res.text)
|
||||
if res_json["code"] == Status.FAILED.value:
|
||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
||||
return process_self_define_json(Status.SUCCESS.value, res_json["message"])
|
||||
|
||||
|
||||
def call_state():
|
||||
"""
|
||||
get cluster state
|
||||
"""
|
||||
url = BASE_URL + Restful.STATE.value
|
||||
res = session.get(url, verify=False)
|
||||
res_json = json.loads(res.text)
|
||||
if res_json["code"] == Status.FAILED.value:
|
||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
||||
result = res_json['cluster_state']
|
||||
return process_result_json(Status.SUCCESS.value, res_json["message"], result)
|
||||
|
||||
|
||||
def process_result_json(code, describe, result):
|
||||
"""
|
||||
process result json
|
||||
"""
|
||||
result_dict = {"code": code, "describe": describe, "result": result}
|
||||
return json.dumps(result_dict)
|
||||
|
||||
|
||||
def process_self_define_json(code, describe):
|
||||
"""
|
||||
process self define json
|
||||
"""
|
||||
result_dict = {"code": code, "describe": describe}
|
||||
return json.dumps(result_dict)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if request_name == Restful.SCALE.value:
|
||||
print(call_scale())
|
||||
|
||||
elif request_name == Restful.NODES.value:
|
||||
print(call_nodes())
|
||||
|
||||
elif request_name == Restful.GET_INSTANCE_DETAIL.value:
|
||||
print(call_get_instance_detail())
|
||||
|
||||
elif request_name == Restful.NEW_INSTANCE.value:
|
||||
print(call_new_instance())
|
||||
|
||||
elif request_name == Restful.QUERY_INSTANCE.value:
|
||||
print(call_query_instance())
|
||||
|
||||
elif request_name == Restful.ENABLE_FLS.value:
|
||||
print(call_enable_fls())
|
||||
|
||||
elif request_name == Restful.DISABLE_FLS.value:
|
||||
print(call_disable_fls())
|
||||
|
||||
elif request_name == Restful.STATE.value:
|
||||
print(call_state())
|
||||
|
||||
elif request_name == Restful.SCALE_OUT_ROLLBACK.value:
|
||||
print(call_scaleout_rollback())
|
||||
|
||||
else:
|
||||
print(process_self_define_json(1, "error. request_name is not found!"))
|
|
@ -1,99 +0,0 @@
|
|||
# 映射数据文件到对应的脚本源码
|
||||
|
||||
## 文档功能与适用场景
|
||||
|
||||
在MindSpore进行计算调试,怀疑遇到精度问题时可以选择dump文件进行对比。此时用户希望知道dump文件夹下的每个数据文件对应的Python源码。
|
||||
本文的主要目的为指导用户使用该工具进行数据文件到python源码的映射。
|
||||
此指导文档适合运行在 **Ascend硬件** 环境下的计算。
|
||||
|
||||
## 辅助工具使用
|
||||
|
||||
1. 使用脚本的3步操作:
|
||||
- 用户在训练脚本里设置context.set_context(mode=context.GRAPH_MODE, save_graphs=True),进行图文件的保存。
|
||||
- 用户开启dump数据功能,参考<https://www.mindspore.cn/tutorial/training/zh-CN/r1.1/advanced_use/custom_debugging_info.html#dump>
|
||||
- 获取dump数据文件的op_num,然后通过辅助脚本进行解析。如数据文件:`Default--network-TrainOneStepCell--network-WithLossCell--_backbone-
|
||||
ResNet--layer2-SequentialCell--0-ResidualBlock--conv2-Conv2d--Cast-op954_input_0_shape_128_128_3_3_kNumberTypeFloat32_DefaultFormat.bin`.
|
||||
可观察到Cast-op954,说明该算子的op_num为op954, 如下图所示。
|
||||
![image](./images/op_image.png)
|
||||
脚本名: **[map_file_to_code.py](https://gitee.com/mindspore/mindspore/blob/master/scripts/map_dump_file_to_code/map_file_to_code.py)**; 执行方式:
|
||||
|
||||
```ruby
|
||||
python3 map_file_to_code.py
|
||||
--graph_path(-p) [the graph path, default is the current path](option)
|
||||
--dump_op(-o) [Dump operator id, case insensitive, such as 'op954'.](required)
|
||||
For example:
|
||||
python3 map_file_to_code.py -p graph_path -o op954
|
||||
```
|
||||
|
||||
2. 解析效果
|
||||
解析文件时通常有2种情况:
|
||||
① 匹配时会显示出调用栈过程,需要用户在调用栈中查找自己的源码:
|
||||
|
||||
```ruby
|
||||
[INFO] Start to map the dump file to source code.
|
||||
[INFO] Find operation 'Cast'.
|
||||
In file /data1/jzg/mindspore/mindspore/nn/layer/conv.py(253)/
|
||||
output = self.conv2d(x, self.weight)
|
||||
In file /data1/jzg/dump_to_code/resnet/scripts/train/src/resnet.py(166)/
|
||||
out = self.conv2(out)
|
||||
In file /data1/jzg/mindspore/mindspore/nn/layer/container.py(173)/
|
||||
for cell in self.cell_list:
|
||||
In file /data1/jzg/dump_to_code/resnet/scripts/train/src/resnet.py(323)/ # 用户代码行
|
||||
c3 = self.layer2(c2)
|
||||
In file /data1/jzg/mindspore/mindspore/train/amp.py(101)/
|
||||
out = self._backbone(data)
|
||||
In file /data1/jzg/mindspore/mindspore/nn/wrap/cell_wrapper.py(247)/
|
||||
loss = self.network(*inputs)
|
||||
In file /data1/jzg/mindspore/mindspore/train/dataset_helper.py(87)/
|
||||
return self.network(*outputs)
|
||||
```
|
||||
|
||||
② 未匹配,在图中未找对应节点的调用栈:
|
||||
|
||||
```ruby
|
||||
[INFO] Start to map the dump file to source code.
|
||||
[WARNING] Cannot find cast's source code in ir file. # 未找到cast算子的信息
|
||||
```
|
||||
|
||||
3. 手动代码查找
|
||||
这里还会存在些特殊情况,需要用户进行自行查找。通过将dump的数据文件名中的'--'替换为'/'可获取到算子的full_name, 如下图所示:
|
||||
![image](./images/replace_symbol.png)
|
||||
input和output文件名shape后面的数据为对应算子的输入输出shape信息。然后利用算子的full_name和输入输出信息回到源码中进行对应代码的查找。
|
||||
举个例子说明如何手动在代码中查找指定full_name和shape的算子,例如full_name为: `Default/network/network/aspp/aspp_pooling/ResizeNearestNeighbor`,输入的shape为[8, 256, 1, 1], dtype为float32。
|
||||
可以观察到其scope为: `Default/network/network/aspp/aspp_pooling`,算子名为: `ResizeNearestNeighbor`。注意:scope中会存在Default、network自动填充,Default表示正向,network为网络名。
|
||||
查看以下用户定义的代码,首先我们先分析scope: `Default/network/network/aspp/aspp_pooling`。由network/aspp可定位到算子的定义与调用处分别为26行与31行,继续由`network/aspp/aspp_pooling`,可以定位到定义与调用处分别为4行与8行,然后通过算子名`ResizeNearestNeighbor`可以定位至定义与调用处分别为16行与19行。最后若存在相同scope下存在相同的算子名时,需要通过输入的shape进行进一步判断。
|
||||
|
||||
```ruby
|
||||
1 class ASPP(nn.Cell):
|
||||
2 def __init__(self):
|
||||
3 super(ASPP, self).__init__()
|
||||
4 self.aspp_pooling = ASPPPooling()
|
||||
5 self.drop = nn.Dropout(0.3)
|
||||
6
|
||||
7 def construct(self, x):
|
||||
8 x = self.aspp_pooling(x)
|
||||
9 x = self.drop(x)
|
||||
10 return x
|
||||
11
|
||||
12 class ASPPPooling(nn.Cell):
|
||||
13 def __init__(self):
|
||||
14 super(ASPPPooling, self).__init__()
|
||||
15 self.shape = P.Shape()
|
||||
16 self.resizenearestneighbor = P.ResizeNearestNeighbor((size[2], size[3]), True)
|
||||
17 def construct(self, x):
|
||||
18 size = self.shape(x)
|
||||
19 out = self.resizenearestneighbor(x)
|
||||
20 return out
|
||||
21
|
||||
22 # 主结构
|
||||
23 class DeepLabV3(nn.Cell):
|
||||
24 def __init__(self, phase='train', num_classes=21, output_stride=16, freeze_bn=False):
|
||||
25 super(DeepLabV3, self).__init__()
|
||||
26 self.aspp = ASPP()
|
||||
27 self.shape = P.Shape()
|
||||
28
|
||||
29 def construct(self, x):
|
||||
30 size = self.shape(x)
|
||||
31 out = self.aspp(x)
|
||||
32 return out
|
||||
```
|
Binary file not shown.
Before Width: | Height: | Size: 14 KiB |
Binary file not shown.
Before Width: | Height: | Size: 7.2 KiB |
|
@ -1,156 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""map_file_to_code"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
|
||||
class ParseIrInfo:
|
||||
"""
|
||||
Parse and return the operation info from ir file.
|
||||
"""
|
||||
|
||||
def __init__(self, ir_file):
|
||||
self.no_in_file_operation = []
|
||||
self.ir_file_path = self.ir_path_parse(ir_file)
|
||||
self.operation_info_dict = self.ir_info_parse()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.operation_info_dict)
|
||||
|
||||
def ir_path_parse(self, ir_file):
|
||||
"""
|
||||
parse the map file path.
|
||||
"""
|
||||
if ir_file == "":
|
||||
print("[WARNING] No graph_path parameter, use current path as graph path.")
|
||||
ir_file = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
map_ir_file = ""
|
||||
file_size = 0
|
||||
map_ir_filename = "trace_code_graph"
|
||||
for filename in os.listdir(os.path.join(ir_file)):
|
||||
if map_ir_filename not in filename:
|
||||
continue
|
||||
tmp_file = os.path.join(ir_file, filename)
|
||||
tmp_file_size = os.path.getsize(tmp_file)
|
||||
if tmp_file_size >= file_size:
|
||||
file_size = tmp_file_size
|
||||
map_ir_file = tmp_file
|
||||
if map_ir_file == "":
|
||||
exit("[ERROR] Please set \"save_graphs=True\" in context to save {} file!".format(map_ir_filename))
|
||||
return map_ir_file
|
||||
|
||||
def ir_info_parse(self):
|
||||
"""
|
||||
parse the ir file and save code line corresponding to the operator
|
||||
"""
|
||||
|
||||
all_op_info_dict = {} # recode all operation info
|
||||
single_op_info_dict = {} # recode single operation info
|
||||
op_start_char_flag = False # Start operator fragment
|
||||
op_end_char_flag = False # End of operator fragment
|
||||
op_start_info_num = 0 # Accumulate the num to recode operation
|
||||
operation_line = 0 # The line number of the operator
|
||||
op_start_line_num = 0 # The line number of starting operator information
|
||||
op_start_info_flag = False # Start operator information
|
||||
|
||||
with open(self.ir_file_path, 'r+') as file:
|
||||
txt_context_list = file.readlines()
|
||||
|
||||
for line_num, txt_context in enumerate(txt_context_list):
|
||||
txt_context = txt_context.strip()
|
||||
# Start operator fragment
|
||||
if txt_context.endswith(") {"):
|
||||
op_start_char_flag = True
|
||||
op_end_char_flag = False
|
||||
|
||||
# End of operator fragment
|
||||
if txt_context == "}":
|
||||
op_end_char_flag = True
|
||||
|
||||
# Determine whether it is operator information
|
||||
if txt_context.startswith("%") and ") = " in txt_context and txt_context[1].isdigit():
|
||||
op_start_info_flag = True
|
||||
op_start_line_num = line_num
|
||||
op_start_info_num += 1
|
||||
single_op_info_dict = {"in_file": []}
|
||||
|
||||
# Judge and start to recode operation info
|
||||
if op_start_char_flag and not op_end_char_flag and op_start_info_flag and line_num != op_start_line_num:
|
||||
if "-op" in txt_context and txt_context.split("-op")[-1].split(")")[0].isdigit():
|
||||
single_op_info_dict["origin_op_name"] = txt_context.split("-op")[0].split("/")[-1]
|
||||
single_op_info_dict["op_name"] = txt_context.split("-op")[0].split("/")[-1].lower()
|
||||
single_op_info_dict["op_num"] = "op" + txt_context.split("-op")[-1].split(")")[0]
|
||||
operation_line = line_num
|
||||
if "In file" in txt_context:
|
||||
in_file_info = txt_context.split("#")[-1].strip().rstrip("/")
|
||||
single_op_info_dict["in_file"].append(in_file_info)
|
||||
if line_num - operation_line == 1 and "In file" not in txt_context and "op_num" in single_op_info_dict:
|
||||
self.no_in_file_operation.append(single_op_info_dict["op_num"])
|
||||
op_start_info_flag = False
|
||||
all_op_info_dict[op_start_info_num] = single_op_info_dict
|
||||
|
||||
return all_op_info_dict
|
||||
|
||||
|
||||
class MapOperationToLine:
|
||||
"""
|
||||
to show operation info
|
||||
"""
|
||||
def __init__(self, dump_op, ir_info_dict):
|
||||
self.dump_op = dump_op
|
||||
self.ir_info_dict = ir_info_dict
|
||||
|
||||
def show_operator_info(self):
|
||||
"""
|
||||
find operator
|
||||
"""
|
||||
origin_dump_op_name = self.dump_op.split("-")[0]
|
||||
dump_op_name = origin_dump_op_name.lower()
|
||||
dump_op_num = self.dump_op.split("-")[-1]
|
||||
for _, op_info in self.ir_info_dict.items():
|
||||
if op_info["op_num"] == dump_op_num and op_info["in_file"] is not None:
|
||||
if dump_op_name in (dump_op_num, op_info["op_name"]):
|
||||
if not op_info["in_file"]:
|
||||
print("[WARNING] Cannot find {}'s source code in ir file.".format(op_info["origin_op_name"]))
|
||||
return False
|
||||
print("[INFO] Find operation '{}'.".format(op_info["origin_op_name"]))
|
||||
for line in op_info["in_file"]:
|
||||
print(" {}".format(line.split(" ")[0]))
|
||||
print(" {}".format(line.split(" ")[-1]))
|
||||
return True
|
||||
print("[WARNING] Cannot find operation {}'s in ir file.".format(origin_dump_op_name))
|
||||
return False
|
||||
|
||||
|
||||
def start_find(dump_op, map_code_file):
|
||||
"""
|
||||
start find error operation in code.
|
||||
"""
|
||||
|
||||
print("[INFO] Start to map the dump file to source code.")
|
||||
ir_op_info_dict = ParseIrInfo(map_code_file).operation_info_dict
|
||||
MapOperationToLine(dump_op, ir_op_info_dict).show_operator_info()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Find the dump operator in the user code')
|
||||
parser.add_argument('--graph_path', '-p', type=str, default="", help='Save graph files path (option)')
|
||||
parser.add_argument('--dump_op', '-o', type=str, default="", required=True,
|
||||
help="Dump operator id, case insensitive, such as 'op3352'.")
|
||||
args_opt = parser.parse_args()
|
||||
start_find(args_opt.dump_op, args_opt.graph_path)
|
Loading…
Reference in New Issue