Modify the way warpctc passes parameters
parent b159ca939f
commit ef3f0cb0cc
# Contents

- [CNNCTC Description](#CNNCTC-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
    - [Inference Process](#inference-process)
        - [Export MindIR](#export-mindir)
        - [Infer on Ascend310](#infer-on-ascend310)
        - [result](#result)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Evaluation Performance](#evaluation-performance)
        - [Inference Performance](#inference-performance)
    - [How to use](#how-to-use)
        - [Inference](#inference)
        - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
        - [Transfer Learning](#transfer-learning)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [CNNCTC Description](#contents)

This paper makes three major contributions to address scene text recognition (STR).
First, we examine the inconsistencies between training and evaluation datasets, and the performance gap that results from them.
Second, we introduce a unified four-stage STR framework that most existing STR models fit into.
Using this framework allows for the extensive evaluation of previously proposed STR modules and the discovery of previously unexplored module combinations.
Third, we analyze the module-wise contributions to performance in terms of accuracy, speed, and memory demand, under one consistent set of training and evaluation datasets.
These analyses clear away the obstacles that currently prevent comparisons from revealing the performance gains of existing modules.

[Paper](https://arxiv.org/abs/1904.01906): J. Baek, G. Kim, J. Lee, S. Park, D. Han, S. Yun, S. J. Oh, and H. Lee, "What is wrong with scene text recognition model comparisons? dataset and model analysis," ArXiv, vol. abs/1904.01906, 2019.
# [Model Architecture](#contents)

This is an example of training a CNN+CTC model for text recognition on the MJSynth and SynthText datasets with MindSpore.
# [Dataset](#contents)

Note that you can run the scripts with the datasets mentioned in the original paper or with others widely used in this domain. The following sections introduce how to run the scripts using the datasets below.

The [MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/) and [SynthText](https://github.com/ankush-me/SynthText) datasets are used for model training. The [IIIT 5K-word dataset](https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset) is used for evaluation.

- step 1:

All the datasets have been preprocessed and stored in .lmdb format and can be downloaded [**HERE**](https://drive.google.com/drive/folders/192UfE9agQUMNq6AgU3_E05_FcPZK4hyt).

- step 2:

Uncompress the downloaded file, rename the MJSynth dataset as MJ, the SynthText dataset as ST and the IIIT dataset as IIIT.

- step 3:

Move the above-mentioned three datasets into the `cnnctc_data` folder; the structure should be as below:

```text
|--- CNNCTC/
    |--- cnnctc_data/
        |--- ST/
            data.mdb
            lock.mdb
        |--- MJ/
            data.mdb
            lock.mdb
        |--- IIIT/
            data.mdb
            lock.mdb

    ......
```

- step 4:

Preprocess the dataset by running:

```bash
python src/preprocess_dataset.py
```

This takes around 75 minutes.
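If you want to sanity-check the downloaded data before preprocessing, the sketch below reads one record from an .lmdb folder. It is a minimal example, assuming the `num-samples` / `image-%09d` / `label-%09d` key layout that this repo's own preprocessing code uses; the dataset path is a placeholder.

```python
# Minimal sanity check of an .lmdb dataset folder (e.g. cnnctc_data/IIIT).
# The key layout (num-samples, image-%09d, label-%09d) matches what this
# repo's preprocessing code expects; adjust the path for your setup.
import lmdb
import six
from PIL import Image

env = lmdb.open("cnnctc_data/IIIT", max_readers=32, readonly=True,
                lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
    n_samples = int(txn.get('num-samples'.encode()))
    print(f"samples: {n_samples}")
    label = txn.get(b'label-%09d' % 1).decode('utf-8')  # lmdb indices start at 1
    buf = six.BytesIO(txn.get(b'image-%09d' % 1))
    img = Image.open(buf).convert('RGB')
    print(f"first sample: label={label!r}, image size={img.size}")
```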
# [Features](#contents)

## Mixed Precision

The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, while maintaining the network accuracy achieved by single-precision training. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.

For FP16 operators, if the input data type is FP32, the MindSpore backend will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and then searching for 'reduce precision'.
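As a minimal sketch (not taken verbatim from this repo's training script), mixed precision can be enabled through the `amp_level` argument of `Model`, the same mechanism the continue-training example later in this README uses:

```python
# Minimal sketch: enable O2 mixed precision via Model's amp_level argument.
# net, loss and opt are assumed to be defined as elsewhere in this README.
from mindspore.train.model import Model

model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
              amp_level="O2", keep_batchnorm_fp32=False)
```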
# [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare a hardware environment with an Ascend processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#contents)

- Install dependencies:

```bash
pip install lmdb
pip install Pillow
pip install tqdm
pip install six
```

- Standalone training:

```bash
bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [PRETRAINED_CKPT(optional)]
```

- Distributed training:

```bash
bash scripts/run_distribute_train_ascend.sh [RANK_TABLE_FILE] [PRETRAINED_CKPT(optional)]
```

- Evaluation:

```bash
bash scripts/run_eval_ascend.sh [DEVICE_ID] [TRAINED_CKPT]
```
# [Script Description](#contents)

## [Script and Sample Code](#contents)

The entire code structure is as follows:

```text
|--- CNNCTC/
    |---README.md                          // descriptions about cnnctc
    |---README_cn.md                       // descriptions about cnnctc (Chinese)
    |---default_config.yaml                // config file
    |---train.py                           // train script
    |---eval.py                            // eval script
    |---export.py                          // export script
    |---preprocess.py                      // preprocess script
    |---postprocess.py                     // postprocess script
    |---ascend310_infer                    // application for 310 inference
    |---scripts
        |---run_infer_310.sh               // shell script for inference on Ascend 310
        |---run_standalone_train_ascend.sh // shell script for standalone training on Ascend
        |---run_distribute_train_ascend.sh // shell script for distributed training on Ascend
        |---run_eval_ascend.sh             // shell script for evaluation on Ascend
    |---src
        |---__init__.py                    // init file
        |---cnn_ctc.py                     // cnn_ctc network
        |---config.py                      // total config
        |---callback.py                    // loss callback file
        |---dataset.py                     // dataset processing
        |---util.py                        // routine operations
        |---preprocess_dataset.py          // dataset preprocessing
        |---model_utils
            |---config.py                  // parameter config
            |---moxing_adapter.py          // ModelArts device configuration
            |---device_adapter.py          // device config
            |---local_adapter.py           // local device config
```
## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in `default_config.yaml`.

Arguments:

- `--CHARACTER`: Character labels.
- `--NUM_CLASS`: The number of classes, including all character labels and the `<blank>` label for CTCLoss.
- `--HIDDEN_SIZE`: Model hidden size.
- `--FINAL_FEATURE_WIDTH`: The number of features.
- `--IMG_H`: The height of the input image.
- `--IMG_W`: The width of the input image.
- `--TRAIN_DATASET_PATH`: The path to the training dataset.
- `--TRAIN_DATASET_INDEX_PATH`: The path to the training dataset index file, which determines the sample order.
- `--TRAIN_BATCH_SIZE`: Training batch size. The batch size and index file must ensure that the input data has a fixed shape.
- `--TRAIN_DATASET_SIZE`: Training dataset size.
- `--TEST_DATASET_PATH`: The path to the test dataset.
- `--TEST_BATCH_SIZE`: Test batch size.
- `--TRAIN_EPOCHS`: Total training epochs.
- `--CKPT_PATH`: The path to a model checkpoint file; can be used to resume training and for evaluation.
- `--SAVE_PATH`: The path where model checkpoint files are saved.
- `--LR`: Learning rate for standalone training.
- `--LR_PARA`: Learning rate for distributed training.
- `--MOMENTUM`: Momentum.
- `--LOSS_SCALE`: Loss scale to prevent gradient underflow.
- `--SAVE_CKPT_PER_N_STEP`: Save a model checkpoint file every N steps.
- `--KEEP_CKPT_MAX_NUM`: The maximum number of saved model checkpoint files.
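For orientation, here is a hedged excerpt of what such entries can look like in `default_config.yaml`; the batch sizes and epoch count match this README's tables, while the paths and other values are illustrative assumptions, not the repo's defaults:

```yaml
# Illustrative excerpt of default_config.yaml; values other than the batch
# sizes and epoch count (taken from this README's tables) are assumptions.
TRAIN_DATASET_PATH: "/data/ST_MJ/"
TRAIN_DATASET_INDEX_PATH: "/data/st_mj_fixed_length_index_list.pkl"
TRAIN_BATCH_SIZE: 192
TRAIN_EPOCHS: 3
TEST_DATASET_PATH: "/data/IIIT5K_3000/"
TEST_BATCH_SIZE: 192
CKPT_PATH: ""
SAVE_PATH: "./ckpt/"
```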
## [Training Process](#contents)

### Training

- Standalone training:

```bash
bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [PRETRAINED_CKPT(optional)]
```

Results and checkpoints are written to the `./train` folder. The log can be found in `./train/log` and loss values are recorded in `./train/loss.log`.

`PRETRAINED_CKPT` is the path to a model checkpoint and is **optional**. If none is given, the model is trained from scratch.

- Distributed training:

```bash
bash scripts/run_distribute_train_ascend.sh [RANK_TABLE_FILE] [PRETRAINED_CKPT(optional)]
```

For distributed training, an hccl configuration file in JSON format needs to be created in advance.

Please follow the instructions in the link below:

<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>
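For orientation, a minimal single-server rank table might look like the sketch below; the server ID and `device_ip` values are placeholders, so generate the real file with the hccl_tools linked above:

```json
{
  "version": "1.0",
  "server_count": "1",
  "server_list": [
    {
      "server_id": "10.0.0.1",
      "device": [
        {"device_id": "0", "device_ip": "192.1.27.6", "rank_id": "0"},
        {"device_id": "1", "device_ip": "192.2.27.6", "rank_id": "1"}
      ]
    }
  ],
  "status": "completed"
}
```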
Results and checkpoints are written to the `./train_parallel_{i}` folder for each device `i`.
The log can be found in `./train_parallel_{i}/log_{i}.log` and loss values are recorded in `./train_parallel_{i}/loss.log`.

`RANK_TABLE_FILE` is needed when you run a distributed task on Ascend.
`PRETRAINED_CKPT` is the path to a model checkpoint and is **optional**. If none is given, the model is trained from scratch.
### Training Result

Training results are stored in the example path, in folders whose names begin with "train" or "train_parallel". Under this path you can find checkpoint files together with results like the following in loss.log.

```text
# distributed training result (8p)
epoch: 1 step: 1 , loss is 76.25, average time per step is 0.235177839748392712
epoch: 1 step: 2 , loss is 73.46875, average time per step is 0.25798572540283203
epoch: 1 step: 3 , loss is 69.46875, average time per step is 0.229678678512573
epoch: 1 step: 4 , loss is 64.3125, average time per step is 0.23512671788533527
epoch: 1 step: 5 , loss is 58.375, average time per step is 0.23149147033691406
epoch: 1 step: 6 , loss is 52.7265625, average time per step is 0.2292975425720215
...
epoch: 1 step: 8689 , loss is 9.706798802612482, average time per step is 0.2184656601312549
epoch: 1 step: 8690 , loss is 9.70612545289855, average time per step is 0.2184725407765116
epoch: 1 step: 8691 , loss is 9.70695776049204, average time per step is 0.21847309686135555
epoch: 1 step: 8692 , loss is 9.707279624277456, average time per step is 0.21847339290613375
epoch: 1 step: 8693 , loss is 9.70763437950938, average time per step is 0.2184720295013031
epoch: 1 step: 8694 , loss is 9.707695425072046, average time per step is 0.21847410284595573
epoch: 1 step: 8695 , loss is 9.708408273381295, average time per step is 0.21847338271072345
epoch: 1 step: 8696 , loss is 9.708703753591953, average time per step is 0.2184726025560777
epoch: 1 step: 8697 , loss is 9.709536406025824, average time per step is 0.21847212061114694
epoch: 1 step: 8698 , loss is 9.708542263610315, average time per step is 0.2184715309307257
```
- Running on ModelArts
    - If you want to train the model on ModelArts, you can refer to the ModelArts [official guidance document](https://support.huaweicloud.com/modelarts/); the steps are as follows:

```python
# Example of running distributed training on ModelArts:
# Dataset layout

# ├── CNNCTC_Data              # dataset dir
#     ├── train                # train dir
#         ├── ST_MJ            # train dataset dir
#             ├── data.mdb     # data file
#             ├── lock.mdb
#         ├── st_mj_fixed_length_index_list.pkl
#     ├── eval                 # eval dir
#         ├── IIIT5K_3000      # eval dataset dir
#         ├── checkpoint       # checkpoint dir

# (1) Choose either a (modify the yaml file parameters) or b (modify the parameters when creating a ModelArts training job).
#     a. set "enable_modelarts=True"
#        set "run_distribute=True"
#        set "TRAIN_DATASET_PATH=/cache/data/ST_MJ/"
#        set "TRAIN_DATASET_INDEX_PATH=/cache/data/st_mj_fixed_length_index_list.pkl"
#        set "SAVE_PATH=/cache/train/checkpoint"
#
#     b. Add the "enable_modelarts=True" parameter on the ModelArts interface.
#        Set the parameters required by method a on the ModelArts interface.
#        Note: path parameters do not need to be quoted.

# (2) Set the path of the network configuration file: "_config_path=/The path of config in default_config.yaml/"
# (3) Set the code path on the ModelArts interface: "/path/cnnctc".
# (4) Set the model's startup file on the ModelArts interface: "train.py".
# (5) Set the model's data path on the ModelArts interface: ".../CNNCTC_Data/train" (choose the CNNCTC_Data/train folder path),
#     the model's output path "Output file path", and the model's log path "Job log path".
# (6) Start training the model.

# Example of running model inference on ModelArts
# (1) Upload the trained model to the corresponding location in the bucket.
# (2) Choose either a or b.
#     a. set "enable_modelarts=True"
#        set "TEST_DATASET_PATH=/cache/data/IIIT5K_3000/"
#        set "CHECKPOINT_PATH=/cache/data/checkpoint/checkpoint file name"
#
#     b. Add the "enable_modelarts=True" parameter on the ModelArts interface.
#        Set the parameters required by method a on the ModelArts interface.
#        Note: path parameters do not need to be quoted.

# (3) Set the path of the network configuration file: "_config_path=/The path of config in default_config.yaml/"
# (4) Set the code path on the ModelArts interface: "/path/cnnctc".
# (5) Set the model's startup file on the ModelArts interface: "eval.py".
# (6) Set the model's data path on the ModelArts interface: ".../CNNCTC_Data/eval" (choose the CNNCTC_Data/eval folder path),
#     the model's output path "Output file path", and the model's log path "Job log path".
# (7) Start model inference.
```
## [Evaluation Process](#contents)

### Evaluation

- Evaluation:

```bash
bash scripts/run_eval_ascend.sh [DEVICE_ID] [TRAINED_CKPT]
```

The model will be evaluated on the IIIT dataset; sample results and the overall accuracy will be printed.
## [Inference Process](#contents)

### Export MindIR

```shell
python export.py --ckpt_file [CKPT_PATH] --file_format [EXPORT_FORMAT] --TEST_BATCH_SIZE [BATCH_SIZE]
```

The `ckpt_file` parameter is required.
`EXPORT_FORMAT` should be in ["AIR", "MINDIR"].
`BATCH_SIZE` can currently only be set to 1.
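For reference, here is a hedged sketch of what such an export typically does with MindSpore's `export` API; `CNNCTC` and `config` come from this repo's `src`, and the NCHW dummy-input shape is an assumption:

```python
# Hedged sketch of a MindIR export; IMG_H/IMG_W and batch size come from the
# config, and the dummy-input layout (NCHW, float32) is an assumption.
import numpy as np
from mindspore import Tensor, context
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

net = CNNCTC(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
load_param_into_net(net, load_checkpoint(config.ckpt_file))

dummy_input = Tensor(np.zeros([config.TEST_BATCH_SIZE, 3, config.IMG_H, config.IMG_W], np.float32))
export(net, dummy_input, file_name=config.file_name, file_format=config.file_format)
```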
- Export MindIR on ModelArts

```Modelarts
Export MindIR example on ModelArts
The dataset layout is the same as for training.
# (1) Choose either a (modify the yaml file parameters) or b (modify the parameters when creating a ModelArts training job).
#     a. set "enable_modelarts=True"
#        set "file_name=/cache/train/cnnctc"
#        set "file_format=MINDIR"
#        set "ckpt_file=/cache/data/checkpoint file name"
#
#     b. Add the "enable_modelarts=True" parameter on the ModelArts interface.
#        Set the parameters required by method a on the ModelArts interface.
#        Note: path parameters do not need to be quoted.
# (2) Set the path of the network configuration file: "_config_path=/The path of config in default_config.yaml/"
# (3) Set the code path on the ModelArts interface: "/path/cnnctc".
# (4) Set the model's startup file on the ModelArts interface: "export.py".
# (5) Set the model's data path on the ModelArts interface: ".../CNNCTC_Data/eval/checkpoint" (choose the CNNCTC_Data/eval/checkpoint folder path),
#     the MindIR output path "Output file path", and the model's log path "Job log path".
```
### Infer on Ascend310

Before performing inference, the MindIR file must be exported by the `export.py` script. We only provide an example of inference using the MINDIR model.

```shell
# Ascend310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID]
```

- `DVPP` is mandatory and must be chosen from ["DVPP", "CPU"]; it is case-insensitive. CNNCTC only supports CPU mode.
- `DEVICE_ID` is optional; the default value is 0.
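A concrete invocation could look like this (both paths are placeholders):

```bash
# Example invocation; the .mindir path and dataset path are placeholders.
bash run_infer_310.sh ./cnnctc.mindir ./cnnctc_data/IIIT CPU 0
```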
### result

Inference results are saved in the current path; you can find results like the following in the acc.log file.

```bash
'Accuracy': 0.8642
```
# [Model Description](#contents)

## [Performance](#contents)

### Training Performance

| Parameters                 | CNNCTC                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Model Version              | V1                                                           |
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G; OS Euler2.8 |
| Uploaded Date              | 09/28/2020 (month/day/year)                                  |
| MindSpore Version          | 1.0.0                                                        |
| Dataset                    | MJSynth, SynthText                                           |
| Training Parameters        | epoch=3, batch_size=192                                      |
| Optimizer                  | RMSProp                                                      |
| Loss Function              | CTCLoss                                                      |
| Speed                      | 1pc: 250 ms/step; 8pcs: 260 ms/step                          |
| Total time                 | 1pc: 15 hours; 8pcs: 1.92 hours                              |
| Parameters (M)             | 177                                                          |
| Scripts                    | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/cnnctc> |

### Evaluation Performance

| Parameters          | CNNCTC                      |
| ------------------- | --------------------------- |
| Model Version       | V1                          |
| Resource            | Ascend 910; OS Euler2.8     |
| Uploaded Date       | 09/28/2020 (month/day/year) |
| MindSpore Version   | 1.0.0                       |
| Dataset             | IIIT5K                      |
| batch_size          | 192                         |
| outputs             | Accuracy                    |
| Accuracy            | 85%                         |
| Model for inference | 675M (.ckpt file)           |

### Inference Performance

| Parameters          | Ascend                      |
| ------------------- | --------------------------- |
| Model Version       | CNNCTC                      |
| Resource            | Ascend 310; CentOS 3.10     |
| Uploaded Date       | 05/19/2021 (month/day/year) |
| MindSpore Version   | 1.2.0                       |
| Dataset             | IIIT5K                      |
| batch_size          | 1                           |
| outputs             | Accuracy                    |
| Accuracy            | 0.8642                      |
| Model for inference | 675M (.ckpt file)           |
## [How to use](#contents)

### Inference

If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/migrate_3rd_scripts.html). The following is a simple example:

- Running on Ascend

```python
# Set context
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
context.set_context(device_id=cfg.device_id)

# Load unseen dataset for inference
dataset = dataset.create_dataset(cfg.data_path, 1, False)

# Define model
net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
               cfg.momentum, weight_decay=cfg.weight_decay)
loss = P.CTCLoss(preprocess_collapse_repeated=False,
                 ctc_merge_repeated=True,
                 ignore_longer_outputs_than_inputs=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

# Load pre-trained model
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)

# Make predictions on the unseen dataset
acc = model.eval(dataset)
print("accuracy: ", acc)
```
### Continue Training on the Pretrained Model

- Running on Ascend

```python
# Load dataset
dataset = create_dataset(cfg.data_path, 1)
batch_num = dataset.get_dataset_size()

# Define model
net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
# Continue training if pre_trained is set to True
if cfg.pre_trained:
    param_dict = load_checkpoint(cfg.checkpoint_path)
    load_param_into_net(net, param_dict)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
              steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
               Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
loss = P.CTCLoss(preprocess_collapse_repeated=False,
                 ctc_merge_repeated=True,
                 ignore_longer_outputs_than_inputs=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
              amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)

# Set callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
                             keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_cnnctc", directory="./",
                             config=config_ck)
loss_cb = LossMonitor()

# Start training
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")
```
# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@@ -127,19 +127,19 @@ pip install six

- Standalone training:

```shell
-bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
+bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [PRETRAINED_CKPT(options)]
```

- Distributed training:

```shell
-bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
+bash scripts/run_distribute_train_ascend.sh [RANK_TABLE_FILE] [PRETRAINED_CKPT(options)]
```

- Evaluation:

```shell
-bash scripts/run_eval_ascend.sh $TRAINED_CKPT
+bash scripts/run_eval_ascend.sh DEVICE_ID TRAINED_CKPT
```

# Script Description
@@ -150,12 +150,15 @@ bash scripts/run_eval_ascend.sh $TRAINED_CKPT

```python
|--- CNNCTC/
    |---README_CN.md // CNN+CTC description
    |---README.md // CNN+CTC description
    |---train.py // training script
    |---eval.py // evaluation script
    |---export.py // model export script
    |---postprocess.py // inference postprocess script
    |---ascend310_infer // for Ascend 310 inference
    |---preprocess.py // inference preprocess script
    |---ascend310_infer // for Ascend 310 inference
    |---default_config.yaml // parameter configuration
    |---scripts
        |---run_standalone_train_ascend.sh // shell script for standalone training on Ascend
        |---run_distribute_train_ascend.sh // shell script for distributed training on Ascend
@@ -164,18 +167,22 @@ bash scripts/run_eval_ascend.sh $TRAINED_CKPT
    |---src
        |---__init__.py // init file
        |---cnn_ctc.py // cnn_ctc network
        |---config.py // total config
        |---callback.py // loss callback file
        |---dataset.py // dataset processing
        |---util.py // routine operations
        |---generate_hccn_file.py // generates the distributed json file
        |---preprocess_dataset.py // dataset preprocessing
+       |---model_utils
+           |---config.py # parameter parsing
+           |---device_adapter.py # device-related information
+           |---local_adapter.py # device-related information
+           |---moxing_adapter.py # decorators (mainly for ModelArts data copying)
```

## Script Parameters

-Parameters for both training and evaluation can be set in `config.py`.
+Parameters for both training and evaluation can be set in `default_config.yaml`.

Arguments:
@@ -208,7 +215,7 @@ bash scripts/run_eval_ascend.sh $TRAINED_CKPT

- Standalone training:

```shell
-bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
+bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [PRETRAINED_CKPT(options)]
```

Results and checkpoints are written to the `./train` folder. The log can be found in `./train/log` and loss values are recorded in `./train/loss.log`.
@@ -218,7 +225,7 @@ bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT

- Distributed training:

```shell
-bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
+bash scripts/run_distribute_train_ascend.sh [RANK_TABLE_FILE] [PRETRAINED_CKPT(options)]
```

Results and checkpoints are written to the `./train_parallel_{i}` folder for each device `i`.
@@ -227,6 +234,10 @@ bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT

`$RANK_TABLE_FILE` is needed when running a distributed task on Ascend.
`$PATH_TO_CHECKPOINT` is the path to a model checkpoint and is **optional**. If it is none, the model is trained from scratch.

+> Note:
+
+For reference material on RANK_TABLE_FILE, see this [link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html); for how to obtain the device_ip, see this [link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).

### Training Result

Training results are stored in the example path, in folders whose names begin with "train" or "train_parallel". You can find checkpoint files together with results in the logs under this path, as shown below.
@@ -259,31 +270,104 @@ epoch: 1 step: 8698 , loss is 9.708542263610315, average time per step is 0.3184

- Evaluation:

```shell
-bash scripts/run_eval_ascend.sh $TRAINED_CKPT
+bash scripts/run_eval_ascend.sh [DEVICE_ID] [TRAINED_CKPT]
```

The model is evaluated on the IIIT dataset, and sample results and the overall accuracy are printed.

+- If you want to train the model on ModelArts, refer to the ModelArts [official guidance document](https://support.huaweicloud.com/modelarts/) to start training and inference; the specific steps are as follows:

```ModelArts
# Example of running distributed training on ModelArts:
# Dataset layout

# ├── CNNCTC_Data              # dataset dir
#     ├── train                # train dir
#         ├── ST_MJ            # train dataset dir
#             ├── data.mdb     # data file
#             ├── lock.mdb
#         ├── st_mj_fixed_length_index_list.pkl
#     ├── eval                 # eval dir
#         ├── IIIT5K_3000      # eval dataset dir
#         ├── checkpoint       # checkpoint dir

# (1) Choose either a (modify the yaml file parameters) or b (modify the parameters when creating a ModelArts training job).
#     a. set "enable_modelarts=True"
#        set "run_distribute=True"
#        set "TRAIN_DATASET_PATH=/cache/data/ST_MJ/"
#        set "TRAIN_DATASET_INDEX_PATH=/cache/data/st_mj_fixed_length_index_list.pkl"
#        set "SAVE_PATH=/cache/train/checkpoint"
#
#     b. Add the "enable_modelarts=True" parameter on the ModelArts interface.
#        Set the parameters required by method a on the ModelArts interface.
#        Note: path parameters do not need to be quoted.

# (2) Set the path of the network configuration file: "_config_path=/The path of config in default_config.yaml/"
# (3) Set the code path on the ModelArts interface: "/path/cnnctc".
# (4) Set the model's startup file on the ModelArts interface: "train.py".
# (5) Set the model's data path on the ModelArts interface: ".../CNNCTC_Data/train" (choose the CNNCTC_Data/train folder path),
#     the model's output path "Output file path", and the model's log path "Job log path".
# (6) Start training the model.

# Example of running model inference on ModelArts
# (1) Upload the trained model to the corresponding location in the bucket.
# (2) Choose either a or b.
#     a. set "enable_modelarts=True"
#        set "TEST_DATASET_PATH=/cache/data/IIIT5K_3000/"
#        set "CHECKPOINT_PATH=/cache/data/checkpoint/checkpoint file name"
#
#     b. Add the "enable_modelarts=True" parameter on the ModelArts interface.
#        Set the parameters required by method a on the ModelArts interface.
#        Note: path parameters do not need to be quoted.

# (3) Set the path of the network configuration file: "_config_path=/The path of config in default_config.yaml/"
# (4) Set the code path on the ModelArts interface: "/path/cnnctc".
# (5) Set the model's startup file on the ModelArts interface: "eval.py".
# (6) Set the model's data path on the ModelArts interface: "../CNNCTC_Data/eval" (choose the CNNCTC_Data/eval folder path),
#     the model's output path "Output file path", and the model's log path "Job log path".
# (7) Start model inference.
```
## Inference Process

### Export MindIR

```shell
-python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [EXPORT_FORMAT]
+python export.py --ckpt_file [CKPT_PATH] --file_format [EXPORT_FORMAT] --TEST_BATCH_SIZE [BATCH_SIZE]
```

The `ckpt_file` parameter is required.
-The `file_name` parameter is the file name after export.
`EXPORT_FORMAT` should be in ["AIR", "MINDIR"].
+`BATCH_SIZE` currently only supports inference with batch_size 1.

- Export MindIR on ModelArts

```Modelarts
Export MindIR example on ModelArts
The dataset layout is the same as for ModelArts training.
# (1) Choose either a (modify the yaml file parameters) or b (modify the parameters when creating a ModelArts training job).
#     a. set "enable_modelarts=True"
#        set "file_name=/cache/train/cnnctc"
#        set "file_format=MINDIR"
#        set "ckpt_file=/cache/data/checkpoint file name"
#
#     b. Add the "enable_modelarts=True" parameter on the ModelArts interface.
#        Set the parameters required by method a on the ModelArts interface.
#        Note: path parameters do not need to be quoted.
# (2) Set the path of the network configuration file: "_config_path=/The path of config in default_config.yaml/"
# (3) Set the code path on the ModelArts interface: "/path/cnnctc".
# (4) Set the model's startup file on the ModelArts interface: "export.py".
# (5) Set the model's data path on the ModelArts interface: ".../CNNCTC_Data/eval/checkpoint" (choose the CNNCTC_Data/eval/checkpoint folder path),
#     the MindIR output path "Output file path", and the model's log path "Job log path".
```
### Infer on Ascend310

Before performing inference, the MindIR file must be exported by the `export.py` script. The following shows an example of performing inference with a MindIR model.
-Currently only batch_size 1 inference is supported; before exporting the model, change the parameter `TEST_BATCH_SIZE` in `config.py` to 1.

```shell
# Ascend310 inference
-bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [LABEL_PATH] [DVPP] [DEVICE_ID]
+bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID]
```

- `DVPP` is mandatory and must be chosen from ["DVPP", "CPU"]; it is case-insensitive. CNNCTC currently only supports inference with CPU operators.
@@ -294,7 +378,7 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [LABEL_PATH] [DVPP] [DEVICE_ID]

Inference results are saved in the current path of script execution; you can find the following accuracy result in acc.log.

```bash
-'Accuracy':0.8546
+'Accuracy':0.8642
```

# Model Description
@@ -343,7 +427,7 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [LABEL_PATH] [DVPP] [DEVICE_ID]

| Dataset             | IIIT5K                      |
| batch_size          | 1                           |
| outputs             | Accuracy                    |
-| Accuracy            | Accuracy=0.8546             |
+| Accuracy            | Accuracy=0.8642             |
| Model for inference | 675M (.ckpt file)           |

## Usage
@@ -75,13 +75,13 @@ int PadImage(const MSTensor &input, MSTensor *output) {
  paddingSize = FLAGS_image_width - NewWidth;
  if (NewWidth > FLAGS_image_width) {
    std::shared_ptr<TensorTransform> resize(new Resize({FLAGS_image_height, FLAGS_image_width},
-                                           InterpolationMode::kArea));
+                                           InterpolationMode::kCubicPil));
    Execute composeResize({resize});
    composeResize(input, &imgResize);
    composeNormalize(imgResize, output);
  } else {
    std::shared_ptr<TensorTransform> resize(new Resize({FLAGS_image_height, NewWidth},
-                                           InterpolationMode::kArea));
+                                           InterpolationMode::kCubicPil));
    Execute composeResize({resize});
    composeResize(input, &imgResize);
    composeNormalize(imgResize, &imgNormalize);
@@ -14,23 +14,16 @@
# ============================================================================
"""post process for 310 inference"""
import os
-import argparse
import numpy as np
-from src.config import Config_CNNCTC
+from src.model_utils.config import config
from src.util import CTCLabelConverter

-parser = argparse.ArgumentParser(description="cnnctc acc calculation")
-parser.add_argument("--result_path", type=str, required=True, help="result files path.")
-parser.add_argument("--label_path", type=str, required=True, help="label path.")
-args = parser.parse_args()


def calcul_acc(labels, preds):
    return sum(1 for x, y in zip(labels, preds) if x == y) / len(labels)


def get_result(result_path, label_path):
-    config = Config_CNNCTC()
    converter = CTCLabelConverter(config.CHARACTER)
    files = os.listdir(result_path)
    preds = []

@@ -42,7 +35,7 @@ def get_result(result_path, label_path):
        label_dict[line.split(',')[0]] = line.split(',')[1].replace('\n', '')
    for file in files:
        file_name = file.split('.')[0]
-        label = label_dict[file_name + '.png']
+        label = label_dict[file_name]
        labels.append(label)
        resultPath = os.path.join(result_path, file)
        output = np.fromfile(resultPath, dtype=np.float32)

@@ -51,10 +44,10 @@ def get_result(result_path, label_path):
        preds_size = np.array([model_predict.shape[0]] * 1)
        preds_index = np.argmax(model_predict, axis=1)
        preds_str = converter.decode(preds_index, preds_size)
-        preds.append(preds_str[0].upper())
+        preds.append(preds_str[0])
    acc = calcul_acc(labels, preds)
-    print("TOtal data: {}, accuracy: {}".format(len(labels), acc))
+    print("Total data: {}, accuracy: {}".format(len(labels), acc))


if __name__ == '__main__':
-    get_result(args.result_path, args.label_path)
+    get_result(config.result_path, config.label_path)
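With this change, the accuracy script is driven by the shared config object rather than its own `argparse` parser; the flags below still work because the shared parser accepts them as overrides, matching the invocation in the updated `run_infer_310.sh` later in this commit:

```bash
python postprocess.py --result_path=./result_Files --label_path=./label.txt
```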
@@ -0,0 +1,95 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""preprocess for 310 inference"""
import os
import sys
import six
import lmdb
from PIL import Image
from src.model_utils.config import config
from src.util import CTCLabelConverter


def get_img_from_lmdb(env_, ind):
    with env_.begin(write=False) as txn_:
        label_key = 'label-%09d'.encode() % ind
        label_ = txn_.get(label_key).decode('utf-8')
        img_key = 'image-%09d'.encode() % ind
        imgbuf = txn_.get(img_key)

        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        try:
            img_ = Image.open(buf).convert('RGB')  # for color images
        except IOError:
            print(f'Corrupted image for {ind}')
            # make a dummy image and dummy label for the corrupted image
            img_ = Image.new('RGB', (config.IMG_W, config.IMG_H))
            label_ = '[dummy_label]'

    label_ = label_.lower()

    return img_, label_


if __name__ == '__main__':
    max_len = int((26 + 1) // 2)
    converter = CTCLabelConverter(config.CHARACTER)
    env = lmdb.open(config.TEST_DATASET_PATH, max_readers=32, readonly=True,
                    lock=False, readahead=False, meminit=False)
    if not env:
        print('cannot create lmdb from %s' % (config.TEST_DATASET_PATH))
        sys.exit(0)

    with env.begin(write=False) as txn:
        nSamples = int(txn.get('num-samples'.encode()))

        # Filtering
        filtered_index_list = []
        for index_ in range(nSamples):
            index_ += 1  # lmdb starts with 1
            label_key_ = 'label-%09d'.encode() % index_
            label = txn.get(label_key_).decode('utf-8')

            if len(label) > max_len:
                continue

            illegal_sample = False
            for char_item in label.lower():
                if char_item not in config.CHARACTER:
                    illegal_sample = True
                    break
            if illegal_sample:
                continue

            filtered_index_list.append(index_)

    img_ret = []
    text_ret = []

    print(f'num of samples in IIIT dataset: {len(filtered_index_list)}')
    i = 0
    label_dict = {}
    for index in filtered_index_list:
        img, label = get_img_from_lmdb(env, index)
        img_name = os.path.join(config.preprocess_output, str(i) + ".png")
        img.save(img_name)
        label_dict[str(i)] = label
        i += 1
    with open('./label.txt', 'w') as file:
        for k, v in label_dict.items():
            file.write(str(k) + ',' + str(v) + '\n')
@@ -15,7 +15,7 @@
# ============================================================================

if [[ $# -lt 3 || $# -gt 4 ]]; then
-    echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [LABEL_PATH] [DVPP] [DEVICE_ID]
+    echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID]
    DVPP is mandatory, and must choose from [DVPP|CPU], it's case-insensitive
    DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
    exit 1

@@ -30,17 +30,15 @@ get_real_path(){
}
model=$(get_real_path $1)
data_path=$(get_real_path $2)
-label_path=$(get_real_path $3)
-DVPP=${4^^}
+DVPP=${3^^}

device_id=0
-if [ $# == 5 ]; then
-    device_id=$5
+if [ $# == 4 ]; then
+    device_id=$4
fi

echo "mindir name: "$model
echo "dataset path: "$data_path
-echo "label path: "$label_path
echo "image process mode: "$DVPP
echo "device id: "$device_id

@@ -58,6 +56,16 @@ else
    export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi

+function preprocess_data()
+{
+    if [ -d preprocess_Result ]; then
+        rm -rf ./preprocess_Result
+    fi
+    mkdir preprocess_Result
+    python ../preprocess.py --preprocess_output=./preprocess_Result --TEST_DATASET_PATH=$data_path &> preprocess.log
+    data_path=./preprocess_Result
+}

function compile_app()
{
    cd ../ascend310_infer || exit

@@ -88,9 +96,14 @@ function infer()

function cal_acc()
{
-    python3.7 ../postprocess.py --result_path=./result_Files --label_path=$label_path &> acc.log &
+    python3.7 ../postprocess.py --result_path=./result_Files --label_path=./label.txt &> acc.log &
}

+preprocess_data
+if [ $? -ne 0 ]; then
+    echo "preprocess data failed"
+    exit 1
+fi
compile_app
if [ $? -ne 0 ]; then
    echo "compile app code failed"
@@ -36,6 +36,10 @@ checkpoint_path: None
file_name: "warpctc"
ckpt_file: ""
file_format: "MINDIR"
+#310infer-related
+dataset_path: ""
+result_path: ""
+label_path: ""

---
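These new keys are what lets the warpctc 310-inference helpers drop their private `argparse` parsers, as the commit title says: each script now reads `src.model_utils.config.config`, which merges the YAML defaults with any command-line overrides, so invocations such as `--result_path=... --label_path=...` keep working. A hedged sketch of that merge pattern follows (the real `model_utils/config.py` is more elaborate):

```python
# Hedged sketch of the yaml-defaults-plus-CLI-overrides pattern used by
# model_utils/config.py; the real implementation differs in detail.
import argparse
import yaml

def get_config(yaml_path="default_config.yaml"):
    with open(yaml_path) as f:
        defaults = next(yaml.safe_load_all(f))  # first yaml document holds the defaults
    parser = argparse.ArgumentParser()
    for key, value in defaults.items():
        parser.add_argument(f"--{key}", default=value)  # CLI flags override yaml defaults
    return parser.parse_args()

config = get_config()
print(config.result_path, config.label_path)
```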
@@ -14,14 +14,10 @@
# ============================================================================
"""post process for 310 inference"""
import os
-import argparse
import numpy as np
from src.model_utils.config import config as cf

batch_Size = 1
-parser = argparse.ArgumentParser(description="warpctc acc calculation")
-parser.add_argument("--result_path", type=str, required=True, help="result files path.")
-parser.add_argument("--label_path", type=str, required=True, help="label path.")
-args = parser.parse_args()


def is_eq(pred_lbl, target):

@@ -76,9 +72,9 @@ def get_result(result_path, label_path):
        resultPath = os.path.join(result_path, file)
        output = np.fromfile(resultPath, dtype=np.float16).reshape((-1, batch_Size, 11))
        preds.append(get_prediction(output))
-    acc = calcul_acc(preds, labels)
+    acc = round(calcul_acc(preds, labels), 3)
    print("Total data: {}, accuracy: {}".format(len(labels), acc))


if __name__ == '__main__':
-    get_result(args.result_path, args.label_path)
+    get_result(cf.result_path, cf.label_path)
@@ -14,24 +14,18 @@
# ============================================================================
import os
import math as m
-import argparse
-from src.config import config as cf
+from src.model_utils.config import config as cf
from src.dataset import create_dataset

batch_size = 1
-parser = argparse.ArgumentParser(description="Warpctc preprocess")
-parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
-parser.add_argument("--output_path", type=str, default=None, help="output path")
-
-args_opt = parser.parse_args()

if __name__ == "__main__":
    input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
-    dataset = create_dataset(dataset_path=args_opt.dataset_path,
+    dataset = create_dataset(dataset_path=cf.dataset_path,
                             batch_size=batch_size,
                             device_target="Ascend")

-    img_path = args_opt.output_path
+    img_path = cf.output_path
    if not os.path.isdir(img_path):
        os.makedirs(img_path)
    total = dataset.get_dataset_size()