!37691 set inputs bugfix

Merge pull request !37691 from Henry Shi/set_inputs_bugfix2
This commit is contained in:
i-robot 2022-07-12 03:01:06 +00:00 committed by Gitee
commit ec7ff48ec0
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 18 additions and 8 deletions

View File

@ -142,6 +142,7 @@ class Cell(Cell_):
self._dynamic_shape_inputs = None
self.saved_dynamic_shape = None
self._jit_config_dict = dict()
self.grad_ops_label = False
def __getstate__(self):
base = Cell_.__getstate__(self)
@ -438,7 +439,8 @@ class Cell(Cell_):
if len(inputs) < positional_args:
raise TypeError(f"For 'Cell', the function construct need {positional_args} positional argument, "
f"but got {len(inputs)}.")
f"but got {len(inputs)}. When using set_inputs, please make sure that all networks"
f"and loss functions are configured with set_inputs.")
if len(inputs) > positional_args + default_args:
raise TypeError(f"For 'Cell', the function construct need {positional_args} positional argument and "
@ -881,27 +883,27 @@ class Cell(Cell_):
>>> from mindspore import nn, ops, Tensor, Model
>>> from mindspore import dataset as ds
>>> import numpy as np
>>>
>>> def get_data(num, w=2.0, b=3.0):
... for _ in range(num):
... x = np.random.uniform(-10.0, 10.0)
... noise = np.random.normal(0, 1)
... y = x * w + b + noise
... yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
>>>
>>> def create_dataset(num_data, batch_size=16):
... dataset = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
... dataset = dataset.batch(batch_size)
... return dataset
>>>
>>> class NetAddN(nn.Cell):
... def __init__(self):
... super(NetAddN, self).__init__()
... self.addn = ops.AddN()
...
... def construct(self, *Z):
... return self.addn(Z)
>>>
>>> ds_train = create_dataset(num_data=160)
>>> net = NetAddN()
>>> loss = nn.MAELoss()
@ -914,6 +916,9 @@ class Cell(Cell_):
NOTE:
This is an experimental interface that is subject to change or deletion.
"""
if self.grad_ops_label:
logger.warning(f'For Cell, set_inputs must be set before the gradient function of the network is'
f'generated.')
self._dynamic_shape_inputs = inputs
self._check_construct_args(*inputs)
for ele in self._dynamic_shape_inputs:

View File

@ -373,6 +373,7 @@ class GradOperation(GradOperation_):
dynamic_shape_inputs = None
if isinstance(fn, ms.nn.Cell):
dynamic_shape_inputs = fn.get_inputs()
fn.grad_ops_label = True
if self.get_by_list:
@ms_function(input_signature=dynamic_shape_inputs)
def after_grad(*args):

View File

@ -294,8 +294,10 @@ class Model:
raise ValueError("The argument 'optimizer' can not be None when set 'loss_scale_manager'.")
net_inputs = network.get_inputs()
loss_inputs = [None]
if self._loss_fn:
loss_inputs = [self._loss_fn.get_inputs()]
if self._loss_fn.get_inputs():
loss_inputs = [*self._loss_fn.get_inputs()]
loss_inputs.pop(0)
if net_inputs:
net_inputs = [*net_inputs, *loss_inputs]
@ -349,7 +351,9 @@ class Model:
f" `loss_fn`.")
net_inputs = self._network.get_inputs()
loss_inputs = [self._loss_fn.get_inputs()]
loss_inputs = [None]
if self._loss_fn.get_inputs():
loss_inputs = [*self._loss_fn.get_inputs()]
loss_inputs.pop(0)
if net_inputs:
net_inputs = [*net_inputs, *loss_inputs]