forked from mindspore-Ecosystem/mindspore
fix bug train.py and test.py
This commit is contained in:
parent
6fb3981170
commit
a6010c8ac3
|
@ -55,13 +55,17 @@ TEST_BUFFER_SIZE: 4
|
|||
TEST_DROP_REMAINDER: False
|
||||
INFERENCE: True
|
||||
|
||||
|
||||
# ======================================================================================
|
||||
#export options
|
||||
device_id: 0
|
||||
batch_size: 1
|
||||
file_name: "psenet"
|
||||
file_format: "MINDIR"
|
||||
|
||||
# ======================================================================================
|
||||
#postprocess
|
||||
result_path: "./scripts/result_Files"
|
||||
img_path: ""
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
|
@ -80,4 +84,6 @@ batch_size: "batch size"
|
|||
file_name: "output file name"
|
||||
file_format: "file format choices[AIR, MINDIR, ONNX]"
|
||||
object_home: "your direction name"
|
||||
modelarts_home: "modelarts working path"
|
||||
modelarts_home: "modelarts working path"
|
||||
result_path: "result Files path."
|
||||
img_path: "image files path."
|
||||
|
|
|
@ -18,19 +18,19 @@ import os
|
|||
import math
|
||||
import operator
|
||||
from functools import reduce
|
||||
import argparse
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from src.config import config
|
||||
from src.model_utils.config import config
|
||||
from src.ETSNET.pse import pse
|
||||
|
||||
|
||||
def sort_to_clockwise(points):
|
||||
center = tuple(map(operator.truediv, reduce(lambda x, y: map(operator.add, x, y), points), [len(points)] * 2))
|
||||
clockwise_points = sorted(points, key=lambda coord: (-135 - math.degrees(
|
||||
math.atan2(*tuple(map(operator.sub, coord, center))[::-1]))) % 360, reverse=True)
|
||||
return clockwise_points
|
||||
|
||||
|
||||
def write_result_as_txt(image_name, img_bboxes, path):
|
||||
if not os.path.isdir(path):
|
||||
os.makedirs(path)
|
||||
|
@ -51,10 +51,6 @@ def get_img(image_path):
|
|||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
return image
|
||||
|
||||
parser = argparse.ArgumentParser(description='postprocess')
|
||||
parser.add_argument("--result_path", type=str, default="./scripts/result_Files", help='result Files path.')
|
||||
parser.add_argument("--img_path", type=str, default="", help='image files path.')
|
||||
args = parser.parse_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not os.path.isdir('./res/submit_ic15/'):
|
||||
|
@ -62,17 +58,17 @@ if __name__ == "__main__":
|
|||
if not os.path.isdir('./res/vis_ic15/'):
|
||||
os.makedirs('./res/vis_ic15/')
|
||||
|
||||
file_list = os.listdir(args.img_path)
|
||||
file_list = os.listdir(config.img_path)
|
||||
for k in file_list:
|
||||
if os.path.splitext(k)[-1].lower() in ['.jpg', '.jpeg', '.png']:
|
||||
img_path = os.path.join(args.img_path, k)
|
||||
img_path = os.path.join(config.img_path, k)
|
||||
img = get_img(img_path).reshape(1, 720, 1280, 3)
|
||||
img = img[0].astype(np.uint8).copy()
|
||||
img_name = os.path.split(img_path)[-1]
|
||||
|
||||
score = np.fromfile(os.path.join(args.result_path, k.split('.')[0] + '_0.bin'), np.float32)
|
||||
score = np.fromfile(os.path.join(config.result_path, k.split('.')[0] + '_0.bin'), np.float32)
|
||||
score = score.reshape(1, 1, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE)
|
||||
kernels = np.fromfile(os.path.join(args.result_path, k.split('.')[0] + '_1.bin'), bool)
|
||||
kernels = np.fromfile(os.path.join(config.result_path, k.split('.')[0] + '_1.bin'), bool)
|
||||
kernels = kernels.reshape(1, config.KERNEL_NUM, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE)
|
||||
score = np.squeeze(score)
|
||||
kernels = np.squeeze(kernels)
|
||||
|
|
|
@ -84,7 +84,7 @@ def modelarts_pre_process():
|
|||
os.system('cd {}/opencv-3.4.9&&mkdir build&&cd ./build&&{}'.format(local_path, cmake_command))
|
||||
|
||||
os.system('cd {}/src/ETSNET/pse&&make clean&&make'.format(local_path))
|
||||
os.system('cd {}&&sed -i ’s/\r//‘ scripts/run_eval_ascend.sh')
|
||||
os.system('cd {}&&sed -i ’s/\r//‘ scripts/run_eval_ascend.sh'.format(local_path))
|
||||
|
||||
|
||||
def modelarts_post_process():
|
||||
|
|
|
@ -14,7 +14,8 @@
|
|||
# ============================================================================
|
||||
|
||||
|
||||
from ast import literal_eval as liter
|
||||
import ast
|
||||
import operator
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init
|
||||
|
@ -37,6 +38,32 @@ set_seed(1)
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
|
||||
|
||||
|
||||
binOps = {
|
||||
ast.Add: operator.add,
|
||||
ast.Sub: operator.sub,
|
||||
ast.Mult: operator.mul,
|
||||
ast.Div: operator.truediv,
|
||||
ast.Mod: operator.mod
|
||||
}
|
||||
|
||||
|
||||
def arithmeticeval(s):
|
||||
node = ast.parse(s, mode='eval')
|
||||
|
||||
def _eval(node):
|
||||
if isinstance(node, ast.BinOp):
|
||||
return binOps[type(node.op)](_eval(node.left), _eval(node.right))
|
||||
|
||||
if isinstance(node, ast.Num):
|
||||
return node.n
|
||||
|
||||
if isinstance(node, ast.Expression):
|
||||
return _eval(node.body)
|
||||
|
||||
raise Exception('unsupported type{}'.format(node))
|
||||
return _eval(node.body)
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
pass
|
||||
|
||||
|
@ -44,8 +71,8 @@ def modelarts_pre_process():
|
|||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def train():
|
||||
rank_id = 0
|
||||
config.BASE_LR = liter(config.BASE_LR)
|
||||
config.WARMUP_RATIO = liter(config.WARMUP_RATIO)
|
||||
config.BASE_LR = arithmeticeval(config.BASE_LR)
|
||||
config.WARMUP_RATIO = arithmeticeval(config.WARMUP_RATIO)
|
||||
|
||||
device_num = get_device_num()
|
||||
if config.run_distribute:
|
||||
|
|
Loading…
Reference in New Issue