set inputs bugfix

This commit is contained in:
Henry Shi 2022-06-26 16:59:49 +08:00 committed by Henry
parent 2bc6e70fd4
commit 7ee74f7ab2
5 changed files with 57 additions and 50 deletions

View File

@ -517,7 +517,8 @@
.. py:method:: set_inputs(*inputs) .. py:method:: set_inputs(*inputs)
设置编译计算图所需的输入,输入需与实例中定义的输入一致。 设置编译计算图所需的输入输入需与数据一致。若使用Model接口请确保所有传入Model的网络和损失函数都配置了set_inputs。
输入可以为动态或静态的Tensor。
**参数:** **参数:**

View File

@ -868,29 +868,47 @@ class Cell(Cell_):
def set_inputs(self, *inputs): def set_inputs(self, *inputs):
""" """
Save set inputs for computation graph. Save set inputs for computation graph. The number of inputs should be the same with that of the datasets. When
using Model for dynamic shape, please make sure that all networks and loss functions passed to the Model are
configured with set_inputs. The inputs can be Tensor of either dynamic or static shape.
Args: Args:
inputs (tuple): Inputs of the Cell object. inputs (tuple): Inputs of the Cell object.
Examples: Examples:
>>> import numpy as np
>>> import mindspore as ms >>> import mindspore as ms
>>> from mindspore import nn, Tensor >>> from mindspore import nn, ops, Tensor, Model
>>> >>> from mindspore import dataset as ds
>>> ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend") >>> import numpy as np
>>> class reluNet(nn.Cell):
>>> def get_data(num, w=2.0, b=3.0):
... for _ in range(num):
... x = np.random.uniform(-10.0, 10.0)
... noise = np.random.normal(0, 1)
... y = x * w + b + noise
... yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
>>> def create_dataset(num_data, batch_size=16):
... dataset = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
... dataset = dataset.batch(batch_size)
... return dataset
>>> class NetAddN(nn.Cell):
... def __init__(self): ... def __init__(self):
... super(reluNet, self).__init__() ... super(NetAddN, self).__init__()
... self.relu = nn.ReLU() ... self.addn = ops.AddN()
... def construct(self, x):
... return self.relu(x) ... def construct(self, *Z):
>>> ... return self.addn(Z)
>>> net = reluNet()
>>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32) >>> ds_train = create_dataset(num_data=160)
>>> net = NetAddN()
>>> loss = nn.MAELoss()
>>> input_dyn = Tensor(shape=[16, None], dtype=ms.float32)
>>> net.set_inputs(input_dyn) >>> net.set_inputs(input_dyn)
>>> input1 = Tensor(np.random.random([3, 10]), dtype=ms.float32) >>> loss.set_inputs(None, input_dyn)
>>> output = net(input1) >>> model = Model(net, loss)
>>> model.train(epoch=1, train_dataset=ds_train)
NOTE: NOTE:
This is an experimental interface that is subject to change or deletion. This is an experimental interface that is subject to change or deletion.

View File

@ -182,7 +182,6 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
ValueError: If device is CPU, property `loss_scale_manager` is not `None` or `FixedLossScaleManager` ValueError: If device is CPU, property `loss_scale_manager` is not `None` or `FixedLossScaleManager`
(with property `drop_overflow_update=False` ). (with property `drop_overflow_update=False` ).
""" """
validator.check_value_type('network', network, nn.Cell)
validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt, validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt,
nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell)) nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell))

View File

@ -140,10 +140,9 @@ def _check_inputs(network, dataset_shapes):
raise ValueError( raise ValueError(
f"For 'set_inputs', the Length of Tensor must be {dataset_inputs_len}, but got {network_inputs_len}" f"For 'set_inputs', the Length of Tensor must be {dataset_inputs_len}, but got {network_inputs_len}"
) )
for tensor_index, ele_tensor in enumerate(dataset_shapes): for tensor_index, ele_dataset_shape in enumerate(dataset_shapes):
i_inputs = dataset_shapes[tensor_index] set_inputs_shape = list(network_shapes[tensor_index].shape)
set_inputs_shape = list(ele_tensor) inputs_shape = list(ele_dataset_shape)
inputs_shape = list(i_inputs)
if len(inputs_shape) != len(set_inputs_shape): if len(inputs_shape) != len(set_inputs_shape):
raise ValueError( raise ValueError(
f"For 'set_inputs', the Dimension of Tensor shape must be {len(inputs_shape)}, but got " f"For 'set_inputs', the Dimension of Tensor shape must be {len(inputs_shape)}, but got "
@ -151,12 +150,12 @@ def _check_inputs(network, dataset_shapes):
) )
if network_shapes[tensor_index] is None: if network_shapes[tensor_index] is None:
break break
for index, ele_shape in enumerate(ele_tensor): for index, ele_shape in enumerate(ele_dataset_shape):
if network_shapes[tensor_index].shape[index] != -1: if network_shapes[tensor_index].shape[index] != -1:
if set_inputs_shape[index] != inputs_shape[index]: if set_inputs_shape[index] != ele_shape:
raise ValueError( raise ValueError(
f"For 'Length of Tensor shape', the value must be the same with that of inputs,but " f"For 'Tensor shape', the value must be the same with that of inputs, but "
f"got {ele_shape}" f"got {set_inputs_shape[index]}"
) )
else: else:
dataset_shapes[tensor_index][index] = -1 dataset_shapes[tensor_index][index] = -1

View File

@ -283,14 +283,16 @@ class Model:
def _build_train_network(self): def _build_train_network(self):
"""Build train network""" """Build train network"""
network = self._network network = self._network
Validator.check_value_type('network', network, nn.Cell)
if self._loss_scale_manager is not None and self._optimizer is None: if self._loss_scale_manager is not None and self._optimizer is None:
raise ValueError("The argument 'optimizer' can not be None when set 'loss_scale_manager'.") raise ValueError("The argument 'optimizer' can not be None when set 'loss_scale_manager'.")
if network.get_inputs(): net_inputs = network.get_inputs()
net_set_inputs_temp = network.get_inputs()
if self._loss_fn: if self._loss_fn:
if self._loss_fn.get_inputs(): loss_inputs = [self._loss_fn.get_inputs()]
loss_set_inputs = self._check_loss_fn_set_inputs() loss_inputs.pop(0)
if net_inputs:
net_inputs = [*net_inputs, *loss_inputs]
if self._optimizer: if self._optimizer:
amp_config = {} amp_config = {}
if self._loss_scale_manager_set: if self._loss_scale_manager_set:
@ -312,20 +314,10 @@ class Model:
if self._optimizer is None: if self._optimizer is None:
# In this case, multiple optimizer(s) is supposed to be included in 'self._network' # In this case, multiple optimizer(s) is supposed to be included in 'self._network'
_set_multi_subgraphs() _set_multi_subgraphs()
if self._loss_fn: if net_inputs is not None:
if self._loss_fn.get_inputs() and network.get_inputs(): network.set_inputs(*net_inputs)
network.set_inputs(*net_set_inputs_temp, *loss_set_inputs)
elif network.get_inputs():
network.set_inputs(*net_set_inputs_temp)
return network return network
def _check_loss_fn_set_inputs(self):
loss_set_inputs = []
for ele in self._loss_fn.get_inputs():
if ele is not None:
loss_set_inputs.append(ele)
return loss_set_inputs
def _build_eval_network(self, metrics, eval_network, eval_indexes): def _build_eval_network(self, metrics, eval_network, eval_indexes):
"""Build the network for evaluation.""" """Build the network for evaluation."""
self._metric_fns = get_metrics(metrics) self._metric_fns = get_metrics(metrics)
@ -350,16 +342,14 @@ class Model:
f" framework will automatically build an evaluation network with `network` and" f" framework will automatically build an evaluation network with `network` and"
f" `loss_fn`.") f" `loss_fn`.")
if self._network.get_inputs() is not None: net_inputs = self._network.get_inputs()
net_set_inputs = self._network.get_inputs() loss_inputs = [self._loss_fn.get_inputs()]
if self._loss_fn.get_inputs() is not None: loss_inputs.pop(0)
loss_set_inputs = self._loss_fn.get_inputs() if net_inputs:
net_inputs = [*net_inputs, *loss_inputs]
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3", "auto"]) self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3", "auto"])
if self._network.get_inputs() is not None: if net_inputs is not None:
if self._loss_fn.get_inputs() is not None: self._eval_network.set_inputs(*net_inputs)
self._eval_network.set_inputs(net_set_inputs, loss_set_inputs)
else:
self._eval_network.set_inputs(net_set_inputs)
self._eval_indexes = [0, 1, 2] self._eval_indexes = [0, 1, 2]
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):