From abd273639617af594279cd5cd3e4898271a31828 Mon Sep 17 00:00:00 2001 From: zhaoting Date: Tue, 7 Sep 2021 11:37:22 +0800 Subject: [PATCH] fix mobilenetv2 bugs on GPU --- .../cv/mobilenetv2/default_config.yaml | 2 + .../cv/mobilenetv2/default_config_acc.yaml | 2 + .../cv/mobilenetv2/default_config_cpu.yaml | 5 ++ .../cv/mobilenetv2/default_config_gpu.yaml | 5 ++ model_zoo/official/cv/mobilenetv2/eval.py | 17 ++--- model_zoo/official/cv/mobilenetv2/export.py | 19 ++---- .../cv/mobilenetv2/scripts/run_train.sh | 6 +- .../official/cv/mobilenetv2/src/dataset.py | 27 +------- .../official/cv/mobilenetv2/src/metric.py | 11 ++-- .../official/cv/mobilenetv2/src/utils.py | 62 +++++-------------- model_zoo/official/cv/mobilenetv2/train.py | 54 +++++++--------- 11 files changed, 78 insertions(+), 132 deletions(-) diff --git a/model_zoo/official/cv/mobilenetv2/default_config.yaml b/model_zoo/official/cv/mobilenetv2/default_config.yaml index 4cafbb569ed..e7060077ab7 100644 --- a/model_zoo/official/cv/mobilenetv2/default_config.yaml +++ b/model_zoo/official/cv/mobilenetv2/default_config.yaml @@ -38,6 +38,7 @@ device_id: 0 rank_id: 0 rank_size: 1 run_distribute: False +run_eval: False activation: "Softmax" # Image classification trian. train_parse_args():return train_args @@ -86,6 +87,7 @@ file_name: "output file name." result_path: "result files path." label_path: "label path." enable_profiling: 'Whether enable profiling while training, default: False' +run_eval: 'Whether run evaluation while training, default is false.' run_distribute: 'Run distribute, default is false.' device_id: 'Device id, default is 0.' rank_id: 'Rank id, default is 0.' diff --git a/model_zoo/official/cv/mobilenetv2/default_config_acc.yaml b/model_zoo/official/cv/mobilenetv2/default_config_acc.yaml index ba553e768eb..59837379e48 100644 --- a/model_zoo/official/cv/mobilenetv2/default_config_acc.yaml +++ b/model_zoo/official/cv/mobilenetv2/default_config_acc.yaml @@ -38,6 +38,7 @@ device_id: 0 rank_id: 0 rank_size: 1 run_distribute: False +run_eval: True activation: "Softmax" # Image classification trian. train_parse_args():return train_args @@ -86,6 +87,7 @@ file_name: "output file name." result_path: "result files path." label_path: "label path." enable_profiling: 'Whether enable profiling while training, default: False' +run_eval: 'Whether run evaluation while training, default is false.' run_distribute: 'Run distribute, default is false.' device_id: 'Device id, default is 0.' rank_id: 'Rank id, default is 0.' diff --git a/model_zoo/official/cv/mobilenetv2/default_config_cpu.yaml b/model_zoo/official/cv/mobilenetv2/default_config_cpu.yaml index 39e61cee409..83b6e587019 100644 --- a/model_zoo/official/cv/mobilenetv2/default_config_cpu.yaml +++ b/model_zoo/official/cv/mobilenetv2/default_config_cpu.yaml @@ -34,7 +34,11 @@ save_checkpoint_epochs: 1 keep_checkpoint_max: 20 save_checkpoint_path: "./" platform: 'CPU' +device_id: 0 +rank_id: 0 +rank_size: 1 run_distribute: False +run_eval: False activation: "Softmax" # Image classification trian. train_parse_args():return train_args @@ -83,6 +87,7 @@ file_name: "output file name." result_path: "result files path." label_path: "label path." enable_profiling: 'Whether enable profiling while training, default: False' +run_eval: 'Whether run evaluation while training, default is false.' run_distribute: 'Run distribute, default is false.' device_id: 'Device id, default is 0.' rank_id: 'Rank id, default is 0.' diff --git a/model_zoo/official/cv/mobilenetv2/default_config_gpu.yaml b/model_zoo/official/cv/mobilenetv2/default_config_gpu.yaml index cf7c0ca051d..d032f4e2c55 100644 --- a/model_zoo/official/cv/mobilenetv2/default_config_gpu.yaml +++ b/model_zoo/official/cv/mobilenetv2/default_config_gpu.yaml @@ -34,6 +34,9 @@ save_checkpoint_epochs: 1 keep_checkpoint_max: 200 save_checkpoint_path: "./" platform: 'GPU' +device_id: 0 +rank_id: 0 +rank_size: 1 run_distribute: True activation: "Softmax" @@ -57,6 +60,7 @@ ckpt_file: "/cache/train/mobilenetv2-200_625.ckpt" file_name: "mobilenetv2" file_format: "MINDIR" is_training_export: False +run_eval: False run_distribute_export: False # postprocess.py / mobilenetv2 acc calculation @@ -83,6 +87,7 @@ file_name: "output file name." result_path: "result files path." label_path: "label path." enable_profiling: 'Whether enable profiling while training, default: False' +run_eval: 'Whether run evaluation while training, default is false.' run_distribute: 'Run distribute, default is false.' device_id: 'Device id, default is 0.' rank_id: 'Rank id, default is 0.' diff --git a/model_zoo/official/cv/mobilenetv2/eval.py b/model_zoo/official/cv/mobilenetv2/eval.py index 5f0c175b5a0..c004fe8322b 100644 --- a/model_zoo/official/cv/mobilenetv2/eval.py +++ b/model_zoo/official/cv/mobilenetv2/eval.py @@ -19,21 +19,14 @@ import time import os from mindspore import nn from mindspore.train.model import Model -from mindspore.common import dtype as mstype - from src.dataset import create_dataset from src.models import define_net, load_ckpt -from src.utils import switch_precision, set_context +from src.utils import context_device_init from src.model_utils.config import config from src.model_utils.moxing_adapter import moxing_wrapper -from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id - +from src.model_utils.device_adapter import get_device_id, get_device_num config.is_training = config.is_training_eval -config.device_id = get_device_id() -config.rank_id = get_rank_id() -config.rank_size = get_device_num() -config.run_distribute = config.rank_size > 1. def modelarts_process(): """ modelarts process """ @@ -96,13 +89,13 @@ def modelarts_process(): def eval_mobilenetv2(): config.dataset_path = os.path.join(config.dataset_path, 'validation_preprocess') print('\nconfig: \n', config) - set_context(config) + if not config.device_id: + config.device_id = get_device_id() + context_device_init(config) _, _, net = define_net(config, config.is_training) load_ckpt(net, config.pretrain_ckpt) - switch_precision(net, mstype.float16, config) - dataset = create_dataset(dataset_path=config.dataset_path, do_train=False, config=config) step_size = dataset.get_dataset_size() if step_size == 0: diff --git a/model_zoo/official/cv/mobilenetv2/export.py b/model_zoo/official/cv/mobilenetv2/export.py index a83ac51edd6..9d643492da0 100644 --- a/model_zoo/official/cv/mobilenetv2/export.py +++ b/model_zoo/official/cv/mobilenetv2/export.py @@ -16,25 +16,16 @@ mobilenetv2 export file. """ import numpy as np -from mindspore import Tensor, export, context +from mindspore import Tensor, export from src.models import define_net, load_ckpt -from src.utils import set_context +from src.utils import context_device_init from src.model_utils.config import config -from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id +from src.model_utils.device_adapter import get_device_id from src.model_utils.moxing_adapter import moxing_wrapper - -config.device_id = get_device_id() -config.rank_id = get_rank_id() -config.rank_size = get_device_num() -config.run_distribute = config.rank_size > 1. config.batch_size = config.batch_size_export config.is_training = config.is_training_export -context.set_context(mode=context.GRAPH_MODE, device_target=config.platform) -if config.platform == "Ascend": - context.set_context(device_id=get_device_id()) - def modelarts_process(): pass @@ -42,7 +33,9 @@ def modelarts_process(): def export_mobilenetv2(): """ export_mobilenetv2 """ print('\nconfig: \n', config) - set_context(config) + if not config.device_id: + config.device_id = get_device_id() + context_device_init(config) _, _, net = define_net(config, config.is_training) load_ckpt(net, config.ckpt_file) diff --git a/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh b/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh index a17c2847d86..d2d00007f54 100644 --- a/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh +++ b/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh @@ -46,7 +46,10 @@ run_ascend() echo "error: DATASET_PATH=$6 is not a directory or file" exit 1 fi - + RUN_DISTRIBUTE=True + if [ $2 -eq 1 ] ; then + RUN_DISTRIBUTE=False + fi BASEPATH=$(cd "`dirname $0`" || exit; pwd) CONFIG_FILE="${BASEPATH}/../$2" @@ -85,6 +88,7 @@ run_ascend() echo "start training for rank $RANK_ID, device $DEVICE_ID" env > env.log taskset -c $cmdopt python train.py \ + --run_distribute=$RUN_DISTRIBUTE \ --config_path=$CONFIG_FILE \ --platform=$1 \ --dataset_path=$6 \ diff --git a/model_zoo/official/cv/mobilenetv2/src/dataset.py b/model_zoo/official/cv/mobilenetv2/src/dataset.py index c3bc07b5612..1e1a5bdb667 100644 --- a/model_zoo/official/cv/mobilenetv2/src/dataset.py +++ b/model_zoo/official/cv/mobilenetv2/src/dataset.py @@ -49,31 +49,8 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1, enable_cache=Fa nfs_dataset_cache = None num_workers = config.num_workers - if config.platform == "Ascend": - rank_size = int(os.getenv("RANK_SIZE", '1')) - rank_id = int(os.getenv("RANK_ID", '0')) - if rank_size == 1: - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True, - cache=nfs_dataset_cache) - else: - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True, - num_shards=rank_size, shard_id=rank_id, cache=nfs_dataset_cache) - elif config.platform == "GPU": - if do_train: - if config.run_distribute: - from mindspore.communication.management import get_rank, get_group_size - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True, - num_shards=get_group_size(), shard_id=get_rank(), - cache=nfs_dataset_cache) - else: - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True, - cache=nfs_dataset_cache) - else: - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True, - cache=nfs_dataset_cache) - elif config.platform == "CPU": - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, \ - shuffle=True, cache=nfs_dataset_cache) + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=do_train, + num_shards=config.rank_size, shard_id=config.rank_id, cache=nfs_dataset_cache) resize_height = config.image_height resize_width = config.image_width diff --git a/model_zoo/official/cv/mobilenetv2/src/metric.py b/model_zoo/official/cv/mobilenetv2/src/metric.py index 497dd6640c6..783d17fa280 100644 --- a/model_zoo/official/cv/mobilenetv2/src/metric.py +++ b/model_zoo/official/cv/mobilenetv2/src/metric.py @@ -42,14 +42,16 @@ class ClassifyCorrectCell(nn.Cell): >>> eval_net = nn.ClassifyCorrectCell(net) """ - def __init__(self, network): + def __init__(self, network, run_distribute): super(ClassifyCorrectCell, self).__init__(auto_prefix=False) self._network = network self.argmax = P.Argmax() self.equal = P.Equal() self.cast = P.Cast() self.reduce_sum = P.ReduceSum() - self.allreduce = P.AllReduce(P.ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP) + self.run_distribute = run_distribute + if run_distribute: + self.allreduce = P.AllReduce(P.ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP) def construct(self, data, label): outputs = self._network(data) @@ -58,8 +60,9 @@ class ClassifyCorrectCell(nn.Cell): y_correct = self.equal(y_pred, label) y_correct = self.cast(y_correct, mstype.float32) y_correct = self.reduce_sum(y_correct) - total_correct = self.allreduce(y_correct) - return (total_correct,) + if self.run_distribute: + y_correct = self.allreduce(y_correct) + return (y_correct,) class DistAccuracy(nn.Metric): diff --git a/model_zoo/official/cv/mobilenetv2/src/utils.py b/model_zoo/official/cv/mobilenetv2/src/utils.py index 1ca23a82d8d..e32470cdb47 100644 --- a/model_zoo/official/cv/mobilenetv2/src/utils.py +++ b/model_zoo/official/cv/mobilenetv2/src/utils.py @@ -14,72 +14,40 @@ # ============================================================================ from mindspore import context -from mindspore import nn -from mindspore.common import dtype as mstype from mindspore.context import ParallelMode from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.communication.management import get_rank, init, get_group_size from src.models import Monitor - -def switch_precision(net, data_type, config): - if config.platform == "Ascend": - net.to_float(data_type) - for _, cell in net.cells_and_names(): - if isinstance(cell, nn.Dense): - cell.to_float(mstype.float32) - - def context_device_init(config): + if config.platform == "GPU" and config.run_distribute: + config.device_id = 0 + config.rank_id = 0 + config.rank_size = 1 if config.platform == "CPU": context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False) - elif config.platform == "GPU": - context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False) - if config.run_distribute: - init() - context.set_auto_parallel_context(device_num=get_group_size(), - parallel_mode=ParallelMode.DATA_PARALLEL, - gradients_mean=True) - - elif config.platform == "Ascend": + elif config.platform in ["Ascend", "GPU"]: context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, device_id=config.device_id, save_graphs=False) if config.run_distribute: + init() + config.rank_id = get_rank() + config.rank_size = get_group_size() context.set_auto_parallel_context(device_num=config.rank_size, parallel_mode=ParallelMode.DATA_PARALLEL, - gradients_mean=True, all_reduce_fusion_config=[140]) - init() + gradients_mean=True) else: raise ValueError("Only support CPU, GPU and Ascend.") -def set_context(config): - if config.platform == "CPU": - context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, - save_graphs=False) - elif config.platform == "Ascend": - context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, - device_id=config.device_id, save_graphs=False) - elif config.platform == "GPU": - context.set_context(mode=context.GRAPH_MODE, - device_target=config.platform, save_graphs=False) - - def config_ckpoint(config, lr, step_size, model=None, eval_dataset=None): cb = [Monitor(lr_init=lr.asnumpy(), model=model, eval_dataset=eval_dataset)] - if config.platform in ("CPU", "GPU") or config.rank_id == 0: - - if config.save_checkpoint: - config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, - keep_checkpoint_max=config.keep_checkpoint_max) - - rank = 0 - if config.run_distribute: - rank = get_rank() - - ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(rank) + "/" - ckpt_cb = ModelCheckpoint(prefix="mobilenetv2", directory=ckpt_save_dir, config=config_ck) - cb += [ckpt_cb] + if config.save_checkpoint and config.rank_id == 0: + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(config.rank_id) + "/" + ckpt_cb = ModelCheckpoint(prefix="mobilenetv2", directory=ckpt_save_dir, config=config_ck) + cb += [ckpt_cb] return cb diff --git a/model_zoo/official/cv/mobilenetv2/train.py b/model_zoo/official/cv/mobilenetv2/train.py index 3bfeafadc2e..d357316d17a 100644 --- a/model_zoo/official/cv/mobilenetv2/train.py +++ b/model_zoo/official/cv/mobilenetv2/train.py @@ -24,7 +24,6 @@ from mindspore import Tensor from mindspore.nn import WithLossCell, TrainOneStepCell from mindspore.nn.optim.momentum import Momentum from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits -from mindspore.communication.management import get_rank from mindspore.train.model import Model from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.serialization import save_checkpoint @@ -37,7 +36,7 @@ from src.models import CrossEntropyWithLabelSmooth, define_net, load_ckpt from src.metric import DistAccuracy, ClassifyCorrectCell from src.model_utils.config import config from src.model_utils.moxing_adapter import moxing_wrapper -from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id +from src.model_utils.device_adapter import get_device_id, get_device_num set_seed(1) @@ -116,23 +115,16 @@ def build_params_groups(net): def train_mobilenetv2(): config.train_dataset_path = os.path.join(config.dataset_path, 'train') config.eval_dataset_path = os.path.join(config.dataset_path, 'validation_preprocess') - - config.device_id = get_device_id() - config.rank_id = get_rank_id() - config.rank_size = get_device_num() - if config.platform == 'Ascend': - config.run_distribute = config.rank_size > 1. - - print('\nconfig: {} \n'.format(config)) + if not config.device_id: + config.device_id = get_device_id() start = time.time() # set context and device init context_device_init(config) - + print('\nconfig: {} \n'.format(config)) # define network backbone_net, head_net, net = define_net(config, config.is_training) dataset = create_dataset(dataset_path=config.train_dataset_path, do_train=True, config=config, enable_cache=config.enable_cache, cache_session_id=config.cache_session_id) - eval_dataset = create_dataset(dataset_path=config.eval_dataset_path, do_train=False, config=config) step_size = dataset.get_dataset_size() if config.platform == "GPU": context.set_context(enable_graph_kernel=True) @@ -165,23 +157,27 @@ def train_mobilenetv2(): warmup_epochs=config.warmup_epochs, total_epochs=epoch_size, steps_per_epoch=step_size)) - + metrics = {"acc"} + dist_eval_network = None + eval_dataset = None + if config.run_eval: + metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=config.rank_size)} + dist_eval_network = ClassifyCorrectCell(net, config.run_distribute) + eval_dataset = create_dataset(dataset_path=config.eval_dataset_path, do_train=False, config=config) if config.pretrain_ckpt == "" or config.freeze_layer != "backbone": - loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) - group_params = build_params_groups(net) - opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale) - - metrics = {"acc"} - dist_eval_network = None - if config.run_distribute: - metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=config.rank_size)} - dist_eval_network = ClassifyCorrectCell(net) - - model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, - metrics=metrics, eval_network=dist_eval_network, - amp_level="O2", keep_batchnorm_fp32=False, - acc_level=config.acc_mode) + if config.platform == "Ascend": + loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) + group_params = build_params_groups(net) + opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale) + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, + metrics=metrics, eval_network=dist_eval_network, + amp_level="O2", keep_batchnorm_fp32=False, + acc_level=config.acc_mode) + else: + opt = Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay) + model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics, eval_network=dist_eval_network, + acc_level=config.acc_mode) cb = config_ckpoint(config, lr, step_size, model, eval_dataset) print("============== Starting Training ==============") model.train(epoch_size, dataset, callbacks=cb) @@ -197,9 +193,7 @@ def train_mobilenetv2(): features_path = config.train_dataset_path + '_features' idx_list = list(range(step_size)) - rank = 0 - if config.run_distribute: - rank = get_rank() + rank = config.rank_id save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/') if not os.path.isdir(save_ckpt_path): os.mkdir(save_ckpt_path)