fix bug train.py and test.py

This commit is contained in:
maijianqiang 2021-06-19 15:08:17 +08:00
parent 6fb3981170
commit a6010c8ac3
4 changed files with 46 additions and 17 deletions

View File

@ -55,13 +55,17 @@ TEST_BUFFER_SIZE: 4
TEST_DROP_REMAINDER: False
INFERENCE: True
# ======================================================================================
#export options
device_id: 0
batch_size: 1
file_name: "psenet"
file_format: "MINDIR"
# ======================================================================================
#postprocess
result_path: "./scripts/result_Files"
img_path: ""
---
# Help description for each configuration
@ -81,3 +85,5 @@ file_name: "output file name"
file_format: "file format choices[AIR, MINDIR, ONNX]"
object_home: "your direction name"
modelarts_home: "modelarts working path"
result_path: "result Files path."
img_path: "image files path."

View File

@ -18,19 +18,19 @@ import os
import math
import operator
from functools import reduce
import argparse
import numpy as np
import cv2
from src.config import config
from src.model_utils.config import config
from src.ETSNET.pse import pse
def sort_to_clockwise(points):
center = tuple(map(operator.truediv, reduce(lambda x, y: map(operator.add, x, y), points), [len(points)] * 2))
clockwise_points = sorted(points, key=lambda coord: (-135 - math.degrees(
math.atan2(*tuple(map(operator.sub, coord, center))[::-1]))) % 360, reverse=True)
return clockwise_points
def write_result_as_txt(image_name, img_bboxes, path):
if not os.path.isdir(path):
os.makedirs(path)
@ -51,10 +51,6 @@ def get_img(image_path):
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return image
parser = argparse.ArgumentParser(description='postprocess')
parser.add_argument("--result_path", type=str, default="./scripts/result_Files", help='result Files path.')
parser.add_argument("--img_path", type=str, default="", help='image files path.')
args = parser.parse_args()
if __name__ == "__main__":
if not os.path.isdir('./res/submit_ic15/'):
@ -62,17 +58,17 @@ if __name__ == "__main__":
if not os.path.isdir('./res/vis_ic15/'):
os.makedirs('./res/vis_ic15/')
file_list = os.listdir(args.img_path)
file_list = os.listdir(config.img_path)
for k in file_list:
if os.path.splitext(k)[-1].lower() in ['.jpg', '.jpeg', '.png']:
img_path = os.path.join(args.img_path, k)
img_path = os.path.join(config.img_path, k)
img = get_img(img_path).reshape(1, 720, 1280, 3)
img = img[0].astype(np.uint8).copy()
img_name = os.path.split(img_path)[-1]
score = np.fromfile(os.path.join(args.result_path, k.split('.')[0] + '_0.bin'), np.float32)
score = np.fromfile(os.path.join(config.result_path, k.split('.')[0] + '_0.bin'), np.float32)
score = score.reshape(1, 1, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE)
kernels = np.fromfile(os.path.join(args.result_path, k.split('.')[0] + '_1.bin'), bool)
kernels = np.fromfile(os.path.join(config.result_path, k.split('.')[0] + '_1.bin'), bool)
kernels = kernels.reshape(1, config.KERNEL_NUM, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE)
score = np.squeeze(score)
kernels = np.squeeze(kernels)

View File

@ -84,7 +84,7 @@ def modelarts_pre_process():
os.system('cd {}/opencv-3.4.9&&mkdir build&&cd ./build&&{}'.format(local_path, cmake_command))
os.system('cd {}/src/ETSNET/pse&&make clean&&make'.format(local_path))
os.system('cd {}&&sed -i s/\r// scripts/run_eval_ascend.sh')
os.system('cd {}&&sed -i s/\r// scripts/run_eval_ascend.sh'.format(local_path))
def modelarts_post_process():

View File

@ -14,7 +14,8 @@
# ============================================================================
from ast import literal_eval as liter
import ast
import operator
import mindspore.nn as nn
from mindspore import context
from mindspore.communication.management import init
@ -37,6 +38,32 @@ set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
binOps = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.Mod: operator.mod
}
def arithmeticeval(s):
node = ast.parse(s, mode='eval')
def _eval(node):
if isinstance(node, ast.BinOp):
return binOps[type(node.op)](_eval(node.left), _eval(node.right))
if isinstance(node, ast.Num):
return node.n
if isinstance(node, ast.Expression):
return _eval(node.body)
raise Exception('unsupported type{}'.format(node))
return _eval(node.body)
def modelarts_pre_process():
pass
@ -44,8 +71,8 @@ def modelarts_pre_process():
@moxing_wrapper(pre_process=modelarts_pre_process)
def train():
rank_id = 0
config.BASE_LR = liter(config.BASE_LR)
config.WARMUP_RATIO = liter(config.WARMUP_RATIO)
config.BASE_LR = arithmeticeval(config.BASE_LR)
config.WARMUP_RATIO = arithmeticeval(config.WARMUP_RATIO)
device_num = get_device_num()
if config.run_distribute: