From 82b88ade658b67356b48338be61fa02590b99764 Mon Sep 17 00:00:00 2001 From: linqingke Date: Wed, 1 Sep 2021 16:35:40 +0800 Subject: [PATCH] [Boost]Add MindBoost. --- cmake/package.cmake | 1 + cmake/package_win.cmake | 1 + mindspore/{nn/acc => boost}/__init__.py | 10 +-- mindspore/{nn/acc => boost}/adasum.py | 0 mindspore/{nn/acc => boost}/base.py | 55 +++++++++++++- mindspore/{nn/acc/acc.py => boost/boost.py} | 30 ++++---- .../boost_cell_wrapper.py} | 72 +++++++++---------- .../{nn/acc => boost}/grad_accumulation.py | 0 mindspore/{nn/acc => boost}/grad_freeze.py | 0 .../acc => boost}/less_batch_normalization.py | 4 +- mindspore/nn/cell.py | 14 ++-- mindspore/train/amp.py | 24 +++---- mindspore/train/model.py | 28 ++++---- model_zoo/official/cv/mobilenetv2/README.md | 1 + .../official/cv/mobilenetv2/README_CN.md | 2 +- .../cv/mobilenetv2/default_config.yaml | 2 +- ...fig_acc.yaml => default_config_boost.yaml} | 2 +- .../cv/mobilenetv2/default_config_cpu.yaml | 2 +- .../cv/mobilenetv2/default_config_gpu.yaml | 2 +- model_zoo/official/cv/mobilenetv2/train.py | 2 +- model_zoo/official/cv/resnet/README.md | 2 +- model_zoo/official/cv/resnet/README_CN.md | 2 +- .../config/resnet101_imagenet2012_config.yaml | 2 +- .../config/resnet18_cifar10_config.yaml | 2 +- .../config/resnet18_cifar10_config_gpu.yaml | 2 +- .../config/resnet18_imagenet2012_config.yaml | 2 +- .../resnet18_imagenet2012_config_gpu.yaml | 2 +- .../config/resnet34_imagenet2012_config.yaml | 2 +- .../config/resnet50_cifar10_config.yaml | 2 +- ...net50_imagenet2012_Ascend_Thor_config.yaml | 2 +- ...> resnet50_imagenet2012_Boost_config.yaml} | 2 +- ...resnet50_imagenet2012_GPU_Thor_config.yaml | 2 +- .../config/resnet50_imagenet2012_config.yaml | 2 +- .../resnet/config/resnet_benchmark_GPU.yaml | 2 +- .../se-resnet50_imagenet2012_config.yaml | 2 +- model_zoo/official/cv/resnet/train.py | 4 +- .../official/nlp/bert/pretrain_config.yaml | 6 +- .../nlp/bert/pretrain_config_Ascend_Thor.yaml | 8 +-- .../nlp/bert/src/model_utils/config.py | 10 +-- .../models/resnet50/src_thor/dataset.py | 8 +-- 40 files changed, 186 insertions(+), 132 deletions(-) rename mindspore/{nn/acc => boost}/__init__.py (85%) rename mindspore/{nn/acc => boost}/adasum.py (100%) rename mindspore/{nn/acc => boost}/base.py (73%) rename mindspore/{nn/acc/acc.py => boost/boost.py} (84%) rename mindspore/{nn/acc/acc_cell_wrapper.py => boost/boost_cell_wrapper.py} (87%) rename mindspore/{nn/acc => boost}/grad_accumulation.py (100%) rename mindspore/{nn/acc => boost}/grad_freeze.py (100%) rename mindspore/{nn/acc => boost}/less_batch_normalization.py (98%) rename model_zoo/official/cv/mobilenetv2/{default_config_acc.yaml => default_config_boost.yaml} (99%) rename model_zoo/official/cv/resnet/config/{resnet50_imagenet2012_Acc_config.yaml => resnet50_imagenet2012_Boost_config.yaml} (99%) diff --git a/cmake/package.cmake b/cmake/package.cmake index 9213f0d5427..04c4007312f 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -277,6 +277,7 @@ install( ${CMAKE_SOURCE_DIR}/mindspore/mindrecord ${CMAKE_SOURCE_DIR}/mindspore/numpy ${CMAKE_SOURCE_DIR}/mindspore/train + ${CMAKE_SOURCE_DIR}/mindspore/boost ${CMAKE_SOURCE_DIR}/mindspore/common ${CMAKE_SOURCE_DIR}/mindspore/ops ${CMAKE_SOURCE_DIR}/mindspore/communication diff --git a/cmake/package_win.cmake b/cmake/package_win.cmake index bbed4e0ff07..aaac79b921a 100644 --- a/cmake/package_win.cmake +++ b/cmake/package_win.cmake @@ -175,6 +175,7 @@ install( ${CMAKE_SOURCE_DIR}/mindspore/mindrecord 
         ${CMAKE_SOURCE_DIR}/mindspore/numpy
         ${CMAKE_SOURCE_DIR}/mindspore/train
+        ${CMAKE_SOURCE_DIR}/mindspore/boost
         ${CMAKE_SOURCE_DIR}/mindspore/common
         ${CMAKE_SOURCE_DIR}/mindspore/ops
         ${CMAKE_SOURCE_DIR}/mindspore/communication
diff --git a/mindspore/nn/acc/__init__.py b/mindspore/boost/__init__.py
similarity index 85%
rename from mindspore/nn/acc/__init__.py
rename to mindspore/boost/__init__.py
index cd79a6f603c..59265a52bfc 100644
--- a/mindspore/nn/acc/__init__.py
+++ b/mindspore/boost/__init__.py
@@ -13,22 +13,22 @@
 # limitations under the License.
 # ============================================================================
 """
-Accelerating.
+MindBoost (Beta Feature): provides automatic acceleration for networks, such as Less BN and Gradient Freeze.
 """
-from .acc import *
+from .boost import *
 from .base import *
-from .acc_cell_wrapper import *
+from .boost_cell_wrapper import *
 from .less_batch_normalization import *
 from .grad_freeze import *
 from .grad_accumulation import *
 from .adasum import *
 
-__all__ = ['AutoAcc',
+__all__ = ['AutoBoost',
            'OptimizerProcess', 'ParameterProcess',
-           'AccTrainOneStepCell', 'AccTrainOneStepWithLossScaleCell',
+           'BoostTrainOneStepCell', 'BoostTrainOneStepWithLossScaleCell',
            'LessBN',
            'GradientFreeze', 'FreezeOpt', 'freeze_cell',
            'GradientAccumulation',
diff --git a/mindspore/nn/acc/adasum.py b/mindspore/boost/adasum.py
similarity index 100%
rename from mindspore/nn/acc/adasum.py
rename to mindspore/boost/adasum.py
diff --git a/mindspore/nn/acc/base.py b/mindspore/boost/base.py
similarity index 73%
rename from mindspore/nn/acc/base.py
rename to mindspore/boost/base.py
index eb15f9cfe65..7a80d469d95 100644
--- a/mindspore/nn/acc/base.py
+++ b/mindspore/boost/base.py
@@ -26,10 +26,35 @@ __all__ = ["OptimizerProcess", "ParameterProcess"]
 
 class OptimizerProcess:
     """
-    Process optimizer for ACC.
+    Process the optimizer for Boost. Currently, this class supports adding GC (gradient centralization)
+    tags and creating new optimizers.
 
     Args:
         opt (Cell): Optimizer used.
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindspore import Tensor, Parameter, nn
+        >>> from mindspore import ops
+        >>> from mindspore.boost import OptimizerProcess
+        >>>
+        >>> class Net(nn.Cell):
+        ...     def __init__(self, in_features, out_features):
+        ...         super(Net, self).__init__()
+        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
+        ...                                 name='weight')
+        ...         self.matmul = ops.MatMul()
+        ...
+        ...     def construct(self, x):
+        ...         output = self.matmul(x, self.weight)
+        ...         return output
+        ...
+        >>> size, in_features, out_features = 16, 16, 10
+        >>> network = Net(in_features, out_features)
+        >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
+        >>> optimizer_process = OptimizerProcess(optimizer)
+        >>> optimizer_process.add_grad_centralization(network)
+        >>> optimizer = optimizer_process.generate_new_optimizer()
     """
     def __init__(self, opt):
         if isinstance(opt, LARS):
@@ -113,7 +138,35 @@ class OptimizerProcess:
 
 class ParameterProcess:
     """
-    Process parameter for ACC.
+    Process parameters for Boost. Currently, this class supports creating group parameters
+    and automatically setting gradient segmentation points.
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindspore import Tensor, Parameter, nn
+        >>> from mindspore import ops
+        >>> from mindspore.boost import ParameterProcess
+        >>>
+        >>> class Net(nn.Cell):
+        ...     def __init__(self, in_features, out_features):
+        ...         super(Net, self).__init__()
+        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
+        ...                                 name='weight')
+        ...         self.weight2 = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
+        ...                                 name='weight2')
+        ...         self.matmul = ops.MatMul()
+        ...         self.matmul2 = ops.MatMul()
+        ...
+        ...     def construct(self, x):
+        ...         output = self.matmul(x, self.weight)
+        ...         output2 = self.matmul2(x, self.weight2)
+        ...         return output + output2
+        ...
+        >>> size, in_features, out_features = 16, 16, 10
+        >>> network = Net(in_features, out_features)
+        >>> new_parameter = network.trainable_params()[:1]
+        >>> parameter_process = ParameterProcess()
+        >>> group_params = parameter_process.generate_group_params(new_parameter, network.trainable_params())
     """
    def __init__(self):
        self._parameter_indices = 1
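
Aside: the GC tag mentioned above stands for gradient centralization. A minimal NumPy sketch of the idea — illustrative only; the real operation is fused into MindSpore's optimizer kernels rather than written like this:

    import numpy as np

    def centralize_gradient(grad):
        # Zero-center each weight gradient over all axes except the output
        # channel; biases and other 1-D parameters are left untouched.
        if grad.ndim <= 1:
            return grad
        axes = tuple(range(1, grad.ndim))
        return grad - grad.mean(axis=axes, keepdims=True)

    g = np.random.randn(10, 16).astype(np.float32)  # e.g. a dense-layer weight gradient
    print(centralize_gradient(g).mean(axis=1))      # per-channel means are now ~0
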
diff --git a/mindspore/nn/acc/acc.py b/mindspore/boost/boost.py
similarity index 84%
rename from mindspore/nn/acc/acc.py
rename to mindspore/boost/boost.py
index c08051c86e1..fac89322608 100644
--- a/mindspore/nn/acc/acc.py
+++ b/mindspore/boost/boost.py
@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""acc"""
+"""boost"""
 from .less_batch_normalization import LessBN
 from .grad_freeze import GradientFreeze
 from .base import OptimizerProcess, ParameterProcess
 
 
-__all__ = ["AutoAcc"]
+__all__ = ["AutoBoost"]
 
 
-_acc_config_level = {
+_boost_config_level = {
     "O0": {
         "less_bn": False,
         "grad_freeze": False,
@@ -36,19 +36,19 @@ _acc_config_level = {
         "adasum": True}}
 
 
-class AutoAcc:
+class AutoBoost:
     """
     Provide auto accelerating for network.
 
     Args:
-        level (Str): acc config level.
+        level (str): Boost config level.
     """
     def __init__(self, level, kwargs):
-        if level not in _acc_config_level.keys():
+        if level not in _boost_config_level.keys():
             level = 'O0'
         self.level = level
-        acc_config = _acc_config_level[level]
-        self._acc_config = acc_config
+        boost_config = _boost_config_level[level]
+        self._boost_config = boost_config
         self._fn_flag = True
         self._gc_flag = True
         self._param_groups = 10
@@ -62,13 +62,13 @@ class AutoAcc:
     def _get_configuration(self, kwargs):
         """Get configuration."""
         for key, val in kwargs.items():
-            if key not in self._acc_config_func_map.keys():
+            if key not in self._boost_config_func_map.keys():
                 continue
-            self._acc_config_func_map[key](self, val)
+            self._boost_config_func_map[key](self, val)
 
     def network_auto_process_train(self, network, optimizer):
         """Network train."""
-        if self._acc_config["less_bn"]:
+        if self._boost_config["less_bn"]:
             network = LessBN(network, fn_flag=self._fn_flag)
             optimizer_process = OptimizerProcess(optimizer)
             group_params = self._param_processer.assign_parameter_group(network.trainable_params(),
@@ -79,18 +79,18 @@
                 optimizer_process.add_grad_centralization(network)
             optimizer = optimizer_process.generate_new_optimizer()
 
-        if self._acc_config["grad_freeze"]:
+        if self._boost_config["grad_freeze"]:
             freeze_processer = GradientFreeze(self._param_groups, self._freeze_type,
                                               self._freeze_p, self._total_steps)
             network, optimizer = freeze_processer.freeze_generate(network, optimizer)
 
-        if self._acc_config["adasum"]:
+        if self._boost_config["adasum"]:
             setattr(optimizer, "adasum", True)
         return network, optimizer
 
     def network_auto_process_eval(self, network):
         """Network eval."""
-        if self._acc_config["less_bn"]:
+        if self._boost_config["less_bn"]:
             network = LessBN(network)
 
         return network
@@ -120,7 +120,7 @@ class AutoAcc:
         gradient_groups = list(gradient_groups)
         self._gradient_groups = gradient_groups
 
-    _acc_config_func_map = {
+    _boost_config_func_map = {
         "fn_flag": set_fn_flag,
         "gc_flag": set_gc_flag,
         "param_groups": set_param_groups,
diff --git a/mindspore/nn/acc/acc_cell_wrapper.py b/mindspore/boost/boost_cell_wrapper.py
similarity index 87%
rename from mindspore/nn/acc/acc_cell_wrapper.py
rename to mindspore/boost/boost_cell_wrapper.py
index 63652f34e06..b72a8ab3d39 100644
--- a/mindspore/nn/acc/acc_cell_wrapper.py
+++ b/mindspore/boost/boost_cell_wrapper.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Acc Mode Cell Wrapper."""
+"""Boost Mode Cell Wrapper."""
 from mindspore.nn.wrap import TrainOneStepCell
 import mindspore.context as context
 from mindspore.context import ParallelMode, get_auto_parallel_context
@@ -31,7 +31,7 @@ from .adasum import AdaSum
 from .grad_accumulation import gradient_accumulation_op, gradient_clear_op
 
 
-__all__ = ["AccTrainOneStepCell", "AccTrainOneStepWithLossScaleCell"]
+__all__ = ["BoostTrainOneStepCell", "BoostTrainOneStepWithLossScaleCell"]
 
 
 _get_delta_weight = C.MultitypeFuncGraph("_get_delta_weight")
@@ -51,9 +51,9 @@ def _save_weight_process(new_parameter, old_parameter):
     return P.Assign()(new_parameter, old_parameter)
 
 
-class AccTrainOneStepCell(TrainOneStepCell):
+class BoostTrainOneStepCell(TrainOneStepCell):
     r"""
-    Acc Network training package class.
+    Boost Network training package class.
 
     Wraps the network with an optimizer. The resulting Cell is trained with input '\*inputs'.
     The backward graph will be created in the construct function to update the parameter. Different
@@ -82,29 +82,29 @@
         >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
         >>> #1) Using the WithLossCell provided
         >>> loss_net = nn.WithLossCell(net, loss_fn)
-        >>> train_net = nn.acc.AccTrainOneStepCell(loss_net, optim)
+        >>> train_net = boost.BoostTrainOneStepCell(loss_net, optim)
         >>>
         >>> #2) Using user-defined WithLossCell
         >>> class MyWithLossCell(Cell):
-        mindspore.    def __init__(self, backbone, loss_fn):
-        mindspore.        super(MyWithLossCell, self).__init__(auto_prefix=False)
-        mindspore.        self._backbone = backbone
-        mindspore.        self._loss_fn = loss_fn
-        mindspore.
-        mindspore.    def construct(self, x, y, label):
-        mindspore.        out = self._backbone(x, y)
-        mindspore.        return self._loss_fn(out, label)
-        mindspore.
-        mindspore.    @property
-        mindspore.    def backbone_network(self):
-        mindspore.        return self._backbone
-        mindspore.
+        ...     def __init__(self, backbone, loss_fn):
+        ...         super(MyWithLossCell, self).__init__(auto_prefix=False)
+        ...         self._backbone = backbone
+        ...         self._loss_fn = loss_fn
+        ...
+        ...     def construct(self, x, y, label):
+        ...         out = self._backbone(x, y)
+        ...         return self._loss_fn(out, label)
+        ...
+        ...     @property
+        ...     def backbone_network(self):
+        ...         return self._backbone
+        ...
         >>> loss_net = MyWithLossCell(net, loss_fn)
-        >>> train_net = nn.acc.AccTrainOneStepCellTrainOneStepCell(loss_net, optim)
+        >>> train_net = boost.BoostTrainOneStepCell(loss_net, optim)
     """
     def __init__(self, network, optimizer, sens=1.0):
-        super(AccTrainOneStepCell, self).__init__(network, optimizer, sens)
+        super(BoostTrainOneStepCell, self).__init__(network, optimizer, sens)
         self.hyper_map = C.HyperMap()
         self.freeze = isinstance(optimizer, FreezeOpt)
         if not self.freeze:
@@ -240,13 +240,13 @@
         return is_enable
 
 
-class AccTrainOneStepWithLossScaleCell(AccTrainOneStepCell):
+class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
     r"""
-    Acc Network training with loss scaling.
+    Boost Network training with loss scaling.
 
     This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update
     Cell as args. The loss scale value can be updated on either the host side or the device side. The
-    AccTrainOneStepWithLossScaleCell will be compiled to be graph which takes `*inputs` as input data.
+    BoostTrainOneStepWithLossScaleCell will be compiled into a graph which takes `*inputs` as input data.
     The Tensor type of `scale_sense` acts as the loss scaling value. If you want to update it on the host side,
     the value must be provided. If the Tensor type of `scale_sense` is not given, the loss scale update logic
     must be provided by the Cell type of `scale_sense`.
@@ -282,16 +282,16 @@
         >>> from mindspore.common import dtype as mstype
         >>>
         >>> class Net(nn.Cell):
-        mindspore.    def __init__(self, in_features, out_features):
-        mindspore.        super(Net, self).__init__()
-        mindspore.        self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
-        mindspore.                                name='weight')
-        mindspore.        self.matmul = P.MatMul()
-        mindspore.
-        mindspore.    def construct(self, x):
-        mindspore.        output = self.matmul(x, self.weight)
-        mindspore.        return output
-        mindspore.
+        ...     def __init__(self, in_features, out_features):
+        ...         super(Net, self).__init__()
+        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
+        ...                                 name='weight')
+        ...         self.matmul = ops.MatMul()
+        ...
+        ...     def construct(self, x):
+        ...         output = self.matmul(x, self.weight)
+        ...         return output
+        ...
         >>> size, in_features, out_features = 16, 16, 10
         >>> #1) when the type of scale_sense is Cell:
         >>> net = Net(in_features, out_features)
@@ -299,7 +299,7 @@
         >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
         >>> net_with_loss = WithLossCell(net, loss)
         >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
-        >>> train_network = nn.acc.AccTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
+        >>> train_network = boost.BoostTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
         >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
         >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
         >>> output = train_network(input, labels)
@@ -312,11 +312,11 @@
         >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
         >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
         >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
-        >>> train_network = nn.acc.AccTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
+        >>> train_network = boost.BoostTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
         >>> output = train_network(inputs, label)
     """
     def __init__(self, network, optimizer, scale_sense):
-        super(AccTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
+        super(BoostTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
         self.base = Tensor(1, mstype.float32)
         self.reduce_sum = P.ReduceSum(keep_dims=False)
         self.less_equal = P.LessEqual()
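
A minimal end-to-end sketch of the renamed wrapper, assuming this patch is applied (the tiny Dense network and synthetic tensors are stand-ins for a real model and data):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, boost

    net = nn.Dense(16, 10)
    optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = boost.BoostTrainOneStepCell(nn.WithLossCell(net, nn.MSELoss()), optim)
    train_net.set_train()

    data = Tensor(np.ones([32, 16]).astype(np.float32))
    label = Tensor(np.zeros([32, 10]).astype(np.float32))
    loss = train_net(data, label)  # one fused forward/backward/update step

Passing a FreezeOpt instead of a plain optimizer flips the cell into the gradient-freeze path (self.freeze above).
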
diff --git a/mindspore/nn/acc/grad_accumulation.py b/mindspore/boost/grad_accumulation.py
similarity index 100%
rename from mindspore/nn/acc/grad_accumulation.py
rename to mindspore/boost/grad_accumulation.py
diff --git a/mindspore/nn/acc/grad_freeze.py b/mindspore/boost/grad_freeze.py
similarity index 100%
rename from mindspore/nn/acc/grad_freeze.py
rename to mindspore/boost/grad_freeze.py
diff --git a/mindspore/nn/acc/less_batch_normalization.py b/mindspore/boost/less_batch_normalization.py
similarity index 98%
rename from mindspore/nn/acc/less_batch_normalization.py
rename to mindspore/boost/less_batch_normalization.py
index c6bbc92650f..e8d9e590217 100644
--- a/mindspore/nn/acc/less_batch_normalization.py
+++ b/mindspore/boost/less_batch_normalization.py
@@ -91,13 +91,13 @@ class LessBN(Cell):
         network (Cell): Network to be modified.
 
     Examples:
-        >>> network = acc.LessBN(network)
+        >>> network = boost.LessBN(network)
     """
 
     def __init__(self, network, fn_flag=False):
         super(LessBN, self).__init__()
         self.network = network
-        self.network.set_acc("less_bn")
+        self.network.set_boost("less_bn")
         self.network.update_cell_prefix()
         if fn_flag:
             self._convert_to_less_bn_net(self.network)
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 65c35432f3e..72a6867825d 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -1145,29 +1145,29 @@ class Cell(Cell_):
         self._add_init_args(**flags)
         return self
 
-    def set_acc(self, acc_type):
+    def set_boost(self, boost_type):
         """
         To improve network performance, configure the network to automatically enable an
         acceleration algorithm from the algorithm library.
 
-        If `acc_type is not in the algorithm library`, Please view the algorithm in the algorithm library
+        If `boost_type` is not in the algorithm library, please view the supported algorithms
         through `algorithm library`.
 
         Note:
            Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
 
        Args:
-            acc_type (str): accelerate algorithm.
+            boost_type (str): The acceleration algorithm to enable.
 
        Returns:
            Cell, the cell itself.
 
        Raises:
-            ValueError: If acc_type is not in the algorithm library.
+            ValueError: If boost_type is not in the algorithm library.
        """
-        if acc_type not in ("less_bn",):
-            raise ValueError("The acc_type is not in the algorithm library.")
-        flags = {"less_bn": acc_type == "less_bn"}
+        if boost_type not in ("less_bn",):
+            raise ValueError("The boost_type is not in the algorithm library.")
+        flags = {"less_bn": boost_type == "less_bn"}
         self.add_flags_recursive(**flags)
         return self
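
What the renamed flag API does, in short: set_boost only tags the cell tree, and LessBN consumes the flag later. A hedged sketch (get_flags is used here purely to inspect the result):

    import mindspore.nn as nn

    net = nn.SequentialCell([nn.Dense(16, 16), nn.BatchNorm1d(16), nn.Dense(16, 10)])
    net.set_boost("less_bn")     # sets the less_bn flag recursively on every sub-cell
    print(net.get_flags())       # expected to contain {'less_bn': True}
    try:
        net.set_boost("no_such_algorithm")
    except ValueError as err:
        print(err)               # "The boost_type is not in the algorithm library."
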
diff --git a/mindspore/train/amp.py b/mindspore/train/amp.py
index 16b73cae5fc..077c7ed7ccd 100644
--- a/mindspore/train/amp.py
+++ b/mindspore/train/amp.py
@@ -17,13 +17,13 @@ from .. import nn
 from .._checkparam import Validator as validator
 from .._checkparam import Rel
 from ..common import dtype as mstype
-from ..nn import acc
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell, _TrainPipelineAccuStepCell
 from ..nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
 from ..ops import functional as F
 from ..parallel._utils import _get_parallel_mode, _get_pipeline_stages
 from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
 from ..context import ParallelMode
+from .. import boost
 from .. import context
 
 
@@ -111,7 +111,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
     return network
 
 
-def build_train_network(network, optimizer, loss_fn=None, level='O0', acc_level='O0', **kwargs):
+def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs):
     """
     Build the mixed precision training cell automatically.
 
@@ -147,9 +147,9 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', acc_level=
         (with property `drop_overflow_update=False` ), or a `ValueError` exception will be raised.
     """
     validator.check_value_type('network', network, nn.Cell)
-    validator.check_value_type('optimizer', optimizer, (nn.Optimizer, acc.FreezeOpt))
+    validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt))
     validator.check('level', level, "", ['O0', 'O2', 'O3', "auto"], Rel.IN)
-    validator.check('acc_level', acc_level, "", ['O0', 'O1', 'O2'], Rel.IN)
+    validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], Rel.IN)
 
     if level == "auto":
         device_target = context.get_context('device_target')
@@ -175,9 +175,9 @@
     if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
         network = _VirtualDatasetCell(network)
 
-    enable_acc = False
-    if acc_level in ["O1", "O2"]:
-        enable_acc = True
+    enable_boost = False
+    if boost_level in ["O1", "O2"]:
+        enable_boost = True
 
     loss_scale = 1.0
     if config["loss_scale_manager"] is not None:
@@ -193,17 +193,17 @@
             if _get_pipeline_stages() > 1:
                 network = _TrainPipelineWithLossScaleCell(network, optimizer,
                                                           scale_sense=update_cell).set_train()
-            elif enable_acc:
-                network = acc.AccTrainOneStepWithLossScaleCell(network, optimizer,
-                                                               scale_sense=update_cell).set_train()
+            elif enable_boost:
+                network = boost.BoostTrainOneStepWithLossScaleCell(network, optimizer,
+                                                                   scale_sense=update_cell).set_train()
             else:
                 network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
                                                            scale_sense=update_cell).set_train()
             return network
     if _get_pipeline_stages() > 1:
         network = _TrainPipelineAccuStepCell(network, optimizer).set_train()
-    elif enable_acc:
-        network = acc.AccTrainOneStepCell(network, optimizer, loss_scale).set_train()
+    elif enable_boost:
+        network = boost.BoostTrainOneStepCell(network, optimizer, loss_scale).set_train()
     else:
         network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train()
     return network
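
Callers that build training cells directly see the same rename — a hedged sketch of the new keyword (the Dense network and tensors are placeholders; "O1" routes through the Boost wrappers above):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.train.amp import build_train_network

    net = nn.Dense(16, 10)
    optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = build_train_network(net, optim, nn.MSELoss(),
                                    level="O0", boost_level="O1")  # formerly acc_level
    loss = train_net(Tensor(np.ones([32, 16]).astype(np.float32)),
                     Tensor(np.zeros([32, 10]).astype(np.float32)))
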
""" validator.check_value_type('network', network, nn.Cell) - validator.check_value_type('optimizer', optimizer, (nn.Optimizer, acc.FreezeOpt)) + validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt)) validator.check('level', level, "", ['O0', 'O2', 'O3', "auto"], Rel.IN) - validator.check('acc_level', acc_level, "", ['O0', 'O1', 'O2'], Rel.IN) + validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], Rel.IN) if level == "auto": device_target = context.get_context('device_target') @@ -175,9 +175,9 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', acc_level= if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): network = _VirtualDatasetCell(network) - enable_acc = False - if acc_level in ["O1", "O2"]: - enable_acc = True + enable_boost = False + if boost_level in ["O1", "O2"]: + enable_boost = True loss_scale = 1.0 if config["loss_scale_manager"] is not None: @@ -193,17 +193,17 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', acc_level= if _get_pipeline_stages() > 1: network = _TrainPipelineWithLossScaleCell(network, optimizer, scale_sense=update_cell).set_train() - elif enable_acc: - network = acc.AccTrainOneStepWithLossScaleCell(network, optimizer, - scale_sense=update_cell).set_train() + elif enable_boost: + network = boost.BoostTrainOneStepWithLossScaleCell(network, optimizer, + scale_sense=update_cell).set_train() else: network = nn.TrainOneStepWithLossScaleCell(network, optimizer, scale_sense=update_cell).set_train() return network if _get_pipeline_stages() > 1: network = _TrainPipelineAccuStepCell(network, optimizer).set_train() - elif enable_acc: - network = acc.AccTrainOneStepCell(network, optimizer, loss_scale).set_train() + elif enable_boost: + network = boost.BoostTrainOneStepCell(network, optimizer, loss_scale).set_train() else: network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train() return network diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 0c5ce12206d..725d1ec2424 100644 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -32,7 +32,7 @@ from ..parallel._ps_context import _is_role_pserver, _is_role_sched from ..nn.metrics import Loss from .. import nn from ..nn.wrap.cell_wrapper import _VirtualDatasetCell -from ..nn.acc import acc +from ..boost import AutoBoost from ..context import ParallelMode from ..parallel._cost_model_context import _set_multi_subgraphs from .dataset_helper import DatasetHelper, connect_network_with_dataset @@ -89,13 +89,13 @@ class Model: O2 is recommended on GPU, O3 is recommended on Ascend.The more detailed explanation of `amp_level` setting can be found at `mindspore.amp.build_train_network` . - acc_level (str): Option for argument `level` in `mindspore.acc` , level for acc mode + boost_level (str): Option for argument `level` in `mindspore.boost` , level for boost mode training. Supports ["O0", "O1", "O2"]. Default: "O0". - O0: Do not change. - - O1: Enable the acc mode, the performance is improved by about 20%, and + - O1: Enable the boost mode, the performance is improved by about 20%, and the accuracy is the same as the original accuracy. - - O2: Enable the acc mode, the performance is improved by about 30%, and + - O2: Enable the boost mode, the performance is improved by about 30%, and the accuracy is reduced by less than 3%. 
Examples: >>> from mindspore import Model, nn @@ -132,7 +132,7 @@ class Model: """ def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, - eval_indexes=None, amp_level="O0", acc_level="O0", **kwargs): + eval_indexes=None, amp_level="O0", boost_level="O0", **kwargs): self._network = network self._loss_fn = loss_fn self._optimizer = optimizer @@ -141,7 +141,7 @@ class Model: self._keep_bn_fp32 = True self._check_kwargs(kwargs) self._amp_level = amp_level - self._acc_level = acc_level + self._boost_level = boost_level self._eval_network = eval_network self._process_amp_args(kwargs) self._parallel_mode = _get_parallel_mode() @@ -152,7 +152,7 @@ class Model: self._check_amp_level_arg(optimizer, amp_level) self._check_for_graph_cell(kwargs) - self._build_acc_network(kwargs) + self._build_boost_network(kwargs) self._train_network = self._build_train_network() self._build_eval_network(metrics, self._eval_network, eval_indexes) self._build_predict_network() @@ -194,16 +194,16 @@ class Model: if hasattr(dataset, '__model_hash__') and dataset.__model_hash__ != hash(self): raise RuntimeError('The Dataset cannot be bound to different models, please create a new dataset.') - def _build_acc_network(self, kwargs): - """Build the acc network.""" - processor = acc.AutoAcc(self._acc_level, kwargs) + def _build_boost_network(self, kwargs): + """Build the boost network.""" + processor = AutoBoost(self._boost_level, kwargs) if processor.level not in ["O1", "O2"]: return if self._optimizer is None: - logger.warning("In acc mode, the optimizer must be defined.") + logger.warning("In boost mode, the optimizer must be defined.") return if self._eval_network is None and self._metrics is None: - logger.warning("In acc mode, the eval_network and metrics cannot be undefined at the same time.") + logger.warning("In boost mode, the eval_network and metrics cannot be undefined at the same time.") return self._network, self._optimizer = processor.network_auto_process_train(self._network, self._optimizer) @@ -222,7 +222,7 @@ class Model: self._optimizer, self._loss_fn, level=self._amp_level, - acc_level=self._acc_level, + boost_level=self._boost_level, loss_scale_manager=self._loss_scale_manager, keep_batchnorm_fp32=self._keep_bn_fp32) else: @@ -230,7 +230,7 @@ class Model: self._optimizer, self._loss_fn, level=self._amp_level, - acc_level=self._acc_level, + boost_level=self._boost_level, keep_batchnorm_fp32=self._keep_bn_fp32) elif self._loss_fn: if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): diff --git a/model_zoo/official/cv/mobilenetv2/README.md b/model_zoo/official/cv/mobilenetv2/README.md index ea5febe75a2..4deadc1165a 100644 --- a/model_zoo/official/cv/mobilenetv2/README.md +++ b/model_zoo/official/cv/mobilenetv2/README.md @@ -210,6 +210,7 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil │ ├──local_adapter.py # Get local ID │ └──moxing_adapter.py # Parameter processing ├── default_config.yaml # Training parameter profile(ascend) + ├── default_config_boost.yaml # Training parameter profile(ascend boost) ├── default_config_cpu.yaml # Training parameter profile(cpu) ├── default_config_gpu.yaml # Training parameter profile(gpu) ├── train.py # training script diff --git a/model_zoo/official/cv/mobilenetv2/README_CN.md b/model_zoo/official/cv/mobilenetv2/README_CN.md index e2a72027c9f..e9c4b105558 100644 --- a/model_zoo/official/cv/mobilenetv2/README_CN.md +++ b/model_zoo/official/cv/mobilenetv2/README_CN.md @@ 
-211,7 +211,7 @@ MobileNetV2总体网络架构如下: │ ├──local_adapter.py # 获取本地id │ └──moxing_adapter.py # 云上数据准备 ├── default_config.yaml # 训练配置参数(ascend) - ├── default_config_acc.yaml # 训练配置参数(ascend acc模式) + ├── default_config_boost.yaml # 训练配置参数(ascend boost模式) ├── default_config_cpu.yaml # 训练配置参数(cpu) ├── default_config_gpu.yaml # 训练配置参数(gpu) ├── train.py # 训练脚本 diff --git a/model_zoo/official/cv/mobilenetv2/default_config.yaml b/model_zoo/official/cv/mobilenetv2/default_config.yaml index e7060077ab7..b8dd4b98af7 100644 --- a/model_zoo/official/cv/mobilenetv2/default_config.yaml +++ b/model_zoo/official/cv/mobilenetv2/default_config.yaml @@ -18,7 +18,7 @@ num_classes: 1000 image_height: 224 image_width: 224 num_workers: 32 -acc_mode: "O0" +boost_mode: "O0" batch_size: 256 epoch_size: 200 warmup_epochs: 4 diff --git a/model_zoo/official/cv/mobilenetv2/default_config_acc.yaml b/model_zoo/official/cv/mobilenetv2/default_config_boost.yaml similarity index 99% rename from model_zoo/official/cv/mobilenetv2/default_config_acc.yaml rename to model_zoo/official/cv/mobilenetv2/default_config_boost.yaml index 59837379e48..381876a6370 100644 --- a/model_zoo/official/cv/mobilenetv2/default_config_acc.yaml +++ b/model_zoo/official/cv/mobilenetv2/default_config_boost.yaml @@ -18,7 +18,7 @@ num_classes: 1000 image_height: 224 image_width: 224 num_workers: 32 -acc_mode: "O1" +boost_mode: "O1" batch_size: 256 epoch_size: 200 warmup_epochs: 4 diff --git a/model_zoo/official/cv/mobilenetv2/default_config_cpu.yaml b/model_zoo/official/cv/mobilenetv2/default_config_cpu.yaml index 83b6e587019..9cab593aebd 100644 --- a/model_zoo/official/cv/mobilenetv2/default_config_cpu.yaml +++ b/model_zoo/official/cv/mobilenetv2/default_config_cpu.yaml @@ -18,7 +18,7 @@ num_classes: 26 image_height: 224 image_width: 224 num_workers: 8 -acc_mode: "O0" +boost_mode: "O0" batch_size: 150 epoch_size: 15 warmup_epochs: 0 diff --git a/model_zoo/official/cv/mobilenetv2/default_config_gpu.yaml b/model_zoo/official/cv/mobilenetv2/default_config_gpu.yaml index d032f4e2c55..40a3f57d608 100644 --- a/model_zoo/official/cv/mobilenetv2/default_config_gpu.yaml +++ b/model_zoo/official/cv/mobilenetv2/default_config_gpu.yaml @@ -18,7 +18,7 @@ num_classes: 1000 image_height: 224 image_width: 224 num_workers: 8 -acc_mode: "O0" +boost_mode: "O0" batch_size: 150 epoch_size: 200 warmup_epochs: 0 diff --git a/model_zoo/official/cv/mobilenetv2/train.py b/model_zoo/official/cv/mobilenetv2/train.py index d357316d17a..8e86f816daf 100644 --- a/model_zoo/official/cv/mobilenetv2/train.py +++ b/model_zoo/official/cv/mobilenetv2/train.py @@ -172,7 +172,7 @@ def train_mobilenetv2(): model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=metrics, eval_network=dist_eval_network, amp_level="O2", keep_batchnorm_fp32=False, - acc_level=config.acc_mode) + boost_level=config.boost_mode) else: opt = Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay) diff --git a/model_zoo/official/cv/resnet/README.md b/model_zoo/official/cv/resnet/README.md index 385643a59ea..0dc17d7bd4f 100644 --- a/model_zoo/official/cv/resnet/README.md +++ b/model_zoo/official/cv/resnet/README.md @@ -209,7 +209,7 @@ If you want to run in modelarts, please check the official documentation of [mod ├── resnet18_imagenet2012_config_gpu.yaml ├── resnet34_imagenet2012_config.yaml ├── resnet50_cifar10_config.yaml - ├── resnet50_imagenet2012_Acc_config.yaml # High performance version: The performance is improved by more than 10% and the precision decrease less 
than 1%
+    ├── resnet50_imagenet2012_Boost_config.yaml     # High performance version: The performance is improved by more than 10% and the precision decreases by less than 1%
     ├── resnet50_imagenet2012_Ascend_Thor_config.yaml
     ├── resnet50_imagenet2012_config.yaml
     ├── resnet50_imagenet2012_GPU_Thor_config.yaml
diff --git a/model_zoo/official/cv/resnet/README_CN.md b/model_zoo/official/cv/resnet/README_CN.md
index 9bc3fc882c0..449425c4a32 100755
--- a/model_zoo/official/cv/resnet/README_CN.md
+++ b/model_zoo/official/cv/resnet/README_CN.md
@@ -195,7 +195,7 @@ bash run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH] [CONFIG_PATH]
   ├── resnet18_imagenet2012_config_gpu.yaml
   ├── resnet34_imagenet2012_config.yaml
   ├── resnet50_cifar10_config.yaml
-  ├── resnet50_imagenet2012_Acc_config.yaml        # 高性能版本:性能提高超过10%而精度下降少于1%
+  ├── resnet50_imagenet2012_Boost_config.yaml      # 高性能版本:性能提高超过10%而精度下降少于1%
   ├── resnet50_imagenet2012_Ascend_Thor_config.yaml
   ├── resnet50_imagenet2012_config.yaml
   ├── resnet50_imagenet2012_GPU_Thor_config.yaml
diff --git a/model_zoo/official/cv/resnet/config/resnet101_imagenet2012_config.yaml b/model_zoo/official/cv/resnet/config/resnet101_imagenet2012_config.yaml
index 49d250cda15..d9144b0b42c 100644
--- a/model_zoo/official/cv/resnet/config/resnet101_imagenet2012_config.yaml
+++ b/model_zoo/official/cv/resnet/config/resnet101_imagenet2012_config.yaml
@@ -50,7 +50,7 @@ eval_interval: 1
 enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
-acc_mode: "O0"
+boost_mode: "O0"
 conv_init: "XavierUniform"
 dense_init: "TruncatedNormal"
 all_reduce_fusion_config:
diff --git a/model_zoo/official/cv/resnet/config/resnet18_cifar10_config.yaml b/model_zoo/official/cv/resnet/config/resnet18_cifar10_config.yaml
index ceded05be0e..e7c5050e3fe 100644
--- a/model_zoo/official/cv/resnet/config/resnet18_cifar10_config.yaml
+++ b/model_zoo/official/cv/resnet/config/resnet18_cifar10_config.yaml
@@ -50,7 +50,7 @@ eval_interval: 1
 enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
-acc_mode: "O0"
+boost_mode: "O0"
 conv_init: "XavierUniform"
 dense_init: "TruncatedNormal"
 train_image_size: 224
diff --git a/model_zoo/official/cv/resnet/config/resnet18_cifar10_config_gpu.yaml b/model_zoo/official/cv/resnet/config/resnet18_cifar10_config_gpu.yaml
index 0d2dd0c3a60..0c1ebc025dd 100644
--- a/model_zoo/official/cv/resnet/config/resnet18_cifar10_config_gpu.yaml
+++ b/model_zoo/official/cv/resnet/config/resnet18_cifar10_config_gpu.yaml
@@ -50,7 +50,7 @@ eval_interval: 1
 enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
-acc_mode: "O0"
+boost_mode: "O0"
 conv_init: "XavierUniform"
 dense_init: "TruncatedNormal"
 train_image_size: 224
diff --git a/model_zoo/official/cv/resnet/config/resnet18_imagenet2012_config.yaml b/model_zoo/official/cv/resnet/config/resnet18_imagenet2012_config.yaml
index 087b7ef31bf..e0fcbce8bc0 100644
--- a/model_zoo/official/cv/resnet/config/resnet18_imagenet2012_config.yaml
+++ b/model_zoo/official/cv/resnet/config/resnet18_imagenet2012_config.yaml
@@ -52,7 +52,7 @@ eval_interval: 1
 enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
-acc_mode: "O0"
+boost_mode: "O0"
 conv_init: "XavierUniform"
 dense_init: "TruncatedNormal"
 train_image_size: 224
diff --git a/model_zoo/official/cv/resnet/config/resnet18_imagenet2012_config_gpu.yaml b/model_zoo/official/cv/resnet/config/resnet18_imagenet2012_config_gpu.yaml
index e0ad3eb8569..ee04dc202c1 100644
--- a/model_zoo/official/cv/resnet/config/resnet18_imagenet2012_config_gpu.yaml
+++
b/model_zoo/official/cv/resnet/config/resnet18_imagenet2012_config_gpu.yaml @@ -52,7 +52,7 @@ eval_interval: 1 enable_cache: False cache_session_id: "" mode_name: "GRAPH" -acc_mode: "O0" +boost_mode: "O0" conv_init: "XavierUniform" dense_init: "TruncatedNormal" train_image_size: 224 diff --git a/model_zoo/official/cv/resnet/config/resnet34_imagenet2012_config.yaml b/model_zoo/official/cv/resnet/config/resnet34_imagenet2012_config.yaml index 93158a3ed9e..2371608aa55 100644 --- a/model_zoo/official/cv/resnet/config/resnet34_imagenet2012_config.yaml +++ b/model_zoo/official/cv/resnet/config/resnet34_imagenet2012_config.yaml @@ -52,7 +52,7 @@ eval_interval: 1 enable_cache: False cache_session_id: "" mode_name: "GRAPH" -acc_mode: "O0" +boost_mode: "O0" conv_init: "XavierUniform" dense_init: "TruncatedNormal" train_image_size: 224 diff --git a/model_zoo/official/cv/resnet/config/resnet50_cifar10_config.yaml b/model_zoo/official/cv/resnet/config/resnet50_cifar10_config.yaml index d6feb0d326e..9e0f04001d6 100644 --- a/model_zoo/official/cv/resnet/config/resnet50_cifar10_config.yaml +++ b/model_zoo/official/cv/resnet/config/resnet50_cifar10_config.yaml @@ -50,7 +50,7 @@ eval_interval: 1 enable_cache: False cache_session_id: "" mode_name: "GRAPH" -acc_mode: "O0" +boost_mode: "O0" conv_init: "XavierUniform" dense_init: "TruncatedNormal" all_reduce_fusion_config: diff --git a/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_Ascend_Thor_config.yaml b/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_Ascend_Thor_config.yaml index 0c2ca84f64a..b0761155dc3 100644 --- a/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_Ascend_Thor_config.yaml +++ b/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_Ascend_Thor_config.yaml @@ -51,7 +51,7 @@ eval_interval: 1 enable_cache: False cache_session_id: "" mode_name: "GRAPH" -acc_mode: "O0" +boost_mode: "O0" conv_init: "HeUniform" dense_init: "HeUniform" all_reduce_fusion_config: diff --git a/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_Acc_config.yaml b/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_Boost_config.yaml similarity index 99% rename from model_zoo/official/cv/resnet/config/resnet50_imagenet2012_Acc_config.yaml rename to model_zoo/official/cv/resnet/config/resnet50_imagenet2012_Boost_config.yaml index 543903d9f1a..02ff74ebd72 100644 --- a/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_Acc_config.yaml +++ b/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_Boost_config.yaml @@ -52,7 +52,7 @@ eval_interval: 1 enable_cache: False cache_session_id: "" mode_name: "GRAPH" -acc_mode: "O1" +boost_mode: "O1" conv_init: "TruncatedNormal" dense_init: "RandomNormal" all_reduce_fusion_config: diff --git a/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_GPU_Thor_config.yaml b/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_GPU_Thor_config.yaml index 89d5ef79f7f..37a704efe3b 100644 --- a/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_GPU_Thor_config.yaml +++ b/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_GPU_Thor_config.yaml @@ -51,7 +51,7 @@ eval_interval: 1 enable_cache: False cache_session_id: "" mode_name: "GRAPH" -acc_mode: "O0" +boost_mode: "O0" conv_init: "HeUniform" dense_init: "HeUniform" all_reduce_fusion_config: diff --git a/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_config.yaml b/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_config.yaml index 4ef9459f019..b9db5f7bab7 100644 --- 
a/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_config.yaml +++ b/model_zoo/official/cv/resnet/config/resnet50_imagenet2012_config.yaml @@ -52,7 +52,7 @@ eval_interval: 1 enable_cache: False cache_session_id: "" mode_name: "GRAPH" -acc_mode: "O0" +boost_mode: "O0" conv_init: "XavierUniform" dense_init: "TruncatedNormal" all_reduce_fusion_config: diff --git a/model_zoo/official/cv/resnet/config/resnet_benchmark_GPU.yaml b/model_zoo/official/cv/resnet/config/resnet_benchmark_GPU.yaml index a1fac9d1c69..7fc403790cc 100644 --- a/model_zoo/official/cv/resnet/config/resnet_benchmark_GPU.yaml +++ b/model_zoo/official/cv/resnet/config/resnet_benchmark_GPU.yaml @@ -25,7 +25,7 @@ eval: False save_ckpt: False mode_name: "GRAPH" dtype: "fp16" -acc_mode: "O0" +boost_mode: "O0" conv_init: "XavierUniform" dense_init: "TruncatedNormal" train_image_size: 224 diff --git a/model_zoo/official/cv/resnet/config/se-resnet50_imagenet2012_config.yaml b/model_zoo/official/cv/resnet/config/se-resnet50_imagenet2012_config.yaml index 7031d024259..71f24869d08 100644 --- a/model_zoo/official/cv/resnet/config/se-resnet50_imagenet2012_config.yaml +++ b/model_zoo/official/cv/resnet/config/se-resnet50_imagenet2012_config.yaml @@ -53,7 +53,7 @@ eval_interval: 1 enable_cache: False cache_session_id: "" mode_name: "GRAPH" -acc_mode: "O0" +boost_mode: "O0" conv_init: "XavierUniform" dense_init: "TruncatedNormal" all_reduce_fusion_config: diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index 5fff69eb3af..d663640c22a 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -110,7 +110,7 @@ def set_parameter(): gradients_mean=True) set_algo_parameters(elementwise_op_strategy_follow=True) if config.net_name == "resnet50" or config.net_name == "se-resnet50": - if config.acc_mode not in ["O1", "O2"]: + if config.boost_mode not in ["O1", "O2"]: context.set_auto_parallel_context(all_reduce_fusion_config=config.all_reduce_fusion_config) elif config.net_name in ["resnet101", "resnet152"]: context.set_auto_parallel_context(all_reduce_fusion_config=config.all_reduce_fusion_config) @@ -258,7 +258,7 @@ def train_net(): model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics, eval_network=dist_eval_network) else: model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=metrics, - amp_level="O2", acc_level=config.acc_mode, keep_batchnorm_fp32=False, + amp_level="O2", boost_level=config.boost_mode, keep_batchnorm_fp32=False, eval_network=dist_eval_network) if config.optimizer == "Thor" and config.dataset == "imagenet2012": diff --git a/model_zoo/official/nlp/bert/pretrain_config.yaml b/model_zoo/official/nlp/bert/pretrain_config.yaml index 01aa73cd379..ff71adfed23 100644 --- a/model_zoo/official/nlp/bert/pretrain_config.yaml +++ b/model_zoo/official/nlp/bert/pretrain_config.yaml @@ -35,7 +35,7 @@ schema_dir: '' # ============================================================================== # pretrain related batch_size: 32 -# Available: [base, nezha, large, large_acc] +# Available: [base, nezha, large, large_boost] bert_network: 'base' loss_scale_value: 65536 scale_factor: 2 @@ -138,8 +138,8 @@ large_net_cfg: dtype: mstype.float32 compute_type: mstype.float16 # Accelerated large network which is only supported in Ascend yet. 
-large_acc_batch_size: 24 -large_acc_net_cfg: +large_boost_batch_size: 24 +large_boost_net_cfg: seq_length: 512 vocab_size: 30522 hidden_size: 1024 diff --git a/model_zoo/official/nlp/bert/pretrain_config_Ascend_Thor.yaml b/model_zoo/official/nlp/bert/pretrain_config_Ascend_Thor.yaml index 93670c29e50..3ee365cfed3 100644 --- a/model_zoo/official/nlp/bert/pretrain_config_Ascend_Thor.yaml +++ b/model_zoo/official/nlp/bert/pretrain_config_Ascend_Thor.yaml @@ -35,8 +35,8 @@ schema_dir: '' # ============================================================================== # pretrain related batch_size: 20 -# Available: [base, nezha, large, large_acc] -bert_network: 'large_acc' +# Available: [base, nezha, large, large_boost] +bert_network: 'large_boost' loss_scale_value: 65536 scale_factor: 2 scale_window: 1000 @@ -138,8 +138,8 @@ large_net_cfg: dtype: mstype.float32 compute_type: mstype.float16 # Accelerated large network which is only supported in Ascend yet. -large_acc_batch_size: 20 -large_acc_net_cfg: +large_boost_batch_size: 20 +large_boost_net_cfg: seq_length: 512 vocab_size: 30522 hidden_size: 1024 diff --git a/model_zoo/official/nlp/bert/src/model_utils/config.py b/model_zoo/official/nlp/bert/src/model_utils/config.py index d36689c29cd..d04d9862143 100644 --- a/model_zoo/official/nlp/bert/src/model_utils/config.py +++ b/model_zoo/official/nlp/bert/src/model_utils/config.py @@ -141,8 +141,8 @@ def extra_operations(cfg): cfg.nezha_net_cfg.compute_type = parse_dtype(cfg.nezha_net_cfg.compute_type) cfg.large_net_cfg.dtype = parse_dtype(cfg.large_net_cfg.dtype) cfg.large_net_cfg.compute_type = parse_dtype(cfg.large_net_cfg.compute_type) - cfg.large_acc_net_cfg.dtype = parse_dtype(cfg.large_acc_net_cfg.dtype) - cfg.large_acc_net_cfg.compute_type = parse_dtype(cfg.large_acc_net_cfg.compute_type) + cfg.large_boost_net_cfg.dtype = parse_dtype(cfg.large_boost_net_cfg.dtype) + cfg.large_boost_net_cfg.compute_type = parse_dtype(cfg.large_boost_net_cfg.compute_type) if cfg.bert_network == 'base': cfg.batch_size = cfg.base_batch_size _bert_net_cfg = cfg.base_net_cfg @@ -152,9 +152,9 @@ def extra_operations(cfg): elif cfg.bert_network == 'large': cfg.batch_size = cfg.large_batch_size _bert_net_cfg = cfg.large_net_cfg - elif cfg.bert_network == 'large_acc': - cfg.batch_size = cfg.large_acc_batch_size - _bert_net_cfg = cfg.large_acc_net_cfg + elif cfg.bert_network == 'large_boost': + cfg.batch_size = cfg.large_boost_batch_size + _bert_net_cfg = cfg.large_boost_net_cfg else: pass cfg.bert_net_cfg = BertConfig(**_bert_net_cfg.__dict__) diff --git a/tests/st/networks/models/resnet50/src_thor/dataset.py b/tests/st/networks/models/resnet50/src_thor/dataset.py index 83efe83e3a8..21cd223c0ad 100644 --- a/tests/st/networks/models/resnet50/src_thor/dataset.py +++ b/tests/st/networks/models/resnet50/src_thor/dataset.py @@ -40,12 +40,12 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): rank_id = int(os.getenv("RANK_ID")) if do_train: if device_num == 1: - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=16, shuffle=True) else: - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True, num_shards=device_num, shard_id=rank_id) else: - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False, + data_set = ds.ImageFolderDataset(dataset_path, 
num_parallel_workers=12, shuffle=False, num_shards=device_num, shard_id=rank_id) image_size = 224 @@ -73,7 +73,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): type_cast_op = C2.TypeCast(mstype.int32) data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=24) - data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=24) + data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12) # apply batch operations data_set = data_set.batch(batch_size, drop_remainder=True)
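
Taken together, the user-facing change is a single keyword rename on Model (plus the model_zoo config keys above). A hedged usage sketch on synthetic data — note that, per the warnings added in model.py, boost "O1"/"O2" only takes effect when an optimizer and either an eval_network or metrics are supplied:

    import numpy as np
    import mindspore.nn as nn
    import mindspore.dataset as ds
    from mindspore import Model

    net = nn.Dense(16, 10)
    opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(net, loss_fn=nn.MSELoss(), optimizer=opt, metrics={"loss"},
                  boost_level="O1")  # formerly acc_level="O1"

    x = np.random.randn(64, 16).astype(np.float32)
    y = np.random.randn(64, 10).astype(np.float32)
    train_ds = ds.NumpySlicesDataset({"data": x, "label": y}, shuffle=False).batch(8)
    model.train(1, train_ds, dataset_sink_mode=False)
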