fix unet Ascend310 inference bug

This commit is contained in:
yuzhenhua 2021-06-25 15:34:28 +08:00 committed by root
parent 43174475e6
commit 24b47f7daf
7 changed files with 42 additions and 25 deletions

View File

@ -474,19 +474,20 @@ the steps below, this is a simple example:
Export MindIR
Before exporting, you need to modify the parameter in the configuration — checkpoint_file_path and batch_ Size . checkpoint_ file_ Path is the CKPT file path, batch_ Size is set to 1.
```shell
python export.py --checkpoint_file_path [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
python export.py --config_path=[CONFIG_PATH]
```
The checkpoint_file_path parameter is required,
`EXPORT_FORMAT` should be in ["AIR", "MINDIR"]
Before performing inference, the MINDIR file must be exported by export script on the 910 environment.
Current batch_size can only be set to 1.
```shell
# Ascend310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
bash run_infer_310.sh [NETWORK] [MINDIR_PATH] [DEVICE_ID] [NEED_PREPROCESS]
```
`DEVICE_ID` is optional, default value is 0.

View File

@ -471,18 +471,17 @@ python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkp
导出mindir模型
在执行导出前需要修改配置文件中的checkpoint_file_path和batch_size参数。checkpoint_file_path为ckpt文件路径batch_size设置为1。
```shell
python export.py --checkpoint_file_path [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
python export.py --config_path=[CONFIG_PATH]
```
参数`checkpoint_file_path` 是必需的,`EXPORT_FORMAT` 必须在 ["AIR", "MINDIR"]中进行选择。
在执行推理前MINDIR文件必须在910上通过export.py文件导出。
目前仅可处理batch_Size为1。
```shell
# Ascend310 推理
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
bash run_infer_310.sh [NETWORK] [MINDIR_PATH] [DEVICE_ID] [NEED_PREPROCESS]
```
`DEVICE_ID` 可选,默认值为 0。

View File

@ -77,7 +77,7 @@ if __name__ == '__main__':
rst_path = config.rst_path
metrics = dice_coeff()
if config.dataset == "Cell_nuclei":
if hasattr(config, "dataset") and config.dataset == "Cell_nuclei":
img_size = tuple(config.image_size)
for i, bin_name in enumerate(os.listdir('./preprocess_Result/')):
f = bin_name.replace(".png", "")

View File

@ -29,7 +29,7 @@ def preprocess_dataset(data_dir, result_path, cross_valid_ind=1):
labels_list = []
for i, data in enumerate(valid_dataset):
file_name = "ISBI_test_bs_1_" + str(i) + ".bin"
file_path = result_path + file_name
file_path = os.path.join(result_path, file_name)
data[0].asnumpy().tofile(file_path)
labels_list.append(data[1].asnumpy())

View File

@ -15,7 +15,7 @@
# ============================================================================
if [[ $# -lt 3 || $# -gt 4 ]]; then
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] [NEED_PREPROCESS]
echo "Usage: bash run_infer_310.sh [NETWORK] [MINDIR_PATH] [DEVICE_ID] [NEED_PREPROCESS]
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero.
NEED_PREPROCESS means weather need preprocess or not, it's value is 'y' or 'n'."
exit 1
@ -28,8 +28,9 @@ get_real_path(){
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
data_path=$(get_real_path $2)
network=$1
model=$(get_real_path $2)
if [ $# == 4 ]; then
device_id=$3
if [ -z $device_id ]; then
@ -40,8 +41,8 @@ if [ $# == 4 ]; then
fi
need_preprocess=$4
echo "network: " $network
echo "mindir name: "$model
echo "dataset path: "$data_path
echo "device id: "$device_id
echo "need preprocess or not: "$need_preprocess
@ -65,7 +66,17 @@ function preprocess_data()
rm -rf ./preprocess_Result
fi
mkdir preprocess_Result
python3.7 ../preprocess.py --data_url=$data_path --result_path=./preprocess_Result/
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
if [ $network == "unet" ]; then
config_path="${BASEPATH}/../unet_simple_config.yaml"
elif [ $network == "unet++" ]; then
config_path="${BASEPATH}/../unet_nested_cell_config.yaml"
else
echo "unsupported network"
exit 1
fi
python3.7 ../preprocess.py --config_path=$config_path
}
function compile_app()
@ -93,7 +104,13 @@ function infer()
function cal_acc()
{
python3.7 ../postprocess.py --data_url=$data_path --rst_path=./result_Files/ &> acc.log &
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
if [ $network == "unet" ]; then
config_path="${BASEPATH}/../unet_simple_config.yaml"
elif [ $network == "unet++" ]; then
config_path="${BASEPATH}/../unet_nested_cell_config.yaml"
fi
python3.7 ../postprocess.py --config_path=$config_path &> acc.log &
}
preprocess_data

View File

@ -45,13 +45,13 @@ eval_resize: False
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'ckpt_unet_nested_adam-4-75.ckpt'
rst_path: './result_Files/'
result_path: ""
result_path: "./preprocess_Result"
# Export options
width: 572
height: 572
file_name: "unet"
file_format: "AIR"
width: 96
height: 96
file_name: "unetplusplus"
file_format: "MINDIR"
---
# Help description for each configuration

View File

@ -41,13 +41,13 @@ eval_resize: False
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'ckpt_unet_simple_adam-4-75.ckpt'
rst_path: './result_Files/'
result_path: ""
result_path: "./preprocess_Result"
# Export options
width: 572
height: 572
width: 576
height: 576
file_name: "unet"
file_format: "AIR"
file_format: "MINDIR"
---
# Help description for each configuration