forked from mindspore-Ecosystem/mindspore
!4058 modify parameter input
Merge pull request !4058 from lijiaqi/cell_inputs
This commit is contained in:
commit cfae4096d2
@@ -383,9 +383,13 @@ class Cell:
             inputs (Function or Cell): inputs of construct method.
         """
         parallel_inputs_run = []
-        if len(inputs) > self._construct_inputs_num:
-            raise ValueError('Len of inputs: {} is bigger than self._construct_inputs_num: {}.'.
-                             format(len(inputs), self._construct_inputs_num))
+        # judge if *args exists in input
+        if self.argspec[1] is not None:
+            prefix = self.argspec[1]
+            for i in range(len(inputs)):
+                key = prefix + str(i)
+                self._construct_inputs_names = self._construct_inputs_names + (key,)
+                self._construct_inputs_num = self._construct_inputs_num + 1
         for i, tensor in enumerate(inputs):
             key = self._construct_inputs_names[i]
             # if input is not used, self.parameter_layout_dict may not contain the key
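Editor's note: a minimal stand-alone sketch (plain Python, not MindSpore code; the helper and variable names are made up) of what the new branch does — when the parsed construct signature declares *args, one synthetic name per positional input is appended, so the later lookup of a name by index still succeeds:

import inspect

def extend_input_names(fn, inputs):
    """Mimic the *args handling in the hunk above: append one synthetic
    name per passed-in input when fn declares a *args parameter."""
    spec = inspect.getfullargspec(fn)
    names = list(spec.args)          # named positional parameters, e.g. ['self']
    if spec.varargs is not None:     # this is what self.argspec[1] checks
        prefix = spec.varargs        # e.g. 'inputs' for `*inputs`
        for i in range(len(inputs)):
            names.append(prefix + str(i))
    return names

def construct(self, *inputs):        # signature comparable to Cell.construct
    return inputs

print(extend_input_names(construct, (1, 2, 3)))
# ['self', 'inputs0', 'inputs1', 'inputs2']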
@@ -412,7 +416,7 @@ class Cell:
         from mindspore._extends.parse.parser import get_parse_method_of_class

         fn = get_parse_method_of_class(self)
-        inspect.getfullargspec(fn)
+        self.argspec = inspect.getfullargspec(fn)
         self._construct_inputs_num = fn.__code__.co_argcount
         self._construct_inputs_names = fn.__code__.co_varnames
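For reference, the standard-library behaviour the changed line relies on: getfullargspec returns a named tuple whose index 1 is the *args name (None when absent), while co_argcount and co_varnames only describe the named positional parameters. A small self-contained check:

import inspect

def construct(self, x, *inputs):
    return x, inputs

spec = inspect.getfullargspec(construct)
print(spec.args)       # ['self', 'x']  -> named positional parameters
print(spec.varargs)    # 'inputs'       -> the *args name
print(spec[1])         # 'inputs'       -> same value, accessed as self.argspec[1] above

code = construct.__code__
print(code.co_argcount)       # 2 -> starting value of _construct_inputs_num
print(code.co_varnames[:2])   # ('self', 'x') -> source of _construct_inputs_names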
@@ -41,7 +41,7 @@ def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment,


 class Momentum(Optimizer):
-    """
+    r"""
     Implements the Momentum algorithm.

     Refer to the paper on the importance of initialization and momentum in deep learning for more details.
@@ -56,13 +56,13 @@ class Momentum(Optimizer):
     .. math::
             v_{t} = v_{t-1} \ast u + gradients

     If use_nesterov is True:
         .. math::
                 p_{t} = p_{t-1} - (grad \ast lr + v_{t} \ast u \ast lr)

     If use_nesterov is False:
         .. math::
                 p_{t} = p_{t-1} - lr \ast v_{t}

     where grad, lr, p, v and u denote the gradients, learning_rate, params, moments, and momentum respectively.
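A minimal NumPy sketch of the update rule exactly as written in the docstring above (illustrative only; all names are made up and this is not the fused kernel the optimizer actually dispatches to):

import numpy as np

def momentum_step(p, v, grad, lr, u, use_nesterov):
    """One parameter update following the Momentum docstring above."""
    v = v * u + grad                              # v_t = v_{t-1} * u + gradients
    if use_nesterov:
        p = p - (grad * lr + v * u * lr)          # p_t = p_{t-1} - (grad*lr + v_t*u*lr)
    else:
        p = p - lr * v                            # p_t = p_{t-1} - lr * v_t
    return p, v

p = np.array([1.0, 2.0]); v = np.zeros(2); g = np.array([0.1, -0.2])
print(momentum_step(p, v, g, lr=0.01, u=0.9, use_nesterov=False))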
@@ -32,7 +32,7 @@ def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, s


 class SGD(Optimizer):
-    """
+    r"""
     Implements stochastic gradient descent (optionally with momentum).

     Introduction to SGD can be found at https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
@ -47,15 +47,15 @@ class SGD(Optimizer):
|
|||
To improve parameter groups performance, the customized order of parameters can be supported.
|
||||
|
||||
.. math::
|
||||
v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
|
||||
v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
|
||||
|
||||
If nesterov is True:
|
||||
.. math::
|
||||
p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
|
||||
If nesterov is True:
|
||||
.. math::
|
||||
p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
|
||||
|
||||
If nesterov is Flase:
|
||||
.. math::
|
||||
p_{t+1} = p_{t} - lr \ast v_{t+1}
|
||||
If nesterov is Flase:
|
||||
.. math::
|
||||
p_{t+1} = p_{t} - lr \ast v_{t+1}
|
||||
|
||||
To be noticed, for the first step, v_{t+1} = gradient
|
||||
|
||||
|
|
|
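The same kind of sketch for the SGD rule above, including the dampening factor and the documented first-step special case (again illustrative NumPy, not MindSpore code):

import numpy as np

def sgd_step(p, v, grad, lr, u, dampening, nesterov, first_step):
    """One parameter update following the SGD docstring above."""
    if first_step:
        v = grad                                  # for the first step, v_{t+1} = gradient
    else:
        v = u * v + grad * (1 - dampening)        # v_{t+1} = u*v_t + gradient*(1-dampening)
    if nesterov:
        p = p - lr * (grad + u * v)               # p_{t+1} = p_t - lr*(gradient + u*v_{t+1})
    else:
        p = p - lr * v                            # p_{t+1} = p_t - lr*v_{t+1}
    return p, v

p = np.array([1.0, 2.0]); v = np.zeros(2); g = np.array([0.1, -0.2])
p, v = sgd_step(p, v, g, lr=0.1, u=0.9, dampening=0.0, nesterov=False, first_step=True)
print(p, v)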
@@ -82,7 +82,7 @@ class WithGradCell(Cell):

     Wraps the network with backward cell to compute gradients. A network with a loss function is necessary
     as argument. If the loss function is None, the network must be a wrapper of network and loss function. This
-    Cell accepts data and label as inputs and returns gradients for each trainable parameter.
+    Cell accepts *inputs as inputs and returns gradients for each trainable parameter.

     Note:
         Run in PyNative mode.
@@ -95,8 +95,7 @@ class WithGradCell(Cell):
             output value. Default: None.

     Inputs:
-        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
-        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
+        - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

     Outputs:
         list, a list of Tensors with identical shapes as trainable weights.
@@ -126,12 +125,12 @@ class WithGradCell(Cell):
             self.network_with_loss = WithLossCell(self.network, self.loss_fn)
         self.network_with_loss.set_train()

-    def construct(self, data, label):
+    def construct(self, *inputs):
         weights = self.weights
         if self.sens is None:
-            grads = self.grad(self.network_with_loss, weights)(data, label)
+            grads = self.grad(self.network_with_loss, weights)(*inputs)
         else:
-            grads = self.grad(self.network_with_loss, weights)(data, label, self.sens)
+            grads = self.grad(self.network_with_loss, weights)(*inputs, self.sens)
         return grads
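With construct now taking *inputs, the wrapper forwards however many tensors it is called with instead of a fixed (data, label) pair. A hedged usage sketch, assuming a MindSpore build where nn.Dense, nn.MSELoss and nn.WithGradCell have their usual signatures (the toy network and shapes are invented):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)   # WithGradCell is documented to run in PyNative mode

net = nn.Dense(16, 10)                            # toy network, illustrative only
loss_fn = nn.MSELoss()
grad_net = nn.WithGradCell(net, loss_fn)

data = Tensor(np.random.randn(32, 16).astype(np.float32))
label = Tensor(np.random.randn(32, 10).astype(np.float32))
grads = grad_net(data, label)                     # forwarded as *inputs to the wrapped loss network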
@@ -139,7 +138,7 @@ class TrainOneStepCell(Cell):
     r"""
     Network training package class.

-    Wraps the network with an optimizer. The resulting Cell can be trained with input data and label.
+    Wraps the network with an optimizer. The resulting Cell can be trained with *inputs.
     Backward graph will be created in the construct function to do parameter updating. Different
     parallel modes are available to run the training.
@@ -149,8 +148,7 @@ class TrainOneStepCell(Cell):
         sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.

     Inputs:
-        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
-        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
+        - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

     Outputs:
         Tensor, a scalar Tensor with shape :math:`()`.
@@ -181,11 +179,11 @@ class TrainOneStepCell(Cell):
             degree = _get_device_num()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

-    def construct(self, data, label):
+    def construct(self, *inputs):
         weights = self.weights
-        loss = self.network(data, label)
+        loss = self.network(*inputs)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        grads = self.grad(self.network, weights)(data, label, sens)
+        grads = self.grad(self.network, weights)(*inputs, sens)
         if self.reducer_flag:
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)
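The same change applies to TrainOneStepCell: whatever tensors the training step is called with are passed straight through to the wrapped network and to the gradient call. A hedged usage sketch of the standard composition, under the same assumptions as above (toy network, loss and shapes are invented):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

net = nn.Dense(16, 10)                              # toy network, illustrative only
loss_fn = nn.MSELoss()
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

net_with_loss = nn.WithLossCell(net, loss_fn)
train_step = nn.TrainOneStepCell(net_with_loss, optimizer)
train_step.set_train()

data = Tensor(np.random.randn(32, 16).astype(np.float32))
label = Tensor(np.random.randn(32, 10).astype(np.float32))
loss = train_step(data, label)                      # both tensors arrive as *inputs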