forked from mindspore-Ecosystem/mindspore
clean up codecheck warnings by deleting unused files
This commit is contained in:
parent
b0e9132745
commit
a63e823b62
|
@ -1,315 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""
|
|
||||||
Function:
|
|
||||||
Use to control the federated learning cluster
|
|
||||||
Usage:
|
|
||||||
python fl_restful_tool.py [http_type] [ip] [port] [request_name] [server_num] [instance_param] [metrics_file_path]
|
|
||||||
"""
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import warnings
|
|
||||||
from enum import Enum
|
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
class Status(Enum):
|
|
||||||
"""
|
|
||||||
Response Status
|
|
||||||
"""
|
|
||||||
SUCCESS = "0"
|
|
||||||
FAILED = "1"
|
|
||||||
|
|
||||||
|
|
||||||
class Restful(Enum):
|
|
||||||
"""
|
|
||||||
Define restful interface constant
|
|
||||||
"""
|
|
||||||
SCALE = "scale"
|
|
||||||
SCALE_OUT = "scaleout"
|
|
||||||
SCALE_IN = "scalein"
|
|
||||||
NODES = "nodes"
|
|
||||||
GET_INSTANCE_DETAIL = "getInstanceDetail"
|
|
||||||
NEW_INSTANCE = "newInstance"
|
|
||||||
QUERY_INSTANCE = "queryInstance"
|
|
||||||
ENABLE_FLS = "enableFLS"
|
|
||||||
DISABLE_FLS = "disableFLS"
|
|
||||||
STATE = "state"
|
|
||||||
SCALE_OUT_ROLLBACK = "scaleoutRollback"
|
|
||||||
|
|
||||||
|
|
||||||
warnings.filterwarnings('ignore')
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--http_type", type=str, default="http", help="http or https")
|
|
||||||
parser.add_argument("--ip", type=str, default="127.0.0.1")
|
|
||||||
parser.add_argument("--port", type=int, default=6666)
|
|
||||||
parser.add_argument("--request_name", type=str, default="")
|
|
||||||
|
|
||||||
parser.add_argument("--server_num", type=int, default=0)
|
|
||||||
parser.add_argument("--instance_param", type=str, default="")
|
|
||||||
parser.add_argument("--metrics_file_path", type=str, default="/opt/huawei/mindspore/hybrid_albert/metrics.json")
|
|
||||||
|
|
||||||
args, _ = parser.parse_known_args()
|
|
||||||
http_type = args.http_type
|
|
||||||
ip = args.ip
|
|
||||||
port = args.port
|
|
||||||
request_name = args.request_name
|
|
||||||
server_num = args.server_num
|
|
||||||
instance_param = args.instance_param
|
|
||||||
metrics_file_path = args.metrics_file_path
|
|
||||||
|
|
||||||
headers = {'Content-Type': 'application/json'}
|
|
||||||
session = requests.Session()
|
|
||||||
BASE_URL = '{}://{}:{}/'.format(http_type, ip, str(port))
|
|
||||||
|
|
||||||
|
|
||||||
def call_scale():
|
|
||||||
"""
|
|
||||||
call cluster scale out or scale in
|
|
||||||
"""
|
|
||||||
if server_num == 0:
|
|
||||||
return process_self_define_json(Status.FAILED.value, "error. server_num is 0")
|
|
||||||
|
|
||||||
node_ids = json.loads(call_nodes())["result"]
|
|
||||||
cluster_abstract_node_num = len(node_ids)
|
|
||||||
if cluster_abstract_node_num == 0:
|
|
||||||
return process_self_define_json(Status.FAILED.value, "error. cluster abstract node num is 0")
|
|
||||||
|
|
||||||
cluster_server_node_num = 0
|
|
||||||
cluster_worker_node_num = 0
|
|
||||||
cluster_server_node_base_name = ''
|
|
||||||
for i in range(0, cluster_abstract_node_num):
|
|
||||||
if node_ids[i]['role'] == 'WORKER':
|
|
||||||
cluster_worker_node_num = cluster_worker_node_num + 1
|
|
||||||
elif node_ids[i]['role'] == 'SERVER':
|
|
||||||
cluster_server_node_num = cluster_server_node_num + 1
|
|
||||||
cluster_server_node_name = str(node_ids[i]['nodeId'])
|
|
||||||
index = cluster_server_node_name.rindex('-')
|
|
||||||
cluster_server_node_base_name = cluster_server_node_name[0:index]
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
if cluster_server_node_num == server_num:
|
|
||||||
return process_self_define_json(Status.FAILED.value, "error. cluster server num is same with server_num.")
|
|
||||||
if cluster_server_node_num > server_num:
|
|
||||||
scale_in_len = cluster_server_node_num - server_num
|
|
||||||
scale_in_node_ids = []
|
|
||||||
for index in range(cluster_server_node_num - scale_in_len, cluster_server_node_num):
|
|
||||||
scale_in_node_name = cluster_server_node_base_name + "-" + str(index)
|
|
||||||
scale_in_node_ids.append(scale_in_node_name)
|
|
||||||
return call_scalein(scale_in_node_ids)
|
|
||||||
return call_scaleout(server_num - cluster_server_node_num)
|
|
||||||
|
|
||||||
|
|
||||||
def call_scaleout(scale_out_server_num, scale_out_worker_num=0):
|
|
||||||
"""
|
|
||||||
call scaleout
|
|
||||||
"""
|
|
||||||
url = BASE_URL + Restful.SCALE_OUT.value
|
|
||||||
data = {"server_num": scale_out_server_num, "worker_num": scale_out_worker_num}
|
|
||||||
res = session.post(url, headers=headers, verify=False, data=json.dumps(data))
|
|
||||||
res_json = json.loads(res.text)
|
|
||||||
if res_json["code"] == Status.FAILED.value:
|
|
||||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
|
||||||
|
|
||||||
result = "scale out server num is " + str(scale_out_server_num)
|
|
||||||
return process_result_json(Status.SUCCESS.value, res_json["message"], result)
|
|
||||||
|
|
||||||
|
|
||||||
def call_scaleout_rollback():
|
|
||||||
"""
|
|
||||||
call scaleout rollback
|
|
||||||
"""
|
|
||||||
url = BASE_URL + Restful.SCALE_OUT_ROLLBACK.value
|
|
||||||
res = session.get(url, verify=False)
|
|
||||||
res_json = json.loads(res.text)
|
|
||||||
if res_json["code"] == Status.FAILED.value:
|
|
||||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
|
||||||
return process_self_define_json(Status.SUCCESS.value, res_json["message"])
|
|
||||||
|
|
||||||
|
|
||||||
def call_scalein(scale_in_node_ids):
|
|
||||||
"""
|
|
||||||
call cluster to scale in
|
|
||||||
"""
|
|
||||||
if not scale_in_node_ids:
|
|
||||||
return process_self_define_json(Status.FAILED.value, "error. node ids is empty.")
|
|
||||||
|
|
||||||
url = BASE_URL + Restful.SCALE_IN.value
|
|
||||||
data = {"node_ids": scale_in_node_ids}
|
|
||||||
res = session.post(url, headers=headers, verify=False, data=json.dumps(data))
|
|
||||||
res_json = json.loads(res.text)
|
|
||||||
if res_json["code"] == Status.FAILED.value:
|
|
||||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
|
||||||
result = "scale in node ids is " + str(scale_in_node_ids)
|
|
||||||
return process_result_json(Status.SUCCESS.value, res_json["message"], result)
|
|
||||||
|
|
||||||
|
|
||||||
def call_nodes():
|
|
||||||
"""
|
|
||||||
get nodes info
|
|
||||||
"""
|
|
||||||
url = BASE_URL + Restful.NODES.value
|
|
||||||
res = session.get(url, verify=False)
|
|
||||||
res_json = json.loads(res.text)
|
|
||||||
if res_json["code"] == Status.FAILED.value:
|
|
||||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
|
||||||
return process_result_json(Status.SUCCESS.value, res_json["message"], res_json["nodeIds"])
|
|
||||||
|
|
||||||
|
|
||||||
def call_get_instance_detail():
|
|
||||||
"""
|
|
||||||
get cluster instance detail
|
|
||||||
"""
|
|
||||||
if not os.path.exists(metrics_file_path):
|
|
||||||
return process_self_define_json(Status.FAILED.value, "error. metrics file is not existed.")
|
|
||||||
|
|
||||||
ans_json_obj = {}
|
|
||||||
|
|
||||||
with open(metrics_file_path, 'r') as f:
|
|
||||||
metrics_list = f.readlines()
|
|
||||||
|
|
||||||
if not metrics_list:
|
|
||||||
return process_self_define_json(Status.FAILED.value, "error. metrics file has no content")
|
|
||||||
|
|
||||||
|
|
||||||
last_metrics = metrics_list[len(metrics_list) - 1]
|
|
||||||
last_metrics_obj = json.loads(last_metrics)
|
|
||||||
|
|
||||||
ans_json_obj["code"] = Status.SUCCESS.value
|
|
||||||
ans_json_obj["describe"] = "get instance metrics detail successful."
|
|
||||||
ans_json_obj["result"] = last_metrics_obj
|
|
||||||
|
|
||||||
return json.dumps(ans_json_obj)
|
|
||||||
|
|
||||||
|
|
||||||
def call_new_instance():
|
|
||||||
"""
|
|
||||||
call cluster new instance
|
|
||||||
"""
|
|
||||||
if instance_param == "":
|
|
||||||
return process_self_define_json(Status.FAILED.value, "error. instance_param is empty.")
|
|
||||||
instance_param_list = instance_param.split(sep=",")
|
|
||||||
instance_param_json_obj = {}
|
|
||||||
|
|
||||||
url = BASE_URL + Restful.NEW_INSTANCE.value
|
|
||||||
for cur in instance_param_list:
|
|
||||||
pair = cur.split(sep="=")
|
|
||||||
instance_param_json_obj[pair[0]] = float(pair[1])
|
|
||||||
|
|
||||||
data = json.dumps(instance_param_json_obj)
|
|
||||||
res = session.post(url, verify=False, data=data)
|
|
||||||
res_json = json.loads(res.text)
|
|
||||||
if res_json["code"] == Status.FAILED.value:
|
|
||||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
|
||||||
return process_self_define_json(Status.SUCCESS.value, res_json["message"])
|
|
||||||
|
|
||||||
|
|
||||||
def call_query_instance():
|
|
||||||
"""
|
|
||||||
query cluster instance
|
|
||||||
"""
|
|
||||||
url = BASE_URL + Restful.QUERY_INSTANCE.value
|
|
||||||
res = session.post(url, verify=False)
|
|
||||||
res_json = json.loads(res.text)
|
|
||||||
if res_json["code"] == Status.FAILED.value:
|
|
||||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
|
||||||
return process_result_json(Status.SUCCESS.value, res_json["message"], res_json["result"])
|
|
||||||
|
|
||||||
|
|
||||||
def call_enable_fls():
|
|
||||||
"""
|
|
||||||
enable cluster fls
|
|
||||||
"""
|
|
||||||
url = BASE_URL + Restful.ENABLE_FLS.value
|
|
||||||
res = session.post(url, verify=False)
|
|
||||||
res_json = json.loads(res.text)
|
|
||||||
if res_json["code"] == Status.FAILED.value:
|
|
||||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
|
||||||
return process_self_define_json(Status.SUCCESS.value, res_json["message"])
|
|
||||||
|
|
||||||
|
|
||||||
def call_disable_fls():
|
|
||||||
"""
|
|
||||||
disable cluster fls
|
|
||||||
"""
|
|
||||||
url = BASE_URL + Restful.DISABLE_FLS.value
|
|
||||||
res = session.post(url, verify=False)
|
|
||||||
res_json = json.loads(res.text)
|
|
||||||
if res_json["code"] == Status.FAILED.value:
|
|
||||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
|
||||||
return process_self_define_json(Status.SUCCESS.value, res_json["message"])
|
|
||||||
|
|
||||||
|
|
||||||
def call_state():
|
|
||||||
"""
|
|
||||||
get cluster state
|
|
||||||
"""
|
|
||||||
url = BASE_URL + Restful.STATE.value
|
|
||||||
res = session.get(url, verify=False)
|
|
||||||
res_json = json.loads(res.text)
|
|
||||||
if res_json["code"] == Status.FAILED.value:
|
|
||||||
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
|
|
||||||
result = res_json['cluster_state']
|
|
||||||
return process_result_json(Status.SUCCESS.value, res_json["message"], result)
|
|
||||||
|
|
||||||
|
|
||||||
def process_result_json(code, describe, result):
|
|
||||||
"""
|
|
||||||
process result json
|
|
||||||
"""
|
|
||||||
result_dict = {"code": code, "describe": describe, "result": result}
|
|
||||||
return json.dumps(result_dict)
|
|
||||||
|
|
||||||
|
|
||||||
def process_self_define_json(code, describe):
|
|
||||||
"""
|
|
||||||
process self define json
|
|
||||||
"""
|
|
||||||
result_dict = {"code": code, "describe": describe}
|
|
||||||
return json.dumps(result_dict)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
if request_name == Restful.SCALE.value:
|
|
||||||
print(call_scale())
|
|
||||||
|
|
||||||
elif request_name == Restful.NODES.value:
|
|
||||||
print(call_nodes())
|
|
||||||
|
|
||||||
elif request_name == Restful.GET_INSTANCE_DETAIL.value:
|
|
||||||
print(call_get_instance_detail())
|
|
||||||
|
|
||||||
elif request_name == Restful.NEW_INSTANCE.value:
|
|
||||||
print(call_new_instance())
|
|
||||||
|
|
||||||
elif request_name == Restful.QUERY_INSTANCE.value:
|
|
||||||
print(call_query_instance())
|
|
||||||
|
|
||||||
elif request_name == Restful.ENABLE_FLS.value:
|
|
||||||
print(call_enable_fls())
|
|
||||||
|
|
||||||
elif request_name == Restful.DISABLE_FLS.value:
|
|
||||||
print(call_disable_fls())
|
|
||||||
|
|
||||||
elif request_name == Restful.STATE.value:
|
|
||||||
print(call_state())
|
|
||||||
|
|
||||||
elif request_name == Restful.SCALE_OUT_ROLLBACK.value:
|
|
||||||
print(call_scaleout_rollback())
|
|
||||||
|
|
||||||
else:
|
|
||||||
print(process_self_define_json(1, "error. request_name is not found!"))
|
|
|
@ -1,99 +0,0 @@
|
||||||
# 映射数据文件到对应的脚本源码
|
|
||||||
|
|
||||||
## 文档功能与适用场景
|
|
||||||
|
|
||||||
在MindSpore进行计算调试,怀疑遇到精度问题时可以选择dump文件进行对比。此时用户希望知道dump文件夹下的每个数据文件对应的Python源码。
|
|
||||||
本文的主要目的为指导用户使用该工具进行数据文件到python源码的映射。
|
|
||||||
此指导文档适合运行在 **Ascend硬件** 环境下的计算。
|
|
||||||
|
|
||||||
## 辅助工具使用
|
|
||||||
|
|
||||||
1. 使用脚本的3步操作:
|
|
||||||
- 用户在训练脚本里设置context.set_context(mode=context.GRAPH_MODE, save_graphs=True),进行图文件的保存。
|
|
||||||
- 用户开启dump数据功能,参考<https://www.mindspore.cn/tutorial/training/zh-CN/r1.1/advanced_use/custom_debugging_info.html#dump>
|
|
||||||
- 获取dump数据文件的op_num,然后通过辅助脚本进行解析。如数据文件:`Default--network-TrainOneStepCell--network-WithLossCell--_backbone-
|
|
||||||
ResNet--layer2-SequentialCell--0-ResidualBlock--conv2-Conv2d--Cast-op954_input_0_shape_128_128_3_3_kNumberTypeFloat32_DefaultFormat.bin`.
|
|
||||||
可观察到Cast-op954,说明该算子的op_num为op954, 如下图所示。
|
|
||||||
![image](./images/op_image.png)
|
|
||||||
脚本名: **[map_file_to_code.py](https://gitee.com/mindspore/mindspore/blob/master/scripts/map_dump_file_to_code/map_file_to_code.py)**; 执行方式:
|
|
||||||
|
|
||||||
```ruby
|
|
||||||
python3 map_file_to_code.py
|
|
||||||
--graph_path(-p) [the graph path, default is the current path](option)
|
|
||||||
--dump_op(-o) [Dump operator id, case insensitive, such as 'op954'.](required)
|
|
||||||
For example:
|
|
||||||
python3 map_file_to_code.py -p graph_path -o op954
|
|
||||||
```
|
|
||||||
|
|
||||||
2. 解析效果
|
|
||||||
解析文件时通常有2种情况:
|
|
||||||
① 匹配时会显示出调用栈过程,需要用户在调用栈中查找自己的源码:
|
|
||||||
|
|
||||||
```ruby
|
|
||||||
[INFO] Start to map the dump file to source code.
|
|
||||||
[INFO] Find operation 'Cast'.
|
|
||||||
In file /data1/jzg/mindspore/mindspore/nn/layer/conv.py(253)/
|
|
||||||
output = self.conv2d(x, self.weight)
|
|
||||||
In file /data1/jzg/dump_to_code/resnet/scripts/train/src/resnet.py(166)/
|
|
||||||
out = self.conv2(out)
|
|
||||||
In file /data1/jzg/mindspore/mindspore/nn/layer/container.py(173)/
|
|
||||||
for cell in self.cell_list:
|
|
||||||
In file /data1/jzg/dump_to_code/resnet/scripts/train/src/resnet.py(323)/ # 用户代码行
|
|
||||||
c3 = self.layer2(c2)
|
|
||||||
In file /data1/jzg/mindspore/mindspore/train/amp.py(101)/
|
|
||||||
out = self._backbone(data)
|
|
||||||
In file /data1/jzg/mindspore/mindspore/nn/wrap/cell_wrapper.py(247)/
|
|
||||||
loss = self.network(*inputs)
|
|
||||||
In file /data1/jzg/mindspore/mindspore/train/dataset_helper.py(87)/
|
|
||||||
return self.network(*outputs)
|
|
||||||
```
|
|
||||||
|
|
||||||
② 未匹配,在图中未找对应节点的调用栈:
|
|
||||||
|
|
||||||
```ruby
|
|
||||||
[INFO] Start to map the dump file to source code.
|
|
||||||
[WARNING] Cannot find cast's source code in ir file. # 未找到cast算子的信息
|
|
||||||
```
|
|
||||||
|
|
||||||
3. 手动代码查找
|
|
||||||
这里还会存在些特殊情况,需要用户进行自行查找。通过将dump的数据文件名中的'--'替换为'/'可获取到算子的full_name, 如下图所示:
|
|
||||||
![image](./images/replace_symbol.png)
|
|
||||||
input和output文件名shape后面的数据为对应算子的输入输出shape信息。然后利用算子的full_name和输入输出信息回到源码中进行对应代码的查找。
|
|
||||||
举个例子说明如何手动在代码中查找指定full_name和shape的算子,例如full_name为: `Default/network/network/aspp/aspp_pooling/ResizeNearestNeighbor`,输入的shape为[8, 256, 1, 1], dtype为float32。
|
|
||||||
可以观察到其scope为: `Default/network/network/aspp/aspp_pooling`,算子名为: `ResizeNearestNeighbor`。注意:scope中会存在Default、network自动填充,Default表示正向,network为网络名。
|
|
||||||
查看以下用户定义的代码,首先我们先分析scope: `Default/network/network/aspp/aspp_pooling`。由network/aspp可定位到算子的定义与调用处分别为26行与31行,继续由`network/aspp/aspp_pooling`,可以定位到定义与调用处分别为4行与8行,然后通过算子名`ResizeNearestNeighbor`可以定位至定义与调用处分别为16行与19行。最后若存在相同scope下存在相同的算子名时,需要通过输入的shape进行进一步判断。
|
|
||||||
|
|
||||||
```ruby
|
|
||||||
1 class ASPP(nn.Cell):
|
|
||||||
2 def __init__(self):
|
|
||||||
3 super(ASPP, self).__init__()
|
|
||||||
4 self.aspp_pooling = ASPPPooling()
|
|
||||||
5 self.drop = nn.Dropout(0.3)
|
|
||||||
6
|
|
||||||
7 def construct(self, x):
|
|
||||||
8 x = self.aspp_pooling(x)
|
|
||||||
9 x = self.drop(x)
|
|
||||||
10 return x
|
|
||||||
11
|
|
||||||
12 class ASPPPooling(nn.Cell):
|
|
||||||
13 def __init__(self):
|
|
||||||
14 super(ASPPPooling, self).__init__()
|
|
||||||
15 self.shape = P.Shape()
|
|
||||||
16 self.resizenearestneighbor = P.ResizeNearestNeighbor((size[2], size[3]), True)
|
|
||||||
17 def construct(self, x):
|
|
||||||
18 size = self.shape(x)
|
|
||||||
19 out = self.resizenearestneighbor(x)
|
|
||||||
20 return out
|
|
||||||
21
|
|
||||||
22 # 主结构
|
|
||||||
23 class DeepLabV3(nn.Cell):
|
|
||||||
24 def __init__(self, phase='train', num_classes=21, output_stride=16, freeze_bn=False):
|
|
||||||
25 super(DeepLabV3, self).__init__()
|
|
||||||
26 self.aspp = ASPP()
|
|
||||||
27 self.shape = P.Shape()
|
|
||||||
28
|
|
||||||
29 def construct(self, x):
|
|
||||||
30 size = self.shape(x)
|
|
||||||
31 out = self.aspp(x)
|
|
||||||
32 return out
|
|
||||||
```
|
|
Binary file not shown.
Before Width: | Height: | Size: 14 KiB |
Binary file not shown.
Before Width: | Height: | Size: 7.2 KiB |
|
@ -1,156 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""map_file_to_code"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
|
|
||||||
class ParseIrInfo:
|
|
||||||
"""
|
|
||||||
Parse and return the operation info from ir file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, ir_file):
|
|
||||||
self.no_in_file_operation = []
|
|
||||||
self.ir_file_path = self.ir_path_parse(ir_file)
|
|
||||||
self.operation_info_dict = self.ir_info_parse()
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.operation_info_dict)
|
|
||||||
|
|
||||||
def ir_path_parse(self, ir_file):
|
|
||||||
"""
|
|
||||||
parse the map file path.
|
|
||||||
"""
|
|
||||||
if ir_file == "":
|
|
||||||
print("[WARNING] No graph_path parameter, use current path as graph path.")
|
|
||||||
ir_file = os.path.abspath(os.path.dirname(__file__))
|
|
||||||
|
|
||||||
map_ir_file = ""
|
|
||||||
file_size = 0
|
|
||||||
map_ir_filename = "trace_code_graph"
|
|
||||||
for filename in os.listdir(os.path.join(ir_file)):
|
|
||||||
if map_ir_filename not in filename:
|
|
||||||
continue
|
|
||||||
tmp_file = os.path.join(ir_file, filename)
|
|
||||||
tmp_file_size = os.path.getsize(tmp_file)
|
|
||||||
if tmp_file_size >= file_size:
|
|
||||||
file_size = tmp_file_size
|
|
||||||
map_ir_file = tmp_file
|
|
||||||
if map_ir_file == "":
|
|
||||||
exit("[ERROR] Please set \"save_graphs=True\" in context to save {} file!".format(map_ir_filename))
|
|
||||||
return map_ir_file
|
|
||||||
|
|
||||||
def ir_info_parse(self):
|
|
||||||
"""
|
|
||||||
parse the ir file and save code line corresponding to the operator
|
|
||||||
"""
|
|
||||||
|
|
||||||
all_op_info_dict = {} # recode all operation info
|
|
||||||
single_op_info_dict = {} # recode single operation info
|
|
||||||
op_start_char_flag = False # Start operator fragment
|
|
||||||
op_end_char_flag = False # End of operator fragment
|
|
||||||
op_start_info_num = 0 # Accumulate the num to recode operation
|
|
||||||
operation_line = 0 # The line number of the operator
|
|
||||||
op_start_line_num = 0 # The line number of starting operator information
|
|
||||||
op_start_info_flag = False # Start operator information
|
|
||||||
|
|
||||||
with open(self.ir_file_path, 'r+') as file:
|
|
||||||
txt_context_list = file.readlines()
|
|
||||||
|
|
||||||
for line_num, txt_context in enumerate(txt_context_list):
|
|
||||||
txt_context = txt_context.strip()
|
|
||||||
# Start operator fragment
|
|
||||||
if txt_context.endswith(") {"):
|
|
||||||
op_start_char_flag = True
|
|
||||||
op_end_char_flag = False
|
|
||||||
|
|
||||||
# End of operator fragment
|
|
||||||
if txt_context == "}":
|
|
||||||
op_end_char_flag = True
|
|
||||||
|
|
||||||
# Determine whether it is operator information
|
|
||||||
if txt_context.startswith("%") and ") = " in txt_context and txt_context[1].isdigit():
|
|
||||||
op_start_info_flag = True
|
|
||||||
op_start_line_num = line_num
|
|
||||||
op_start_info_num += 1
|
|
||||||
single_op_info_dict = {"in_file": []}
|
|
||||||
|
|
||||||
# Judge and start to recode operation info
|
|
||||||
if op_start_char_flag and not op_end_char_flag and op_start_info_flag and line_num != op_start_line_num:
|
|
||||||
if "-op" in txt_context and txt_context.split("-op")[-1].split(")")[0].isdigit():
|
|
||||||
single_op_info_dict["origin_op_name"] = txt_context.split("-op")[0].split("/")[-1]
|
|
||||||
single_op_info_dict["op_name"] = txt_context.split("-op")[0].split("/")[-1].lower()
|
|
||||||
single_op_info_dict["op_num"] = "op" + txt_context.split("-op")[-1].split(")")[0]
|
|
||||||
operation_line = line_num
|
|
||||||
if "In file" in txt_context:
|
|
||||||
in_file_info = txt_context.split("#")[-1].strip().rstrip("/")
|
|
||||||
single_op_info_dict["in_file"].append(in_file_info)
|
|
||||||
if line_num - operation_line == 1 and "In file" not in txt_context and "op_num" in single_op_info_dict:
|
|
||||||
self.no_in_file_operation.append(single_op_info_dict["op_num"])
|
|
||||||
op_start_info_flag = False
|
|
||||||
all_op_info_dict[op_start_info_num] = single_op_info_dict
|
|
||||||
|
|
||||||
return all_op_info_dict
|
|
||||||
|
|
||||||
|
|
||||||
class MapOperationToLine:
|
|
||||||
"""
|
|
||||||
to show operation info
|
|
||||||
"""
|
|
||||||
def __init__(self, dump_op, ir_info_dict):
|
|
||||||
self.dump_op = dump_op
|
|
||||||
self.ir_info_dict = ir_info_dict
|
|
||||||
|
|
||||||
def show_operator_info(self):
|
|
||||||
"""
|
|
||||||
find operator
|
|
||||||
"""
|
|
||||||
origin_dump_op_name = self.dump_op.split("-")[0]
|
|
||||||
dump_op_name = origin_dump_op_name.lower()
|
|
||||||
dump_op_num = self.dump_op.split("-")[-1]
|
|
||||||
for _, op_info in self.ir_info_dict.items():
|
|
||||||
if op_info["op_num"] == dump_op_num and op_info["in_file"] is not None:
|
|
||||||
if dump_op_name in (dump_op_num, op_info["op_name"]):
|
|
||||||
if not op_info["in_file"]:
|
|
||||||
print("[WARNING] Cannot find {}'s source code in ir file.".format(op_info["origin_op_name"]))
|
|
||||||
return False
|
|
||||||
print("[INFO] Find operation '{}'.".format(op_info["origin_op_name"]))
|
|
||||||
for line in op_info["in_file"]:
|
|
||||||
print(" {}".format(line.split(" ")[0]))
|
|
||||||
print(" {}".format(line.split(" ")[-1]))
|
|
||||||
return True
|
|
||||||
print("[WARNING] Cannot find operation {}'s in ir file.".format(origin_dump_op_name))
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def start_find(dump_op, map_code_file):
|
|
||||||
"""
|
|
||||||
start find error operation in code.
|
|
||||||
"""
|
|
||||||
|
|
||||||
print("[INFO] Start to map the dump file to source code.")
|
|
||||||
ir_op_info_dict = ParseIrInfo(map_code_file).operation_info_dict
|
|
||||||
MapOperationToLine(dump_op, ir_op_info_dict).show_operator_info()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description='Find the dump operator in the user code')
|
|
||||||
parser.add_argument('--graph_path', '-p', type=str, default="", help='Save graph files path (option)')
|
|
||||||
parser.add_argument('--dump_op', '-o', type=str, default="", required=True,
|
|
||||||
help="Dump operator id, case insensitive, such as 'op3352'.")
|
|
||||||
args_opt = parser.parse_args()
|
|
||||||
start_find(args_opt.dump_op, args_opt.graph_path)
|
|
Loading…
Reference in New Issue