From 94d63b90f4ad266ea27e9c4713bfb72131b96d1c Mon Sep 17 00:00:00 2001
From: Jiaqi
Date: Fri, 4 Sep 2020 11:39:59 +0800
Subject: [PATCH] remove sens parameter

---
 mindspore/nn/wrap/cell_wrapper.py | 4 +-
 mindspore/nn/wrap/loss_scale.py | 56 ++++++++++---------
 mindspore/train/amp.py | 2 +-
 mindspore/train/model.py | 6 --
 .../st/pynative/loss_scale/test_loss_scale.py | 27 +++++----
 .../python/optimizer/test_debug_location.py | 2 +-
 .../test_optimizer_with_loss_scale.py | 53 +++++++++---------
 7 files changed, 77 insertions(+), 73 deletions(-)

diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py
index 91194e8d784..86ab01375ef 100644
--- a/mindspore/nn/wrap/cell_wrapper.py
+++ b/mindspore/nn/wrap/cell_wrapper.py
@@ -88,7 +88,7 @@ class WithGradCell(Cell):
 Run in PyNative mode.

 Args:
- network (Cell): The target network to wrap.
+ network (Cell): The target network to wrap. The network only supports single output.
 loss_fn (Cell): Primitive loss function used to compute gradients. Default: None.
 sens (Union[None, Tensor, Scalar, Tuple ...]): The sensitive for backpropagation, the type and shape
 should be same as the `network` output. If None, we will fill one to a same type shape of
@@ -143,7 +143,7 @@ class TrainOneStepCell(Cell):
 parallel modes are available for training.

 Args:
- network (Cell): The training network.
+ network (Cell): The training network. The network only supports single output.
 optimizer (Cell): Optimizer for updating the weights.
 sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.

diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py
index 19bd6b6580f..1dfc91743c8 100644
--- a/mindspore/nn/wrap/loss_scale.py
+++ b/mindspore/nn/wrap/loss_scale.py
@@ -49,6 +49,7 @@ grad_overflow = P.FloatStatus()
 def _tensor_grad_overflow(grad):
 return grad_overflow(grad)

+
 class DynamicLossScaleUpdateCell(Cell):
 r"""
 Dynamic Loss scale update cell.
@@ -168,27 +169,26 @@ class TrainOneStepWithLossScaleCell(Cell):

 This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update
 Cell as args. The loss scale value can be updated in both host side or device side. The
- TrainOneStepWithLossScaleCell will be compiled to be graph which takes `data`, `label`, `sens` as input
- data. The `sens` is acting as loss scaling value. If you want to update it on host side, the value should
- be provided. If `sens` is not given, the loss scale update logic should be provied by `scale_update_cell`.
- If `scale_update_cell` is not None and `sens` is provided, the `scale_update_cell` will be ignored.
+ TrainOneStepWithLossScaleCell will be compiled into a graph which takes `*inputs` as input data.
+ The Tensor type of `scale_sense` acts as the loss scaling value. If you want to update it on host side,
+ the value should be provided. If the Tensor type of `scale_sense` is not given, the loss scale update logic
+ should be provided by the Cell type of `scale_sense`. If the Cell type of `scale_sense` is not None and the
+ Tensor type of `scale_sense` is provided, the Cell type of `scale_sense` will be ignored.

 Args:
- network (Cell): The training network.
+ network (Cell): The training network. The network only supports single output.
 optimizer (Cell): Optimizer for updating the weights.
- scale_update_cell(Cell): The loss scaling update logic cell. Default: None.
+ scale_sense (Union[Tensor, Cell]): If this value is a Cell type, it is the loss scaling update logic cell.
+ If this value is a Tensor type, it is a Tensor with shape :math:`()`. Default: None.

 Inputs:
- - **inputs** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
- - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
- - **scaling_sens** (Tensor) - Tensor of shape :math:`()`.
+ - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

 Outputs:
 Tuple of 3 Tensor, the loss, overflow flag and current loss scaling value.

 - **loss** (Tensor) - Tensor with shape :math:`()`.
 - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
- - **loss_scale** (Tensor) - Tensor with shape :math:`()`.

 Examples:
 >>> net_with_loss = Net()
@@ -203,7 +203,7 @@ class TrainOneStepWithLossScaleCell(Cell):
 >>> output = train_network(inputs, label, scaling_sens)
 """

- def __init__(self, network, optimizer, scale_update_cell=None):
+ def __init__(self, network, optimizer, scale_sense=None):
 super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
 self.network = network
 self.network.set_grad()
@@ -236,29 +236,29 @@ class TrainOneStepWithLossScaleCell(Cell):
 self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
 self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE

- self.loss_scale = None
- self.loss_scaling_manager = scale_update_cell
- if scale_update_cell:
- self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
- name="loss_scale")
+ self.scale_sense = None
+ self.loss_scaling_manager = None
+ if isinstance(scale_sense, Cell):
+ self.loss_scaling_manager = scale_sense
+ self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
+ name="scale_sense")
+ if isinstance(scale_sense, Tensor):
+ self.scale_sense = Parameter(scale_sense, name='scale_sense')

 @C.add_flags(has_effect=True)
- def construct(self, data, label, sens=None):
+ def construct(self, *inputs):
 weights = self.weights
- loss = self.network(data, label)
+ loss = self.network(*inputs)
 init = False
 if not self.gpu_target:
 # init overflow buffer
 init = self.alloc_status()
 # clear overflow buffer
 self.clear_status(init)
- if sens is None:
- scaling_sens = self.loss_scale
- else:
- scaling_sens = sens
+ scaling_sens = self.scale_sense
 scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
- grads = self.grad(self.network, weights)(data, label, scaling_sens_filled)
+ grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
 grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
 # apply grad reducer on grads
 grads = self.grad_reducer(grads)
@@ -279,8 +279,8 @@ class TrainOneStepWithLossScaleCell(Cell):
 else:
 cond = self.less_equal(self.base, flag_sum)
 overflow = cond
- if sens is None:
- overflow = self.loss_scaling_manager(self.loss_scale, cond)
+ if self.loss_scaling_manager is not None:
+ overflow = self.loss_scaling_manager(self.scale_sense, cond)
 # if there is no overflow, do optimize
 if overflow:
 opt = False
@@ -288,3 +288,9 @@ class TrainOneStepWithLossScaleCell(Cell):
 opt = self.optimizer(grads)
 ret = (loss, cond, scaling_sens)
 return F.depend(ret, opt)
+
+ def set_sense_scale(self, sens):
+ """If the user has set the sens in the training process and wants to reassign the value, this
+ function can be called again to modify it. The new sens needs to be of type Tensor."""
+ if self.scale_sense and isinstance(sens, Tensor):
+
self.scale_sense.set_data(sens)
diff --git a/mindspore/train/amp.py b/mindspore/train/amp.py
index d00609bde37..c25e4739d4e 100644
--- a/mindspore/train/amp.py
+++ b/mindspore/train/amp.py
@@ -182,7 +182,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
 "are supported in current version. If you use `O2` option, please"
 "use `loss_scale_manager=None` or `FixedLossScaleManager`")
 network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
- scale_update_cell=update_cell).set_train()
+ scale_sense=update_cell).set_train()
 return network
 network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train()
 return network
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index d67d4fef98d..2e478ed798d 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -34,7 +34,6 @@ from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from ..context import ParallelMode
 from ..parallel._utils import _need_to_full, _to_full_tensor
 from ..parallel._cost_model_context import _set_multi_subgraphs
-from ..common import dtype as mstype
 from .dataset_helper import DatasetHelper, connect_network_with_dataset
 from . import amp
@@ -489,11 +488,6 @@ class Model:
 "return two elements, but got {}".format(len_element))
 cb_params.cur_step_num += 1
- overflow = False
- if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
- scaling_sens = self._get_scaling_sens()
- next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
-
 cb_params.train_dataset_element = next_element
 list_callback.step_begin(run_context)
 outputs = self._train_network(*next_element)
diff --git a/tests/st/pynative/loss_scale/test_loss_scale.py b/tests/st/pynative/loss_scale/test_loss_scale.py
index 246ded8900b..b34fa257b69 100644
--- a/tests/st/pynative/loss_scale/test_loss_scale.py
+++ b/tests/st/pynative/loss_scale/test_loss_scale.py
@@ -148,7 +148,6 @@ class MSELoss(nn.Cell):
 def test_loss_scale_fp16_lr_overflow():
 inputs = Tensor(np.ones([16, 16]).astype(np.float32))
 label = Tensor(np.zeros([16, 16]).astype(np.float32))
- scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
 lr = Tensor(np.ones([1], np.float32) * 0.1)
 net = NetFP16(16, 16)
 net.set_train()
@@ -157,9 +156,11 @@ def test_loss_scale_fp16_lr_overflow():
 optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9)
 net_with_loss = WithLossCell(net, loss)
- train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
- output_1 = train_network(inputs, label, scaling_sens)
- output_2 = train_network(inputs, label, scaling_sens)
+ train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
+ scale_sense=Tensor(np.full((1), np.finfo(np.float32).max),
+ dtype=mstype.float32))
+ output_1 = train_network(inputs, label)
+ output_2 = train_network(inputs, label)
 assert output_1[0].asnumpy() == output_2[0].asnumpy()
 assert output_1[1].asnumpy() == output_2[1].asnumpy() == True
@@ -188,16 +189,17 @@ def test_loss_scale_fp16_model_train_overflow():
 def test_loss_scale_fp16_opt_rmsprop_overflow():
 inputs = Tensor(np.ones([16, 16]).astype(np.float32))
 label = Tensor(np.zeros([16, 16]).astype(np.float32))
- scaling_sens = Tensor(np.full(1, np.finfo(np.float32).max), dtype=mstype.float32)
 net = NetFP16(16, 16)
 net.set_train()
 loss = MSELoss()
 optimizer = RMSProp(net.trainable_params(), learning_rate=0.1)
 net_with_loss = WithLossCell(net, loss)
- train_network = TrainOneStepWithLossScaleCell(net_with_loss, 
optimizer) - output_1 = train_network(inputs, label, scaling_sens) - output_2 = train_network(inputs, label, scaling_sens) + train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, + scale_sense=Tensor(np.full(1, np.finfo(np.float32).max), + dtype=mstype.float32)) + output_1 = train_network(inputs, label) + output_2 = train_network(inputs, label) assert output_1[0].asnumpy() == output_2[0].asnumpy() assert output_1[1].asnumpy() == output_2[1].asnumpy() == True @@ -208,7 +210,6 @@ def test_loss_scale_fp16_opt_rmsprop_overflow(): def test_loss_scale_fp16_overflow(): inputs = Tensor(np.ones([16, 16]).astype(np.float32)) label = Tensor(np.zeros([16, 16]).astype(np.float32)) - scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32) net = NetFP16(16, 16) net.set_train() @@ -216,8 +217,10 @@ def test_loss_scale_fp16_overflow(): optimizer = Lamb(net.trainable_params(), learning_rate=0.01) net_with_loss = WithLossCell(net, loss) net_with_loss.set_grad() - train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) - output_1 = train_network(inputs, label, scaling_sens) - output_2 = train_network(inputs, label, scaling_sens) + train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, + scale_sense=Tensor(np.full((1), np.finfo(np.float32).max), + dtype=mstype.float32)) + output_1 = train_network(inputs, label) + output_2 = train_network(inputs, label) assert output_1[0].asnumpy() == output_2[0].asnumpy() assert output_1[1].asnumpy() == output_2[1].asnumpy() == True diff --git a/tests/ut/python/optimizer/test_debug_location.py b/tests/ut/python/optimizer/test_debug_location.py index a29fe453faf..e4ab48e78a1 100644 --- a/tests/ut/python/optimizer/test_debug_location.py +++ b/tests/ut/python/optimizer/test_debug_location.py @@ -177,7 +177,7 @@ def test_compile_grad_error(): net_with_loss = WithLossCell(net, loss) scale_manager = DynamicLossScaleManager() update_cell = scale_manager.get_update_cell() - train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell) + train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell) train_network.set_train() with pytest.raises(TypeError) as e: train_network(inputs, label) diff --git a/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py b/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py index cccaf8e3b82..84feccb672c 100644 --- a/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py +++ b/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py @@ -100,70 +100,71 @@ class MSELoss(nn.Cell): def test_momentum_compile(): inputs = Tensor(np.ones([15, 1]).astype(np.float32)) label = Tensor(np.zeros([15, 1]).astype(np.float32)) - scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32) net = Net(1, 1) loss = MSELoss() optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) net_with_loss = WithLossCell(net, loss) - train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) + train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, + scale_sense=Tensor(np.full((1), 1.0), dtype=mstype.float32)) train_network.set_train() - output = train_network(inputs, label, scaling_sens) + output = train_network(inputs, label) print("the result is ", output) def test_compile_fp16_not_overflow(): inputs = Tensor(np.ones([16, 16]).astype(np.float32)) label = Tensor(np.zeros([16, 16]).astype(np.float32)) - scaling_sens = Tensor(np.full((1), 1.0), 
dtype=mstype.float32) net = NetFP16(16, 16) loss = MSELoss() optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) net_with_loss = WithLossCell(net, loss) - train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) + train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, + scale_sense=Tensor(np.full((1), 1.0), dtype=mstype.float32)) train_network.set_train() - output = train_network(inputs, label, scaling_sens) + output = train_network(inputs, label) print("the result is ", output) def test_compile_fp16_lr_overflow(): inputs = Tensor(np.ones([16, 16]).astype(np.float32)) label = Tensor(np.zeros([16, 16]).astype(np.float32)) - scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32) lr = Tensor(np.ones([1], np.float32) * 0.1) net = NetFP16(16, 16) loss = MSELoss() optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9) net_with_loss = WithLossCell(net, loss) - train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) + train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, + scale_sense=Tensor(np.full((1), np.finfo(np.float32).max), + dtype=mstype.float32)) train_network.set_train() - output = train_network(inputs, label, scaling_sens) + output = train_network(inputs, label) print("the result is ", output) def test_compile_fp16_overflow(): inputs = Tensor(np.ones([16, 16]).astype(np.float32)) label = Tensor(np.zeros([16, 16]).astype(np.float32)) - scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32) net = NetFP16(16, 16) loss = MSELoss() optimizer = Lamb(net.trainable_params(), learning_rate=0.01) net_with_loss = WithLossCell(net, loss) - train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) + train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, + scale_sense=Tensor(np.full((1), np.finfo(np.float32).max), + dtype=mstype.float32)) train_network.set_train() - output = train_network(inputs, label, scaling_sens) + output = train_network(inputs, label) print("the result is ", output) def test_compile_fp16_lr_overflow_with_lossscale_update(): inputs = Tensor(np.ones([16, 16]).astype(np.float32)) label = Tensor(np.zeros([16, 16]).astype(np.float32)) - scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32) lr = Tensor(np.ones([1], np.float32) * 0.1) net = NetFP16(16, 16) loss = MSELoss() @@ -172,9 +173,9 @@ def test_compile_fp16_lr_overflow_with_lossscale_update(): net_with_loss = WithLossCell(net, loss) scale_manager = DynamicLossScaleManager() manager = scale_manager.get_update_cell() - train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=manager) + train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager) train_network.set_train() - output = train_network(inputs, label, scaling_sens) + output = train_network(inputs, label) print("the result is ", output) @@ -209,7 +210,6 @@ def test_compile_f16_model_train_fixed(): def test_compile_fp16_lr_overflow_fixed_feed(): inputs = Tensor(np.ones([16, 16]).astype(np.float32)) label = Tensor(np.zeros([16, 16]).astype(np.float32)) - scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32) lr = Tensor(np.ones([1], np.float32) * 0.1) net = NetFP16(16, 16) loss = MSELoss() @@ -218,16 +218,15 @@ def test_compile_fp16_lr_overflow_fixed_feed(): net_with_loss = WithLossCell(net, loss) scale_manager = FixedLossScaleManager() update_cell = 
scale_manager.get_update_cell() - train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell) + train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell) train_network.set_train() - output = train_network(inputs, label, scaling_sens) + output = train_network(inputs, label) print("the result is ", output) def test_compile_fp16_lr_overflow_dynamic_feed(): inputs = Tensor(np.ones([16, 16]).astype(np.float32)) label = Tensor(np.zeros([16, 16]).astype(np.float32)) - scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32) lr = Tensor(np.ones([1], np.float32) * 0.1) net = NetFP16(16, 16) loss = MSELoss() @@ -236,9 +235,9 @@ def test_compile_fp16_lr_overflow_dynamic_feed(): net_with_loss = WithLossCell(net, loss) scale_manager = DynamicLossScaleManager() update_cell = scale_manager.get_update_cell() - train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell) + train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell) train_network.set_train() - output = train_network(inputs, label, scaling_sens) + output = train_network(inputs, label) print("the result is ", output) @@ -253,7 +252,7 @@ def test_compile_fp16_lr_overflow_fixed_graph(): net_with_loss = WithLossCell(net, loss) scale_manager = FixedLossScaleManager(drop_overflow_update=True) update_cell = scale_manager.get_update_cell() - train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell) + train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell) train_network.set_train() output = train_network(inputs, label) print("the result is ", output) @@ -270,7 +269,7 @@ def test_compile_fp16_lr_overflow_dynamic_graph(): net_with_loss = WithLossCell(net, loss) scale_manager = DynamicLossScaleManager() update_cell = scale_manager.get_update_cell() - train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell) + train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell) train_network.set_train() output = train_network(inputs, label) print("the result is ", output) @@ -279,7 +278,6 @@ def test_compile_fp16_lr_overflow_dynamic_graph(): def adam_compile(loss_scale=1.0): inputs = Tensor(np.ones([15, 1]).astype(np.float32)) label = Tensor(np.zeros([15, 1]).astype(np.float32)) - scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32) net = Net(1, 1) loss = MSELoss() @@ -287,14 +285,17 @@ def adam_compile(loss_scale=1.0): use_nesterov=False, weight_decay=0.0, loss_scale=loss_scale) net_with_loss = WithLossCell(net, loss) - train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) + train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, + scale_sense=Tensor(np.full((1), 1.0), dtype=mstype.float32)) train_network.set_train() - output = train_network(inputs, label, scaling_sens) + output = train_network(inputs, label) print("the result is ", output) + def test_adam_compile(): adam_compile() + def test_adam_loss_scale_compile(): """ test setting loss_scale to 1e-40 """ adam_compile(loss_scale=1e-40)
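
For reference, a minimal usage sketch of the reworked `scale_sense` interface after this patch, showing both the Tensor form (fixed loss scale) and the Cell form (dynamic loss scale). The small `Net` class, the tensor shapes, and the chosen scale values are illustrative placeholders rather than part of the patch; the wrapper, loss, optimizer, and loss-scale-manager classes are the ones exercised in the tests above, with import paths assumed to match the MindSpore version this patch targets.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore.nn import TrainOneStepWithLossScaleCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.train.loss_scale_manager import DynamicLossScaleManager

context.set_context(mode=context.PYNATIVE_MODE)


class Net(nn.Cell):
    """Stand-in single-output network; any Cell with one output works here."""
    def __init__(self, in_features, out_features):
        super(Net, self).__init__()
        self.dense = nn.Dense(in_features, out_features)

    def construct(self, x):
        return self.dense(x)


net = Net(16, 16)
net_with_loss = WithLossCell(net, nn.MSELoss())
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# Tensor form of `scale_sense`: a fixed loss scale held in a Parameter.
# The initial value (2 ** 12) is arbitrary and only for illustration.
fixed_train = TrainOneStepWithLossScaleCell(
    net_with_loss, optimizer,
    scale_sense=Tensor(np.full((1), 2 ** 12), dtype=mstype.float32))
fixed_train.set_train()

# Cell form of `scale_sense`: the dynamic update cell drives the loss scale.
manager = DynamicLossScaleManager()
dynamic_train = TrainOneStepWithLossScaleCell(
    net_with_loss, optimizer, scale_sense=manager.get_update_cell())
dynamic_train.set_train()

inputs = Tensor(np.ones([16, 16]).astype(np.float32))
label = Tensor(np.zeros([16, 16]).astype(np.float32))

# The wrapper now takes only the network inputs; the trailing sens argument is gone.
loss_value, overflow, scale = dynamic_train(inputs, label)

# A Tensor-type loss scale can later be reassigned via the helper added by this patch.
fixed_train.set_sense_scale(Tensor(np.full((1), 2 ** 10), dtype=mstype.float32))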