forked from mindspore-Ecosystem/mindspore
clear pylint warning
This commit is contained in:
parent
76a2f7c69d
commit
453439367c
|
@ -137,7 +137,7 @@ def ssd_bboxes_encode(boxes):
|
|||
num_match_num = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
|
||||
return bboxes, t_label.astype(np.int32), num_match_num
|
||||
|
||||
def ssd_bboxes_decode(boxes, index, image_shape):
|
||||
def ssd_bboxes_decode(boxes, index):
|
||||
"""Decode predict boxes to [x, y, w, h]"""
|
||||
boxes_t = boxes[index]
|
||||
default_boxes_t = default_boxes[index]
|
||||
|
|
|
@ -110,14 +110,12 @@ def metrics(pred_data):
|
|||
pred_boxes = sample['boxes']
|
||||
boxes_scores = sample['box_scores']
|
||||
annotation = sample['annotation']
|
||||
image_shape = sample['image_shape']
|
||||
|
||||
annotation = np.squeeze(annotation, axis=0)
|
||||
image_shape = np.squeeze(image_shape, axis=0)
|
||||
|
||||
pred_labels = np.argmax(boxes_scores, axis=-1)
|
||||
index = np.nonzero(pred_labels)
|
||||
pred_boxes = ssd_bboxes_decode(pred_boxes, index, image_shape)
|
||||
pred_boxes = ssd_bboxes_decode(pred_boxes, index)
|
||||
|
||||
pred_boxes = pred_boxes.clip(0, 1)
|
||||
boxes_scores = np.max(boxes_scores, axis=-1)
|
||||
|
|
|
@ -60,7 +60,7 @@ def init_net_param(net, init='ones'):
|
|||
p.set_parameter_data(initializer(init, p.data.shape(), p.data.dtype()))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="YOLOv3 train")
|
||||
parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create "
|
||||
"Mindrecord, default is false.")
|
||||
|
@ -153,3 +153,6 @@ if __name__ == '__main__':
|
|||
dataset_sink_mode = True
|
||||
print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.")
|
||||
model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue