!37691 set inputs bugfix
Merge pull request !37691 from Henry Shi/set_inputs_bugfix2
This commit is contained in:
commit
ec7ff48ec0
|
@ -142,6 +142,7 @@ class Cell(Cell_):
|
|||
self._dynamic_shape_inputs = None
|
||||
self.saved_dynamic_shape = None
|
||||
self._jit_config_dict = dict()
|
||||
self.grad_ops_label = False
|
||||
|
||||
def __getstate__(self):
|
||||
base = Cell_.__getstate__(self)
|
||||
|
@ -438,7 +439,8 @@ class Cell(Cell_):
|
|||
|
||||
if len(inputs) < positional_args:
|
||||
raise TypeError(f"For 'Cell', the function construct need {positional_args} positional argument, "
|
||||
f"but got {len(inputs)}.")
|
||||
f"but got {len(inputs)}. When using set_inputs, please make sure that all networks"
|
||||
f"and loss functions are configured with set_inputs.")
|
||||
|
||||
if len(inputs) > positional_args + default_args:
|
||||
raise TypeError(f"For 'Cell', the function construct need {positional_args} positional argument and "
|
||||
|
@ -881,27 +883,27 @@ class Cell(Cell_):
|
|||
>>> from mindspore import nn, ops, Tensor, Model
|
||||
>>> from mindspore import dataset as ds
|
||||
>>> import numpy as np
|
||||
|
||||
>>>
|
||||
>>> def get_data(num, w=2.0, b=3.0):
|
||||
... for _ in range(num):
|
||||
... x = np.random.uniform(-10.0, 10.0)
|
||||
... noise = np.random.normal(0, 1)
|
||||
... y = x * w + b + noise
|
||||
... yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
|
||||
|
||||
>>>
|
||||
>>> def create_dataset(num_data, batch_size=16):
|
||||
... dataset = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
|
||||
... dataset = dataset.batch(batch_size)
|
||||
... return dataset
|
||||
|
||||
>>>
|
||||
>>> class NetAddN(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super(NetAddN, self).__init__()
|
||||
... self.addn = ops.AddN()
|
||||
|
||||
...
|
||||
... def construct(self, *Z):
|
||||
... return self.addn(Z)
|
||||
|
||||
>>>
|
||||
>>> ds_train = create_dataset(num_data=160)
|
||||
>>> net = NetAddN()
|
||||
>>> loss = nn.MAELoss()
|
||||
|
@ -914,6 +916,9 @@ class Cell(Cell_):
|
|||
NOTE:
|
||||
This is an experimental interface that is subject to change or deletion.
|
||||
"""
|
||||
if self.grad_ops_label:
|
||||
logger.warning(f'For Cell, set_inputs must be set before the gradient function of the network is'
|
||||
f'generated.')
|
||||
self._dynamic_shape_inputs = inputs
|
||||
self._check_construct_args(*inputs)
|
||||
for ele in self._dynamic_shape_inputs:
|
||||
|
|
|
@ -373,6 +373,7 @@ class GradOperation(GradOperation_):
|
|||
dynamic_shape_inputs = None
|
||||
if isinstance(fn, ms.nn.Cell):
|
||||
dynamic_shape_inputs = fn.get_inputs()
|
||||
fn.grad_ops_label = True
|
||||
if self.get_by_list:
|
||||
@ms_function(input_signature=dynamic_shape_inputs)
|
||||
def after_grad(*args):
|
||||
|
|
|
@ -294,8 +294,10 @@ class Model:
|
|||
raise ValueError("The argument 'optimizer' can not be None when set 'loss_scale_manager'.")
|
||||
|
||||
net_inputs = network.get_inputs()
|
||||
loss_inputs = [None]
|
||||
if self._loss_fn:
|
||||
loss_inputs = [self._loss_fn.get_inputs()]
|
||||
if self._loss_fn.get_inputs():
|
||||
loss_inputs = [*self._loss_fn.get_inputs()]
|
||||
loss_inputs.pop(0)
|
||||
if net_inputs:
|
||||
net_inputs = [*net_inputs, *loss_inputs]
|
||||
|
@ -349,7 +351,9 @@ class Model:
|
|||
f" `loss_fn`.")
|
||||
|
||||
net_inputs = self._network.get_inputs()
|
||||
loss_inputs = [self._loss_fn.get_inputs()]
|
||||
loss_inputs = [None]
|
||||
if self._loss_fn.get_inputs():
|
||||
loss_inputs = [*self._loss_fn.get_inputs()]
|
||||
loss_inputs.pop(0)
|
||||
if net_inputs:
|
||||
net_inputs = [*net_inputs, *loss_inputs]
|
||||
|
|
Loading…
Reference in New Issue