forked from mindspore-Ecosystem/mindspore
!16568 modify deeptext network for clould
From: @zhanghuiyao Reviewed-by: @oacjiewen,@wuxuejian Signed-off-by: @wuxuejian
This commit is contained in:
commit
c2413c2b0c
|
@ -65,6 +65,12 @@ Here we used 4 datasets for training, and 1 datasets for Evaluation.
|
|||
└─deeptext
|
||||
├─README.md
|
||||
├─ascend310_infer #application for 310 inference
|
||||
├─model_utils
|
||||
├─__init__.py # package init file
|
||||
├─config.py # Parse arguments
|
||||
├─device_adapter.py # Device adapter for ModelArts
|
||||
├─local_adapter.py # Local adapter
|
||||
└─moxing_adapter.py # Moxing adapter for ModelArts
|
||||
├─scripts
|
||||
├─run_standalone_train_ascend.sh # launch standalone training with ascend platform(1p)
|
||||
├─run_distribute_train_ascend.sh # launch distributed training with ascend platform(8p)
|
||||
|
@ -82,12 +88,12 @@ Here we used 4 datasets for training, and 1 datasets for Evaluation.
|
|||
├─roi_align.py # roi_align cell wrapper
|
||||
├─rpn.py # region-proposal network
|
||||
└─vgg16.py # backbone
|
||||
├─config.py # training configuration
|
||||
├─aipp.cfg # aipp config file
|
||||
├─dataset.py # data proprocessing
|
||||
├─lr_schedule.py # learning rate scheduler
|
||||
├─network_define.py # network definition
|
||||
└─utils.py # some functions which is commonly used
|
||||
├─default_config.yaml # configurations
|
||||
├─eval.py # eval net
|
||||
├─export.py # export checkpoint, surpport .onnx, .air, .mindir convert
|
||||
├─postprogress.py # post process for 310 inference
|
||||
|
@ -117,6 +123,96 @@ sh run_eval_ascend.sh [IMGS_PATH] [ANNOS_PATH] [CHECKPOINT_PATH] [COCO_TEXT_PARS
|
|||
> The `pretrained_path` should be a checkpoint of vgg16 trained on Imagenet2012. The name of weight in dict should be totally the same, also the batch_norm should be enabled in the trainig of vgg16, otherwise fails in further steps.
|
||||
> COCO_TEXT_PARSER_PATH coco_text.py can refer to [Link](https://github.com/andreasveit/coco-text).
|
||||
>
|
||||
|
||||
- ModelArts (If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows)
|
||||
|
||||
```bash
|
||||
# Train 8p on ModelArts
|
||||
# (1) copy [COCO_TEXT_PARSER_PATH] file to /CODE_PATH/deeptext/src/
|
||||
# (2) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "imgs_path='/YOUR IMGS_PATH/'" on default_config.yaml file.
|
||||
# Set "annos_path='/YOUR ANNOS_PATH/'" on default_config.yaml file.
|
||||
# Set "run_distribute=True" on default_config.yaml file.
|
||||
# Set "checkpoint_url='s3://dir_to_your_pretrain/'" on default_config.yaml file.
|
||||
# Set "pre_trained='/cache/checkpoint_path/YOUR PRETRAINED_PATH/'" on default_config.yaml file.
|
||||
# Set "mindrecord_dir='/cache/data/deeptext_dataset/mindrecord'" on default_config.yaml file.
|
||||
# Set "coco_root='/cache/data/deeptext_dataset/coco2017'" on default_config.yaml file.
|
||||
# Set "cocotext_json='/cache/data/deeptext_dataset/cocotext.v2.json'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "imgs_path=/YOUR IMGS_PATH/" on the website UI interface.
|
||||
# Add "annos_path=/YOUR ANNOS_PATH/" on the website UI interface.
|
||||
# Add "run_distribute=True" on the website UI interface.
|
||||
# Add "checkpoint_url='s3://dir_to_your_pretrain/'" on the website UI interface.
|
||||
# Add "pre_trained=/cache/checkpoint_path/YOUR PRETRAINED_PATH/" on the website UI interface.
|
||||
# Add "mindrecord_dir=/cache/data/deeptext_dataset/mindrecord" on the website UI interface.
|
||||
# Add "coco_root=/cache/data/deeptext_dataset/coco2017" on the website UI interface.
|
||||
# Add "cocotext_json=/cache/data/deeptext_dataset/cocotext.v2.json" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (3) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.)
|
||||
# (4) Set the code directory to "/path/deeptext" on the website UI interface.
|
||||
# (5) Set the startup file to "train.py" on the website UI interface.
|
||||
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (7) Create your job.
|
||||
#
|
||||
# Train 1p on ModelArts
|
||||
# (1) copy [COCO_TEXT_PARSER_PATH] file to /CODE_PATH/deeptext/src/
|
||||
# (2) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "imgs_path='/YOUR IMGS_PATH/'" on default_config.yaml file.
|
||||
# Set "annos_path='/YOUR ANNOS_PATH/'" on default_config.yaml file.
|
||||
# Set "run_distribute=False" on default_config.yaml file.
|
||||
# Set "checkpoint_url='s3://dir_to_your_pretrain/'" on default_config.yaml file.
|
||||
# Set "pre_trained='/cache/checkpoint_path/YOUR PRETRAINED_PATH/'" on default_config.yaml file.
|
||||
# Set "mindrecord_dir='/cache/data/deeptext_dataset/mindrecord'" on default_config.yaml file.
|
||||
# Set "coco_root='/cache/data/deeptext_dataset/coco2017'" on default_config.yaml file.
|
||||
# Set "cocotext_json='/cache/data/deeptext_dataset/cocotext.v2.json'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "imgs_path=/YOUR IMGS_PATH/" on the website UI interface.
|
||||
# Add "annos_path=/YOUR ANNOS_PATH/" on the website UI interface.
|
||||
# Add "run_distribute=False" on the website UI interface.
|
||||
# Add "checkpoint_url='s3://dir_to_your_pretrain/'" on the website UI interface.
|
||||
# Add "pre_trained=/cache/data/YOUR PRETRAINED_PATH/" on the website UI interface.
|
||||
# Add "mindrecord_dir=/cache/data/deeptext_dataset/mindrecord" on the website UI interface.
|
||||
# Add "coco_root=/cache/data/deeptext_dataset/coco2017" on the website UI interface.
|
||||
# Add "cocotext_json=/cache/data/deeptext_dataset/cocotext.v2.json" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (3) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.)
|
||||
# (4) Set the code directory to "/path/deeptext" on the website UI interface.
|
||||
# (5) Set the startup file to "train.py" on the website UI interface.
|
||||
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (7) Create your job.
|
||||
#
|
||||
# Eval 1p on ModelArts
|
||||
# (1) copy [COCO_TEXT_PARSER_PATH] file to /CODE_PATH/deeptext/src/
|
||||
# (2) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "imgs_path='/YOUR IMGS_PATH/'" on default_config.yaml file.
|
||||
# Set "annos_path='/YOUR ANNOS_PATH/'" on default_config.yaml file.
|
||||
# Set "checkpoint_url='s3://dir_to_trained_model/'" on default_config.yaml file.
|
||||
# Set "checkpoint_path='/cache/checkpoint_path/YOUR CHECKPOINT_PATH/'" on default_config.yaml file.
|
||||
# Set "mindrecord_dir='/cache/data/deeptext_dataset/mindrecord'" on default_config.yaml file.
|
||||
# Set "coco_root='/cache/data/deeptext_dataset/coco2017'" on default_config.yaml file.
|
||||
# Set "cocotext_json='/cache/data/deeptext_dataset/cocotext.v2.json'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "imgs_path=/YOUR IMGS_PATH/" on the website UI interface.
|
||||
# Add "annos_path=/YOUR ANNOS_PATH/" on the website UI interface.
|
||||
# Add "checkpoint_url='s3://dir_to_trained_model/'" on the website UI interface.
|
||||
# Add "checkpoint_path=/cache/checkpoint_path/YOUR CHECKPOINT_PATH/" on the website UI interface.
|
||||
# Add "mindrecord_dir=/cache/data/deeptext_dataset/mindrecord" on the website UI interface.
|
||||
# Add "coco_root=/cache/data/deeptext_dataset/coco2017" on the website UI interface.
|
||||
# Add "cocotext_json=/cache/data/deeptext_dataset/cocotext.v2.json" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (3) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.)
|
||||
# (4) Set the code directory to "/path/deeptext" on the website UI interface.
|
||||
# (5) Set the startup file to "eval.py" on the website UI interface.
|
||||
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (7) Create your job.
|
||||
```
|
||||
|
||||
### Launch
|
||||
|
||||
```bash
|
||||
|
@ -232,7 +328,7 @@ class 1 precision is 84.24%, recall is 87.40%, F1 is 85.79%
|
|||
| MindSpore Version | 1.1.0 |
|
||||
| Dataset | 229 images |
|
||||
| Batch_size | 2 |
|
||||
| Accuracy | precision=0.8801, recall=0.8277 |
|
||||
| Accuracy | F1 score is 84.50% |
|
||||
| Total time | 1 min |
|
||||
| Model for inference | 3492M (.ckpt file) |
|
||||
|
||||
|
|
|
@ -0,0 +1,165 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
need_modelarts_dataset_unzip: True
|
||||
modelarts_dataset_unzip_name: "deeptext_dataset"
|
||||
|
||||
# ==============================================================================
|
||||
# options
|
||||
img_width: 960
|
||||
img_height: 576
|
||||
keep_ratio: False
|
||||
flip_ratio: 0.0
|
||||
photo_ratio: 0.0
|
||||
expand_ratio: 0.3
|
||||
|
||||
# anchor
|
||||
feature_shapes: [[36, 60]]
|
||||
anchor_scales: [2, 4, 6, 8, 12]
|
||||
anchor_ratios: [0.2, 0.5, 0.8, 1.0, 1.2, 1.5]
|
||||
anchor_strides: [16]
|
||||
num_anchors: 30 # 5*6
|
||||
|
||||
# rpn
|
||||
rpn_in_channels: 512
|
||||
rpn_feat_channels: 640
|
||||
rpn_loss_cls_weight: 1.0
|
||||
rpn_loss_reg_weight: 3.0
|
||||
rpn_cls_out_channels: 1
|
||||
rpn_target_means: [0., 0., 0., 0.]
|
||||
rpn_target_stds: [1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
# bbox_assign_sampler
|
||||
neg_iou_thr: 0.3
|
||||
pos_iou_thr: 0.5
|
||||
min_pos_iou: 0.3
|
||||
num_bboxes: 64800 # 5 * 6 * 36 * 60
|
||||
num_gts: 128
|
||||
num_expected_neg: 256
|
||||
num_expected_pos: 128
|
||||
|
||||
# proposal
|
||||
activate_num_classes: 2
|
||||
use_sigmoid_cls: True
|
||||
|
||||
# roi_align
|
||||
roi_layer:
|
||||
type: "RoIAlign"
|
||||
out_size: 7
|
||||
sample_num: 2
|
||||
|
||||
# bbox_assign_sampler_stage2
|
||||
neg_iou_thr_stage2: 0.2
|
||||
pos_iou_thr_stage2: 0.5
|
||||
min_pos_iou_stage2: 0.5
|
||||
num_bboxes_stage2: 2000
|
||||
use_ambigous_sample: True
|
||||
num_expected_pos_stage2: 128
|
||||
num_expected_amb_stage2: 128
|
||||
num_expected_neg_stage2: 640
|
||||
num_expected_total_stage2: 640
|
||||
|
||||
# rcnn
|
||||
rcnn_in_channels: 512
|
||||
rcnn_fc_out_channels: 4096
|
||||
rcnn_loss_cls_weight: 1
|
||||
rcnn_loss_reg_weight: 1
|
||||
rcnn_target_means: [0., 0., 0., 0.]
|
||||
rcnn_target_stds: [0.1, 0.1, 0.2, 0.2]
|
||||
|
||||
# train proposal
|
||||
rpn_proposal_nms_across_levels: False
|
||||
rpn_proposal_nms_pre: 2000
|
||||
rpn_proposal_nms_post: 2000
|
||||
rpn_proposal_max_num: 2000
|
||||
rpn_proposal_nms_thr: 0.7
|
||||
rpn_proposal_min_bbox_size: 0
|
||||
|
||||
# test proposal
|
||||
rpn_nms_across_levels: False
|
||||
rpn_nms_pre: 1000
|
||||
rpn_nms_post: 1000
|
||||
rpn_max_num: 1000
|
||||
rpn_nms_thr: 0.7
|
||||
rpn_min_bbox_min_size: 0
|
||||
test_score_thr: 0.80
|
||||
test_iou_thr: 0.5
|
||||
test_max_per_img: 100
|
||||
test_batch_size: 2
|
||||
rpn_head_loss_type: "CrossEntropyLoss"
|
||||
rpn_head_use_sigmoid: True
|
||||
rpn_head_weight: 1.0
|
||||
|
||||
# LR
|
||||
base_lr: 0.02
|
||||
base_step: 7856 # 982 * 8
|
||||
total_epoch: 70
|
||||
warmup_step: 50
|
||||
warmup_mode: "linear"
|
||||
warmup_ratio: 0.333333
|
||||
sgd_step: [8, 11]
|
||||
sgd_momentum: 0.9
|
||||
|
||||
# train
|
||||
batch_size: 2
|
||||
loss_scale: 1
|
||||
momentum: 0.91
|
||||
weight_decay: 0.0001 # 1e-4
|
||||
epoch_size: 70
|
||||
save_checkpoint: True
|
||||
save_checkpoint_epochs: 10
|
||||
keep_checkpoint_max: 5
|
||||
save_checkpoint_path: "./"
|
||||
|
||||
mindrecord_dir: "/PATH_to_DATA/deeptext_dataset/mindrecord"
|
||||
use_coco: True
|
||||
coco_root: "/PATH_to_DATA/deeptext_dataset/coco2017"
|
||||
cocotext_json: "/PATH_to_DATA/deeptext_dataset/cocotext.v2.json"
|
||||
coco_train_data_type: "train2017"
|
||||
num_classes: 3
|
||||
|
||||
# Train args
|
||||
run_distribute: False
|
||||
dataset: "coco"
|
||||
pre_trained: ""
|
||||
imgs_path: ""
|
||||
annos_path: ""
|
||||
mindrecord_prefix: "Deeptext-TRAIN"
|
||||
|
||||
# Eval args
|
||||
checkpoint_path: "test"
|
||||
eval_mindrecord_prefix: "Deeptext-TEST"
|
||||
|
||||
# Export args
|
||||
export_batch_size: 1
|
||||
file_name: "deeptext"
|
||||
file_format: "MINDIR"
|
||||
export_device_target: "Ascend"
|
||||
ckpt_file: ""
|
||||
|
||||
# Postprocess args
|
||||
result_path: ""
|
||||
label_path: ""
|
||||
img_path: ""
|
||||
|
||||
---
|
||||
|
||||
|
||||
# Help description for each configuration
|
||||
run_distribute: "Run distribute."
|
||||
dataset: "Dataset name."
|
||||
pre_trained: "Pretrained file path."
|
||||
device_id: "Device id."
|
||||
device_num: "Use device nums."
|
||||
rank_id: "Rank id."
|
||||
imgs_path: "Train images files paths, multiple paths can be separated by ','."
|
||||
annos_path: "Annotations files paths of train images, multiple paths can be separated by ','."
|
||||
mindrecord_prefix: "Prefix of mindrecord."
|
|
@ -14,33 +14,25 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Evaluation for Deeptext"""
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from src.Deeptext.deeptext_vgg16 import Deeptext_VGG16
|
||||
from src.config import config
|
||||
from src.dataset import data_to_mindrecord_byte_image, create_deeptext_dataset
|
||||
from src.utils import metrics
|
||||
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Deeptext evaluation")
|
||||
parser.add_argument("--checkpoint_path", type=str, default='test', help="Checkpoint file path.")
|
||||
parser.add_argument("--imgs_path", type=str, required=True,
|
||||
help="Test images files paths, multiple paths can be separated by ','.")
|
||||
parser.add_argument("--annos_path", type=str, required=True,
|
||||
help="Annotations files paths of test images, multiple paths can be separated by ','.")
|
||||
parser.add_argument("--device_id", type=int, default=7, help="Device id, default is 7.")
|
||||
parser.add_argument("--mindrecord_prefix", type=str, default='Deeptext-TEST', help="Prefix of mindrecord.")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
|
||||
|
||||
|
||||
def deeptext_eval_test(dataset_path='', ckpt_path=''):
|
||||
|
@ -55,8 +47,8 @@ def deeptext_eval_test(dataset_path='', ckpt_path=''):
|
|||
net.set_train(False)
|
||||
eval_iter = 0
|
||||
|
||||
print("\n========================================\n")
|
||||
print("Processing, please wait a moment.")
|
||||
print("\n========================================\n", flush=True)
|
||||
print("Processing, please wait a moment.", flush=True)
|
||||
max_num = 32
|
||||
|
||||
pred_data = []
|
||||
|
@ -75,12 +67,12 @@ def deeptext_eval_test(dataset_path='', ckpt_path=''):
|
|||
gt_bboxes = gt_bboxes.asnumpy()
|
||||
|
||||
gt_bboxes = gt_bboxes[gt_num.asnumpy().astype(bool), :]
|
||||
print(gt_bboxes)
|
||||
print(gt_bboxes, flush=True)
|
||||
gt_labels = gt_labels.asnumpy()
|
||||
gt_labels = gt_labels[gt_num.asnumpy().astype(bool)]
|
||||
print(gt_labels)
|
||||
print(gt_labels, flush=True)
|
||||
end = time.time()
|
||||
print("Iter {} cost time {}".format(eval_iter, end - start))
|
||||
print("Iter {} cost time {}".format(eval_iter, end - start), flush=True)
|
||||
|
||||
# output
|
||||
all_bbox = output[0]
|
||||
|
@ -108,33 +100,91 @@ def deeptext_eval_test(dataset_path='', ckpt_path=''):
|
|||
|
||||
percent = round(eval_iter / total * 100, 2)
|
||||
|
||||
print(' %s [%d/%d]' % (str(percent) + '%', eval_iter, total), end='\r')
|
||||
print(' %s [%d/%d]' % (str(percent) + '%', eval_iter, total), end='\r', flush=True)
|
||||
|
||||
precisions, recalls = metrics(pred_data)
|
||||
print("\n========================================\n")
|
||||
print("\n========================================\n", flush=True)
|
||||
for i in range(config.num_classes - 1):
|
||||
j = i + 1
|
||||
f1 = (2 * precisions[j] * recalls[j]) / (precisions[j] + recalls[j] + 1e-6)
|
||||
print("class {} precision is {:.2f}%, recall is {:.2f}%,"
|
||||
"F1 is {:.2f}%".format(j, precisions[j] * 100, recalls[j] * 100, f1 * 100))
|
||||
"F1 is {:.2f}%".format(j, precisions[j] * 100, recalls[j] * 100, f1 * 100), flush=True)
|
||||
if config.use_ambigous_sample:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
prefix = args_opt.mindrecord_prefix
|
||||
config.test_images = args_opt.imgs_path
|
||||
config.test_txts = args_opt.annos_path
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...", flush=True)
|
||||
print("unzip file num: {}".format(data_num), flush=True)
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)), flush=True)
|
||||
print("Extract Done.", flush=True)
|
||||
else:
|
||||
print("This is not zip.", flush=True)
|
||||
else:
|
||||
print("Zip has been extracted.", flush=True)
|
||||
|
||||
if config.need_modelarts_dataset_unzip:
|
||||
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(config.data_path)
|
||||
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1, flush=True)
|
||||
print("Unzip file save dir: ", save_dir_1, flush=True)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===", flush=True)
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}."
|
||||
.format(get_device_id(), zip_file_1, save_dir_1), flush=True)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_eval():
|
||||
prefix = config.eval_mindrecord_prefix
|
||||
config.test_images = config.imgs_path
|
||||
config.test_txts = config.annos_path
|
||||
mindrecord_dir = config.mindrecord_dir
|
||||
mindrecord_file = os.path.join(mindrecord_dir, prefix)
|
||||
print("CHECKING MINDRECORD FILES ...")
|
||||
print("CHECKING MINDRECORD FILES ...", flush=True)
|
||||
if not os.path.exists(mindrecord_file):
|
||||
if not os.path.isdir(mindrecord_dir):
|
||||
os.makedirs(mindrecord_dir)
|
||||
print("Create Mindrecord. It may take some time.")
|
||||
print("Create Mindrecord. It may take some time.", flush=True)
|
||||
data_to_mindrecord_byte_image(False, prefix, file_num=1)
|
||||
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
||||
print("Create Mindrecord Done, at {}".format(mindrecord_dir), flush=True)
|
||||
|
||||
print("CHECKING MINDRECORD FILES DONE!")
|
||||
print("Start Eval!")
|
||||
deeptext_eval_test(mindrecord_file, args_opt.checkpoint_path)
|
||||
print("CHECKING MINDRECORD FILES DONE!", flush=True)
|
||||
print("Start Eval!", flush=True)
|
||||
deeptext_eval_test(mindrecord_file, config.checkpoint_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_eval()
|
||||
|
|
|
@ -13,34 +13,26 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export checkpoint file into air, mindir models"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||
|
||||
from src.Deeptext.deeptext_vgg16 import Deeptext_VGG16_Infer
|
||||
from src.config import config
|
||||
|
||||
parser = argparse.ArgumentParser(description='deeptext export')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--file_name", type=str, default="deeptext", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
|
||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||
help="device target")
|
||||
parser.add_argument('--ckpt_file', type=str, default='', help='deeptext ckpt file.')
|
||||
args = parser.parse_args()
|
||||
from model_utils.config import config
|
||||
from model_utils.device_adapter import get_device_id
|
||||
|
||||
config.test_batch_size = args.batch_size
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
config.test_batch_size = config.export_batch_size
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.export_device_target)
|
||||
context.set_context(device_id=get_device_id())
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = Deeptext_VGG16_Infer(config=config)
|
||||
net.set_train(False)
|
||||
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
param_dict = load_checkpoint(config.ckpt_file)
|
||||
|
||||
param_dict_new = {}
|
||||
for key, value in param_dict.items():
|
||||
|
@ -50,4 +42,4 @@ if __name__ == '__main__':
|
|||
|
||||
img_data = Tensor(np.zeros([config.test_batch_size, 3, config.img_height, config.img_width]), ms.float32)
|
||||
|
||||
export(net, img_data, file_name=args.file_name, file_format=args.file_format)
|
||||
export(net, img_data, file_name=config.file_name, file_format=config.file_format)
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError("Failed to parse yaml")
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
pprint(default)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return "Local Job"
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import config
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||
Upload data from local directory to remote obs in contrast.
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_sync_count
|
||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||
_global_sync_count += 1
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("from path: ", from_path)
|
||||
print("to path: ", to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print("===finish data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print("===save flag===")
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs.
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
# Run the main function
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print("Start to copy output directory")
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -14,19 +14,13 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Evaluation for Deeptext"""
|
||||
import argparse
|
||||
import os
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import mmcv
|
||||
from src.config import config
|
||||
from src.utils import metrics
|
||||
|
||||
parser = argparse.ArgumentParser(description="Deeptext evaluation")
|
||||
parser.add_argument("--result_path", type=str, required=True, help="result file path")
|
||||
parser.add_argument("--label_path", type=str, required=True, help="label path")
|
||||
parser.add_argument("--img_path", type=str, required=True, help="img path")
|
||||
args_opt = parser.parse_args()
|
||||
from src.utils import metrics
|
||||
from model_utils.config import config
|
||||
|
||||
config.test_batch_size = 1
|
||||
|
||||
|
@ -70,8 +64,8 @@ def get_gt_bboxes_labels(label_file, img_file):
|
|||
def deeptext_eval_test(result_path='', label_path='', img_path=''):
|
||||
eval_iter = 0
|
||||
|
||||
print("\n========================================\n")
|
||||
print("Processing, please wait a moment.")
|
||||
print("\n========================================\n", flush=True)
|
||||
print("Processing, please wait a moment.", flush=True)
|
||||
max_num = 32
|
||||
|
||||
pred_data = []
|
||||
|
@ -109,14 +103,14 @@ def deeptext_eval_test(result_path='', label_path='', img_path=''):
|
|||
"gt_labels": gt_labels})
|
||||
|
||||
precisions, recalls = metrics(pred_data)
|
||||
print("\n========================================\n")
|
||||
print("\n========================================\n", flush=True)
|
||||
for i in range(config.num_classes - 1):
|
||||
j = i + 1
|
||||
f1 = (2 * precisions[j] * recalls[j]) / (precisions[j] + recalls[j] + 1e-6)
|
||||
print("class {} precision is {:.2f}%, recall is {:.2f}%,"
|
||||
"F1 is {:.2f}%".format(j, precisions[j] * 100, recalls[j] * 100, f1 * 100))
|
||||
"F1 is {:.2f}%".format(j, precisions[j] * 100, recalls[j] * 100, f1 * 100), flush=True)
|
||||
if config.use_ambigous_sample:
|
||||
break
|
||||
|
||||
if __name__ == '__main__':
|
||||
deeptext_eval_test(args_opt.result_path, args_opt.label_path, args_opt.img_path)
|
||||
deeptext_eval_test(config.result_path, config.label_path, config.img_path)
|
||||
|
|
|
@ -70,11 +70,13 @@ do
|
|||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp ../*.yaml ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cp -r ../model_utils ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --device_id=$i --rank_id=$i --imgs_path=$PATH1 --annos_path=$PATH2 --run_distribute=True --device_num=$DEVICE_NUM --pre_trained=$PATH4 &> log &
|
||||
python train.py --imgs_path=$PATH1 --annos_path=$PATH2 --run_distribute=True --pre_trained=$PATH4 &> log &
|
||||
cd ..
|
||||
done
|
||||
|
|
|
@ -61,10 +61,12 @@ then
|
|||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp ../*.yaml ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cp -r ../model_utils ./eval
|
||||
cd ./eval || exit
|
||||
env > env.log
|
||||
echo "start eval for device $DEVICE_ID"
|
||||
python eval.py --device_id=$DEVICE_ID --imgs_path=$PATH1 --annos_path=$PATH2 --checkpoint_path=$PATH3 &> log &
|
||||
python eval.py --imgs_path=$PATH1 --annos_path=$PATH2 --checkpoint_path=$PATH3 &> log &
|
||||
cd ..
|
||||
|
|
|
@ -61,10 +61,12 @@ then
|
|||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp ../*.yaml ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cp -r ../model_utils ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --device_id=$DEVICE_ID --imgs_path=$PATH1 --annos_path=$PATH2 --pre_trained=$PATH3 &> log &
|
||||
python train.py --imgs_path=$PATH1 --annos_path=$PATH2 --pre_trained=$PATH3 &> log &
|
||||
cd ..
|
||||
|
|
|
@ -120,7 +120,7 @@ class Deeptext_VGG16(nn.Cell):
|
|||
stds=self.target_stds)
|
||||
|
||||
# Rcnn
|
||||
self.rcnn = Rcnn(config, config.rcnn_in_channels * config.roi_layer['out_size'] * config.roi_layer['out_size'],
|
||||
self.rcnn = Rcnn(config, config.rcnn_in_channels * config.roi_layer.out_size * config.roi_layer.out_size,
|
||||
self.train_batch_size, self.num_classes)
|
||||
|
||||
# Op declare
|
||||
|
|
|
@ -86,8 +86,8 @@ class SingleRoIExtractor(nn.Cell):
|
|||
self.out_channels = out_channels
|
||||
self.featmap_strides = featmap_strides
|
||||
self.num_levels = len(self.featmap_strides)
|
||||
self.out_size = roi_layer['out_size']
|
||||
self.sample_num = roi_layer['sample_num']
|
||||
self.out_size = roi_layer.out_size
|
||||
self.sample_num = roi_layer.sample_num
|
||||
self.roi_layers = self.build_roi_layers(self.featmap_strides)
|
||||
self.roi_layers = L.CellList(self.roi_layers)
|
||||
|
||||
|
|
|
@ -1,130 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# " :===========================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config = ed({
|
||||
"img_width": 960,
|
||||
"img_height": 576,
|
||||
"keep_ratio": False,
|
||||
"flip_ratio": 0.0,
|
||||
"photo_ratio": 0.0,
|
||||
"expand_ratio": 0.3,
|
||||
|
||||
# anchor
|
||||
"feature_shapes": [(36, 60)],
|
||||
"anchor_scales": [2, 4, 6, 8, 12],
|
||||
"anchor_ratios": [0.2, 0.5, 0.8, 1.0, 1.2, 1.5],
|
||||
"anchor_strides": [16],
|
||||
"num_anchors": 5 * 6,
|
||||
|
||||
# rpn
|
||||
"rpn_in_channels": 512,
|
||||
"rpn_feat_channels": 640,
|
||||
"rpn_loss_cls_weight": 1.0,
|
||||
"rpn_loss_reg_weight": 3.0,
|
||||
"rpn_cls_out_channels": 1,
|
||||
"rpn_target_means": [0., 0., 0., 0.],
|
||||
"rpn_target_stds": [1.0, 1.0, 1.0, 1.0],
|
||||
|
||||
# bbox_assign_sampler
|
||||
"neg_iou_thr": 0.3,
|
||||
"pos_iou_thr": 0.5,
|
||||
"min_pos_iou": 0.3,
|
||||
"num_bboxes": 5 * 6 * 36 * 60,
|
||||
"num_gts": 128,
|
||||
"num_expected_neg": 256,
|
||||
"num_expected_pos": 128,
|
||||
|
||||
# proposal
|
||||
"activate_num_classes": 2,
|
||||
"use_sigmoid_cls": True,
|
||||
|
||||
# roi_align
|
||||
"roi_layer": dict(type='RoIAlign', out_size=7, sample_num=2),
|
||||
|
||||
# bbox_assign_sampler_stage2
|
||||
"neg_iou_thr_stage2": 0.2,
|
||||
"pos_iou_thr_stage2": 0.5,
|
||||
"min_pos_iou_stage2": 0.5,
|
||||
"num_bboxes_stage2": 2000,
|
||||
"use_ambigous_sample": True,
|
||||
"num_expected_pos_stage2": 128,
|
||||
"num_expected_amb_stage2": 128,
|
||||
"num_expected_neg_stage2": 640,
|
||||
"num_expected_total_stage2": 640,
|
||||
|
||||
# rcnn
|
||||
"rcnn_in_channels": 512,
|
||||
"rcnn_fc_out_channels": 4096,
|
||||
"rcnn_loss_cls_weight": 1,
|
||||
"rcnn_loss_reg_weight": 1,
|
||||
"rcnn_target_means": [0., 0., 0., 0.],
|
||||
"rcnn_target_stds": [0.1, 0.1, 0.2, 0.2],
|
||||
|
||||
# train proposal
|
||||
"rpn_proposal_nms_across_levels": False,
|
||||
"rpn_proposal_nms_pre": 2000,
|
||||
"rpn_proposal_nms_post": 2000,
|
||||
"rpn_proposal_max_num": 2000,
|
||||
"rpn_proposal_nms_thr": 0.7,
|
||||
"rpn_proposal_min_bbox_size": 0,
|
||||
|
||||
# test proposal
|
||||
"rpn_nms_across_levels": False,
|
||||
"rpn_nms_pre": 1000,
|
||||
"rpn_nms_post": 1000,
|
||||
"rpn_max_num": 1000,
|
||||
"rpn_nms_thr": 0.7,
|
||||
"rpn_min_bbox_min_size": 0,
|
||||
"test_score_thr": 0.80,
|
||||
"test_iou_thr": 0.5,
|
||||
"test_max_per_img": 100,
|
||||
"test_batch_size": 2,
|
||||
|
||||
"rpn_head_loss_type": "CrossEntropyLoss",
|
||||
"rpn_head_use_sigmoid": True,
|
||||
"rpn_head_weight": 1.0,
|
||||
|
||||
# LR
|
||||
"base_lr": 0.02,
|
||||
"base_step": 982 * 8,
|
||||
"total_epoch": 70,
|
||||
"warmup_step": 50,
|
||||
"warmup_mode": "linear",
|
||||
"warmup_ratio": 1 / 3.0,
|
||||
"sgd_step": [8, 11],
|
||||
"sgd_momentum": 0.9,
|
||||
|
||||
# train
|
||||
"batch_size": 2,
|
||||
"loss_scale": 1,
|
||||
"momentum": 0.91,
|
||||
"weight_decay": 1e-4,
|
||||
"epoch_size": 70,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 10,
|
||||
"keep_checkpoint_max": 5,
|
||||
"save_checkpoint_path": "./",
|
||||
|
||||
"mindrecord_dir": "/home/deeptext_sustech/data/mindrecord/full_ori",
|
||||
"use_coco": True,
|
||||
"coco_root": "/d0/dataset/coco2017",
|
||||
"cocotext_json": "/home/deeptext_sustech/data/cocotext.v2.json",
|
||||
"coco_train_data_type": "train2017",
|
||||
"num_classes": 3
|
||||
})
|
|
@ -26,7 +26,7 @@ import mindspore.dataset.vision.c_transforms as C
|
|||
import mindspore.dataset.transforms.c_transforms as CC
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.mindrecord import FileWriter
|
||||
from src.config import config
|
||||
from model_utils.config import config
|
||||
|
||||
|
||||
def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
"""metrics utils"""
|
||||
|
||||
import numpy as np
|
||||
from src.config import config
|
||||
from model_utils.config import config
|
||||
|
||||
|
||||
def calc_iou(bbox_pred, bbox_ground):
|
||||
|
|
|
@ -15,18 +15,19 @@
|
|||
|
||||
"""train Deeptext and get checkpoint files."""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from src.Deeptext.deeptext_vgg16 import Deeptext_VGG16
|
||||
from src.config import config
|
||||
from src.dataset import data_to_mindrecord_byte_image, create_deeptext_dataset
|
||||
from src.lr_schedule import dynamic_lr
|
||||
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
|
||||
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.common import set_seed
|
||||
|
@ -41,74 +42,113 @@ np.set_printoptions(threshold=np.inf)
|
|||
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Deeptext training")
|
||||
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: False.")
|
||||
parser.add_argument("--dataset", type=str, default="coco", help="Dataset name, default: coco.")
|
||||
parser.add_argument("--pre_trained", type=str, default="", help="Pretrained file path.")
|
||||
parser.add_argument("--device_id", type=int, default=5, help="Device id, default: 5.")
|
||||
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.")
|
||||
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
|
||||
parser.add_argument("--imgs_path", type=str, required=True,
|
||||
help="Train images files paths, multiple paths can be separated by ','.")
|
||||
parser.add_argument("--annos_path", type=str, required=True,
|
||||
help="Annotations files paths of train images, multiple paths can be separated by ','.")
|
||||
parser.add_argument("--mindrecord_prefix", type=str, default='Deeptext-TRAIN', help="Prefix of mindrecord.")
|
||||
args_opt = parser.parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...", flush=True)
|
||||
print("unzip file num: {}".format(data_num), flush=True)
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)), flush=True)
|
||||
print("Extract Done.", flush=True)
|
||||
else:
|
||||
print("This is not zip.", flush=True)
|
||||
else:
|
||||
print("Zip has been extracted.", flush=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args_opt.run_distribute:
|
||||
rank = args_opt.rank_id
|
||||
device_num = args_opt.device_num
|
||||
if config.need_modelarts_dataset_unzip:
|
||||
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(config.data_path)
|
||||
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1, flush=True)
|
||||
print("Unzip file save dir: ", save_dir_1, flush=True)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===", flush=True)
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}."
|
||||
.format(get_device_id(), zip_file_1, save_dir_1), flush=True)
|
||||
|
||||
config.save_checkpoint_path = os.path.join(config.output_path, config.save_checkpoint_path)
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_train():
|
||||
if config.run_distribute:
|
||||
rank = get_rank_id()
|
||||
device_num = get_device_num()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
else:
|
||||
rank = 0
|
||||
rank = get_rank_id()
|
||||
device_num = 1
|
||||
|
||||
print("Start create dataset!")
|
||||
print("Start create dataset!", flush=True)
|
||||
|
||||
# It will generate mindrecord file in args_opt.mindrecord_dir,
|
||||
# It will generate mindrecord file in config.mindrecord_dir,
|
||||
# and the file name is DeepText.mindrecord0, 1, ... file_num.
|
||||
prefix = args_opt.mindrecord_prefix
|
||||
config.train_images = args_opt.imgs_path
|
||||
config.train_txts = args_opt.annos_path
|
||||
prefix = config.mindrecord_prefix
|
||||
config.train_images = config.imgs_path
|
||||
config.train_txts = config.annos_path
|
||||
mindrecord_dir = config.mindrecord_dir
|
||||
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
|
||||
print("CHECKING MINDRECORD FILES ...")
|
||||
print("CHECKING MINDRECORD FILES ...", flush=True)
|
||||
|
||||
if rank == 0 and not os.path.exists(mindrecord_file):
|
||||
if not os.path.isdir(mindrecord_dir):
|
||||
os.makedirs(mindrecord_dir)
|
||||
if os.path.isdir(config.coco_root):
|
||||
if not os.path.exists(config.coco_root):
|
||||
print("Please make sure config:coco_root is valid.")
|
||||
print("Please make sure config:coco_root is valid.", flush=True)
|
||||
raise ValueError(config.coco_root)
|
||||
print("Create Mindrecord. It may take some time.")
|
||||
print("Create Mindrecord. It may take some time.", flush=True)
|
||||
data_to_mindrecord_byte_image(True, prefix)
|
||||
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
||||
print("Create Mindrecord Done, at {}".format(mindrecord_dir), flush=True)
|
||||
else:
|
||||
print("coco_root not exits.")
|
||||
print("coco_root not exits.", flush=True)
|
||||
|
||||
while not os.path.exists(mindrecord_file + ".db"):
|
||||
time.sleep(5)
|
||||
|
||||
print("CHECKING MINDRECORD FILES DONE!")
|
||||
|
||||
loss_scale = float(config.loss_scale)
|
||||
print("CHECKING MINDRECORD FILES DONE!", flush=True)
|
||||
|
||||
# When create MindDataset, using the fitst mindrecord file, such as FasterRcnn.mindrecord0.
|
||||
dataset = create_deeptext_dataset(mindrecord_file, repeat_num=1,
|
||||
batch_size=config.batch_size, device_num=device_num, rank_id=rank)
|
||||
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
print("Create dataset done! dataset_size = ", dataset_size)
|
||||
print("Create dataset done! dataset_size = ", dataset_size, flush=True)
|
||||
net = Deeptext_VGG16(config=config)
|
||||
net = net.set_train()
|
||||
|
||||
load_path = args_opt.pre_trained
|
||||
load_path = config.pre_trained
|
||||
if load_path != "":
|
||||
param_dict = load_checkpoint(load_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
@ -119,7 +159,7 @@ if __name__ == '__main__':
|
|||
opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
|
||||
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
if args_opt.run_distribute:
|
||||
if config.run_distribute:
|
||||
net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True,
|
||||
mean=True, degree=device_num)
|
||||
else:
|
||||
|
@ -137,3 +177,7 @@ if __name__ == '__main__':
|
|||
|
||||
model = Model(net)
|
||||
model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_train()
|
||||
|
|
Loading…
Reference in New Issue