diff --git a/model_zoo/official/cv/crnn/README.md b/model_zoo/official/cv/crnn/README.md index a50b2593715..9285f90ff1e 100644 --- a/model_zoo/official/cv/crnn/README.md +++ b/model_zoo/official/cv/crnn/README.md @@ -160,15 +160,15 @@ max_text_length": 23, # max number of digits in each ### [Training](#contents) -- Run `run_standalone_train.sh` for non-distributed training of CRNN model, either on Ascend or on GPU. +- Run `run_standalone_train.sh` for non-distributed training of CRNN model, only support Ascend now. ``` bash -bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM] +bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM](optional) ``` #### [Distributed Training](#contents) -- Run `run_distribute_train.sh` for distributed training of WarpCTC model on Ascend. +- Run `run_distribute_train.sh` for distributed training of CRNN model on Ascend. ``` bash bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH] @@ -188,7 +188,7 @@ Epoch time: 2743.688s, per step time: 0.097s - Run `run_eval.sh` for evaluation. ``` bash -bash run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM] +bash run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM](optional) ``` Check the `eval/log.txt` and you will get outputs as following: @@ -232,7 +232,7 @@ result: {'CRNNAccuracy': (0.806)} | Dataset | SVT | IIIT5K | | batch_size | 1 | 1 | | outputs | ACC | ACC | -| Accuracy | 80.9% | 80.6% | +| Accuracy | 80.8% | 79.7% | | Model for inference | 83M (.ckpt file) | 83M (.ckpt file) | ## [Description of Random Situation](#contents) diff --git a/model_zoo/official/cv/crnn/scripts/run_eval.sh b/model_zoo/official/cv/crnn/scripts/run_eval.sh index 30bd8ee9603..da1d54ab3c3 100644 --- a/model_zoo/official/cv/crnn/scripts/run_eval.sh +++ b/model_zoo/official/cv/crnn/scripts/run_eval.sh @@ -14,8 +14,8 @@ # limitations under the License. # ============================================================================ -if [ $# != 4 ]; then - echo "Usage: sh run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]" +if [ $# != 4 ] && [ $# != 3 ]; then + echo "Usage: sh run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM](optional) " exit 1 fi @@ -30,7 +30,12 @@ get_real_path() { DATASET_NAME=$1 PATH1=$(get_real_path $2) PATH2=$(get_real_path $3) -PLATFORM=$4 + +if [ $# == 4 ]; then + PLATFORM=$4 +else + PLATFORM="Ascend" +fi if [ ! -d $PATH1 ]; then echo "error: DATASET_PATH=$PATH1 is not a directory" diff --git a/model_zoo/official/cv/crnn/scripts/run_standalone_train.sh b/model_zoo/official/cv/crnn/scripts/run_standalone_train.sh index fde707aaaf6..a8e47eb20c7 100644 --- a/model_zoo/official/cv/crnn/scripts/run_standalone_train.sh +++ b/model_zoo/official/cv/crnn/scripts/run_standalone_train.sh @@ -14,8 +14,8 @@ # limitations under the License. # ============================================================================ -if [ $# != 3 ]; then - echo "Usage: sh run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]" +if [ $# != 3 ] && [ $# != 2 ]; then + echo "Usage: sh run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM](optional)" exit 1 fi @@ -29,7 +29,11 @@ get_real_path() { DATASET_NAME=$1 PATH1=$(get_real_path $2) -PLATFORM=$3 +if [ $# == 3 ]; then + PLATFORM=$3 +else + PLATFORM="Ascend" +fi if [ ! -d $PATH1 ]; then echo "error: DATASET_PATH=$PATH1 is not a directory" @@ -58,7 +62,7 @@ run_gpu() { if [ -d "train" ]; then rm -rf ./train fi -WORKDIR=./train$(DEVICE_ID) +WORKDIR=./train${DEVICE_ID} mkdir $WORKDIR cp ../*.py $WORKDIR cp -r ../src $WORKDIR diff --git a/model_zoo/official/cv/crnn/train.py b/model_zoo/official/cv/crnn/train.py index c71a4e4b8ab..9a8c28dcc97 100644 --- a/model_zoo/official/cv/crnn/train.py +++ b/model_zoo/official/cv/crnn/train.py @@ -34,8 +34,8 @@ set_seed(1) parser = argparse.ArgumentParser(description="crnn training") parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.") parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None') -parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'], - help='Running platform, choose from Ascend, GPU, and default is Ascend.') +parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend'], + help='Running platform, only support Ascend now. Default is Ascend.') parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase") parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k']) parser.set_defaults(run_distribute=False) @@ -92,7 +92,7 @@ if __name__ == '__main__': model = Model(net) # define callbacks callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)] - if config.save_checkpoint: + if config.save_checkpoint and rank == 0: config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, keep_checkpoint_max=config.keep_checkpoint_max) save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')