!18532 tinydarknet pass parameter modification and amend cnnctc

Merge pull request !18532 from chenweitao_295/tinydarknet_amend
This commit is contained in:
i-robot 2021-06-21 11:45:30 +08:00 committed by Gitee
commit 3259d65d00
5 changed files with 33 additions and 50 deletions

View File

@ -61,9 +61,8 @@ DEFINE_int32(image_height, 32, "image height");
DEFINE_int32(image_width, 100, "image width");
int PadImage(const MSTensor &input, MSTensor *output) {
std::shared_ptr<TensorTransform> normalize(new Normalize({127.5, 127.5, 127.5},
{127.5, 127.5, 127.5}));
Execute composeNormalize({normalize});
auto normalize = Normalize({127.5, 127.5, 127.5}, {127.5, 127.5, 127.5});
Execute composeNormalize(normalize);
std::vector<int64_t> shape = input.Shape();
auto imgResize = MSTensor();
auto imgNormalize = MSTensor();
@ -74,19 +73,17 @@ int PadImage(const MSTensor &input, MSTensor *output) {
NewWidth = ceil(FLAGS_image_height * ratio);
paddingSize = FLAGS_image_width - NewWidth;
if (NewWidth > FLAGS_image_width) {
std::shared_ptr<TensorTransform> resize(new Resize({FLAGS_image_height, FLAGS_image_width},
InterpolationMode::kCubicPil));
Execute composeResize({resize});
auto resize = Resize({FLAGS_image_height, FLAGS_image_width}, InterpolationMode::kArea);
Execute composeResize(resize);
composeResize(input, &imgResize);
composeNormalize(imgResize, output);
} else {
std::shared_ptr<TensorTransform> resize(new Resize({FLAGS_image_height, NewWidth},
InterpolationMode::kCubicPil));
Execute composeResize({resize});
auto resize = Resize({FLAGS_image_height, NewWidth}, InterpolationMode::kArea);
Execute composeResize(resize);
composeResize(input, &imgResize);
composeNormalize(imgResize, &imgNormalize);
std::shared_ptr<TensorTransform> pad(new Pad({0, 0, paddingSize, 0}));
Execute composePad({pad});
auto pad = Pad({0, 0, paddingSize, 0});
Execute composePad(pad);
composePad(imgNormalize, output);
}
return 0;
@ -118,10 +115,10 @@ int main(int argc, char **argv) {
auto all_files = GetAllFiles(FLAGS_dataset_path);
std::map<double, double> costTime_map;
size_t size = all_files.size();
std::shared_ptr<TensorTransform> decode(new Decode());
std::shared_ptr<TensorTransform> hwc2chw(new HWC2CHW());
Execute composeDecode({decode});
Execute composeTranspose({hwc2chw});
auto decode = Decode();
auto hwc2chw = HWC2CHW();
Execute composeDecode(decode);
Execute composeTranspose(hwc2chw);
for (size_t i = 0; i < size; ++i) {
struct timeval start = {0};
struct timeval end = {0};

View File

@ -88,14 +88,12 @@ int main(int argc, char **argv) {
return 1;
}
std::shared_ptr<TensorTransform> decode(new Decode());
std::shared_ptr<TensorTransform> resize(new Resize({256}));
std::shared_ptr<TensorTransform> dvpp_resize(new Resize({256, 256}));
auto crop_size = {FLAGS_image_height, FLAGS_image_width};
std::shared_ptr<TensorTransform> center_crop(new CenterCrop(crop_size));
std::shared_ptr<TensorTransform> normalize(new Normalize({123.675, 116.28, 103.53},
{58.395, 57.120, 57.375}));
std::shared_ptr<TensorTransform> hwc2chw(new HWC2CHW());
auto decode = Decode();
auto resize = Resize({256});
auto dvpp_resize = Resize({256, 256});
auto center_crop = CenterCrop({FLAGS_image_height, FLAGS_image_width});
auto normalize = Normalize({123.675, 116.28, 103.53}, {58.395, 57.120, 57.375});
auto hwc2chw = HWC2CHW();
Execute transform({decode, resize, center_crop, normalize, hwc2chw});
Execute dvpptransform({decode, dvpp_resize});

View File

@ -16,35 +16,24 @@
##############export checkpoint file into air and onnx models#################
python export.py
"""
import argparse
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.model_utils.config import config as imagenet_cfg
from src.model_utils.config import config
from src.tinydarknet import TinyDarkNet
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Classification')
parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'],
help='dataset name.')
parser.add_argument('--file_format', type=str, default='AIR', choices=['MINDIR', 'AIR'],
help='file format.')
parser.add_argument('--file_name', type=str, default='tinydarknet', help='output file name.')
args_opt = parser.parse_args()
if args_opt.dataset_name == 'imagenet':
cfg = imagenet_cfg
else:
if config.dataset_name != 'imagenet':
raise ValueError("Dataset is not support.")
net = TinyDarkNet(num_classes=cfg.num_classes)
net = TinyDarkNet(num_classes=config.num_classes)
assert cfg.checkpoint_path is not None, "cfg.checkpoint_path is None."
param_dict = load_checkpoint(cfg.checkpoint_path)
assert config.checkpoint_path is not None, "config.checkpoint_path is None."
param_dict = load_checkpoint(config.checkpoint_path)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]), ms.float32)
export(net, input_arr, file_name=args_opt.file_name, file_format=args_opt.file_format)
input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[config.batch_size, 3, 224, 224]), ms.float32)
export(net, input_arr, file_name=config.file_name, file_format=config.file_format)

View File

@ -30,8 +30,8 @@ train_data_dir: './dataset/imagenet_original/train/'
val_data_dir: './dataset/imagenet_original/val/'
keep_checkpoint_max: 1
checkpoint_path: './scripts/train_parallel4/ckpt_4/train_tinydarknet_imagenet-300_1251.ckpt'
onnx_filename: 'tinydarknet.onnx'
air_filename: 'tinydarknet.air'
file_name: 'tinydarknet'
file_format: 'MINDIR'
# optimizer and lr related
lr_scheduler: 'exponential'
lr_epochs: [70, 140, 210, 280]
@ -44,6 +44,9 @@ is_dynamic_loss_scale: False
loss_scale: 1024
label_smooth_factor: 0.1
use_label_smooth: True
#310infer postprocess
result_path: ''
label_file: ''
---
@ -55,3 +58,4 @@ data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
enable_profiling: 'Whether enable profiling while training, default: False'
file_format: '["MINDIR", "AIR"]'

View File

@ -14,13 +14,8 @@
# ============================================================================
"""post process for 310 inference"""
import os
import argparse
import numpy as np
parser = argparse.ArgumentParser(description='tinydarknet calcul top1 and top5 acc')
parser.add_argument("--result_path", type=str, required=True, default='', help="result file path")
parser.add_argument("--label_file", type=str, required=True, default='', help="label file")
args = parser.parse_args()
from src.model_utils.config import config
def get_top5_acc(top_arg, gt_class):
@ -69,4 +64,4 @@ def cal_acc(result_path, label_file):
if __name__ == '__main__':
cal_acc(args.result_path, args.label_file)
cal_acc(config.result_path, config.label_file)