forked from mindspore-Ecosystem/mindspore
!17701 revert alexnet datset.py config.py&modify 310 infer
From: @zeyangao Reviewed-by: @c_34,@wuxuejian Signed-off-by: @c_34
This commit is contained in:
commit
7aaadbe500
|
@ -14,24 +14,23 @@
|
|||
# ============================================================================
|
||||
"""postprocess for 310 inference"""
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore.nn import Top1CategoricalAccuracy, Top5CategoricalAccuracy
|
||||
|
||||
from mindspore.nn import Top1CategoricalAccuracy
|
||||
from src.model_utils.config import config as cfg
|
||||
|
||||
batch_size = 1
|
||||
parser = argparse.ArgumentParser(description="postprocess")
|
||||
parser.add_argument("--result_dir", type=str, required=True, help="result files path.")
|
||||
parser.add_argument("--label_dir", type=str, required=True, help="image file path.")
|
||||
label_path = "./preprocess_Result/cifar10_label_ids.npy"
|
||||
parser.add_argument("--result_dir", type=str, default="./result_Files", help="result files path.")
|
||||
parser.add_argument("--label_dir", type=str, default=label_path, help="image file path.")
|
||||
parser.add_argument("--config_path", type=str, default="../default_config.yaml", help="config file path.")
|
||||
parser.add_argument('--dataset_name', type=str, choices=["cifar10", "imagenet2012"], default="cifar10")
|
||||
args = parser.parse_args()
|
||||
|
||||
cfg.config_path = args.config_path
|
||||
if __name__ == '__main__':
|
||||
top1_acc = Top1CategoricalAccuracy()
|
||||
rst_path = args.result_dir
|
||||
if args.dataset_name == "cifar10":
|
||||
from src.config import alexnet_cifar10_cfg as cfg
|
||||
labels = np.load(args.label_dir, allow_pickle=True)
|
||||
for idx, label in enumerate(labels):
|
||||
f_name = os.path.join(rst_path, "alexnet_data_bs" + str(cfg.batch_size) + "_" + str(idx) + "_0.bin")
|
||||
|
@ -39,17 +38,3 @@ if __name__ == '__main__':
|
|||
pred = pred.reshape(cfg.batch_size, int(pred.shape[0] / cfg.batch_size))
|
||||
top1_acc.update(pred, labels[idx])
|
||||
print("acc: ", top1_acc.eval())
|
||||
else:
|
||||
from src.config import alexnet_imagenet_cfg as cfg
|
||||
top5_acc = Top5CategoricalAccuracy()
|
||||
file_list = os.listdir(rst_path)
|
||||
with open(args.label_dir, "r") as label:
|
||||
labels = json.load(label)
|
||||
for f in file_list:
|
||||
label = f.split("_0.bin")[0] + ".JPEG"
|
||||
pred = np.fromfile(os.path.join(rst_path, f), np.float32)
|
||||
pred = pred.reshape(cfg.batch_size, int(pred.shape[0] / cfg.batch_size))
|
||||
top1_acc.update(pred, [labels[label],])
|
||||
top5_acc.update(pred, [labels[label],])
|
||||
print("Top1 acc: ", top1_acc.eval())
|
||||
print("Top5 acc: ", top5_acc.eval())
|
||||
|
|
|
@ -15,59 +15,28 @@
|
|||
"""preprocess"""
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
from src.model_utils.config import config
|
||||
from src.dataset import create_dataset_cifar10
|
||||
|
||||
def create_label(result_path, dir_path):
|
||||
print("[WARNING] Create imagenet label. Currently only use for Imagenet2012!")
|
||||
dirs = os.listdir(dir_path)
|
||||
file_list = []
|
||||
for file in dirs:
|
||||
file_list.append(file)
|
||||
file_list = sorted(file_list)
|
||||
|
||||
total = 0
|
||||
img_label = {}
|
||||
for i, file_dir in enumerate(file_list):
|
||||
files = os.listdir(os.path.join(dir_path, file_dir))
|
||||
for f in files:
|
||||
img_label[f] = i
|
||||
total += len(files)
|
||||
|
||||
json_file = os.path.join(result_path, "imagenet_label.json")
|
||||
with open(json_file, "w+") as label:
|
||||
json.dump(img_label, label)
|
||||
|
||||
print("[INFO] Completed! Total {} data.".format(total))
|
||||
|
||||
parser = argparse.ArgumentParser('preprocess')
|
||||
parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="cifar10")
|
||||
parser.add_argument('--dataset_name', type=str, choices=["cifar10", "imagenet2012"], default="cifar10")
|
||||
parser.add_argument('--data_path', type=str, default='', help='eval data dir')
|
||||
parser.add_argument('--result_path', type=str, default='./preprocess_Result/', help='result path')
|
||||
parser.add_argument("--config_path", type=str, default="../default_config.yaml", help="config file path.")
|
||||
result_path = './preprocess_Result/'
|
||||
#parser.add_argument('--result_path', type=str, default='./preprocess_Result/', help='result path')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.dataset == "cifar10":
|
||||
from src.config import alexnet_cifar10_cfg as cfg
|
||||
else:
|
||||
from src.config import alexnet_imagenet_cfg as cfg
|
||||
|
||||
args.per_batch_size = cfg.batch_size
|
||||
#args.image_size = cfg.image_size
|
||||
|
||||
|
||||
config.config_path = args.config_path
|
||||
if __name__ == "__main__":
|
||||
if args.dataset == "cifar10":
|
||||
dataset = create_dataset_cifar10(args.data_path, args.per_batch_size, training=False)
|
||||
img_path = os.path.join(args.result_path, "00_data")
|
||||
if args.dataset_name == "cifar10":
|
||||
dataset = create_dataset_cifar10(config, args.data_path, batch_size=config.batch_size, status="eval")
|
||||
img_path = os.path.join(result_path, "00_data")
|
||||
os.makedirs(img_path)
|
||||
label_list = []
|
||||
for idx, data in enumerate(dataset.create_dict_iterator(output_numpy=True)):
|
||||
file_name = "alexnet_data_bs" + str(args.per_batch_size) + "_" + str(idx) + ".bin"
|
||||
file_name = "alexnet_data_bs" + str(config.batch_size) + "_" + str(idx) + ".bin"
|
||||
file_path = os.path.join(img_path, file_name)
|
||||
data["image"].tofile(file_path)
|
||||
label_list.append(data["label"])
|
||||
np.save(os.path.join(args.result_path, "cifar10_label_ids.npy"), label_list)
|
||||
np.save(os.path.join(result_path, "cifar10_label_ids.npy"), label_list)
|
||||
print("=" * 20, "export bin files finished", "=" * 20)
|
||||
else:
|
||||
create_label(args.result_path, args.data_path)
|
||||
|
|
@ -45,7 +45,10 @@ device_id=0
|
|||
if [ $# == 4 ]; then
|
||||
device_id=$4
|
||||
fi
|
||||
|
||||
BASEPATH=$(dirname "$(pwd)")
|
||||
config_path=$BASEPATH"/default_config.yaml"
|
||||
echo "base path :"$BASEPATH
|
||||
echo "config path :"$config_path
|
||||
echo "mindir name: "$model
|
||||
echo "dataset name: "$dataset_name
|
||||
echo "dataset path: "$dataset_path
|
||||
|
@ -91,7 +94,7 @@ function preprocess_data()
|
|||
rm -rf ./preprocess_Result
|
||||
fi
|
||||
mkdir preprocess_Result
|
||||
python3.7 ../preprocess.py --dataset=$dataset_name --data_path=$dataset_path --result_path=./preprocess_Result/
|
||||
python3.7 ../preprocess.py --config_path=$config_path --dataset_name=$dataset_name --data_path=$dataset_path
|
||||
}
|
||||
|
||||
function compile_app()
|
||||
|
@ -112,20 +115,13 @@ function infer()
|
|||
mkdir result_Files
|
||||
mkdir time_Result
|
||||
|
||||
if [ "$dataset_name" == "cifar10" ]; then
|
||||
../ascend310_infer/out/main --mindir_path=$model --dataset_name=$dataset_name --input0_path=./preprocess_Result/00_data --device_id=$device_id &> infer.log
|
||||
else
|
||||
../ascend310_infer/out/main --mindir_path=$model --dataset_name=$dataset_name --input0_path=$dataset_path --device_id=$device_id &> infer.log
|
||||
fi
|
||||
|
||||
}
|
||||
|
||||
function cal_acc()
|
||||
{
|
||||
if [ "$dataset_name" == "cifar10" ]; then
|
||||
python3.7 ../postprocess.py --result_dir=./result_Files --label_dir=./preprocess_Result/cifar10_label_ids.npy --dataset_name=$dataset_name &> acc.log
|
||||
else
|
||||
python3.7 ../postprocess.py --result_dir=./result_Files --label_dir=./preprocess_Result/imagenet_label.json --dataset_name=$dataset_name &> acc.log
|
||||
fi
|
||||
python3.7 ../postprocess.py --dataset_name=$dataset_name &> acc.log
|
||||
}
|
||||
|
||||
if [ $need_preprocess == "y" ]; then
|
||||
|
|
|
@ -1,54 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
network config setting, will be used in train.py
|
||||
"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
alexnet_cifar10_cfg = edict({
|
||||
'num_classes': 10,
|
||||
'learning_rate': 0.002,
|
||||
'momentum': 0.9,
|
||||
'epoch_size': 30,
|
||||
'batch_size': 32,
|
||||
'buffer_size': 1000,
|
||||
'image_height': 227,
|
||||
'image_width': 227,
|
||||
'save_checkpoint_steps': 1562,
|
||||
'keep_checkpoint_max': 10,
|
||||
'air_name': "alexnet.air",
|
||||
})
|
||||
|
||||
alexnet_imagenet_cfg = edict({
|
||||
'num_classes': 1000,
|
||||
'learning_rate': 0.13,
|
||||
'momentum': 0.9,
|
||||
'epoch_size': 150,
|
||||
'batch_size': 256,
|
||||
'buffer_size': None, # invalid parameter
|
||||
'image_height': 224,
|
||||
'image_width': 224,
|
||||
'save_checkpoint_steps': 625,
|
||||
'keep_checkpoint_max': 10,
|
||||
'air_name': "alexnet.air",
|
||||
|
||||
# opt
|
||||
'weight_decay': 0.0001,
|
||||
'loss_scale': 1024,
|
||||
|
||||
# lr
|
||||
'is_dynamic_loss_scale': 0,
|
||||
})
|
|
@ -22,10 +22,9 @@ import mindspore.dataset.transforms.c_transforms as C
|
|||
import mindspore.dataset.vision.c_transforms as CV
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
from .config import alexnet_cifar10_cfg, alexnet_imagenet_cfg
|
||||
|
||||
|
||||
def create_dataset_cifar10(data_path, cfg, batch_size=32, repeat_size=1, training=True, target="Ascend"):
|
||||
def create_dataset_cifar10(cfg, data_path, batch_size=32, repeat_size=1, status="train", target="Ascend"):
|
||||
"""
|
||||
create dataset for train or test
|
||||
"""
|
||||
|
@ -40,18 +39,18 @@ def create_dataset_cifar10(data_path, cfg, batch_size=32, repeat_size=1, trainin
|
|||
num_shards=device_num, shard_id=rank_id)
|
||||
rescale = 1.0 / 255.0
|
||||
shift = 0.0
|
||||
cfg = alexnet_cifar10_cfg
|
||||
# cfg = alexnet_cifar10_cfg
|
||||
|
||||
resize_op = CV.Resize((cfg.image_height, cfg.image_width))
|
||||
rescale_op = CV.Rescale(rescale, shift)
|
||||
normalize_op = CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
|
||||
if training:
|
||||
if status == "train":
|
||||
random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4])
|
||||
random_horizontal_op = CV.RandomHorizontalFlip()
|
||||
channel_swap_op = CV.HWC2CHW()
|
||||
typecast_op = C.TypeCast(mstype.int32)
|
||||
cifar_ds = cifar_ds.map(input_columns="label", operations=typecast_op, num_parallel_workers=8)
|
||||
if training:
|
||||
if status == "train":
|
||||
cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op, num_parallel_workers=8)
|
||||
cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op, num_parallel_workers=8)
|
||||
cifar_ds = cifar_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=8)
|
||||
|
@ -65,7 +64,7 @@ def create_dataset_cifar10(data_path, cfg, batch_size=32, repeat_size=1, trainin
|
|||
return cifar_ds
|
||||
|
||||
|
||||
def create_dataset_imagenet(dataset_path, cfg, batch_size=32, repeat_num=1, training=True,
|
||||
def create_dataset_imagenet(cfg, dataset_path, batch_size=32, repeat_num=1, training=True,
|
||||
num_parallel_workers=None, shuffle=None, sampler=None, class_indexing=None):
|
||||
"""
|
||||
create a train or eval imagenet2012 dataset for resnet50
|
||||
|
@ -82,7 +81,7 @@ def create_dataset_imagenet(dataset_path, cfg, batch_size=32, repeat_num=1, trai
|
|||
"""
|
||||
|
||||
device_num, rank_id = _get_rank_info()
|
||||
cfg = alexnet_imagenet_cfg
|
||||
# cfg = alexnet_imagenet_cfg
|
||||
|
||||
num_parallel_workers = 16
|
||||
if device_num == 1:
|
||||
|
|
Loading…
Reference in New Issue