add distributed training for gpu benchmark

VectorSL 2020-11-12 19:50:24 +08:00
parent 51a57b9243
commit a57fdc3b2b
3 changed files with 47 additions and 28 deletions

README.md

@@ -277,7 +277,7 @@ sh run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATA
 sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
 # gpu benchmark example
-sh run_gpu_resnet_benchmark.sh [IMAGENET_DATASET_PATH] [BATCH_SIZE](optional)
+sh run_gpu_resnet_benchmark.sh [IMAGENET_DATASET_PATH] [BATCH_SIZE](optional) [DEVICE_NUM](optional)
 ```
 #### Running parameter server mode training
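With the new optional DEVICE_NUM argument, the benchmark command above covers both the single-GPU and the distributed case. For instance (the dataset path is a placeholder; the 8-GPU line mirrors the Example string added to the launch script below):

```bash
# single-GPU benchmark, default batch size
sh run_gpu_resnet_benchmark.sh /path/imagenet/train
# 8-GPU data-parallel benchmark, batch size 256 per device
sh run_gpu_resnet_benchmark.sh /path/imagenet/train 256 8
```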
@@ -345,16 +345,11 @@ epoch: 5 step: 5004, loss is 3.3501816
 ```
 # ========START RESNET50 GPU BENCHMARK========
-step time: 22549.130 ms, fps: 11 img/sec. epoch: 1 step: 1, loss is 6.940182
-step time: 182.485 ms, fps: 1402 img/sec. epoch: 1 step: 2, loss is 7.078993
-step time: 175.263 ms, fps: 1460 img/sec. epoch: 1 step: 3, loss is 7.559594
-step time: 174.775 ms, fps: 1464 img/sec. epoch: 1 step: 4, loss is 8.020937
-step time: 175.564 ms, fps: 1458 img/sec. epoch: 1 step: 5, loss is 8.140132
-step time: 175.438 ms, fps: 1459 img/sec. epoch: 1 step: 6, loss is 8.021118
-step time: 175.760 ms, fps: 1456 img/sec. epoch: 1 step: 7, loss is 7.910158
-step time: 176.033 ms, fps: 1454 img/sec. epoch: 1 step: 8, loss is 7.940162
-step time: 175.995 ms, fps: 1454 img/sec. epoch: 1 step: 9, loss is 7.740654
-step time: 175.313 ms, fps: 1460 img/sec. epoch: 1 step: 10, loss is 7.956182
+step time: 12416.098 ms, fps: 412 img/sec. epoch: 1 step: 20, loss is 6.940182
+step time: 3472.037 ms, fps: 1474 img/sec. epoch: 2 step: 20, loss is 7.078993
+step time: 3469.523 ms, fps: 1475 img/sec. epoch: 3 step: 20, loss is 7.559594
+step time: 3460.311 ms, fps: 1479 img/sec. epoch: 4 step: 20, loss is 6.920937
+step time: 3460.543 ms, fps: 1479 img/sec. epoch: 5 step: 20, loss is 6.814013
 ...
 ```
 ## [Evaluation Process](#contents)
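Each log line now covers one sink cycle (print_per_steps real steps) rather than a single step, so the reported fps folds in the sink size. A minimal sketch of the throughput math used by the updated MyTimeMonitor below (the helper name is ours; values are taken from the sample output):

```python
def throughput_img_per_sec(batch_size, sink_size, cycle_time_ms):
    # one logged "step" spans sink_size real steps, so
    # fps = images processed in the cycle / cycle duration
    return batch_size * sink_size / cycle_time_ms * 1000

# batch 256, sink_size 20, 3472.037 ms per cycle -> 1474 img/sec,
# matching the second line of the sample output above
print(int(throughput_img_per_sec(256, 20, 3472.037)))
```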

gpu_resnet_benchmark.py

@@ -14,45 +14,54 @@
 # ============================================================================
 """train resnet."""
 import argparse
+import ast
 import time
 import numpy as np
 from mindspore import context
 from mindspore import Tensor
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.train.model import Model
+from mindspore.context import ParallelMode
 from mindspore.train.callback import Callback, LossMonitor
 from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
+from mindspore.communication.management import init, get_group_size
 from mindspore.common import set_seed
 import mindspore.nn as nn
 import mindspore.common.initializer as weight_init
 import mindspore.common.dtype as mstype
 import mindspore.dataset.engine as de
 import mindspore.dataset.vision.c_transforms as C
 import mindspore.dataset.transforms.c_transforms as C2
 from src.resnet_gpu_benchmark import resnet50 as resnet

 parser = argparse.ArgumentParser(description='Image classification')
 parser.add_argument('--batch_size', type=str, default="256", help='Batch_size: default 256.')
 parser.add_argument('--epoch_size', type=str, default="2", help='Epoch_size: default 2')
+parser.add_argument('--print_per_steps', type=str, default="20", help='Print loss and time per steps: default 20')
+parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
 parser.add_argument('--dataset_path', type=str, default=None, help='Imagenet dataset path')
 args_opt = parser.parse_args()

 set_seed(1)

 class MyTimeMonitor(Callback):
-    def __init__(self, batch_size):
+    def __init__(self, batch_size, sink_size):
         super(MyTimeMonitor, self).__init__()
         self.batch_size = batch_size
+        self.size = sink_size
     def step_begin(self, run_context):
         self.step_time = time.time()
     def step_end(self, run_context):
         step_mseconds = (time.time() - self.step_time) * 1000
-        fps = self.batch_size / step_mseconds * 1000
+        fps = self.batch_size / step_mseconds * 1000 * self.size
         print("step time: {:5.3f} ms, fps: {:d} img/sec.".format(step_mseconds, int(fps)), flush=True, end=" ")

+def pad(image):
+    zeros = np.zeros([224, 224, 1], dtype=np.uint8)
+    output = np.concatenate((image, zeros), axis=2)
+    return output

 def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="GPU"):
-    ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
+    ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True)

     image_size = 224
     mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
@@ -73,16 +82,13 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
         C.Normalize(mean=mean, std=std),
     ]

     type_cast_op = C2.TypeCast(mstype.int32)
-    ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8)
-    ds = ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
-    ds = ds.map(operations=C2.PadEnd(pad_shape=[224, 224, 4], pad_value=0), input_columns="image",
-                num_parallel_workers=8)
+    ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=4)
+    ds = ds.map(operations=pad, input_columns="image", num_parallel_workers=4)

     # apply batch operations
     ds = ds.batch(batch_size, drop_remainder=True)

     # apply dataset repeat operation
-    ds = ds.repeat(repeat_num)
+    if repeat_num > 1:
+        ds = ds.repeat(repeat_num)

     return ds
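The C2.PadEnd op is swapped for the plain NumPy pad helper defined above, which appends an all-zero fourth channel to each HWC image; presumably the 4-channel layout maps better onto the fused fp16 GPU kernels, though the commit does not say. A quick shape sanity check (assuming the 224x224 RGB output of the resize ops):

```python
import numpy as np

def pad(image):  # as defined in gpu_resnet_benchmark.py above
    zeros = np.zeros([224, 224, 1], dtype=np.uint8)
    return np.concatenate((image, zeros), axis=2)

img = np.zeros([224, 224, 3], dtype=np.uint8)  # a decoded, resized RGB image
padded = pad(img)
assert padded.shape == (224, 224, 4)           # all-zero fourth channel appended
assert (padded[..., 3] == 0).all()
```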
@@ -101,16 +107,27 @@ def get_liner_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per
     return lr_each_step

 if __name__ == '__main__':
     # set args
     dev = "GPU"
     epoch_size = int(args_opt.epoch_size)
     total_batch = int(args_opt.batch_size)
+    print_per_steps = int(args_opt.print_per_steps)

     # init context
     context.set_context(mode=context.GRAPH_MODE, device_target=dev, save_graphs=False)
+    if args_opt.run_distribute:
+        init()
+        context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          gradients_mean=True, all_reduce_fusion_config=[85, 160])

     # create dataset
     dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
                              batch_size=total_batch, target=dev)
     step_size = dataset.get_dataset_size()
+    if print_per_steps > step_size or print_per_steps < 1:
+        print("Arg: print_per_steps should be less than or equal to dataset size ", step_size)
+        print("Change to default: 20")
+        print_per_steps = 20

     # define net
     net = resnet(class_num=1001)
@@ -151,10 +168,10 @@ if __name__ == '__main__':
                   amp_level="O2", keep_batchnorm_fp32=False)

     # define callbacks
-    time_cb = MyTimeMonitor(total_batch)
+    time_cb = MyTimeMonitor(total_batch, print_per_steps)
     loss_cb = LossMonitor()
     cb = [time_cb, loss_cb]

     # train model
     print("========START RESNET50 GPU BENCHMARK========")
-    model.train(epoch_size, dataset, callbacks=cb, sink_size=dataset.get_dataset_size())
+    model.train(int(epoch_size * step_size / print_per_steps), dataset, callbacks=cb, sink_size=print_per_steps)
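Rescaling the first argument keeps the total step count (almost) unchanged while letting the callbacks fire once per sink cycle of print_per_steps steps. The arithmetic, with illustrative defaults (step_size 5004 comes from the ImageNet sample log in the README above):

```python
epoch_size = 2          # --epoch_size default
print_per_steps = 20    # --print_per_steps default
step_size = 5004        # dataset.get_dataset_size() for ImageNet at batch 256

total_steps = epoch_size * step_size                          # 10008 real steps
train_epochs = int(epoch_size * step_size / print_per_steps)  # 500 sink cycles
# 500 * 20 = 10000 steps are run; the 8-step remainder is dropped by int()
assert train_epochs * print_per_steps <= total_steps
```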

run_gpu_resnet_benchmark.sh

@@ -14,9 +14,10 @@
 # limitations under the License.
 # ============================================================================

-if [ $# != 1 ] && [ $# != 2 ]
+if [ $# != 1 ] && [ $# != 2 ] && [ $# != 3 ]
 then
-    echo "Usage: sh run_gpu_resnet_benchmark.sh [DATASET_PATH] [BATCH_SIZE](optional)"
+    echo "Usage: sh run_gpu_resnet_benchmark.sh [DATASET_PATH] [BATCH_SIZE](optional) [DEVICE_NUM](optional)"
+    echo "Example: sh run_gpu_resnet_benchmark.sh /path/imagenet/train 256 8"
     exit 1
 fi
@@ -40,3 +41,9 @@ if [ $# == 2 ]
 then
     python ${self_path}/../gpu_resnet_benchmark.py --dataset_path=$DATAPATH --batch_size=$2
 fi
+
+if [ $# == 3 ]
+then
+    mpirun --allow-run-as-root -n $3 python ${self_path}/../gpu_resnet_benchmark.py --run_distribute=True \
+        --dataset_path=$DATAPATH --batch_size=$2
+fi
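So the three-argument form from the Example line expands, after substitution, to roughly the following, spawning one training process per device (paths are placeholders):

```bash
# sh run_gpu_resnet_benchmark.sh /path/imagenet/train 256 8  ->
mpirun --allow-run-as-root -n 8 python /path/to/resnet/gpu_resnet_benchmark.py --run_distribute=True \
    --dataset_path=/path/imagenet/train --batch_size=256
```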