!20096 modify resnet 310 infer
Merge pull request !20096 from lilei/modify_model_zoo_readme_R1.3
This commit is contained in:
commit
6216e651fb
|
@ -14,22 +14,17 @@
|
|||
# ============================================================================
|
||||
"""train resnet."""
|
||||
import os
|
||||
import argparse
|
||||
from src.dataset import create_dataset1 as create_dataset
|
||||
|
||||
parser = argparse.ArgumentParser(description='preprocess data')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--output_path', type=str, default=None, help='output path')
|
||||
args_opt = parser.parse_args()
|
||||
from src.model_utils.config import config
|
||||
|
||||
if __name__ == '__main__':
|
||||
# create dataset
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=1,
|
||||
dataset = create_dataset(dataset_path=config.data_path, do_train=False, batch_size=1,
|
||||
target="Ascend")
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
||||
img_path = os.path.join(args_opt.output_path, "img_data")
|
||||
label_path = os.path.join(args_opt.output_path, "label")
|
||||
img_path = os.path.join(config.output_path, "img_data")
|
||||
label_path = os.path.join(config.output_path, "label")
|
||||
os.makedirs(img_path)
|
||||
os.makedirs(label_path)
|
||||
|
||||
|
|
|
@ -85,8 +85,10 @@ function preprocess_data()
|
|||
rm -rf ./preprocess_Result
|
||||
fi
|
||||
mkdir preprocess_Result
|
||||
BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")")
|
||||
CONFIG_FILE="${BASE_PATH}/$1"
|
||||
|
||||
python3.7 ../preprocess.py --dataset_path=$data_path --output_path=./preprocess_Result
|
||||
python3.7 ../preprocess.py --data_path=$data_path --output_path=./preprocess_Result --config_path=$CONFIG_FILE &> preprocess.log
|
||||
}
|
||||
|
||||
function infer()
|
||||
|
@ -118,7 +120,12 @@ function cal_acc()
|
|||
}
|
||||
|
||||
if [ "x${dataset}" == "xcifar10" ] || [ "x${dataset}" == "xCifar10" ]; then
|
||||
preprocess_data
|
||||
if [ $2 == 'resnet18' ]; then
|
||||
CONFIG_PATH=resnet18_cifar10_config.yaml
|
||||
else
|
||||
CONFIG_PATH=resnet50_cifar10_config.yaml
|
||||
fi
|
||||
preprocess_data ${CONFIG_PATH}
|
||||
data_path=./preprocess_Result/img_data
|
||||
fi
|
||||
|
||||
|
|
Loading…
Reference in New Issue