!1016 add resnet50_imagenet2012 to example

Merge pull request !1016 from gengdongjie/master
Authored by mindspore-ci-bot on 2020-05-09 15:07:03 +08:00; committed by Gitee.
commit 1fde96546e
15 changed files with 746 additions and 48 deletions

View File: example/resnet50_cifar10/README.md

@@ -8,9 +8,9 @@ This is an example of training ResNet-50 with CIFAR-10 dataset in MindSpore.
 - Install [MindSpore](https://www.mindspore.cn/install/en).
-- Download the dataset [CIFAR-10](http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz).
-> Unzip the CIFAR-10 dataset to any path you want and the folder structure should be as follows:
+- Download the dataset CIFAR-10
+> Unzip the CIFAR-10 dataset to any path you want and the folder structure should include train and eval dataset as follows:
 > ```
 > .
 > ├── cifar-10-batches-bin # train dataset
@@ -26,9 +26,9 @@ This is an example of training ResNet-50 with CIFAR-10 dataset in MindSpore.
 ├── dataset.py # data preprocessing
 ├── eval.py # infer script
 ├── lr_generator.py # generate learning rate for each step
-├── run_distribute_train.sh # launch distributed training
+├── run_distribute_train.sh # launch distributed training(8 pcs)
 ├── run_infer.sh # launch infering
-├── run_standalone_train.sh # launch standalone training
+├── run_standalone_train.sh # launch standalone training(1 pcs)
 └── train.py # train script
 ```
@@ -51,11 +51,11 @@ Parameters for both training and inference can be set in config.py.
 "save_checkpoint_steps": 195, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step
 "keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint
 "save_checkpoint_path": "./", # path to save checkpoint
-"warmup_epochs": 5, # number of warmup epoch
-"lr_decay_mode": "poly" # decay mode can be selected in steps, ploy and default
 "lr_init": 0.01, # initial learning rate
 "lr_end": 0.00001, # final learning rate
 "lr_max": 0.1, # maximum learning rate
+"warmup_epochs": 5, # number of warmup epoch
+"lr_decay_mode": "poly" # decay mode can be selected in steps, ploy and default
 ```
 ## Running the example
@@ -65,7 +65,7 @@ Parameters for both training and inference can be set in config.py.
 #### Usage
 ```
-# distribute training
+# distributed training
 Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]
 # standalone training
@@ -90,7 +90,7 @@ sh run_standalone_train.sh ~/cifar-10-batches-bin
 Training result will be stored in the example path, whose folder name begins with "train" or "train_parallel". Under this, you can find checkpoint file together with result like the followings in log.
 ```
-# distribute training result(8p)
+# distribute training result(8 pcs)
 epoch: 1 step: 195, loss is 1.9601055
 epoch: 2 step: 195, loss is 1.8555021
 epoch: 3 step: 195, loss is 1.6707983

View File: example/resnet50_cifar10/config.py

@@ -31,9 +31,9 @@ config = ed({
     "save_checkpoint_steps": 195,
     "keep_checkpoint_max": 10,
     "save_checkpoint_path": "./",
-    "warmup_epochs": 5,
-    "lr_decay_mode": "poly",
     "lr_init": 0.01,
     "lr_end": 0.00001,
-    "lr_max": 0.1
+    "lr_max": 0.1,
+    "warmup_epochs": 5,
+    "lr_decay_mode": "poly"
 })

View File: example/resnet50_cifar10/dataset.py

@@ -40,39 +40,30 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
     rank_id = int(os.getenv("RANK_ID"))
     if device_num == 1:
-        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True)
+        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
     else:
-        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True,
+        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
                                num_shards=device_num, shard_id=rank_id)
 
-    resize_height = config.image_height
-    resize_width = config.image_width
-    rescale = 1.0 / 255.0
-    shift = 0.0
 
     # define map operations
-    random_crop_op = C.RandomCrop((32, 32), (4, 4, 4, 4))
-    random_horizontal_flip_op = C.RandomHorizontalFlip(rank_id / (rank_id + 1))
-    resize_op = C.Resize((resize_height, resize_width))
-    rescale_op = C.Rescale(rescale, shift)
-    normalize_op = C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
-    change_swap_op = C.HWC2CHW()
     trans = []
     if do_train:
-        trans += [random_crop_op, random_horizontal_flip_op]
+        trans += [
+            C.RandomCrop((32, 32), (4, 4, 4, 4)),
+            C.RandomHorizontalFlip(prob=0.5)
+        ]
 
-    trans += [resize_op, rescale_op, normalize_op, change_swap_op]
+    trans += [
+        C.Resize((config.image_height, config.image_width)),
+        C.Rescale(1.0 / 255.0, 0.0),
+        C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
+        C.HWC2CHW()
+    ]
 
     type_cast_op = C2.TypeCast(mstype.int32)
-    ds = ds.map(input_columns="label", operations=type_cast_op)
-    ds = ds.map(input_columns="image", operations=trans)
+    ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)
+    ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
 
-    # apply shuffle operations
-    ds = ds.shuffle(buffer_size=config.buffer_size)
-
     # apply batch operations
     ds = ds.batch(batch_size, drop_remainder=True)
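One detail of the hunk above is worth spelling out: the old flip probability `rank_id / (rank_id + 1)` depended on the device rank, so device 0 never flipped and every device augmented differently; the new code uses a fixed `prob=0.5` on all devices. A quick plain-Python check of the old behavior:

```python
# Old per-rank flip probabilities: rank 0 never flips, rank 7 flips 87.5% of the time.
for rank_id in range(8):
    print(rank_id, rank_id / (rank_id + 1))
# 0 0.0
# 1 0.5
# ...
# 7 0.875
```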

View File: example/resnet50_cifar10/eval.py

@@ -17,8 +17,6 @@ eval.
 """
 import os
 import argparse
-import random
-import numpy as np
 from dataset import create_dataset
 from config import config
 from mindspore import context
@@ -27,13 +25,8 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
 from mindspore.train.model import Model, ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-import mindspore.dataset.engine as de
 from mindspore.communication.management import init
 
-random.seed(1)
-np.random.seed(1)
-de.config.set_seed(1)
-
 parser = argparse.ArgumentParser(description='Image classification')
 parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
 parser.add_argument('--device_num', type=int, default=1, help='Device num.')

View File: example/resnet50_cifar10/train.py

@@ -15,8 +15,6 @@
 """train_imagenet."""
 import os
 import argparse
-import random
-import numpy as np
 from dataset import create_dataset
 from lr_generator import get_lr
 from config import config
@@ -31,13 +29,8 @@ from mindspore.train.model import Model, ParallelMode
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
-import mindspore.dataset.engine as de
 from mindspore.communication.management import init
 
-random.seed(1)
-np.random.seed(1)
-de.config.set_seed(1)
-
 parser = argparse.ArgumentParser(description='Image classification')
 parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
 parser.add_argument('--device_num', type=int, default=1, help='Device num.')

View File: example/resnet50_imagenet2012/README.md (new file)

@@ -0,0 +1,127 @@
# ResNet-50 Example
## Description
This is an example of training ResNet-50 with the ImageNet2012 dataset in MindSpore.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the ImageNet2012 dataset.
> Unzip the ImageNet2012 dataset to any path you want; the folder structure should include the train and eval datasets as follows:
> ```
> .
> ├── ilsvrc # train dataset
> └── ilsvrc_eval # infer dataset
> ```
## Example structure
```shell
.
├── crossentropy.py # CrossEntropy loss function
├── config.py # parameter configuration
├── dataset.py # data preprocessing
├── eval.py # infer script
├── lr_generator.py # generate learning rate for each step
├── run_distribute_train.sh # launch distributed training(8 pcs)
├── run_infer.sh # launch inferring
├── run_standalone_train.sh # launch standalone training(1 pcs)
└── train.py # train script
```
## Parameter configuration
Parameters for both training and inference can be set in config.py.
```
"class_num": 1001, # dataset class number
"batch_size": 32, # batch size of input tensor
"loss_scale": 1024, # loss scale
"momentum": 0.9, # momentum optimizer
"weight_decay": 1e-4, # weight decay
"epoch_size": 90, # only valid for taining, which is always 1 for inference
"buffer_size": 1000, # number of queue size in data preprocessing
"image_height": 224, # image height
"image_width": 224, # image width
"save_checkpoint": True, # whether save checkpoint or not
"save_checkpoint_epochs": 1, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last epoch
"keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint
"save_checkpoint_path": "./", # path to save checkpoint relative to the executed path
"warmup_epochs": 0, # number of warmup epoch
"lr_decay_mode": "cosine", # decay mode for generating learning rate
"label_smooth": True, # label smooth
"label_smooth_factor": 0.1, # label smooth factor
"lr_init": 0, # initial learning rate
"lr_max": 0.1, # maximum learning rate
```
## Running the example
### Train
#### Usage
```
# distributed training
Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]
# standalone training
Usage: sh run_standalone_train.sh [DATASET_PATH]
```
#### Launch
```bash
# distributed training example(8 pcs)
sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc
# standalone training example(1 pcs)
sh run_standalone_train.sh dataset/ilsvrc
```
> For details about rank_table.json, refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
#### Result
Training results will be stored in the example path, in a folder whose name begins with "train" or "train_parallel". There you can find checkpoint files together with results like the following in the log.
```
# distribute training result(8 pcs)
epoch: 1 step: 5004, loss is 4.8995576
epoch: 2 step: 5004, loss is 3.9235563
epoch: 3 step: 5004, loss is 3.833077
epoch: 4 step: 5004, loss is 3.2795618
epoch: 5 step: 5004, loss is 3.1978393
```
### Infer
#### Usage
```
# infer
Usage: sh run_infer.sh [DATASET_PATH] [CHECKPOINT_PATH]
```
#### Launch
```bash
# infer with checkpoint
sh run_infer.sh dataset/ilsvrc_eval train_parallel0/resnet-90_5004.ckpt
```
> The checkpoint is produced during the training process.
#### Result
Inference results will be stored in the example path, in a folder named "infer". There you can find results like the following in the log.
```
result: {'acc': 0.7671054737516005} ckpt=train_parallel0/resnet-90_5004.ckpt
```
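As a quick sanity check on the logs above, the 5004 steps per epoch seen in both the loss output and the checkpoint name follow from the batch setup (a sketch assuming the standard ImageNet2012 train split of 1,281,167 images; batch size and device count are those used by the 8-device launch script):

```python
# dataset size / global batch, truncated because dataset.py batches with drop_remainder=True
train_images = 1281167               # standard ImageNet2012 train split (assumption)
global_batch = 32 * 8                # per-device batch_size from config.py x 8 devices
print(train_images // global_batch)  # -> 5004
```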

View File: example/resnet50_imagenet2012/config.py (new file)

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed

config = ed({
    "class_num": 1001,
    "batch_size": 32,
    "loss_scale": 1024,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "epoch_size": 90,
    "buffer_size": 1000,
    "image_height": 224,
    "image_width": 224,
    "save_checkpoint": True,
    "save_checkpoint_epochs": 1,
    "keep_checkpoint_max": 10,
    "save_checkpoint_path": "./",
    "warmup_epochs": 0,
    "lr_decay_mode": "cosine",
    "use_label_smooth": True,
    "label_smooth_factor": 0.1,
    "lr_init": 0,
    "lr_max": 0.1
})
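A brief aside on the EasyDict wrapper (a minimal sketch, not part of the example files): fields are reachable as attributes and stay mutable, which is exactly how train.py and eval.py below zero out the smoothing factor.

```python
from easydict import EasyDict as ed

cfg = ed({"use_label_smooth": True, "label_smooth_factor": 0.1})
print(cfg.label_smooth_factor)      # attribute access instead of cfg["label_smooth_factor"]
if not cfg.use_label_smooth:
    cfg.label_smooth_factor = 0.0   # the same override train.py/eval.py apply
```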

View File: example/resnet50_imagenet2012/crossentropy.py (new file)

@@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn

class CrossEntropy(_Loss):
    """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
    def __init__(self, smooth_factor=0, num_classes=1001):
        super(CrossEntropy, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logit, label):
        one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, one_hot_label)
        loss = self.mean(loss, 0)
        return loss
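To make the on/off values concrete, here is a NumPy sketch of the same computation (illustrative only; `smoothed_cross_entropy` is a hypothetical helper, not part of the example):

```python
import numpy as np

def smoothed_cross_entropy(logits, labels, smooth_factor=0.1, num_classes=1001):
    # one-hot targets with label smoothing: 1 - eps on the true class,
    # eps / (num_classes - 1) spread over the rest, as in CrossEntropy above
    on_value = 1.0 - smooth_factor
    off_value = smooth_factor / (num_classes - 1)
    one_hot = np.full((labels.size, num_classes), off_value)
    one_hot[np.arange(labels.size), labels] = on_value
    # numerically stable log-softmax, then mean cross entropy over the batch
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(one_hot * log_probs).sum(axis=1).mean()

print(smoothed_cross_entropy(np.random.randn(4, 1001), np.array([0, 5, 42, 1000])))
```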

View File: example/resnet50_imagenet2012/dataset.py (new file)

@@ -0,0 +1,79 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
    """
    create a train or eval dataset

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32

    Returns:
        dataset
    """
    device_num = int(os.getenv("DEVICE_NUM"))
    rank_id = int(os.getenv("RANK_ID"))

    if device_num == 1:
        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
    else:
        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
                                     num_shards=device_num, shard_id=rank_id)

    image_size = 224
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize((256, 256)),
            C.CenterCrop(image_size),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
    ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds
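A hedged usage sketch (assumes the `dataset/ilsvrc` layout from the README and that `DEVICE_NUM`/`RANK_ID` are exported, as the launch scripts below do):

```python
import os

# create_dataset reads these; the launch scripts export one pair per device
os.environ.setdefault("DEVICE_NUM", "1")
os.environ.setdefault("RANK_ID", "0")

from dataset import create_dataset

ds = create_dataset("dataset/ilsvrc", do_train=True, repeat_num=1, batch_size=32)
for row in ds.create_dict_iterator():
    print(row["image"].shape, row["label"].shape)  # expect (32, 3, 224, 224) and (32,)
    break
```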

View File: example/resnet50_imagenet2012/eval.py (new file)

@@ -0,0 +1,62 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
eval.
"""
import os
import argparse
from dataset import create_dataset
from config import config
from mindspore import context
from mindspore.model_zoo.resnet import resnet50
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from crossentropy import CrossEntropy
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--do_train', type=bool, default=False, help='Do train or not.')
parser.add_argument('--do_eval', type=bool, default=True, help='Do eval or not.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(enable_task_sink=True, device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)

if __name__ == '__main__':
    net = resnet50(class_num=config.class_num)

    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)

    if args_opt.do_eval:
        dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size)
        step_size = dataset.get_dataset_size()

        if args_opt.checkpoint_path:
            param_dict = load_checkpoint(args_opt.checkpoint_path)
            load_param_into_net(net, param_dict)
        net.set_train(False)

        model = Model(net, loss_fn=loss, metrics={'acc'})
        res = model.eval(dataset)
        print("result:", res, "ckpt=", args_opt.checkpoint_path)

View File: example/resnet50_imagenet2012/lr_generator.py (new file)

@@ -0,0 +1,90 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
    """
    generate learning rate array

    Args:
        global_step(int): total steps of the training
        lr_init(float): init learning rate
        lr_end(float): end learning rate
        lr_max(float): max learning rate
        warmup_epochs(int): number of warmup epochs
        total_epochs(int): total epoch of training
        steps_per_epoch(int): steps of one epoch
        lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or default

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    if lr_decay_mode == 'steps':
        decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
        for i in range(total_steps):
            if i < decay_epoch_index[0]:
                lr = lr_max
            elif i < decay_epoch_index[1]:
                lr = lr_max * 0.1
            elif i < decay_epoch_index[2]:
                lr = lr_max * 0.01
            else:
                lr = lr_max * 0.001
            lr_each_step.append(lr)
    elif lr_decay_mode == 'poly':
        if warmup_steps != 0:
            inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
        else:
            inc_each_step = 0
        for i in range(total_steps):
            if i < warmup_steps:
                lr = float(lr_init) + inc_each_step * float(i)
            else:
                base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
                lr = float(lr_max) * base * base
                if lr < 0.0:
                    lr = 0.0
            lr_each_step.append(lr)
    elif lr_decay_mode == 'cosine':
        decay_steps = total_steps - warmup_steps
        for i in range(total_steps):
            if i < warmup_steps:
                lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
                lr = float(lr_init) + lr_inc * (i + 1)
            else:
                linear_decay = (total_steps - i) / decay_steps
                cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
                decayed = linear_decay * cosine_decay + 0.00001
                lr = lr_max * decayed
            lr_each_step.append(lr)
    else:
        for i in range(total_steps):
            if i < warmup_steps:
                lr = lr_init + (lr_max - lr_init) * i / warmup_steps
            else:
                lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
            lr_each_step.append(lr)

    current_step = global_step
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[current_step:]

    return learning_rate
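A small sketch of how train.py below consumes this function (parameter values mirror config.py; the 5004 steps/epoch figure assumes the 8-device ImageNet run shown in the README):

```python
from lr_generator import get_lr

# cosine decay over 90 epochs with no warmup, as configured for ImageNet2012
lr = get_lr(global_step=0, lr_init=0, lr_end=0.0, lr_max=0.1,
            warmup_epochs=0, total_epochs=90, steps_per_epoch=5004,
            lr_decay_mode='cosine')
print(lr.shape)       # (450360,) -- one learning rate per training step
print(lr[0], lr[-1])  # starts at ~lr_max, decays to ~1e-6
```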

View File: example/resnet50_imagenet2012/run_distribute_train.sh (new file)

@@ -0,0 +1,65 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
    echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)

if [ ! -f "$PATH1" ]
then
    echo "error: MINDSPORE_HCCL_CONFIG_PATH=$PATH1 is not a file"
    exit 1
fi

if [ ! -d "$PATH2" ]
then
    echo "error: DATASET_PATH=$PATH2 is not a directory"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=$PATH1
export RANK_TABLE_FILE=$PATH1

for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp *.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log &
    cd ..
done

View File: example/resnet50_imagenet2012/run_infer.sh (new file)

@@ -0,0 +1,64 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
    echo "Usage: sh run_infer.sh [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)

if [ ! -d $PATH1 ]
then
    echo "error: DATASET_PATH=$1 is not a directory"
    exit 1
fi

if [ ! -f $PATH2 ]
then
    echo "error: CHECKPOINT_PATH=$2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

if [ -d "infer" ];
then
    rm -rf ./infer
fi
mkdir ./infer
cp *.py ./infer
cp *.sh ./infer
cd ./infer || exit
env > env.log
echo "start inferring for device $DEVICE_ID"
python eval.py --do_eval=True --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..

View File: example/resnet50_imagenet2012/run_standalone_train.sh (new file)

@@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
    echo "Usage: sh run_standalone_train.sh [DATASET_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)

if [ ! -d "$PATH1" ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0

if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp *.py ./train
cp *.sh ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --do_train=True --dataset_path=$PATH1 &> log &
cd ..

View File: example/resnet50_imagenet2012/train.py (new file)

@@ -0,0 +1,99 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_imagenet."""
import os
import argparse
from dataset import create_dataset
from lr_generator import get_lr
from config import config
from mindspore import context
from mindspore import Tensor
from mindspore.model_zoo.resnet import resnet50
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.communication.management import init
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from crossentropy import CrossEntropy
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--do_train', type=bool, default=True, help='Do train or not.')
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(enable_task_sink=True, device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)

if __name__ == '__main__':
    if not args_opt.do_eval and args_opt.run_distribute:
        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          mirror_mean=True, parameter_broadcast=True)
        auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
        init()

    epoch_size = config.epoch_size
    net = resnet50(class_num=config.class_num)

    # weight init
    for _, cell in net.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
                                                                cell.weight.default_input.shape(),
                                                                cell.weight.default_input.dtype())
        if isinstance(cell, nn.Dense):
            cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
                                                                cell.weight.default_input.shape(),
                                                                cell.weight.default_input.dtype())

    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)

    if args_opt.do_train:
        dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
                                 repeat_num=epoch_size, batch_size=config.batch_size)
        step_size = dataset.get_dataset_size()

        loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
        lr = Tensor(get_lr(global_step=0, lr_init=config.lr_init, lr_end=0.0, lr_max=config.lr_max,
                           warmup_epochs=config.warmup_epochs, total_epochs=epoch_size, steps_per_epoch=step_size,
                           lr_decay_mode='cosine'))
        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
                       config.weight_decay, config.loss_scale)

        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})

        time_cb = TimeMonitor(data_size=step_size)
        loss_cb = LossMonitor()
        cb = [time_cb, loss_cb]
        if config.save_checkpoint:
            config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*step_size,
                                         keep_checkpoint_max=config.keep_checkpoint_max)
            ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck)
            cb += [ckpt_cb]
        model.train(epoch_size, dataset, callbacks=cb)