forked from mindspore-Ecosystem/mindspore
some words misspelled,it has modified
This commit is contained in:
parent
0dd77bf9dc
commit
e031e60a6e
|
@ -87,7 +87,7 @@ class DetectionEngine:
|
||||||
|
|
||||||
def _nms(self, predicts, threshold):
|
def _nms(self, predicts, threshold):
|
||||||
"""Calculate NMS."""
|
"""Calculate NMS."""
|
||||||
# conver xywh -> xmin ymin xmax ymax
|
# convert xywh -> xmin ymin xmax ymax
|
||||||
x1 = predicts[:, 0]
|
x1 = predicts[:, 0]
|
||||||
y1 = predicts[:, 1]
|
y1 = predicts[:, 1]
|
||||||
x2 = x1 + predicts[:, 2]
|
x2 = x1 + predicts[:, 2]
|
||||||
|
@ -111,8 +111,8 @@ class DetectionEngine:
|
||||||
intersect_area = intersect_w * intersect_h
|
intersect_area = intersect_w * intersect_h
|
||||||
ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
|
ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
|
||||||
|
|
||||||
indexs = np.where(ovr <= threshold)[0]
|
indexes = np.where(ovr <= threshold)[0]
|
||||||
order = order[indexs + 1]
|
order = order[indexes + 1]
|
||||||
return reserved_boxes
|
return reserved_boxes
|
||||||
|
|
||||||
def write_result(self):
|
def write_result(self):
|
||||||
|
@ -179,7 +179,7 @@ class DetectionEngine:
|
||||||
|
|
||||||
x_top_left = x - w / 2.
|
x_top_left = x - w / 2.
|
||||||
y_top_left = y - h / 2.
|
y_top_left = y - h / 2.
|
||||||
# creat all False
|
# create all False
|
||||||
flag = np.random.random(cls_emb.shape) > sys.maxsize
|
flag = np.random.random(cls_emb.shape) > sys.maxsize
|
||||||
for i in range(flag.shape[0]):
|
for i in range(flag.shape[0]):
|
||||||
c = cls_argmax[i]
|
c = cls_argmax[i]
|
||||||
|
|
|
@ -58,7 +58,7 @@ cp ../*.py ./eval
|
||||||
cp -r ../src ./eval
|
cp -r ../src ./eval
|
||||||
cd ./eval || exit
|
cd ./eval || exit
|
||||||
env > env.log
|
env > env.log
|
||||||
echo "start infering for device $DEVICE_ID"
|
echo "start inferring for device $DEVICE_ID"
|
||||||
python eval.py \
|
python eval.py \
|
||||||
--data_dir=$DATASET_PATH \
|
--data_dir=$DATASET_PATH \
|
||||||
--pretrained=$CHECKPOINT_PATH \
|
--pretrained=$CHECKPOINT_PATH \
|
||||||
|
|
|
@ -58,7 +58,7 @@ cp ../*.py ./eval
|
||||||
cp -r ../src ./eval
|
cp -r ../src ./eval
|
||||||
cd ./eval || exit
|
cd ./eval || exit
|
||||||
env > env.log
|
env > env.log
|
||||||
echo "start infering for device $DEVICE_ID"
|
echo "start inferring for device $DEVICE_ID"
|
||||||
python eval.py \
|
python eval.py \
|
||||||
--device_target="GPU" \
|
--device_target="GPU" \
|
||||||
--data_dir=$DATASET_PATH \
|
--data_dir=$DATASET_PATH \
|
||||||
|
|
|
@ -39,7 +39,7 @@ def build_network():
|
||||||
|
|
||||||
|
|
||||||
def convert(weights_file, output_file):
|
def convert(weights_file, output_file):
|
||||||
"""Conver weight to mindspore ckpt."""
|
"""Convert weight to mindspore ckpt."""
|
||||||
params = build_network()
|
params = build_network()
|
||||||
weights = load_weight(weights_file)
|
weights = load_weight(weights_file)
|
||||||
index = 0
|
index = 0
|
||||||
|
|
|
@ -59,7 +59,7 @@ class YoloBlock(nn.Cell):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_channels: Integer. Input channel.
|
in_channels: Integer. Input channel.
|
||||||
out_chls: Interger. Middle channel.
|
out_chls: Integer. Middle channel.
|
||||||
out_channels: Integer. Output channel.
|
out_channels: Integer. Output channel.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -108,7 +108,7 @@ class YOLOv3(nn.Cell):
|
||||||
Args:
|
Args:
|
||||||
backbone_shape: List. Darknet output channels shape.
|
backbone_shape: List. Darknet output channels shape.
|
||||||
backbone: Cell. Backbone Network.
|
backbone: Cell. Backbone Network.
|
||||||
out_channel: Interger. Output channel.
|
out_channel: Integer. Output channel.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor, output tensor.
|
Tensor, output tensor.
|
||||||
|
|
|
@ -45,7 +45,7 @@ def has_valid_annotation(anno):
|
||||||
# if all boxes have close to zero area, there is no annotation
|
# if all boxes have close to zero area, there is no annotation
|
||||||
if _has_only_empty_bbox(anno):
|
if _has_only_empty_bbox(anno):
|
||||||
return False
|
return False
|
||||||
# keypoints task have a slight different critera for considering
|
# keypoints task have a slight different criteria for considering
|
||||||
# if an annotation is valid
|
# if an annotation is valid
|
||||||
if "keypoints" not in anno[0]:
|
if "keypoints" not in anno[0]:
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -131,9 +131,7 @@ def conver_training_shape(args):
|
||||||
return training_shape
|
return training_shape
|
||||||
|
|
||||||
|
|
||||||
def train():
|
def network_init(args):
|
||||||
"""Train function."""
|
|
||||||
args = parse_args()
|
|
||||||
devid = int(os.getenv('DEVICE_ID', '0'))
|
devid = int(os.getenv('DEVICE_ID', '0'))
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||||
device_target=args.device_target, save_graphs=True, device_id=devid)
|
device_target=args.device_target, save_graphs=True, device_id=devid)
|
||||||
|
@ -145,26 +143,21 @@ def train():
|
||||||
init("nccl")
|
init("nccl")
|
||||||
args.rank = get_rank()
|
args.rank = get_rank()
|
||||||
args.group_size = get_group_size()
|
args.group_size = get_group_size()
|
||||||
# select for master rank save ckpt or all rank save, compatiable for model parallel
|
# select for master rank save ckpt or all rank save, compatible for model parallel
|
||||||
args.rank_save_ckpt_flag = 0
|
args.rank_save_ckpt_flag = 0
|
||||||
if args.is_save_on_master:
|
if args.is_save_on_master:
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
args.rank_save_ckpt_flag = 1
|
args.rank_save_ckpt_flag = 1
|
||||||
else:
|
else:
|
||||||
args.rank_save_ckpt_flag = 1
|
args.rank_save_ckpt_flag = 1
|
||||||
|
|
||||||
# logger
|
# logger
|
||||||
args.outputs_dir = os.path.join(args.ckpt_path,
|
args.outputs_dir = os.path.join(args.ckpt_path,
|
||||||
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||||
args.logger = get_logger(args.outputs_dir, args.rank)
|
args.logger = get_logger(args.outputs_dir, args.rank)
|
||||||
args.logger.save_args(args)
|
args.logger.save_args(args)
|
||||||
|
|
||||||
if args.need_profiler:
|
|
||||||
from mindspore.profiler.profiling import Profiler
|
|
||||||
profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)
|
|
||||||
|
|
||||||
loss_meter = AverageMeter('loss')
|
|
||||||
|
|
||||||
|
def parallel_init(args):
|
||||||
context.reset_auto_parallel_context()
|
context.reset_auto_parallel_context()
|
||||||
parallel_mode = ParallelMode.STAND_ALONE
|
parallel_mode = ParallelMode.STAND_ALONE
|
||||||
degree = 1
|
degree = 1
|
||||||
|
@ -173,6 +166,17 @@ def train():
|
||||||
degree = get_group_size()
|
degree = get_group_size()
|
||||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)
|
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)
|
||||||
|
|
||||||
|
def train():
|
||||||
|
"""Train function."""
|
||||||
|
args = parse_args()
|
||||||
|
network_init(args)
|
||||||
|
if args.need_profiler:
|
||||||
|
from mindspore.profiler.profiling import Profiler
|
||||||
|
profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)
|
||||||
|
|
||||||
|
loss_meter = AverageMeter('loss')
|
||||||
|
parallel_init(args)
|
||||||
|
|
||||||
network = YOLOV3DarkNet53(is_training=True)
|
network = YOLOV3DarkNet53(is_training=True)
|
||||||
# default is kaiming-normal
|
# default is kaiming-normal
|
||||||
default_recurisive_init(network)
|
default_recurisive_init(network)
|
||||||
|
@ -182,7 +186,6 @@ def train():
|
||||||
args.logger.info('finish get network')
|
args.logger.info('finish get network')
|
||||||
|
|
||||||
config = ConfigYOLOV3DarkNet53()
|
config = ConfigYOLOV3DarkNet53()
|
||||||
|
|
||||||
config.label_smooth = args.label_smooth
|
config.label_smooth = args.label_smooth
|
||||||
config.label_smooth_factor = args.label_smooth_factor
|
config.label_smooth_factor = args.label_smooth_factor
|
||||||
|
|
||||||
|
@ -202,7 +205,6 @@ def train():
|
||||||
args.ckpt_interval = args.steps_per_epoch
|
args.ckpt_interval = args.steps_per_epoch
|
||||||
|
|
||||||
lr = get_lr(args)
|
lr = get_lr(args)
|
||||||
|
|
||||||
opt = Momentum(params=get_param_groups(network),
|
opt = Momentum(params=get_param_groups(network),
|
||||||
learning_rate=Tensor(lr),
|
learning_rate=Tensor(lr),
|
||||||
momentum=args.momentum,
|
momentum=args.momentum,
|
||||||
|
@ -281,7 +283,6 @@ def train():
|
||||||
if i == 10:
|
if i == 10:
|
||||||
profiler.analyse()
|
profiler.analyse()
|
||||||
break
|
break
|
||||||
|
|
||||||
args.logger.info('==========end training===============')
|
args.logger.info('==========end training===============')
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue