forked from mindspore-Ecosystem/mindspore

remove sens parameter

This commit is contained in:
parent a26fdb83ee
commit 94d63b90f4
@@ -88,7 +88,7 @@ class WithGradCell(Cell):
     Run in PyNative mode.

     Args:
-        network (Cell): The target network to wrap.
+        network (Cell): The target network to wrap. The network only supports single output.
         loss_fn (Cell): Primitive loss function used to compute gradients. Default: None.
         sens (Union[None, Tensor, Scalar, Tuple ...]): The sensitivity value for backpropagation; the type and shape
             should be the same as the `network` output. If None, we will fill one to a same type shape of
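As a quick orientation for the docstring above, here is a minimal usage sketch of WithGradCell in PyNative mode. The TinyNet class, data shapes, and values are placeholders invented for illustration; they are not part of this change.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

class TinyNet(nn.Cell):
    # Placeholder single-output network used only for this sketch.
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(4, 2)

    def construct(self, x):
        return self.dense(x)

net = TinyNet()
grad_net = nn.WithGradCell(net, nn.MSELoss())     # returns gradients w.r.t. trainable parameters
x = Tensor(np.ones([3, 4]).astype(np.float32))
label = Tensor(np.zeros([3, 2]).astype(np.float32))
grads = grad_net(x, label)                        # tuple of gradient Tensors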
@@ -143,7 +143,7 @@ class TrainOneStepCell(Cell):
     parallel modes are available for training.

     Args:
-        network (Cell): The training network.
+        network (Cell): The training network. The network only supports single output.
         optimizer (Cell): Optimizer for updating the weights.
         sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
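To make the `sens` argument above concrete, a hedged usage sketch of wrapping a loss network in TrainOneStepCell with an explicit scaling number; TinyNet is the same placeholder as in the previous sketch and the value 1024.0 is illustrative.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

net = TinyNet()                                   # placeholder network from the sketch above
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
net_with_loss = nn.WithLossCell(net, nn.MSELoss())

# sens=1024.0 is broadcast to the loss shape and fed in as the initial backward gradient.
train_step = nn.TrainOneStepCell(net_with_loss, optimizer, sens=1024.0)
loss = train_step(Tensor(np.ones([3, 4]).astype(np.float32)),
                  Tensor(np.zeros([3, 2]).astype(np.float32)))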
@@ -49,6 +49,7 @@ grad_overflow = P.FloatStatus()
 def _tensor_grad_overflow(grad):
     return grad_overflow(grad)


 class DynamicLossScaleUpdateCell(Cell):
     r"""
     Dynamic Loss scale update cell.
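As a rough illustration of what such an update cell computes, here is a plain-Python sketch of the usual dynamic policy: shrink the scale when an overflow is seen, grow it again after a run of clean steps. The default values below are illustrative, not taken from this diff.

class DynamicLossScaleSketch:
    """Illustrative only: the common dynamic loss-scale update policy."""

    def __init__(self, init_scale=2.0 ** 24, scale_factor=2.0, scale_window=2000):
        self.scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.good_steps = 0

    def update(self, overflow):
        if overflow:
            # Overflow seen: shrink the scale (never below 1) and restart the window.
            self.scale = max(self.scale / self.scale_factor, 1.0)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps >= self.scale_window:
                # A full window without overflow: grow the scale again.
                self.scale *= self.scale_factor
                self.good_steps = 0
        return self.scale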
@@ -168,27 +169,26 @@ class TrainOneStepWithLossScaleCell(Cell):

     This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update
     Cell as args. The loss scale value can be updated on either the host side or the device side. The
-    TrainOneStepWithLossScaleCell will be compiled into a graph which takes `data`, `label`, `sens` as input
-    data. The `sens` acts as the loss scaling value. If you want to update it on the host side, the value should
-    be provided. If `sens` is not given, the loss scale update logic should be provided by `scale_update_cell`.
-    If `scale_update_cell` is not None and `sens` is provided, the `scale_update_cell` will be ignored.
+    TrainOneStepWithLossScaleCell will be compiled into a graph which takes `*inputs` as input data.
+    A Tensor-type `scale_sense` acts as the loss scaling value. If you want to update it on the host side,
+    the value should be provided. If a Tensor-type `scale_sense` is not given, the loss scale update logic
+    should be provided by a Cell-type `scale_sense`. If a Cell-type `scale_sense` is not None and a Tensor-type
+    `scale_sense` is provided, the Cell-type `scale_sense` will be ignored.

     Args:
-        network (Cell): The training network.
+        network (Cell): The training network. The network only supports single output.
         optimizer (Cell): Optimizer for updating the weights.
-        scale_update_cell (Cell): The loss scaling update logic cell. Default: None.
+        scale_sense (Union[Tensor, Cell]): If this value is a Cell, it is the loss scaling update logic cell.
+            If this value is a Tensor, it is the loss scaling value with shape :math:`()`. Default: None.

     Inputs:
-        - **inputs** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
-        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
-        - **scaling_sens** (Tensor) - Tensor of shape :math:`()`.
+        - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

     Outputs:
         Tuple of 3 Tensor, the loss, overflow flag and current loss scaling value.

         - **loss** (Tensor) - Tensor with shape :math:`()`.
         - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
         - **loss_scale** (Tensor) - Tensor with shape :math:`()`.

     Examples:
         >>> net_with_loss = Net()
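A hedged sketch of the two ways the new `scale_sense` argument can be supplied, mirroring the tests further down; TinyNet is the placeholder network from the earlier sketches and the scale values are illustrative.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.train.loss_scale_manager import DynamicLossScaleManager

net_with_loss = nn.WithLossCell(TinyNet(), nn.MSELoss())
optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)

# 1) Tensor: a fixed loss scale held in the cell's `scale_sense` parameter.
fixed = nn.TrainOneStepWithLossScaleCell(
    net_with_loss, optimizer,
    scale_sense=Tensor(np.full((1,), 1024.0), dtype=mstype.float32))

# 2) Cell: dynamic loss scaling driven by the update cell of a loss scale manager.
dynamic = nn.TrainOneStepWithLossScaleCell(
    net_with_loss, optimizer,
    scale_sense=DynamicLossScaleManager().get_update_cell())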
@@ -203,7 +203,7 @@ class TrainOneStepWithLossScaleCell(Cell):
         >>> output = train_network(inputs, label, scaling_sens)
     """

-    def __init__(self, network, optimizer, scale_update_cell=None):
+    def __init__(self, network, optimizer, scale_sense=None):
         super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_grad()
@@ -236,29 +236,29 @@ class TrainOneStepWithLossScaleCell(Cell):
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
         self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE

-        self.loss_scale = None
-        self.loss_scaling_manager = scale_update_cell
-        if scale_update_cell:
-            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
-                                        name="loss_scale")
+        self.scale_sense = None
+        self.loss_scaling_manager = None
+        if isinstance(scale_sense, Cell):
+            self.loss_scaling_manager = scale_sense
+            self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
+                                         name="scale_sense")
+        if isinstance(scale_sense, Tensor):
+            self.scale_sense = Parameter(scale_sense, name='scale_sense')

     @C.add_flags(has_effect=True)
-    def construct(self, data, label, sens=None):
+    def construct(self, *inputs):
         weights = self.weights
-        loss = self.network(data, label)
+        loss = self.network(*inputs)
         init = False
         if not self.gpu_target:
             # init overflow buffer
             init = self.alloc_status()
             # clear overflow buffer
             self.clear_status(init)
-        if sens is None:
-            scaling_sens = self.loss_scale
-        else:
-            scaling_sens = sens
-
+        scaling_sens = self.scale_sense
         scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
-        grads = self.grad(self.network, weights)(data, label, scaling_sens_filled)
+        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
         grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
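The body above is the usual loss-scaling round trip: `scaling_sens` is broadcast to the loss shape and fed in as the initial backward gradient, and `_grad_scale` divides the gradients by the same value before the optimizer step. A small numpy illustration of why that matters for float16:

import numpy as np

tiny = 1e-8                                  # a gradient magnitude float16 cannot represent
print(np.float16(tiny))                      # 0.0 -- underflows
scale = 1024.0
scaled = np.float16(tiny * scale)            # representable once scaled
print(np.float32(scaled) / scale)            # divided back: roughly 1e-8 again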
@@ -279,8 +279,8 @@ class TrainOneStepWithLossScaleCell(Cell):
         else:
             cond = self.less_equal(self.base, flag_sum)
         overflow = cond
-        if sens is None:
-            overflow = self.loss_scaling_manager(self.loss_scale, cond)
+        if self.loss_scaling_manager is not None:
+            overflow = self.loss_scaling_manager(self.scale_sense, cond)
         # if there is no overflow, do optimize
         if overflow:
             opt = False
@@ -288,3 +288,9 @@ class TrainOneStepWithLossScaleCell(Cell):
             opt = self.optimizer(grads)
         ret = (loss, cond, scaling_sens)
         return F.depend(ret, opt)
+
+    def set_sense_scale(self, sens):
+        """If the user has set the sens in the training process and wants to reassign the value, he can call
+        this function again to make the modification, and sens needs to be of type Tensor."""
+        if self.scale_sense and isinstance(sens, Tensor):
+            self.scale_sense.set_data(sens)
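A brief usage sketch for the new method; `train_network` stands for any TrainOneStepWithLossScaleCell built with a Tensor `scale_sense` (as in the earlier sketch), and the value 512.0 is illustrative.

import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype

# Lower the fixed loss scale between steps without rebuilding the cell; the new
# Tensor is assumed to match the shape of the original scale_sense (here (1,)).
train_network.set_sense_scale(Tensor(np.full((1,), 512.0), dtype=mstype.float32))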
@@ -182,7 +182,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
                                  "are supported in current version. If you use `O2` option, please"
                                  "use `loss_scale_manager=None` or `FixedLossScaleManager`")
         network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
-                                                   scale_update_cell=update_cell).set_train()
+                                                   scale_sense=update_cell).set_train()
         return network
     network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train()
     return network
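For orientation, a hedged sketch of how this branch is typically reached through the mixed-precision helper; TinyNet is the placeholder from the earlier sketches and the keyword values are illustrative.

import mindspore.nn as nn
from mindspore.train import amp
from mindspore.train.loss_scale_manager import DynamicLossScaleManager

net = TinyNet()
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# A manager whose drop_overflow_update is True leads build_train_network to wrap the
# model in TrainOneStepWithLossScaleCell, now passing its update cell as `scale_sense`.
train_net = amp.build_train_network(net, optimizer, loss_fn=nn.MSELoss(),
                                    loss_scale_manager=DynamicLossScaleManager())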
@@ -34,7 +34,6 @@ from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from ..context import ParallelMode
 from ..parallel._utils import _need_to_full, _to_full_tensor
 from ..parallel._cost_model_context import _set_multi_subgraphs
 from ..common import dtype as mstype
 from .dataset_helper import DatasetHelper, connect_network_with_dataset
 from . import amp
@@ -489,11 +488,6 @@ class Model:
                                      "return two elements, but got {}".format(len_element))
                 cb_params.cur_step_num += 1

-                overflow = False
-                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
-                    scaling_sens = self._get_scaling_sens()
-                    next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
-
                 cb_params.train_dataset_element = next_element
                 list_callback.step_begin(run_context)
                 outputs = self._train_network(*next_element)
@@ -148,7 +148,6 @@ class MSELoss(nn.Cell):
 def test_loss_scale_fp16_lr_overflow():
     inputs = Tensor(np.ones([16, 16]).astype(np.float32))
     label = Tensor(np.zeros([16, 16]).astype(np.float32))
-    scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
     lr = Tensor(np.ones([1], np.float32) * 0.1)
     net = NetFP16(16, 16)
     net.set_train()
@@ -157,9 +156,11 @@ def test_loss_scale_fp16_lr_overflow():
     optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9)

     net_with_loss = WithLossCell(net, loss)
-    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
-    output_1 = train_network(inputs, label, scaling_sens)
-    output_2 = train_network(inputs, label, scaling_sens)
+    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
+                                                  scale_sense=Tensor(np.full((1), np.finfo(np.float32).max),
+                                                                     dtype=mstype.float32))
+    output_1 = train_network(inputs, label)
+    output_2 = train_network(inputs, label)
     assert output_1[0].asnumpy() == output_2[0].asnumpy()
     assert output_1[1].asnumpy() == output_2[1].asnumpy() == True
@@ -188,16 +189,17 @@ def test_loss_scale_fp16_model_train_overflow():
 def test_loss_scale_fp16_opt_rmsprop_overflow():
     inputs = Tensor(np.ones([16, 16]).astype(np.float32))
     label = Tensor(np.zeros([16, 16]).astype(np.float32))
-    scaling_sens = Tensor(np.full(1, np.finfo(np.float32).max), dtype=mstype.float32)
     net = NetFP16(16, 16)
     net.set_train()

     loss = MSELoss()
     optimizer = RMSProp(net.trainable_params(), learning_rate=0.1)
     net_with_loss = WithLossCell(net, loss)
-    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
-    output_1 = train_network(inputs, label, scaling_sens)
-    output_2 = train_network(inputs, label, scaling_sens)
+    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
+                                                  scale_sense=Tensor(np.full(1, np.finfo(np.float32).max),
+                                                                     dtype=mstype.float32))
+    output_1 = train_network(inputs, label)
+    output_2 = train_network(inputs, label)
     assert output_1[0].asnumpy() == output_2[0].asnumpy()
     assert output_1[1].asnumpy() == output_2[1].asnumpy() == True
@@ -208,7 +210,6 @@ def test_loss_scale_fp16_opt_rmsprop_overflow():
 def test_loss_scale_fp16_overflow():
     inputs = Tensor(np.ones([16, 16]).astype(np.float32))
     label = Tensor(np.zeros([16, 16]).astype(np.float32))
-    scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
     net = NetFP16(16, 16)
     net.set_train()
@@ -216,8 +217,10 @@ def test_loss_scale_fp16_overflow():
     optimizer = Lamb(net.trainable_params(), learning_rate=0.01)
     net_with_loss = WithLossCell(net, loss)
     net_with_loss.set_grad()
-    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
-    output_1 = train_network(inputs, label, scaling_sens)
-    output_2 = train_network(inputs, label, scaling_sens)
+    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
+                                                  scale_sense=Tensor(np.full((1), np.finfo(np.float32).max),
+                                                                     dtype=mstype.float32))
+    output_1 = train_network(inputs, label)
+    output_2 = train_network(inputs, label)
     assert output_1[0].asnumpy() == output_2[0].asnumpy()
     assert output_1[1].asnumpy() == output_2[1].asnumpy() == True
@@ -177,7 +177,7 @@ def test_compile_grad_error():
     net_with_loss = WithLossCell(net, loss)
     scale_manager = DynamicLossScaleManager()
     update_cell = scale_manager.get_update_cell()
-    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell)
+    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell)
     train_network.set_train()
     with pytest.raises(TypeError) as e:
         train_network(inputs, label)
@@ -100,70 +100,71 @@ class MSELoss(nn.Cell):
 def test_momentum_compile():
     inputs = Tensor(np.ones([15, 1]).astype(np.float32))
     label = Tensor(np.zeros([15, 1]).astype(np.float32))
-    scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32)
     net = Net(1, 1)

     loss = MSELoss()
     optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

     net_with_loss = WithLossCell(net, loss)
-    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
+    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
+                                                  scale_sense=Tensor(np.full((1), 1.0), dtype=mstype.float32))
     train_network.set_train()
-    output = train_network(inputs, label, scaling_sens)
+    output = train_network(inputs, label)
     print("the result is ", output)


 def test_compile_fp16_not_overflow():
     inputs = Tensor(np.ones([16, 16]).astype(np.float32))
     label = Tensor(np.zeros([16, 16]).astype(np.float32))
-    scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32)
     net = NetFP16(16, 16)

     loss = MSELoss()
     optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

     net_with_loss = WithLossCell(net, loss)
-    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
+    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
+                                                  scale_sense=Tensor(np.full((1), 1.0), dtype=mstype.float32))
     train_network.set_train()
-    output = train_network(inputs, label, scaling_sens)
+    output = train_network(inputs, label)
     print("the result is ", output)


 def test_compile_fp16_lr_overflow():
     inputs = Tensor(np.ones([16, 16]).astype(np.float32))
     label = Tensor(np.zeros([16, 16]).astype(np.float32))
-    scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
     lr = Tensor(np.ones([1], np.float32) * 0.1)
     net = NetFP16(16, 16)
     loss = MSELoss()
     optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9)

     net_with_loss = WithLossCell(net, loss)
-    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
+    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
+                                                  scale_sense=Tensor(np.full((1), np.finfo(np.float32).max),
+                                                                     dtype=mstype.float32))
     train_network.set_train()
-    output = train_network(inputs, label, scaling_sens)
+    output = train_network(inputs, label)
     print("the result is ", output)


 def test_compile_fp16_overflow():
     inputs = Tensor(np.ones([16, 16]).astype(np.float32))
     label = Tensor(np.zeros([16, 16]).astype(np.float32))
-    scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
     net = NetFP16(16, 16)

     loss = MSELoss()
     optimizer = Lamb(net.trainable_params(), learning_rate=0.01)
     net_with_loss = WithLossCell(net, loss)
-    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
+    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
+                                                  scale_sense=Tensor(np.full((1), np.finfo(np.float32).max),
+                                                                     dtype=mstype.float32))
     train_network.set_train()
-    output = train_network(inputs, label, scaling_sens)
+    output = train_network(inputs, label)
     print("the result is ", output)


 def test_compile_fp16_lr_overflow_with_lossscale_update():
     inputs = Tensor(np.ones([16, 16]).astype(np.float32))
     label = Tensor(np.zeros([16, 16]).astype(np.float32))
-    scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
     lr = Tensor(np.ones([1], np.float32) * 0.1)
     net = NetFP16(16, 16)
     loss = MSELoss()
@@ -172,9 +173,9 @@ def test_compile_fp16_lr_overflow_with_lossscale_update():
     net_with_loss = WithLossCell(net, loss)
     scale_manager = DynamicLossScaleManager()
     manager = scale_manager.get_update_cell()
-    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=manager)
+    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
     train_network.set_train()
-    output = train_network(inputs, label, scaling_sens)
+    output = train_network(inputs, label)
     print("the result is ", output)
@@ -209,7 +210,6 @@ def test_compile_f16_model_train_fixed():
 def test_compile_fp16_lr_overflow_fixed_feed():
     inputs = Tensor(np.ones([16, 16]).astype(np.float32))
     label = Tensor(np.zeros([16, 16]).astype(np.float32))
-    scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
     lr = Tensor(np.ones([1], np.float32) * 0.1)
     net = NetFP16(16, 16)
     loss = MSELoss()
@@ -218,16 +218,15 @@ def test_compile_fp16_lr_overflow_fixed_feed():
     net_with_loss = WithLossCell(net, loss)
     scale_manager = FixedLossScaleManager()
     update_cell = scale_manager.get_update_cell()
-    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell)
+    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell)
     train_network.set_train()
-    output = train_network(inputs, label, scaling_sens)
+    output = train_network(inputs, label)
     print("the result is ", output)


 def test_compile_fp16_lr_overflow_dynamic_feed():
     inputs = Tensor(np.ones([16, 16]).astype(np.float32))
     label = Tensor(np.zeros([16, 16]).astype(np.float32))
-    scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
     lr = Tensor(np.ones([1], np.float32) * 0.1)
     net = NetFP16(16, 16)
     loss = MSELoss()
@@ -236,9 +235,9 @@ def test_compile_fp16_lr_overflow_dynamic_feed():
     net_with_loss = WithLossCell(net, loss)
     scale_manager = DynamicLossScaleManager()
     update_cell = scale_manager.get_update_cell()
-    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell)
+    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell)
     train_network.set_train()
-    output = train_network(inputs, label, scaling_sens)
+    output = train_network(inputs, label)
     print("the result is ", output)
@@ -253,7 +252,7 @@ def test_compile_fp16_lr_overflow_fixed_graph():
     net_with_loss = WithLossCell(net, loss)
     scale_manager = FixedLossScaleManager(drop_overflow_update=True)
     update_cell = scale_manager.get_update_cell()
-    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell)
+    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell)
     train_network.set_train()
     output = train_network(inputs, label)
     print("the result is ", output)
@@ -270,7 +269,7 @@ def test_compile_fp16_lr_overflow_dynamic_graph():
     net_with_loss = WithLossCell(net, loss)
     scale_manager = DynamicLossScaleManager()
     update_cell = scale_manager.get_update_cell()
-    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell)
+    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=update_cell)
     train_network.set_train()
     output = train_network(inputs, label)
     print("the result is ", output)
@@ -279,7 +278,6 @@ def test_compile_fp16_lr_overflow_dynamic_graph():
 def adam_compile(loss_scale=1.0):
     inputs = Tensor(np.ones([15, 1]).astype(np.float32))
     label = Tensor(np.zeros([15, 1]).astype(np.float32))
-    scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32)
     net = Net(1, 1)

     loss = MSELoss()
@@ -287,14 +285,17 @@ def adam_compile(loss_scale=1.0):
                      use_nesterov=False, weight_decay=0.0, loss_scale=loss_scale)

     net_with_loss = WithLossCell(net, loss)
-    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
+    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
+                                                  scale_sense=Tensor(np.full((1), 1.0), dtype=mstype.float32))
     train_network.set_train()
-    output = train_network(inputs, label, scaling_sens)
+    output = train_network(inputs, label)
     print("the result is ", output)


 def test_adam_compile():
     adam_compile()


 def test_adam_loss_scale_compile():
     """ test setting loss_scale to 1e-40 """
     adam_compile(loss_scale=1e-40)