fix bug train.py and test.py

This commit is contained in:
maijianqiang 2021-06-19 15:08:17 +08:00
parent 6fb3981170
commit a6010c8ac3
4 changed files with 46 additions and 17 deletions

View File

@ -55,13 +55,17 @@ TEST_BUFFER_SIZE: 4
TEST_DROP_REMAINDER: False TEST_DROP_REMAINDER: False
INFERENCE: True INFERENCE: True
# ======================================================================================
#export options #export options
device_id: 0 device_id: 0
batch_size: 1 batch_size: 1
file_name: "psenet" file_name: "psenet"
file_format: "MINDIR" file_format: "MINDIR"
# ======================================================================================
#postprocess
result_path: "./scripts/result_Files"
img_path: ""
--- ---
# Help description for each configuration # Help description for each configuration
@ -80,4 +84,6 @@ batch_size: "batch size"
file_name: "output file name" file_name: "output file name"
file_format: "file format choices[AIR, MINDIR, ONNX]" file_format: "file format choices[AIR, MINDIR, ONNX]"
object_home: "your direction name" object_home: "your direction name"
modelarts_home: "modelarts working path" modelarts_home: "modelarts working path"
result_path: "result Files path."
img_path: "image files path."

View File

@ -18,19 +18,19 @@ import os
import math import math
import operator import operator
from functools import reduce from functools import reduce
import argparse
import numpy as np import numpy as np
import cv2 import cv2
from src.model_utils.config import config
from src.config import config
from src.ETSNET.pse import pse from src.ETSNET.pse import pse
def sort_to_clockwise(points): def sort_to_clockwise(points):
center = tuple(map(operator.truediv, reduce(lambda x, y: map(operator.add, x, y), points), [len(points)] * 2)) center = tuple(map(operator.truediv, reduce(lambda x, y: map(operator.add, x, y), points), [len(points)] * 2))
clockwise_points = sorted(points, key=lambda coord: (-135 - math.degrees( clockwise_points = sorted(points, key=lambda coord: (-135 - math.degrees(
math.atan2(*tuple(map(operator.sub, coord, center))[::-1]))) % 360, reverse=True) math.atan2(*tuple(map(operator.sub, coord, center))[::-1]))) % 360, reverse=True)
return clockwise_points return clockwise_points
def write_result_as_txt(image_name, img_bboxes, path): def write_result_as_txt(image_name, img_bboxes, path):
if not os.path.isdir(path): if not os.path.isdir(path):
os.makedirs(path) os.makedirs(path)
@ -51,10 +51,6 @@ def get_img(image_path):
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return image return image
parser = argparse.ArgumentParser(description='postprocess')
parser.add_argument("--result_path", type=str, default="./scripts/result_Files", help='result Files path.')
parser.add_argument("--img_path", type=str, default="", help='image files path.')
args = parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
if not os.path.isdir('./res/submit_ic15/'): if not os.path.isdir('./res/submit_ic15/'):
@ -62,17 +58,17 @@ if __name__ == "__main__":
if not os.path.isdir('./res/vis_ic15/'): if not os.path.isdir('./res/vis_ic15/'):
os.makedirs('./res/vis_ic15/') os.makedirs('./res/vis_ic15/')
file_list = os.listdir(args.img_path) file_list = os.listdir(config.img_path)
for k in file_list: for k in file_list:
if os.path.splitext(k)[-1].lower() in ['.jpg', '.jpeg', '.png']: if os.path.splitext(k)[-1].lower() in ['.jpg', '.jpeg', '.png']:
img_path = os.path.join(args.img_path, k) img_path = os.path.join(config.img_path, k)
img = get_img(img_path).reshape(1, 720, 1280, 3) img = get_img(img_path).reshape(1, 720, 1280, 3)
img = img[0].astype(np.uint8).copy() img = img[0].astype(np.uint8).copy()
img_name = os.path.split(img_path)[-1] img_name = os.path.split(img_path)[-1]
score = np.fromfile(os.path.join(args.result_path, k.split('.')[0] + '_0.bin'), np.float32) score = np.fromfile(os.path.join(config.result_path, k.split('.')[0] + '_0.bin'), np.float32)
score = score.reshape(1, 1, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE) score = score.reshape(1, 1, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE)
kernels = np.fromfile(os.path.join(args.result_path, k.split('.')[0] + '_1.bin'), bool) kernels = np.fromfile(os.path.join(config.result_path, k.split('.')[0] + '_1.bin'), bool)
kernels = kernels.reshape(1, config.KERNEL_NUM, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE) kernels = kernels.reshape(1, config.KERNEL_NUM, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE)
score = np.squeeze(score) score = np.squeeze(score)
kernels = np.squeeze(kernels) kernels = np.squeeze(kernels)

View File

@ -84,7 +84,7 @@ def modelarts_pre_process():
os.system('cd {}/opencv-3.4.9&&mkdir build&&cd ./build&&{}'.format(local_path, cmake_command)) os.system('cd {}/opencv-3.4.9&&mkdir build&&cd ./build&&{}'.format(local_path, cmake_command))
os.system('cd {}/src/ETSNET/pse&&make clean&&make'.format(local_path)) os.system('cd {}/src/ETSNET/pse&&make clean&&make'.format(local_path))
os.system('cd {}&&sed -i s/\r// scripts/run_eval_ascend.sh') os.system('cd {}&&sed -i s/\r// scripts/run_eval_ascend.sh'.format(local_path))
def modelarts_post_process(): def modelarts_post_process():

View File

@ -14,7 +14,8 @@
# ============================================================================ # ============================================================================
from ast import literal_eval as liter import ast
import operator
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
from mindspore.communication.management import init from mindspore.communication.management import init
@ -37,6 +38,32 @@ set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id()) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
binOps = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.Mod: operator.mod
}
def arithmeticeval(s):
node = ast.parse(s, mode='eval')
def _eval(node):
if isinstance(node, ast.BinOp):
return binOps[type(node.op)](_eval(node.left), _eval(node.right))
if isinstance(node, ast.Num):
return node.n
if isinstance(node, ast.Expression):
return _eval(node.body)
raise Exception('unsupported type{}'.format(node))
return _eval(node.body)
def modelarts_pre_process(): def modelarts_pre_process():
pass pass
@ -44,8 +71,8 @@ def modelarts_pre_process():
@moxing_wrapper(pre_process=modelarts_pre_process) @moxing_wrapper(pre_process=modelarts_pre_process)
def train(): def train():
rank_id = 0 rank_id = 0
config.BASE_LR = liter(config.BASE_LR) config.BASE_LR = arithmeticeval(config.BASE_LR)
config.WARMUP_RATIO = liter(config.WARMUP_RATIO) config.WARMUP_RATIO = arithmeticeval(config.WARMUP_RATIO)
device_num = get_device_num() device_num = get_device_num()
if config.run_distribute: if config.run_distribute: