forked from mindspore-Ecosystem/mindspore
!18952 modify model_zoo network bug for clould
Merge pull request !18952 from lilei/modify_model_zoo_bug
This commit is contained in:
commit
7860c29397
|
@ -58,7 +58,7 @@ def set_seed(seed):
|
|||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore.ops as ops
|
||||
>>> import mindspore.ops as ops
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.common import set_seed
|
||||
>>> from mindspore.common.initializer import initializer
|
||||
|
|
|
@ -15,10 +15,16 @@
|
|||
"""post process for 310 inference"""
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import numpy as np
|
||||
from src.model_utils.config import config
|
||||
|
||||
parser = argparse.ArgumentParser(description="resnet inference")
|
||||
parser.add_argument("--result_path", type=str, required=True, help="result files path.")
|
||||
parser.add_argument("--label_path", type=str, required=True, help="image file path.")
|
||||
args = parser.parse_args()
|
||||
|
||||
batch_size = 1
|
||||
num_classes = 1000
|
||||
|
||||
def get_result(result_path, label_path):
|
||||
files = os.listdir(result_path)
|
||||
|
@ -31,7 +37,7 @@ def get_result(result_path, label_path):
|
|||
for file in files:
|
||||
img_ids_name = file.split('_0.')[0]
|
||||
data_path = os.path.join(result_path, img_ids_name + "_0.bin")
|
||||
result = np.fromfile(data_path, dtype=np.float16).reshape(batch_size, config.num_classes)
|
||||
result = np.fromfile(data_path, dtype=np.float16).reshape(batch_size, num_classes)
|
||||
for batch in range(batch_size):
|
||||
predict = np.argsort(-result[batch], axis=-1)
|
||||
if labels[img_ids_name+".JPEG"] == predict[0]:
|
||||
|
@ -42,4 +48,4 @@ def get_result(result_path, label_path):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
get_result(config.result_path, config.label_path)
|
||||
get_result(args.result_path, args.label_path)
|
||||
|
|
|
@ -45,7 +45,7 @@ do
|
|||
mkdir ./LOG$i
|
||||
cp *.py ./LOG$i
|
||||
cp *.yaml ./LOG$i
|
||||
cp ./src ./LOG$i
|
||||
cp -r ./src ./LOG$i
|
||||
cd ./LOG$i || exit
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
|
||||
|
|
|
@ -85,7 +85,7 @@ function infer()
|
|||
|
||||
function cal_acc()
|
||||
{
|
||||
python3.7 ../postprocess.py --result_path=./result_Files --img_path=$data_path --drop &> acc.log &
|
||||
python3.7 ../postprocess.py --result_path=./result_Files --img_path=$data_path --drop=True &> acc.log &
|
||||
}
|
||||
|
||||
compile_app
|
||||
|
|
|
@ -71,8 +71,8 @@ file_name: "mass"
|
|||
file_format: "AIR"
|
||||
vocab_file: ""
|
||||
result_path: "./preprocess_Result/"
|
||||
source_id_folder: ""
|
||||
target_id_folder: ""
|
||||
source_id_folder: "./preprocess_Result/00_source_eos_ids"
|
||||
target_id_folder: "./preprocess_Result/target_eos_ids"
|
||||
result_dir: "./result_Files"
|
||||
|
||||
---
|
||||
|
|
|
@ -26,12 +26,17 @@ from src.model_utils.config import config
|
|||
from src.transformer.transformer_for_infer import TransformerInferModel
|
||||
|
||||
|
||||
def get_config():
|
||||
config.compute_type = mstype.float16 if config.compute_type == "float16" else mstype.float32
|
||||
config.dtype = mstype.float16 if config.dtype == "float16" else mstype.float32
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=config.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
vocab = Dictionary.load_from_persisted_dict(config.vocab_file)
|
||||
get_config()
|
||||
dec_len = config.max_decode_length
|
||||
|
||||
tfm_model = TransformerInferModel(config=config, use_one_hot_embeddings=False)
|
||||
|
|
|
@ -18,7 +18,7 @@ export DEVICE_ID=0
|
|||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
options=`getopt -u -o ht:n:i:j:c:o:v:m: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab:,metric: -- "$@"`
|
||||
options=`getopt -u -o ht:n:i:j:c:o:v:m: -l help,task:,device_num:,device_id:,hccl_json:,output:,vocab:,metric: -- "$@"`
|
||||
eval set -- "$options"
|
||||
echo $options
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ export DEVICE_ID=0
|
|||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
options=`getopt -u -o ht:n:i::o:v:m: -l help,task:,device_num:,device_id:,config:,output:,vocab:,metric: -- "$@"`
|
||||
options=`getopt -u -o ht:n:i::o:v:m: -l help,task:,device_num:,device_id:,output:,vocab:,metric: -- "$@"`
|
||||
eval set -- "$options"
|
||||
echo $options
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [[ $# -lt 5 || $# -gt 6 ]]; then
|
||||
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [CONFIG] [VOCAB] [OUTPUT] [NEED_PREPROCESS] [DEVICE_ID]
|
||||
if [[ $# -lt 4 || $# -gt 5 ]]; then
|
||||
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [VOCAB] [OUTPUT] [NEED_PREPROCESS] [DEVICE_ID]
|
||||
NEED_PREPROCESS means weather need preprocess or not, it's value is 'y' or 'n'.
|
||||
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
|
||||
exit 1
|
||||
|
@ -29,24 +29,22 @@ get_real_path(){
|
|||
fi
|
||||
}
|
||||
model=$(get_real_path $1)
|
||||
config=$(get_real_path $2)
|
||||
vocab=$(get_real_path $3)
|
||||
output=$(get_real_path $4)
|
||||
vocab=$(get_real_path $2)
|
||||
output=$(get_real_path $3)
|
||||
|
||||
if [ "$5" == "y" ] || [ "$5" == "n" ];then
|
||||
need_preprocess=$5
|
||||
if [ "$4" == "y" ] || [ "$4" == "n" ];then
|
||||
need_preprocess=$4
|
||||
else
|
||||
echo "weather need preprocess or not, it's value must be in [y, n]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
device_id=0
|
||||
if [ $# == 6 ]; then
|
||||
device_id=$6
|
||||
if [ $# == 5 ]; then
|
||||
device_id=$5
|
||||
fi
|
||||
|
||||
echo "mindir name: "$model
|
||||
echo "config: "$config
|
||||
echo "vocab: "$vocab
|
||||
echo "output: "$output
|
||||
echo "need preprocess: "$need_preprocess
|
||||
|
@ -72,7 +70,7 @@ function preprocess_data()
|
|||
rm -rf ./preprocess_Result
|
||||
fi
|
||||
mkdir preprocess_Result
|
||||
python3.7 ../preprocess.py --config=$config --result_path=./preprocess_Result/
|
||||
python3.7 ../preprocess.py
|
||||
}
|
||||
|
||||
function compile_app()
|
||||
|
@ -99,7 +97,7 @@ function infer()
|
|||
|
||||
function cal_acc()
|
||||
{
|
||||
python3.7 ../postprocess.py --config=$config --vocab=$vocab --output=$output --source_id_folder=./preprocess_Result/00_source_eos_ids --target_id_folder=./preprocess_Result/target_eos_ids --result_dir=./result_Files &> acc.log
|
||||
python3.7 ../postprocess.py --vocab=$vocab --output=$output &> acc.log
|
||||
}
|
||||
|
||||
if [ $need_preprocess == "y" ]; then
|
||||
|
|
Loading…
Reference in New Issue