forked from mindspore-Ecosystem/mindspore
!10297 Fix typos in CRNN and remove the support of GPU
From: @c_34 Reviewed-by: @wuxuejian,@liangchenghui Signed-off-by: @liangchenghui
This commit is contained in:
commit
62271467d5
|
@ -160,15 +160,15 @@ max_text_length": 23, # max number of digits in each
|
|||
|
||||
### [Training](#contents)
|
||||
|
||||
- Run `run_standalone_train.sh` for non-distributed training of CRNN model, either on Ascend or on GPU.
|
||||
- Run `run_standalone_train.sh` for non-distributed training of CRNN model, only support Ascend now.
|
||||
|
||||
``` bash
|
||||
bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
|
||||
bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM](optional)
|
||||
```
|
||||
|
||||
#### [Distributed Training](#contents)
|
||||
|
||||
- Run `run_distribute_train.sh` for distributed training of WarpCTC model on Ascend.
|
||||
- Run `run_distribute_train.sh` for distributed training of CRNN model on Ascend.
|
||||
|
||||
``` bash
|
||||
bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
|
@ -188,7 +188,7 @@ Epoch time: 2743.688s, per step time: 0.097s
|
|||
- Run `run_eval.sh` for evaluation.
|
||||
|
||||
``` bash
|
||||
bash run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]
|
||||
bash run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM](optional)
|
||||
```
|
||||
|
||||
Check the `eval/log.txt` and you will get outputs as following:
|
||||
|
@ -232,7 +232,7 @@ result: {'CRNNAccuracy': (0.806)}
|
|||
| Dataset | SVT | IIIT5K |
|
||||
| batch_size | 1 | 1 |
|
||||
| outputs | ACC | ACC |
|
||||
| Accuracy | 80.9% | 80.6% |
|
||||
| Accuracy | 80.8% | 79.7% |
|
||||
| Model for inference | 83M (.ckpt file) | 83M (.ckpt file) |
|
||||
|
||||
## [Description of Random Situation](#contents)
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 4 ]; then
|
||||
echo "Usage: sh run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]"
|
||||
if [ $# != 4 ] && [ $# != 3 ]; then
|
||||
echo "Usage: sh run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM](optional) "
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -30,7 +30,12 @@ get_real_path() {
|
|||
DATASET_NAME=$1
|
||||
PATH1=$(get_real_path $2)
|
||||
PATH2=$(get_real_path $3)
|
||||
PLATFORM=$4
|
||||
|
||||
if [ $# == 4 ]; then
|
||||
PLATFORM=$4
|
||||
else
|
||||
PLATFORM="Ascend"
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH1 ]; then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]; then
|
||||
echo "Usage: sh run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]"
|
||||
if [ $# != 3 ] && [ $# != 2 ]; then
|
||||
echo "Usage: sh run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM](optional)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -29,7 +29,11 @@ get_real_path() {
|
|||
|
||||
DATASET_NAME=$1
|
||||
PATH1=$(get_real_path $2)
|
||||
PLATFORM=$3
|
||||
if [ $# == 3 ]; then
|
||||
PLATFORM=$3
|
||||
else
|
||||
PLATFORM="Ascend"
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH1 ]; then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||
|
@ -58,7 +62,7 @@ run_gpu() {
|
|||
if [ -d "train" ]; then
|
||||
rm -rf ./train
|
||||
fi
|
||||
WORKDIR=./train$(DEVICE_ID)
|
||||
WORKDIR=./train${DEVICE_ID}
|
||||
mkdir $WORKDIR
|
||||
cp ../*.py $WORKDIR
|
||||
cp -r ../src $WORKDIR
|
||||
|
|
|
@ -34,8 +34,8 @@ set_seed(1)
|
|||
parser = argparse.ArgumentParser(description="crnn training")
|
||||
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
|
||||
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
|
||||
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend'],
|
||||
help='Running platform, only support Ascend now. Default is Ascend.')
|
||||
parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase")
|
||||
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
|
||||
parser.set_defaults(run_distribute=False)
|
||||
|
@ -92,7 +92,7 @@ if __name__ == '__main__':
|
|||
model = Model(net)
|
||||
# define callbacks
|
||||
callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
|
||||
if config.save_checkpoint:
|
||||
if config.save_checkpoint and rank == 0:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
|
||||
|
|
Loading…
Reference in New Issue