From cb30181bb1dc11231510575affd06da4cc097067 Mon Sep 17 00:00:00 2001 From: chenweitao_295 Date: Fri, 18 Jun 2021 15:40:55 +0800 Subject: [PATCH] tinydarknet pass parameter modification --- .../cv/cnnctc/ascend310_infer/src/main.cc | 27 +++++++++---------- .../tinydarknet/ascend310_infer/src/main.cc | 14 +++++----- model_zoo/official/cv/tinydarknet/export.py | 25 +++++------------ .../cv/tinydarknet/imagenet_config.yaml | 8 ++++-- .../official/cv/tinydarknet/postprocess.py | 9 ++----- 5 files changed, 33 insertions(+), 50 deletions(-) diff --git a/model_zoo/official/cv/cnnctc/ascend310_infer/src/main.cc b/model_zoo/official/cv/cnnctc/ascend310_infer/src/main.cc index e1681a757cc..dea34ae72c1 100644 --- a/model_zoo/official/cv/cnnctc/ascend310_infer/src/main.cc +++ b/model_zoo/official/cv/cnnctc/ascend310_infer/src/main.cc @@ -61,9 +61,8 @@ DEFINE_int32(image_height, 32, "image height"); DEFINE_int32(image_width, 100, "image width"); int PadImage(const MSTensor &input, MSTensor *output) { - std::shared_ptr normalize(new Normalize({127.5, 127.5, 127.5}, - {127.5, 127.5, 127.5})); - Execute composeNormalize({normalize}); + auto normalize = Normalize({127.5, 127.5, 127.5}, {127.5, 127.5, 127.5}); + Execute composeNormalize(normalize); std::vector shape = input.Shape(); auto imgResize = MSTensor(); auto imgNormalize = MSTensor(); @@ -74,19 +73,17 @@ int PadImage(const MSTensor &input, MSTensor *output) { NewWidth = ceil(FLAGS_image_height * ratio); paddingSize = FLAGS_image_width - NewWidth; if (NewWidth > FLAGS_image_width) { - std::shared_ptr resize(new Resize({FLAGS_image_height, FLAGS_image_width}, - InterpolationMode::kCubicPil)); - Execute composeResize({resize}); + auto resize = Resize({FLAGS_image_height, FLAGS_image_width}, InterpolationMode::kArea); + Execute composeResize(resize); composeResize(input, &imgResize); composeNormalize(imgResize, output); } else { - std::shared_ptr resize(new Resize({FLAGS_image_height, NewWidth}, - InterpolationMode::kCubicPil)); - Execute composeResize({resize}); + auto resize = Resize({FLAGS_image_height, NewWidth}, InterpolationMode::kArea); + Execute composeResize(resize); composeResize(input, &imgResize); composeNormalize(imgResize, &imgNormalize); - std::shared_ptr pad(new Pad({0, 0, paddingSize, 0})); - Execute composePad({pad}); + auto pad = Pad({0, 0, paddingSize, 0}); + Execute composePad(pad); composePad(imgNormalize, output); } return 0; @@ -118,10 +115,10 @@ int main(int argc, char **argv) { auto all_files = GetAllFiles(FLAGS_dataset_path); std::map costTime_map; size_t size = all_files.size(); - std::shared_ptr decode(new Decode()); - std::shared_ptr hwc2chw(new HWC2CHW()); - Execute composeDecode({decode}); - Execute composeTranspose({hwc2chw}); + auto decode = Decode(); + auto hwc2chw = HWC2CHW(); + Execute composeDecode(decode); + Execute composeTranspose(hwc2chw); for (size_t i = 0; i < size; ++i) { struct timeval start = {0}; struct timeval end = {0}; diff --git a/model_zoo/official/cv/tinydarknet/ascend310_infer/src/main.cc b/model_zoo/official/cv/tinydarknet/ascend310_infer/src/main.cc index d1e12f5130a..47233c8eb42 100644 --- a/model_zoo/official/cv/tinydarknet/ascend310_infer/src/main.cc +++ b/model_zoo/official/cv/tinydarknet/ascend310_infer/src/main.cc @@ -88,14 +88,12 @@ int main(int argc, char **argv) { return 1; } - std::shared_ptr decode(new Decode()); - std::shared_ptr resize(new Resize({256})); - std::shared_ptr dvpp_resize(new Resize({256, 256})); - auto crop_size = {FLAGS_image_height, FLAGS_image_width}; - std::shared_ptr center_crop(new CenterCrop(crop_size)); - std::shared_ptr normalize(new Normalize({123.675, 116.28, 103.53}, - {58.395, 57.120, 57.375})); - std::shared_ptr hwc2chw(new HWC2CHW()); + auto decode = Decode(); + auto resize = Resize({256}); + auto dvpp_resize = Resize({256, 256}); + auto center_crop = CenterCrop({FLAGS_image_height, FLAGS_image_width}); + auto normalize = Normalize({123.675, 116.28, 103.53}, {58.395, 57.120, 57.375}); + auto hwc2chw = HWC2CHW(); Execute transform({decode, resize, center_crop, normalize, hwc2chw}); Execute dvpptransform({decode, dvpp_resize}); diff --git a/model_zoo/official/cv/tinydarknet/export.py b/model_zoo/official/cv/tinydarknet/export.py index 010bfe4b671..b9d7700e95c 100644 --- a/model_zoo/official/cv/tinydarknet/export.py +++ b/model_zoo/official/cv/tinydarknet/export.py @@ -16,35 +16,24 @@ ##############export checkpoint file into air and onnx models################# python export.py """ -import argparse import numpy as np import mindspore as ms from mindspore import Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net, export -from src.model_utils.config import config as imagenet_cfg +from src.model_utils.config import config from src.tinydarknet import TinyDarkNet if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Classification') - parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'], - help='dataset name.') - parser.add_argument('--file_format', type=str, default='AIR', choices=['MINDIR', 'AIR'], - help='file format.') - parser.add_argument('--file_name', type=str, default='tinydarknet', help='output file name.') - args_opt = parser.parse_args() - - if args_opt.dataset_name == 'imagenet': - cfg = imagenet_cfg - else: + if config.dataset_name != 'imagenet': raise ValueError("Dataset is not support.") - net = TinyDarkNet(num_classes=cfg.num_classes) + net = TinyDarkNet(num_classes=config.num_classes) - assert cfg.checkpoint_path is not None, "cfg.checkpoint_path is None." - param_dict = load_checkpoint(cfg.checkpoint_path) + assert config.checkpoint_path is not None, "config.checkpoint_path is None." + param_dict = load_checkpoint(config.checkpoint_path) load_param_into_net(net, param_dict) - input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]), ms.float32) - export(net, input_arr, file_name=args_opt.file_name, file_format=args_opt.file_format) + input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[config.batch_size, 3, 224, 224]), ms.float32) + export(net, input_arr, file_name=config.file_name, file_format=config.file_format) diff --git a/model_zoo/official/cv/tinydarknet/imagenet_config.yaml b/model_zoo/official/cv/tinydarknet/imagenet_config.yaml index 5617f77eacb..6222cd7d9ba 100644 --- a/model_zoo/official/cv/tinydarknet/imagenet_config.yaml +++ b/model_zoo/official/cv/tinydarknet/imagenet_config.yaml @@ -30,8 +30,8 @@ train_data_dir: './dataset/imagenet_original/train/' val_data_dir: './dataset/imagenet_original/val/' keep_checkpoint_max: 1 checkpoint_path: './scripts/train_parallel4/ckpt_4/train_tinydarknet_imagenet-300_1251.ckpt' -onnx_filename: 'tinydarknet.onnx' -air_filename: 'tinydarknet.air' +file_name: 'tinydarknet' +file_format: 'MINDIR' # optimizer and lr related lr_scheduler: 'exponential' lr_epochs: [70, 140, 210, 280] @@ -44,6 +44,9 @@ is_dynamic_loss_scale: False loss_scale: 1024 label_smooth_factor: 0.1 use_label_smooth: True +#310infer postprocess +result_path: '' +label_file: '' --- @@ -55,3 +58,4 @@ data_path: "The location of the input data." output_path: "The location of the output file." device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend." enable_profiling: 'Whether enable profiling while training, default: False' +file_format: '["MINDIR", "AIR"]' diff --git a/model_zoo/official/cv/tinydarknet/postprocess.py b/model_zoo/official/cv/tinydarknet/postprocess.py index c08e5747d95..e47019e4a3a 100644 --- a/model_zoo/official/cv/tinydarknet/postprocess.py +++ b/model_zoo/official/cv/tinydarknet/postprocess.py @@ -14,13 +14,8 @@ # ============================================================================ """post process for 310 inference""" import os -import argparse import numpy as np - -parser = argparse.ArgumentParser(description='tinydarknet calcul top1 and top5 acc') -parser.add_argument("--result_path", type=str, required=True, default='', help="result file path") -parser.add_argument("--label_file", type=str, required=True, default='', help="label file") -args = parser.parse_args() +from src.model_utils.config import config def get_top5_acc(top_arg, gt_class): @@ -69,4 +64,4 @@ def cal_acc(result_path, label_file): if __name__ == '__main__': - cal_acc(args.result_path, args.label_file) + cal_acc(config.result_path, config.label_file)