gpu for unet++

郑彬 2021-08-10 16:48:24 +08:00
parent 57bbc7fb94
commit a4593ec89f
8 changed files with 291 additions and 106 deletions


@ -127,7 +127,7 @@ After installing MindSpore via the official website, you can start training and
- Run on Ascend
```python
```shell
# run training example
python train.py --data_path=/path/to/data/ --config_path=/path/to/yaml > train.log 2>&1 &
OR
@ -142,6 +142,26 @@ OR
bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]
```
- Run on GPU
```shell
# run training example
python train.py --data_path=/path/to/data/ --config_path=/path/to/yaml --device_target=GPU > train.log 2>&1 &
OR
bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH] [DEVICE_ID](optional)
# run distributed training example
bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH] [CUDA_VISIBLE_DEVICES(0,1,2,3,4,5,6,7)](optional)
# run evaluation example
python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/yaml > eval.log 2>&1 &
OR
bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH] [DEVICE_ID](optional)
# run export
python export.py --config_path=[CONFIG_PATH] --checkpoint_file_path=[model_ckpt_path] --file_name=[air_model_name] --file_format=MINDIR --device_target=GPU
```
- Run on docker
Build docker images (change the version to the one you actually use)
@ -162,7 +182,7 @@ Then you can run everything just like on ascend.
If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training and evaluation as follows:
```python
```text
# run distributed training on modelarts example
# (1) First, Perform a or b.
# a. Set "enable_modelarts=True" on yaml file.
@ -191,33 +211,18 @@ If you want to run in modelarts, please check the official documentation of [mod
# (7) Create your job.
```
- Run on GPU
```python
# run training example
python train.py --data_path=/path/to/data/ --config_path=/path/to/config/ --output ./output > train.log 2>&1 &
OR
bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH]
# run distributed training example
bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH]
# run evaluation example
python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/config/ > eval.log 2>&1 &
OR
bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]
```
## [Script Description](#contents)
### [Script and Sample Code](#contents)
```shell
```text
├── model_zoo
├── README.md // descriptions about all the models
├── unet
├── README.md // descriptions about Unet
├── README_CN.md // chinese descriptions about Unet
├── ascend310_infer // code of infer on ascend 310
├── Dockerfile
├── scripts
│ ├──docker_start.sh // shell script for quick docker start
│ ├──run_disribute_train.sh // shell script for distributed on Ascend
@ -228,7 +233,7 @@ If you want to run in modelarts, please check the official documentation of [mod
│ ├──run_standalone_eval_gpu.sh // shell script for evaluation on GPU
│ ├──run_distribute_train_gpu.sh // shell script for distributed on GPU
├── src
│ ├──config.py // parameter configuration
│ ├──__init__.py
│ ├──data_loader.py // creating dataset
│ ├──loss.py // loss
│ ├──eval_callback.py // evaluation callback while training
@ -236,18 +241,21 @@ If you want to run in modelarts, please check the official documentation of [mod
│ ├──unet_medical // Unet medical architecture
├──__init__.py // init file
├──unet_model.py // unet model
──unet_parts.py // unet part
──unet_parts.py // unet part
│ ├──unet_nested // Unet++ architecture
├──__init__.py // init file
├──unet_model.py // unet model
├──unet_parts.py // unet part
├── model_utils
│ ├── config.py // parameter configuration
│ ├── device_adapter.py // device adapter
│ ├── local_adapter.py // local adapter
│ ├── moxing_adapter.py // moxing adapter
└──unet_parts.py // unet part
│ ├──model_utils
├──__init__.py
├── config.py // parameter configuration
├── device_adapter.py // device adapter
├── local_adapter.py // local adapter
└── moxing_adapter.py // moxing adapter
├── unet_medical_config.yaml // parameter configuration
├── unet_medicl_gpu_config.yaml // parameter configuration
├── unet_nested_cell_config.yaml // parameter configuration
├── unet_nested_coco_config.yaml // parameter configuration
├── unet_nested_config.yaml // parameter configuration
├── unet_simple_config.yaml // parameter configuration
├── unet_simple_coco_config.yaml // parameter configuration
@ -258,16 +266,16 @@ If you want to run in modelarts, please check the official documentation of [mod
├── postprocess.py // unet 310 infer postprocess.
├── preprocess.py // unet 310 infer preprocess dataset
├── preprocess_dataset.py // the script to adapt MultiClass dataset
── requirements.txt // Requirements of third party package.
── requirements.txt // Requirements of third party package.
```
### [Script Parameters](#contents)
Parameters for both training and evaluation can be set in config.py
Parameters for both training and evaluation can be set in *.yaml
- config for Unet, ISBI dataset
```python
```yaml
'name': 'Unet', # model name
'lr': 0.0001, # learning rate
'epochs': 400, # total training epochs when run 1p
@ -298,7 +306,7 @@ Parameters for both training and evaluation can be set in config.py
- config for Unet++, cell nuclei dataset
```python
```yaml
'model': 'unet_nested', # model name
'dataset': 'Cell_nuclei', # dataset name
'img_size': [96, 96], # image size
@ -366,9 +374,9 @@ The model checkpoint will be saved in the current directory.
#### running on GPU
```shell
python train.py --data_path=/path/to/data/ --config_path=/path/to/config/ --output ./output > train.log 2>&1 &
python train.py --data_path=/path/to/data/ --config_path=/path/to/config/ --output ./output --device_target GPU > train.log 2>&1 &
OR
bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH]
bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH] [DEVICE_ID](optional)
```
The python command above runs in the background; you can view the results in the file train.log. The model checkpoint will be saved in the current directory.
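The shell idiom `> train.log 2>&1 &` can also be reproduced from Python when a launcher script is more convenient. The following is a minimal sketch and not part of this commit: `launch_in_background` is a hypothetical helper, and the demo runs a trivial stand-in command rather than `train.py`.

```python
import subprocess
import sys


def launch_in_background(cmd, log_path):
    """Start cmd without waiting, sending stdout and stderr to log_path.

    Rough equivalent of `cmd > log_path 2>&1 &` in a shell.
    """
    log_file = open(log_path, "w")
    try:
        # stderr=STDOUT merges both streams into the same log file.
        return subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)
    finally:
        # The child holds its own copy of the descriptor, so the parent can close it.
        log_file.close()


if __name__ == "__main__":
    # Stand-in command; a real call would pass the train.py arguments instead.
    proc = launch_in_background(
        [sys.executable, "-c", "print('training started')"], "train_demo.log")
    proc.wait()
    with open("train_demo.log") as f:
        print(f.read().strip())
```

The returned `Popen` object lets the caller poll or wait for the job, which a bare `&` in the shell does not offer directly.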
@ -466,6 +474,25 @@ The above python command will run in the background. You can view the results th
| Checkpoint for Fine tuning | 355.11M (.ckpt file) | 355.11M (.ckpt file) |
| Scripts | [unet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) | [unet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) |
| Parameters | Ascend | GPU |
| -----| ----- | ----- |
| Model Version | U-Net nested(unet++) | U-Net nested(unet++) |
| Resource | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G; OS Euler2.8 | NV SMX2 V100-32G |
| Uploaded Date | 2021-8-20 | 2021-8-20 |
| MindSpore Version | 1.3.0 | 1.3.0 |
| Dataset | Cell_nuclei | Cell_nuclei |
| Training Parameters | 1pc: epoch=200, total steps=6700, batch_size=16, lr=0.0003, 8pc: epoch=1600, total steps=6560, batch_size=16*8, lr=0.0003 | 1pc: epoch=200, total steps=6700, batch_size=16, lr=0.0003, 8pc: epoch=1600, total steps=6560, batch_size=16*8, lr=0.0003 |
| Optimizer | ADAM | ADAM |
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
| Outputs | probability | probability |
| probability | cross valid dice coeff is 0.966, cross valid IOU is 0.936 | cross valid dice coeff is 0.976, cross valid IOU is 0.955 |
| Loss | <0.1 | <0.1 |
| Speed | 1pc: 150~200 fps | 1pc: 230~280 fps, 8pc: (170~210)*8 fps |
| Total time | 1pc: 10.8min | 1pc: 8min |
| Parameters (M) | 27M | 27M |
| Checkpoint for Fine tuning | 103.4M(.ckpt file) | 103.4M(.ckpt file) |
| Scripts | [unet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) | [unet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) |
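
The 8-card GPU speed in the Unet++ table is given per card, and the 8-card batch size is written as a product. As a small worked calculation (a sketch using only the figures in the table; the helper names are hypothetical), the job-wide numbers follow by multiplication:

```python
# Figures copied from the GPU column of the Unet++ table.
PER_CARD_FPS_8PC = (170, 210)   # fps per card when training on 8 cards
SINGLE_CARD_FPS = (230, 280)    # fps when training on 1 card
NUM_CARDS = 8
BATCH_PER_CARD = 16


def aggregate_fps(per_card_range, num_cards):
    """Scale a (low, high) per-card fps range to the whole job."""
    low, high = per_card_range
    return low * num_cards, high * num_cards


def per_card_ratio(multi_card_range, single_card_range):
    """Per-card throughput in the multi-card run relative to a single card."""
    return (multi_card_range[0] / single_card_range[0],
            multi_card_range[1] / single_card_range[1])


if __name__ == "__main__":
    print("global batch size:", BATCH_PER_CARD * NUM_CARDS)
    print("aggregate fps:", aggregate_fps(PER_CARD_FPS_8PC, NUM_CARDS))
    low, high = per_card_ratio(PER_CARD_FPS_8PC, SINGLE_CARD_FPS)
    print(f"per-card ratio vs 1pc: {low:.2f} to {high:.2f}")
```
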
## [How to use](#contents)
### Inference
@ -489,7 +516,7 @@ The checkpoint_file_path parameter is required,
Export on ModelArts (If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start as follows)
```python
```text
# Export on ModelArts
# (1) Perform a or b.
# a. Set "enable_modelarts=True" on default_config.yaml file.
@ -530,7 +557,7 @@ Cross valid dice coeff is: 0.9054352151297033
Set the `resume` option to True in `*.yaml`, and set `resume_ckpt` to the path of your checkpoint, e.g.
```python
```yaml
'resume': True,
'resume_ckpt': 'ckpt_unet_sample_adam_1-1_600.ckpt',
'transfer_training': False,
@ -541,7 +568,7 @@ Set options `resume` to True in `*.yaml`, and set `resume_ckpt` to the path of y
Do the same thing as for resuming training above. In addition, set `transfer_training` to True. The `filter_weight` option lists the weights that will be filtered out when moving to a different dataset. Usually, the default value of `filter_weight` doesn't need to be changed. The default value includes the weights that depend on the number of classes, e.g.
```python
```yaml
'resume': True,
'resume_ckpt': 'ckpt_unet_sample_adam_1-1_600.ckpt',
'transfer_training': True,


@ -131,9 +131,9 @@ python preprocess_dataset.py --config_path path/unet/*.yaml --data_path /data/s
- Run on Ascend
```python
```shell
# run training example
python train.py --data_path=/path/to/data/ --config_path=/path/to/yaml > train.log 2>&1 &
python train.py --data_path=/path/to/data/ --config_path=/path/to/yaml > train.log 2>&1 &
OR
bash scripts/run_standalone_train.sh [DATASET] [CONFIG_PATH]
@ -141,11 +141,31 @@ python train.py --data_path=/path/to/data/ --config_path=/path/to/yaml > train.l
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET] [CONFIG_PATH]
# run evaluation example
python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/yaml > eval.log 2>&1 &
python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/yaml > eval.log 2>&1 &
OR
bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]
```
- Run on GPU
```shell
# run training example
python train.py --data_path=/path/to/data/ --config_path=/path/to/yaml --device_target=GPU > train.log 2>&1 &
OR
bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH] [DEVICE_ID](optional)
# run distributed training example
bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH] [CUDA_VISIBLE_DEVICES(0,1,2,3,4,5,6,7)](optional)
# run evaluation example
python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/yaml > eval.log 2>&1 &
OR
bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH] [DEVICE_ID](optional)
# run export
python export.py --config_path=[CONFIG_PATH] --checkpoint_file_path=[model_ckpt_path] --file_name=[air_model_name] --file_format=MINDIR --device_target=GPU
```
- Run on docker
Build docker images (change the version to the one you actually use)
@ -167,7 +187,7 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
To train the model on ModelArts, see the official ModelArts documentation (https://support.huaweicloud.com/modelarts/).
You can then start training and inference as follows:
```python
```text
# example of distributed training on ModelArts
# (1) Choose either a or b.
# a. Set "enable_modelarts=True".
@ -198,35 +218,20 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
# (7) Start model inference.
```
- Run on GPU
```python
# run training example
python train.py --data_path=/path/to/data/ --config_path=/path/to/config/ --output ./output > train.log 2>&1 &
OR
bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH]
# run distributed training example
bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH]
# run evaluation example
python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/config/ > eval.log 2>&1 &
OR
bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]
```
# Script Description
## Script Description
### Script and Sample Code
```path
```text
├── model_zoo
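├── README.md // descriptions about all the models
├── unet
├── README.md // descriptions about Unet
├── README_CN.md // Chinese descriptions about Unet
├── ascend310_infer // inference code for Ascend 310
├── Dockerfile
├── scripts
│ ├──docker_start.sh // docker script
│ ├──run_disribute_train.sh // shell script for distributed training on Ascend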
@ -237,26 +242,29 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
│ ├──run_standalone_eval_gpu.sh // shell script for evaluation on GPU
│ ├──run_distribute_train_gpu.sh // shell script for distributed training on GPU
├── src
│ ├──config.py // parameter configuration
│ ├──__init__.py
│ ├──data_loader.py // data processing
│ ├──loss.py // loss function
│ ├─ eval_callback.py // evaluation callback while training
│ ├──eval_callback.py // evaluation callback while training
│ ├──utils.py // general components (callback functions)
│ ├──unet_medical // Unet architecture for medical image processing
├──__init__.py
├──unet_model.py // Unet network architecture
──unet_parts.py // Unet sub-network
──unet_parts.py // Unet sub-network
│ ├──unet_nested // Unet++
├──__init__.py
├──unet_model.py // Unet++ network architecture
├──unet_parts.py // Unet++ sub-network
├── model_utils
│ ├── config.py // parameter configuration
│ ├── device_adapter.py // device configuration
│ ├── local_adapter.py // local device configuration
│ ├── moxing_adapter.py // ModelArts device configuration
└──net_parts.py // Unet++ sub-network
│ ├──model_utils
├──__init__.py
├──config.py // parameter configuration
├──device_adapter.py // device configuration
├──local_adapter.py // local device configuration
└──moxing_adapter.py // ModelArts device configuration
├── unet_medical_config.yaml // configuration file
├── unet_medicl_gpu_config.yaml // configuration file
├── unet_nested_cell_config.yaml // configuration file
├── unet_nested_coco_config.yaml // configuration file
├── unet_nested_config.yaml // configuration file
├── unet_simple_config.yaml // configuration file
├── unet_simple_coco_config.yaml // configuration file
@ -267,16 +275,16 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
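├── postprocess.py // postprocess script for 310 inference
├── preprocess.py // preprocess script for 310 inference
├── preprocess_dataset.py // script to adapt the MultiClass dataset
── requirements.txt // required third-party packages
── requirements.txt // required third-party packages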
```
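### Script Parameters
Parameters for both training and evaluation can be set in config.py.
Parameters for both training and evaluation can be set in *.yaml.
- config for U-Net, ISBI dataset
```python
```yaml
'name': 'Unet', # model name
'lr': 0.0001, # learning rate
'epochs': 400, # total training epochs when run 1p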
@ -300,7 +308,7 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
- config for Unet++, cell nuclei dataset
```python
```yaml
'model': 'unet_nested', # model name
'dataset': 'Cell_nuclei', # dataset name
'img_size': [96, 96], # input image size
@ -335,7 +343,7 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
- Run on Ascend
```shell
python train.py --data_path=/path/to/data/ --config_path=/path/to/yaml > train.log 2>&1 &
python train.py --data_path=/path/to/data/ --config_path=/path/to/yaml > train.log 2>&1 &
OR
bash scripts/run_standalone_train.sh [DATASET] [CONFIG_PATH]
```
@ -363,9 +371,9 @@ python train.py --data_path=/path/to/data/ --config_path=/path/to/yaml > train.l
- Run on GPU
```shell
python train.py --data_path=/path/to/data/ --config_path=/path/to/config/ --output ./output > train.log 2>&1 &
python train.py --data_path=/path/to/data/ --config_path=/path/to/config/ --output ./output --device_target GPU > train.log 2>&1 &
OR
bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH]
bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH] [DEVICE_ID](optional)
```
The python command above runs in the background; you can view the results in the `train.log` file.
@ -412,7 +420,7 @@ bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH]
Before running the command below, check the checkpoint path used for evaluation. Set the checkpoint path to an absolute full path, e.g. "username/unet/ckpt_unet_medical_adam-48_600.ckpt".
```shell
python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/yaml > eval.log 2>&1 &
python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/yaml > eval.log 2>&1 &
OR
bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]
```
@ -465,6 +473,25 @@ python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkp
| Checkpoint for Fine tuning | 355.11M (.ckpt file) | 355.11M (.ckpt file) |
| Scripts | [U-Net script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) | [U-Net script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) |
| Parameters | Ascend | GPU |
| ----- | ------ | ----- |
| Model Version | U-Net nested(unet++) | U-Net nested(unet++) |
| Resource | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755 GB; OS Euler2.8 | NV SMX2 V100, Memory 32G |
| Uploaded Date | 2021-8-20 | 2021-8-20 |
| MindSpore Version | 1.3.0 | 1.3.0 |
| Dataset | Cell_nuclei | Cell_nuclei |
| Training Parameters | 1pc: epoch=200, total steps=6700, batch_size=16, lr=0.0003; 8pc: epoch=1600, total steps=6560, batch_size=16*8, lr=0.0003 | 1pc: epoch=200, total steps=6700, batch_size=16, lr=0.0003; 8pc: epoch=1600, total steps=6560, batch_size=16*8, lr=0.0003 |
| Optimizer | ADAM | ADAM |
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
| Outputs | probability | probability |
| probability | cross valid dice coeff is 0.966, cross valid IOU is 0.936 | cross valid dice coeff is 0.976, cross valid IOU is 0.955 |
| Loss | <0.1 | <0.1 |
| Speed | 1pc: 150~200 fps | 1pc: 230~280 fps, 8pc: (170~210)*8 fps |
| Total time | 1pc: 10.8 min | 1pc: 8 min |
| Parameters (M) | 27M | 27M |
| Checkpoint for Fine tuning | 103.4M (.ckpt file) | 103.4M (.ckpt file) |
| Scripts | [U-Net script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) | [U-Net script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) |
### How to use
#### Inference
@ -485,7 +512,7 @@ python export.py --config_path=[CONFIG_PATH] --checkpoint_file_path=[model_ckpt_
Export MindIR on ModelArts
```python
```text
# (1) Upload the trained model to the corresponding location in the bucket.
# (2) Choose either a or b.
# a. Set "enable_modelarts=True"
@ -522,9 +549,9 @@ Cross valid dice coeff is: 0.9054352151297033
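#### Continue Training on the Pretrained Model
Set `resume` to True in `config.py` and set `resume_ckpt` to the path of the corresponding checkpoint file, e.g.
Set `resume` to True in `*.yaml` and set `resume_ckpt` to the path of the corresponding checkpoint file, e.g.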
```python
```yaml
'resume': True,
'resume_ckpt': 'ckpt_unet_medical_adam_1-1_600.ckpt',
'transfer_training': False,
@ -535,7 +562,7 @@ Cross valid dice coeff is: 0.9054352151297033
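First load the checkpoint as described above for resuming training. Then set `transfer_training` to True. The config also has a `filter_weight` parameter, which filters out weights that do not carry over to a different dataset. Usually `filter_weight` does not need to be changed; its default value lists the parameters that depend on the number of classes. For example: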
```python
```yaml
'resume': True,
'resume_ckpt': 'ckpt_unet_medical_adam_1-1_600.ckpt',
'transfer_training': True,


@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
import os
import logging
from mindspore import context, Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
@ -24,6 +23,7 @@ from src.unet_nested import NestedUNet, UNet
from src.utils import UnetEval, TempLoss, dice_coeff
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id
@moxing_wrapper()
def test_net(data_dir,
@ -62,7 +62,7 @@ if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
    if config.device_target == "Ascend":
        device_id = int(os.getenv('DEVICE_ID'))
        device_id = get_device_id()
        context.set_context(device_id=device_id)
    test_net(data_dir=config.data_path,
             ckpt_path=config.checkpoint_file_path,
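
The replaced pattern `int(os.getenv('DEVICE_ID'))` raises a `TypeError` when the variable is unset, because `os.getenv` returns `None`. The body of `get_device_id` in `src/model_utils/device_adapter.py` is not shown in this diff, so the following is only a sketch of what a tolerant lookup could look like; `get_device_id_sketch` and its default of 0 are assumptions, not the project's implementation.

```python
import os


def get_device_id_sketch(default=0):
    """Read DEVICE_ID from the environment, falling back to a default.

    Hypothetical stand-in; the real get_device_id may behave differently.
    """
    raw = os.getenv("DEVICE_ID")
    if raw is None or raw.strip() == "":
        return default
    return int(raw)


if __name__ == "__main__":
    os.environ.pop("DEVICE_ID", None)
    try:
        int(os.getenv("DEVICE_ID"))  # the pattern this commit removes
    except TypeError as exc:
        print("unset DEVICE_ID fails with", type(exc).__name__)
    print("sketch with unset variable:", get_device_id_sketch())
    os.environ["DEVICE_ID"] = "3"
    print("sketch with DEVICE_ID=3:", get_device_id_sketch())
```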


@ -13,10 +13,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH]"
echo "for example: bash run_distribute_train_gpu.sh 8 /path/to/data/ /path/to/config/"
echo "=============================================================================================================="
mpirun -n $1 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
python train.py --run_distribute=True --data_path=$2 --config_path=$3 --output=./output > train.log 2>&1 &
get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m "$PWD/$1")"
    fi
}
if [ $# != 3 ] && [ $# != 4 ]
then
    echo "=============================================================================================================="
    echo "Please run the script as: "
    echo "bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH] [CUDA_VISIBLE_DEVICES(0,1,2,3,4,5,6,7)](optional)"
    echo "for example: bash run_distribute_train_gpu.sh 8 /path/to/data/ /path/to/config/"
    echo "=============================================================================================================="
    exit 1
fi
RANK_SIZE=`expr $1 + 0`
if [ $? != 0 ]; then
    echo "RANK_SIZE=$1 is not an integer!"
    exit 1
fi
export RANK_SIZE=$RANK_SIZE
DATASET=$(get_real_path "$2")
CONFIG_PATH=$(get_real_path "$3")
if [ $# != 4 ]; then
    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
else
    export CUDA_VISIBLE_DEVICES=$4
fi
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
TRAIN_OUTPUT=${PROJECT_DIR}/../train_distributed_gpu
if [ -d "$TRAIN_OUTPUT" ]; then
    rm -rf "$TRAIN_OUTPUT"
fi
mkdir "$TRAIN_OUTPUT"
cd "$TRAIN_OUTPUT" || exit
cp ../train.py ./
cp ../eval.py ./
cp -r ../src ./
cp "$CONFIG_PATH" ./
env > env.log
mpirun -n $RANK_SIZE --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
    python train.py --run_distribute=True \
    --data_path="$DATASET" \
    --config_path="${CONFIG_PATH##*/}" \
    --output=./output \
    --device_target=GPU > train.log 2>&1 &
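
The script resolves its path arguments before changing into a fresh working directory, then passes only the config file's base name because the file was copied there. A Python sketch of those two path operations may make the intent easier to see; it is an illustration rather than part of the commit, `config_file_name` is a hypothetical name, and unlike `realpath -m` the `normpath` call does not resolve symlinks.

```python
import os
import posixpath


def get_real_path(path, cwd=None):
    """Return path unchanged if it starts with '/', else resolve it against cwd.

    Like `realpath -m`, this does not require the path to exist.
    """
    if path.startswith("/"):
        return path
    base = os.getcwd() if cwd is None else cwd
    return posixpath.normpath(posixpath.join(base, path))


def config_file_name(config_path):
    """Equivalent of ${CONFIG_PATH##*/}: drop everything up to the last '/'."""
    return config_path.rsplit("/", 1)[-1]


if __name__ == "__main__":
    resolved = get_real_path("../cfg/unet_nested_cell_config.yaml", cwd="/home/user/unet")
    print(resolved)
    print(config_file_name(resolved))
```

Resolving relative paths up front matters here because the later `cd` would otherwise change what they refer to.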


@ -13,10 +13,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m "$PWD/$1")"
    fi
}
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]"
echo "for example: bash run_standalone_eval_gpu.sh /path/to/data/ /path/to/checkpoint/ /path/to/config/"
echo "=============================================================================================================="
python eval.py --data_path=$1 --checkpoint_file_path=$2 --config_path=$3 > eval.log 2>&1 &
if [ $# != 3 ] && [ $# != 4 ]
then
    echo "=============================================================================================================="
    echo "Please run the script as: "
    echo "bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH] [DEVICE_ID](optional)"
    echo "for example: bash run_standalone_eval_gpu.sh /path/to/data/ /path/to/checkpoint/ /path/to/config/"
    echo "=============================================================================================================="
    exit 1
fi
if [ $# != 4 ]; then
    DEVICE_ID=0
else
    # Match digits directly: expr exits with status 1 when its result is 0,
    # so a status check on `expr $4 + 0` would reject DEVICE_ID=0.
    if ! [[ "$4" =~ ^[0-9]+$ ]]; then
        echo "DEVICE_ID=$4 is not an integer"
        exit 1
    fi
    DEVICE_ID=$4
fi
export CUDA_VISIBLE_DEVICES=$DEVICE_ID
DATASET=$(get_real_path "$1")
CHECKPOINT=$(get_real_path "$2")
CONFIG_PATH=$(get_real_path "$3")
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
TRAIN_OUTPUT=${PROJECT_DIR}/../eval_gpu
if [ -d "$TRAIN_OUTPUT" ]; then
    rm -rf "$TRAIN_OUTPUT"
fi
mkdir "$TRAIN_OUTPUT"
cd "$TRAIN_OUTPUT" || exit
cp ../eval.py ./
cp -r ../src ./
cp "$CONFIG_PATH" ./
env > env.log
python eval.py --data_path="$DATASET" \
    --checkpoint_file_path="$CHECKPOINT" \
    --config_path="${CONFIG_PATH##*/}" \
    --device_target=GPU > eval.log 2>&1 &
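
The argument handling in this script (three required arguments, an optional integer device ID defaulting to 0, then `CUDA_VISIBLE_DEVICES`) can be expressed as a small Python function. This is a sketch for illustration only, with hypothetical names. It validates digits directly, which accepts an explicit 0; a status check on `expr N + 0` is less forgiving there, since `expr` exits with status 1 whenever its result is 0.

```python
def parse_device_id(args, required=3, default=0):
    """Return the device ID from positional args, mirroring the script's checks.

    args: positional arguments after the script name.
    required: number of mandatory arguments (3 for the eval script).
    """
    if len(args) not in (required, required + 1):
        raise ValueError(
            f"expected {required} or {required + 1} arguments, got {len(args)}")
    if len(args) == required:
        return default
    raw = args[required]
    if not (raw.isascii() and raw.isdigit()):
        raise ValueError(f"DEVICE_ID={raw} is not an integer")
    return int(raw)


def cuda_env(device_id):
    """Environment entry the script exports before launching eval.py."""
    return {"CUDA_VISIBLE_DEVICES": str(device_id)}


if __name__ == "__main__":
    example = ["/path/to/data/", "/path/to/checkpoint/", "/path/to/config/", "0"]
    print(cuda_env(parse_device_id(example)))
```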


@ -14,9 +14,50 @@
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH] "
echo "for example: bash scripts/run_standalone_train_gpu.sh /path/to/data/ /path/to/config/"
echo "=============================================================================================================="
python train.py --data_path=$1 --config_path=$2 --output ./output > train.log 2>&1 &
get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m "$PWD/$1")"
    fi
}
if [ $# != 2 ] && [ $# != 3 ]
then
    echo "=============================================================================================================="
    echo "Please run the script as: "
    echo "bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH] [DEVICE_ID](optional)"
    echo "for example: bash scripts/run_standalone_train_gpu.sh /path/to/data/ /path/to/config/"
    echo "=============================================================================================================="
    exit 1
fi
if [ $# != 3 ]; then
    DEVICE_ID=0
else
    # Match digits directly: expr exits with status 1 when its result is 0,
    # so a status check on `expr $3 + 0` would reject DEVICE_ID=0.
    if ! [[ "$3" =~ ^[0-9]+$ ]]; then
        echo "DEVICE_ID=$3 is not an integer"
        exit 1
    fi
    DEVICE_ID=$3
fi
export CUDA_VISIBLE_DEVICES=$DEVICE_ID
DATASET=$(get_real_path "$1")
CONFIG_PATH=$(get_real_path "$2")
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
TRAIN_OUTPUT=${PROJECT_DIR}/../train_standalone_gpu
if [ -d "$TRAIN_OUTPUT" ]; then
    rm -rf "$TRAIN_OUTPUT"
fi
mkdir "$TRAIN_OUTPUT"
cd "$TRAIN_OUTPUT" || exit
cp ../train.py ./
cp ../eval.py ./
cp -r ../src ./
cp "$CONFIG_PATH" ./
env > env.log
python train.py --data_path="$DATASET" \
    --config_path="${CONFIG_PATH##*/}" \
    --output ./output \
    --device_target=GPU > train.log 2>&1 &


@ -32,6 +32,7 @@ from src.eval_callback import EvalCallBack
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id
mindspore.set_seed(1)
@ -79,9 +80,11 @@ def train_net(cross_valid_ind=1,
per_print_times = 0
repeat = config.repeat if hasattr(config, "repeat") else 1
split = config.split if hasattr(config, "split") else 0.8
python_multiprocessing = not (config.device_target == "GPU" and run_distribute)
train_dataset = create_multi_class_dataset(data_dir, config.image_size, repeat, batch_size,
num_classes=config.num_classes, is_train=True, augment=True,
split=split, rank=rank, group_size=group_size, shuffle=True)
split=split, rank=rank, group_size=group_size, shuffle=True,
python_multiprocessing=python_multiprocessing)
valid_dataset = create_multi_class_dataset(data_dir, config.image_size, 1, 1,
num_classes=config.num_classes, is_train=False,
eval_resize=config.eval_resize, split=split,
@ -110,9 +113,9 @@ def train_net(cross_valid_ind=1,
loss_scale=config.loss_scale)
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(config.FixedLossScaleManager, False)
model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
amp_level = "O0" if config.device_target == "GPU" else "O3"
model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer,
amp_level=amp_level)
print("============== Starting Training ==============")
callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb]
if config.run_eval:
@ -132,7 +135,7 @@ if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
    if config.device_target == "Ascend":
        device_id = int(os.getenv('DEVICE_ID'))
        device_id = get_device_id()
        context.set_context(device_id=device_id)
    epoch_size = config.epochs if not config.run_distribute else config.distribute_epochs
    batchsize = config.batch_size
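
The two device-dependent choices this file introduces, `python_multiprocessing` and `amp_level`, are plain boolean logic and can be checked in isolation. The sketch below copies the two expressions from the hunks above into a hypothetical helper, `select_runtime_options`, and prints the resulting combinations; it does not import MindSpore.

```python
def select_runtime_options(device_target, run_distribute):
    """Reproduce the two device-dependent choices made in train.py."""
    # Python multiprocessing in the data pipeline is switched off only for
    # distributed GPU runs.
    python_multiprocessing = not (device_target == "GPU" and run_distribute)
    # Mixed-precision level passed to Model: "O0" on GPU, "O3" otherwise.
    amp_level = "O0" if device_target == "GPU" else "O3"
    return python_multiprocessing, amp_level


if __name__ == "__main__":
    for target in ("Ascend", "GPU"):
        for distributed in (False, True):
            print(target, distributed, select_runtime_options(target, distributed))
```

Only the distributed GPU case disables multiprocessing, while every GPU run, distributed or not, uses `amp_level="O0"`.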


@ -25,6 +25,7 @@ epochs: 200
repeat: 10
distribute_epochs: 1600
batch_size: 16
distribute_batchsize: 16
cross_valid_ind: 1
num_classes: 2
num_channels: 3
@ -69,6 +70,7 @@ device_target: "Target device type, available: [Ascend, GPU, CPU]"
enable_profiling: "Whether enable profiling while training, default: False"
num_classes: "Class for dataset"
batch_size: "Batch size for training and evaluation"
distribute_batchsize: "Batch size for distribute training"
weight_decay: "Weight decay."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."