forked from mindspore-Ecosystem/mindspore

debug_for_new_interface_forwardvalueandgrad

parent a1232aa987
commit 26f6fea675
@@ -18,7 +18,6 @@ from types import FunctionType, MethodType
 from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                        _get_parallel_mode)
 from mindspore.context import ParallelMode
-from ...common.tensor import Tensor
 from ...common import dtype as mstype
 from ...common.parameter import Parameter, ParameterTuple
 from ...ops import composite as C
@@ -197,15 +196,16 @@ class ForwardValueAndGrad(Cell):
             If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
             Default: False.
             If the sens_param is True, a sensitivity (gradient with respect to output) needs to be transferred through
-            the location parameter or key-value pair parameter. If the value is transferred through the key-value pair
-            parameter, the key must be sens.
-        sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
+            the input parameter.

     Inputs:
         - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
+        - **(\*sens)** - A sensitivity (gradient with respect to output) as the input of backpropagation.
+          If the network has a single output, the sens is a tensor.
+          If the network has multiple outputs, the sens is a tuple(tensor).

     Outputs:
-        - **forward value** (a scalar Tensor with shape :math:`()`) - The result of network forward running.
+        - **forward value** - The result of network forward running.
         - **gradients** (tuple(tensor)) - The gradients of network parameters and inputs.

     Supported Platforms:
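Context for the docstring change above: the sensitivity moves from a constructor-time `sens` number to the last positional input of the call. A minimal runnable sketch of the new convention, assuming this branch of MindSpore; the Dense net, shapes, and variable names are illustrative and not part of the commit:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, ParameterTuple

    net = nn.Dense(3, 2)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    loss_net = nn.WithLossCell(net, loss_fn)
    weights = ParameterTuple(net.trainable_params())
    # sens_param=True: the sensitivity is appended as the last positional input.
    fvg = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True, sens_param=True)
    data = Tensor(np.ones([1, 3]).astype(np.float32))
    label = Tensor(np.ones([1]).astype(np.int32))
    sens = Tensor(np.ones([1]).astype(np.float32))  # seed/scale for backpropagation
    loss, grads = fvg(data, label, sens)            # loss is the forward value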
@@ -219,8 +219,8 @@ class ForwardValueAndGrad(Cell):
         >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
         >>> #1) Using the existing WithLossCell
         >>> loss_net = nn.WithLossCell(net, loss_fn)
-        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weight, get_by_list=True, sens_param=True)
-        >>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
+        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True)
+        >>> loss, grads = forward_value_and_grad(inputs, labels)
         >>>
         >>> #2) Using a user-defined WithLossCell
         >>> class MyWithLossCell(Cell):
@@ -238,40 +238,40 @@ class ForwardValueAndGrad(Cell):
         ...     return self._backbone
         ...
         >>> loss_net = MyWithLossCell(net, loss_fn)
-        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weight, get_by_list=True, sens_param=True)
-        >>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
+        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True)
+        >>> loss, grads = forward_value_and_grad(inputs, labels)
     """

-    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False, sens=1.0):
+    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False):
         super(ForwardValueAndGrad, self).__init__(auto_prefix=False)
         if not isinstance(network, (Cell, FunctionType, MethodType)):
             raise TypeError(f"The type of training network should be cell, function type or method type, "
                             f"but got '{type(network)}'")
+        if not isinstance(get_all, bool):
+            raise TypeError(f"The type of get_all should be bool, but got '{type(get_all)}'")
+        if not isinstance(get_by_list, bool):
+            raise TypeError(f"The type of get_by_list should be bool, but got '{type(get_by_list)}'")
         if get_by_list and not isinstance(weights, ParameterTuple):
             raise TypeError(f"When get_by_list is set to True, the parameters of training network should be "
                             f"ParameterTuple type, but got '{type(weights)}'")
-        if get_by_list is not True and weights is not None:
-            raise TypeError(f"When get_by_list is set to False, the parameters of training network should be "
-                            f"NoneType, but got '{type(weights)}'")
         self.network = network
-        self.network.set_grad()
+        if isinstance(network, Cell):
+            self.network.set_grad()
         self.weights = weights
         self.get_all = get_all
         self.get_by_list = get_by_list
         self.sens_param = sens_param
-        self.sens = sens
         self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)

     def construct(self, *inputs):
-        weights = self.weights
-        loss = self.network(*inputs)
+        grad_inputs = inputs
         if self.sens_param:
-            sens = self.sens
-            if not isinstance(self.sens, Tensor):
-                sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-            grads = self.grad(self.network, weights)(*inputs, sens)
+            inputs = inputs[:-1]
+        loss = self.network(*inputs)
+        if self.get_by_list:
+            grads = self.grad(self.network, self.weights)(*grad_inputs)
         else:
-            grads = self.grad(self.network, weights)(*inputs)
+            grads = self.grad(self.network)(*grad_inputs)
         return loss, grads

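The construct() rewrite above keeps the raw input tuple for the gradient call but strips the trailing sens tensor before the forward pass. A pure-Python sketch of that contract, with names mirroring the diff (an illustration of the control flow, not the actual graph-mode implementation):

    def forward_value_and_grad_call(network, grad_fn, sens_param, *inputs):
        grad_inputs = inputs           # the grad call sees every input, sens included
        if sens_param:
            inputs = inputs[:-1]       # the forward pass must not see the sens tensor
        loss = network(*inputs)        # forward value computed on the real inputs only
        grads = grad_fn(*grad_inputs)  # GradOperation(sens_param=True) consumes the trailing sens itself
        return loss, grads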
@@ -414,40 +414,14 @@ def test_trainTensor_with_new_interface(num_classes=10, epoch=8, batch_size=1):
     weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
     optimizer = Momentum(weights, 0.1, 0.9)

-    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True,
-                                        sens=1.0)
+    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
     losses = []
     for i in range(0, epoch):
         data = Tensor(np.ones([batch_size, 3, 224, 224]
                               ).astype(np.float32) * 0.01)
         label = Tensor(np.ones([batch_size]).astype(np.int32))
-        loss, grads = train_network(data, label)
-        grads = F.identity(grads)
-        optimizer(grads)
-        losses.append(loss)
-    assert (losses[-1].asnumpy() < 0.8)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_big_batchSize_with_new_interface(num_classes=10, epoch=8, batch_size=338):
-    net = resnet50(num_classes)
-    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
-    net_with_criterion = WithLossCell(net, criterion)
-    net_with_criterion.set_train()
-
-    weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
-    optimizer = Momentum(weights, 0.1, 0.9)
-
-    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True,
-                                        sens=1.0)
-    losses = []
-    for i in range(0, epoch):
-        data = Tensor(np.ones([batch_size, 3, 224, 224]
-                              ).astype(np.float32) * 0.01)
-        label = Tensor(np.ones([batch_size]).astype(np.int32))
-        loss, grads = train_network(data, label)
+        sens = Tensor(np.ones([1]).astype(np.float32))
+        loss, grads = train_network(data, label, sens)
         grads = F.identity(grads)
         optimizer(grads)
         losses.append(loss)
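Net effect on the test: the backpropagation scale is no longer frozen into the wrapper via `sens=1.0` at construction; it is built as a Tensor inside the training loop and passed on every step. Schematically (condensed from the hunks above; `...` elides the unchanged keyword arguments):

    # before: sens fixed at construction
    train_network = ForwardValueAndGrad(..., sens_param=True, sens=1.0)
    loss, grads = train_network(data, label)

    # after: sens supplied per call
    train_network = ForwardValueAndGrad(..., sens_param=True)
    loss, grads = train_network(data, label, sens)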