some words misspelled,it has modified

This commit is contained in:
wsq3 2021-01-29 15:54:11 +08:00
parent 0dd77bf9dc
commit e031e60a6e
7 changed files with 24 additions and 23 deletions

View File

@ -87,7 +87,7 @@ class DetectionEngine:
def _nms(self, predicts, threshold): def _nms(self, predicts, threshold):
"""Calculate NMS.""" """Calculate NMS."""
# conver xywh -> xmin ymin xmax ymax # convert xywh -> xmin ymin xmax ymax
x1 = predicts[:, 0] x1 = predicts[:, 0]
y1 = predicts[:, 1] y1 = predicts[:, 1]
x2 = x1 + predicts[:, 2] x2 = x1 + predicts[:, 2]
@ -111,8 +111,8 @@ class DetectionEngine:
intersect_area = intersect_w * intersect_h intersect_area = intersect_w * intersect_h
ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area) ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
indexs = np.where(ovr <= threshold)[0] indexes = np.where(ovr <= threshold)[0]
order = order[indexs + 1] order = order[indexes + 1]
return reserved_boxes return reserved_boxes
def write_result(self): def write_result(self):
@ -179,7 +179,7 @@ class DetectionEngine:
x_top_left = x - w / 2. x_top_left = x - w / 2.
y_top_left = y - h / 2. y_top_left = y - h / 2.
# creat all False # create all False
flag = np.random.random(cls_emb.shape) > sys.maxsize flag = np.random.random(cls_emb.shape) > sys.maxsize
for i in range(flag.shape[0]): for i in range(flag.shape[0]):
c = cls_argmax[i] c = cls_argmax[i]

View File

@ -58,7 +58,7 @@ cp ../*.py ./eval
cp -r ../src ./eval cp -r ../src ./eval
cd ./eval || exit cd ./eval || exit
env > env.log env > env.log
echo "start infering for device $DEVICE_ID" echo "start inferring for device $DEVICE_ID"
python eval.py \ python eval.py \
--data_dir=$DATASET_PATH \ --data_dir=$DATASET_PATH \
--pretrained=$CHECKPOINT_PATH \ --pretrained=$CHECKPOINT_PATH \

View File

@ -58,7 +58,7 @@ cp ../*.py ./eval
cp -r ../src ./eval cp -r ../src ./eval
cd ./eval || exit cd ./eval || exit
env > env.log env > env.log
echo "start infering for device $DEVICE_ID" echo "start inferring for device $DEVICE_ID"
python eval.py \ python eval.py \
--device_target="GPU" \ --device_target="GPU" \
--data_dir=$DATASET_PATH \ --data_dir=$DATASET_PATH \

View File

@ -39,7 +39,7 @@ def build_network():
def convert(weights_file, output_file): def convert(weights_file, output_file):
"""Conver weight to mindspore ckpt.""" """Convert weight to mindspore ckpt."""
params = build_network() params = build_network()
weights = load_weight(weights_file) weights = load_weight(weights_file)
index = 0 index = 0

View File

@ -59,7 +59,7 @@ class YoloBlock(nn.Cell):
Args: Args:
in_channels: Integer. Input channel. in_channels: Integer. Input channel.
out_chls: Interger. Middle channel. out_chls: Integer. Middle channel.
out_channels: Integer. Output channel. out_channels: Integer. Output channel.
Returns: Returns:
@ -108,7 +108,7 @@ class YOLOv3(nn.Cell):
Args: Args:
backbone_shape: List. Darknet output channels shape. backbone_shape: List. Darknet output channels shape.
backbone: Cell. Backbone Network. backbone: Cell. Backbone Network.
out_channel: Interger. Output channel. out_channel: Integer. Output channel.
Returns: Returns:
Tensor, output tensor. Tensor, output tensor.

View File

@ -45,7 +45,7 @@ def has_valid_annotation(anno):
# if all boxes have close to zero area, there is no annotation # if all boxes have close to zero area, there is no annotation
if _has_only_empty_bbox(anno): if _has_only_empty_bbox(anno):
return False return False
# keypoints task have a slight different critera for considering # keypoints task have a slight different criteria for considering
# if an annotation is valid # if an annotation is valid
if "keypoints" not in anno[0]: if "keypoints" not in anno[0]:
return True return True

View File

@ -131,9 +131,7 @@ def conver_training_shape(args):
return training_shape return training_shape
def train(): def network_init(args):
"""Train function."""
args = parse_args()
devid = int(os.getenv('DEVICE_ID', '0')) devid = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.device_target, save_graphs=True, device_id=devid) device_target=args.device_target, save_graphs=True, device_id=devid)
@ -145,26 +143,21 @@ def train():
init("nccl") init("nccl")
args.rank = get_rank() args.rank = get_rank()
args.group_size = get_group_size() args.group_size = get_group_size()
# select for master rank save ckpt or all rank save, compatiable for model parallel # select for master rank save ckpt or all rank save, compatible for model parallel
args.rank_save_ckpt_flag = 0 args.rank_save_ckpt_flag = 0
if args.is_save_on_master: if args.is_save_on_master:
if args.rank == 0: if args.rank == 0:
args.rank_save_ckpt_flag = 1 args.rank_save_ckpt_flag = 1
else: else:
args.rank_save_ckpt_flag = 1 args.rank_save_ckpt_flag = 1
# logger # logger
args.outputs_dir = os.path.join(args.ckpt_path, args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank) args.logger = get_logger(args.outputs_dir, args.rank)
args.logger.save_args(args) args.logger.save_args(args)
if args.need_profiler:
from mindspore.profiler.profiling import Profiler
profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)
loss_meter = AverageMeter('loss')
def parallel_init(args):
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
parallel_mode = ParallelMode.STAND_ALONE parallel_mode = ParallelMode.STAND_ALONE
degree = 1 degree = 1
@ -173,6 +166,17 @@ def train():
degree = get_group_size() degree = get_group_size()
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree) context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)
def train():
"""Train function."""
args = parse_args()
network_init(args)
if args.need_profiler:
from mindspore.profiler.profiling import Profiler
profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)
loss_meter = AverageMeter('loss')
parallel_init(args)
network = YOLOV3DarkNet53(is_training=True) network = YOLOV3DarkNet53(is_training=True)
# default is kaiming-normal # default is kaiming-normal
default_recurisive_init(network) default_recurisive_init(network)
@ -182,7 +186,6 @@ def train():
args.logger.info('finish get network') args.logger.info('finish get network')
config = ConfigYOLOV3DarkNet53() config = ConfigYOLOV3DarkNet53()
config.label_smooth = args.label_smooth config.label_smooth = args.label_smooth
config.label_smooth_factor = args.label_smooth_factor config.label_smooth_factor = args.label_smooth_factor
@ -202,7 +205,6 @@ def train():
args.ckpt_interval = args.steps_per_epoch args.ckpt_interval = args.steps_per_epoch
lr = get_lr(args) lr = get_lr(args)
opt = Momentum(params=get_param_groups(network), opt = Momentum(params=get_param_groups(network),
learning_rate=Tensor(lr), learning_rate=Tensor(lr),
momentum=args.momentum, momentum=args.momentum,
@ -281,7 +283,6 @@ def train():
if i == 10: if i == 10:
profiler.analyse() profiler.analyse()
break break
args.logger.info('==========end training===============') args.logger.info('==========end training===============')