debug_for_new_interface_forwardvalueandgrad

lvliang 2021-03-01 17:12:59 +08:00
parent a1232aa987
commit 26f6fea675
2 changed files with 25 additions and 51 deletions


@@ -18,7 +18,6 @@ from types import FunctionType, MethodType
 from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                        _get_parallel_mode)
 from mindspore.context import ParallelMode
-from ...common.tensor import Tensor
 from ...common import dtype as mstype
 from ...common.parameter import Parameter, ParameterTuple
 from ...ops import composite as C
@@ -197,15 +196,16 @@ class ForwardValueAndGrad(Cell):
             If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
             Default: False.
             If the sens_param is True, a sensitivity (gradient with respect to output) needs to be transferred through
-            the location parameter or key-value pair parameter. If the value is transferred through the key-value pair
-            parameter, the key must be sens.
-        sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
+            the input parameter.
 
     Inputs:
         - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
+        - **(\*sens)** - A sensitivity (gradient with respect to output) as the input of backpropagation.
+          If the network has a single output, the sens is a tensor.
+          If the network has multiple outputs, the sens is a tuple(tensor).
 
     Outputs:
-        - **forward value** (a scalar Tensor with shape :math:`()`) - The result of network forward running.
+        - **forward value** - The result of the network's forward running.
         - **gradients** (tuple(tensor)) - The gradients of network parameters and inputs.
 
     Supported Platforms:
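
Note on the semantics above: the new sens input is the initial gradient that seeds backpropagation (the vector in the vector-Jacobian product), so a tensor of ones reproduces plain loss gradients. A minimal sketch of the resulting call pattern with sens_param=True follows; the toy network, shapes, and variable names are illustrative assumptions, not taken from this diff:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, ParameterTuple

net = nn.Dense(4, 3)                                   # toy network, stands in for a real model
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_net = nn.WithLossCell(net, loss_fn)
weights = ParameterTuple(net.trainable_params())
fwd_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights,
                                      get_by_list=True, sens_param=True)

data = Tensor(np.random.randn(2, 4).astype(np.float32))
label = Tensor(np.array([0, 1]).astype(np.int32))
sens = Tensor(np.ones([1]).astype(np.float32))         # backprop seed for the scalar loss
loss, grads = fwd_and_grad(data, label, sens)          # sens rides along as the last input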
@@ -219,8 +219,8 @@ class ForwardValueAndGrad(Cell):
         >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
         >>> #1) Using the existing WithLossCell provided by MindSpore
         >>> loss_net = nn.WithLossCell(net, loss_fn)
-        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weight, get_by_list=True, sens_param=True)
-        >>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
+        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True)
+        >>> loss, grads = forward_value_and_grad(inputs, labels)
         >>>
         >>> #2) Using user-defined WithLossCell
         >>> class MyWithLossCell(Cell):
@@ -238,40 +238,40 @@ class ForwardValueAndGrad(Cell):
         ...     return self._backbone
         ...
         >>> loss_net = MyWithLossCell(net, loss_fn)
-        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weight, get_by_list=True, sens_param=True)
-        >>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
+        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True)
+        >>> loss, grads = forward_value_and_grad(inputs, labels)
     """
 
-    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False, sens=1.0):
+    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False):
         super(ForwardValueAndGrad, self).__init__(auto_prefix=False)
         if not isinstance(network, (Cell, FunctionType, MethodType)):
             raise TypeError(f"The type of training network should be cell, function type or method type, "
                             f"but got '{type(network)}'")
+        if not isinstance(get_all, bool):
+            raise TypeError(f"The type of get_all should be bool, but got '{type(get_all)}'")
+        if not isinstance(get_by_list, bool):
+            raise TypeError(f"The type of get_by_list should be bool, but got '{type(get_by_list)}'")
         if get_by_list and not isinstance(weights, ParameterTuple):
             raise TypeError(f"When get_by_list is set to True, the parameters of training network should be "
                             f"ParameterTuple type, but got '{type(weights)}'")
-        if get_by_list is not True and weights is not None:
-            raise TypeError(f"When get_by_list is set to False, the parameters of training network should be "
-                            f"NoneType, but got '{type(weights)}'")
         self.network = network
-        self.network.set_grad()
+        if isinstance(network, Cell):
+            self.network.set_grad()
         self.weights = weights
         self.get_all = get_all
         self.get_by_list = get_by_list
         self.sens_param = sens_param
-        self.sens = sens
         self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)
 
     def construct(self, *inputs):
-        weights = self.weights
-        loss = self.network(*inputs)
+        grad_inputs = inputs
         if self.sens_param:
-            sens = self.sens
-            if not isinstance(self.sens, Tensor):
-                sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-            grads = self.grad(self.network, weights)(*inputs, sens)
+            inputs = inputs[:-1]
+        loss = self.network(*inputs)
+        if self.get_by_list:
+            grads = self.grad(self.network, self.weights)(*grad_inputs)
         else:
-            grads = self.grad(self.network, weights)(*inputs)
+            grads = self.grad(self.network)(*grad_inputs)
         return loss, grads
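
The heart of the change is in construct: the full argument tuple (sens last) is handed to the function produced by GradOperation, which itself expects the trailing sensitivity because it was built with sens_param=True, while the forward pass runs on the tuple with sens stripped. A framework-free sketch of that control flow, with hypothetical names:

# Framework-free illustration of the construct() flow above (hypothetical names).
def forward_value_and_grad(network, grad_fn, sens_param, *inputs):
    grad_inputs = inputs            # keep the full tuple, sens included, for the grad call
    if sens_param:
        inputs = inputs[:-1]        # the forward pass must not see the sensitivity
    loss = network(*inputs)         # forward value
    grads = grad_fn(*grad_inputs)   # grad_fn consumes the trailing sens itself
    return loss, grads

# Toy check: y = 3*x, so dy/dx scaled by sens is 3*sens.
network = lambda x: 3.0 * x
grad_fn = lambda x, sens: 3.0 * sens    # hand-written gradient seeded by sens
print(forward_value_and_grad(network, grad_fn, True, 2.0, 1.0))   # (6.0, 3.0)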


@@ -414,40 +414,14 @@ def test_trainTensor_with_new_interface(num_classes=10, epoch=8, batch_size=1):
     weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
     optimizer = Momentum(weights, 0.1, 0.9)
-    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True,
-                                        sens=1.0)
+    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
     losses = []
     for i in range(0, epoch):
         data = Tensor(np.ones([batch_size, 3, 224, 224]
                               ).astype(np.float32) * 0.01)
         label = Tensor(np.ones([batch_size]).astype(np.int32))
-        loss, grads = train_network(data, label)
-        grads = F.identity(grads)
-        optimizer(grads)
-        losses.append(loss)
-    assert (losses[-1].asnumpy() < 0.8)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_big_batchSize_with_new_interface(num_classes=10, epoch=8, batch_size=338):
-    net = resnet50(num_classes)
-    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
-    net_with_criterion = WithLossCell(net, criterion)
-    net_with_criterion.set_train()
-    weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
-    optimizer = Momentum(weights, 0.1, 0.9)
-    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True,
-                                        sens=1.0)
-    losses = []
-    for i in range(0, epoch):
-        data = Tensor(np.ones([batch_size, 3, 224, 224]
-                              ).astype(np.float32) * 0.01)
-        label = Tensor(np.ones([batch_size]).astype(np.int32))
-        loss, grads = train_network(data, label)
+        sens = Tensor(np.ones([1]).astype(np.float32))
+        loss, grads = train_network(data, label, sens)
         grads = F.identity(grads)
         optimizer(grads)
         losses.append(loss)
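
For anyone updating their own training loops to the new interface, this is the pattern the revised test settles on, shrunk to a self-contained toy (nn.Dense standing in for resnet50; shapes and hyperparameters here are illustrative assumptions):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, ParameterTuple
from mindspore.nn import WithLossCell, ForwardValueAndGrad
from mindspore.nn.optim import Momentum

net = nn.Dense(10, 10)                       # stands in for resnet50
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
net_with_criterion.set_train()
weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
optimizer = Momentum(weights, 0.1, 0.9)
train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights,
                                    get_by_list=True, sens_param=True)
for _ in range(3):
    data = Tensor(np.ones([1, 10]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([1]).astype(np.int32))
    sens = Tensor(np.ones([1]).astype(np.float32))   # explicit backprop seed, passed last
    loss, grads = train_network(data, label, sens)
    optimizer(grads)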