forked from mindspore-Ecosystem/mindspore
fix bug train.py and test.py
This commit is contained in:
parent
6fb3981170
commit
a6010c8ac3
|
@ -55,13 +55,17 @@ TEST_BUFFER_SIZE: 4
|
||||||
TEST_DROP_REMAINDER: False
|
TEST_DROP_REMAINDER: False
|
||||||
INFERENCE: True
|
INFERENCE: True
|
||||||
|
|
||||||
|
# ======================================================================================
|
||||||
#export options
|
#export options
|
||||||
device_id: 0
|
device_id: 0
|
||||||
batch_size: 1
|
batch_size: 1
|
||||||
file_name: "psenet"
|
file_name: "psenet"
|
||||||
file_format: "MINDIR"
|
file_format: "MINDIR"
|
||||||
|
|
||||||
|
# ======================================================================================
|
||||||
|
#postprocess
|
||||||
|
result_path: "./scripts/result_Files"
|
||||||
|
img_path: ""
|
||||||
|
|
||||||
---
|
---
|
||||||
# Help description for each configuration
|
# Help description for each configuration
|
||||||
|
@ -80,4 +84,6 @@ batch_size: "batch size"
|
||||||
file_name: "output file name"
|
file_name: "output file name"
|
||||||
file_format: "file format choices[AIR, MINDIR, ONNX]"
|
file_format: "file format choices[AIR, MINDIR, ONNX]"
|
||||||
object_home: "your direction name"
|
object_home: "your direction name"
|
||||||
modelarts_home: "modelarts working path"
|
modelarts_home: "modelarts working path"
|
||||||
|
result_path: "result Files path."
|
||||||
|
img_path: "image files path."
|
||||||
|
|
|
@ -18,19 +18,19 @@ import os
|
||||||
import math
|
import math
|
||||||
import operator
|
import operator
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
import argparse
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
|
from src.model_utils.config import config
|
||||||
from src.config import config
|
|
||||||
from src.ETSNET.pse import pse
|
from src.ETSNET.pse import pse
|
||||||
|
|
||||||
|
|
||||||
def sort_to_clockwise(points):
|
def sort_to_clockwise(points):
|
||||||
center = tuple(map(operator.truediv, reduce(lambda x, y: map(operator.add, x, y), points), [len(points)] * 2))
|
center = tuple(map(operator.truediv, reduce(lambda x, y: map(operator.add, x, y), points), [len(points)] * 2))
|
||||||
clockwise_points = sorted(points, key=lambda coord: (-135 - math.degrees(
|
clockwise_points = sorted(points, key=lambda coord: (-135 - math.degrees(
|
||||||
math.atan2(*tuple(map(operator.sub, coord, center))[::-1]))) % 360, reverse=True)
|
math.atan2(*tuple(map(operator.sub, coord, center))[::-1]))) % 360, reverse=True)
|
||||||
return clockwise_points
|
return clockwise_points
|
||||||
|
|
||||||
|
|
||||||
def write_result_as_txt(image_name, img_bboxes, path):
|
def write_result_as_txt(image_name, img_bboxes, path):
|
||||||
if not os.path.isdir(path):
|
if not os.path.isdir(path):
|
||||||
os.makedirs(path)
|
os.makedirs(path)
|
||||||
|
@ -51,10 +51,6 @@ def get_img(image_path):
|
||||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='postprocess')
|
|
||||||
parser.add_argument("--result_path", type=str, default="./scripts/result_Files", help='result Files path.')
|
|
||||||
parser.add_argument("--img_path", type=str, default="", help='image files path.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if not os.path.isdir('./res/submit_ic15/'):
|
if not os.path.isdir('./res/submit_ic15/'):
|
||||||
|
@ -62,17 +58,17 @@ if __name__ == "__main__":
|
||||||
if not os.path.isdir('./res/vis_ic15/'):
|
if not os.path.isdir('./res/vis_ic15/'):
|
||||||
os.makedirs('./res/vis_ic15/')
|
os.makedirs('./res/vis_ic15/')
|
||||||
|
|
||||||
file_list = os.listdir(args.img_path)
|
file_list = os.listdir(config.img_path)
|
||||||
for k in file_list:
|
for k in file_list:
|
||||||
if os.path.splitext(k)[-1].lower() in ['.jpg', '.jpeg', '.png']:
|
if os.path.splitext(k)[-1].lower() in ['.jpg', '.jpeg', '.png']:
|
||||||
img_path = os.path.join(args.img_path, k)
|
img_path = os.path.join(config.img_path, k)
|
||||||
img = get_img(img_path).reshape(1, 720, 1280, 3)
|
img = get_img(img_path).reshape(1, 720, 1280, 3)
|
||||||
img = img[0].astype(np.uint8).copy()
|
img = img[0].astype(np.uint8).copy()
|
||||||
img_name = os.path.split(img_path)[-1]
|
img_name = os.path.split(img_path)[-1]
|
||||||
|
|
||||||
score = np.fromfile(os.path.join(args.result_path, k.split('.')[0] + '_0.bin'), np.float32)
|
score = np.fromfile(os.path.join(config.result_path, k.split('.')[0] + '_0.bin'), np.float32)
|
||||||
score = score.reshape(1, 1, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE)
|
score = score.reshape(1, 1, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE)
|
||||||
kernels = np.fromfile(os.path.join(args.result_path, k.split('.')[0] + '_1.bin'), bool)
|
kernels = np.fromfile(os.path.join(config.result_path, k.split('.')[0] + '_1.bin'), bool)
|
||||||
kernels = kernels.reshape(1, config.KERNEL_NUM, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE)
|
kernels = kernels.reshape(1, config.KERNEL_NUM, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE)
|
||||||
score = np.squeeze(score)
|
score = np.squeeze(score)
|
||||||
kernels = np.squeeze(kernels)
|
kernels = np.squeeze(kernels)
|
||||||
|
|
|
@ -84,7 +84,7 @@ def modelarts_pre_process():
|
||||||
os.system('cd {}/opencv-3.4.9&&mkdir build&&cd ./build&&{}'.format(local_path, cmake_command))
|
os.system('cd {}/opencv-3.4.9&&mkdir build&&cd ./build&&{}'.format(local_path, cmake_command))
|
||||||
|
|
||||||
os.system('cd {}/src/ETSNET/pse&&make clean&&make'.format(local_path))
|
os.system('cd {}/src/ETSNET/pse&&make clean&&make'.format(local_path))
|
||||||
os.system('cd {}&&sed -i ’s/\r//‘ scripts/run_eval_ascend.sh')
|
os.system('cd {}&&sed -i ’s/\r//‘ scripts/run_eval_ascend.sh'.format(local_path))
|
||||||
|
|
||||||
|
|
||||||
def modelarts_post_process():
|
def modelarts_post_process():
|
||||||
|
|
|
@ -14,7 +14,8 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
from ast import literal_eval as liter
|
import ast
|
||||||
|
import operator
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.communication.management import init
|
from mindspore.communication.management import init
|
||||||
|
@ -37,6 +38,32 @@ set_seed(1)
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
|
||||||
|
|
||||||
|
|
||||||
|
binOps = {
|
||||||
|
ast.Add: operator.add,
|
||||||
|
ast.Sub: operator.sub,
|
||||||
|
ast.Mult: operator.mul,
|
||||||
|
ast.Div: operator.truediv,
|
||||||
|
ast.Mod: operator.mod
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def arithmeticeval(s):
|
||||||
|
node = ast.parse(s, mode='eval')
|
||||||
|
|
||||||
|
def _eval(node):
|
||||||
|
if isinstance(node, ast.BinOp):
|
||||||
|
return binOps[type(node.op)](_eval(node.left), _eval(node.right))
|
||||||
|
|
||||||
|
if isinstance(node, ast.Num):
|
||||||
|
return node.n
|
||||||
|
|
||||||
|
if isinstance(node, ast.Expression):
|
||||||
|
return _eval(node.body)
|
||||||
|
|
||||||
|
raise Exception('unsupported type{}'.format(node))
|
||||||
|
return _eval(node.body)
|
||||||
|
|
||||||
|
|
||||||
def modelarts_pre_process():
|
def modelarts_pre_process():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -44,8 +71,8 @@ def modelarts_pre_process():
|
||||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||||
def train():
|
def train():
|
||||||
rank_id = 0
|
rank_id = 0
|
||||||
config.BASE_LR = liter(config.BASE_LR)
|
config.BASE_LR = arithmeticeval(config.BASE_LR)
|
||||||
config.WARMUP_RATIO = liter(config.WARMUP_RATIO)
|
config.WARMUP_RATIO = arithmeticeval(config.WARMUP_RATIO)
|
||||||
|
|
||||||
device_num = get_device_num()
|
device_num = get_device_num()
|
||||||
if config.run_distribute:
|
if config.run_distribute:
|
||||||
|
|
Loading…
Reference in New Issue