set inputs bugfix

This commit is contained in:
Henry Shi 2022-06-26 16:59:49 +08:00 committed by Henry
parent 2bc6e70fd4
commit 7ee74f7ab2
5 changed files with 57 additions and 50 deletions

View File

@ -517,7 +517,8 @@
.. py:method:: set_inputs(*inputs)
设置编译计算图所需的输入,输入需与实例中定义的输入一致。
设置编译计算图所需的输入输入需与数据一致。若使用Model接口请确保所有传入Model的网络和损失函数都配置了set_inputs。
输入可以为动态或静态的Tensor。
**参数:**

View File

@ -868,29 +868,47 @@ class Cell(Cell_):
def set_inputs(self, *inputs):
"""
Save set inputs for computation graph.
Save set inputs for computation graph. The number of inputs should be the same with that of the datasets. When
using Model for dynamic shape, please make sure that all networks and loss functions passed to the Model are
configured with set_inputs. The inputs can be Tensor of either dynamic or static shape.
Args:
inputs (tuple): Inputs of the Cell object.
Examples:
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn, Tensor
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
>>> class reluNet(nn.Cell):
>>> from mindspore import nn, ops, Tensor, Model
>>> from mindspore import dataset as ds
>>> import numpy as np
>>> def get_data(num, w=2.0, b=3.0):
... for _ in range(num):
... x = np.random.uniform(-10.0, 10.0)
... noise = np.random.normal(0, 1)
... y = x * w + b + noise
... yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
>>> def create_dataset(num_data, batch_size=16):
... dataset = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
... dataset = dataset.batch(batch_size)
... return dataset
>>> class NetAddN(nn.Cell):
... def __init__(self):
... super(reluNet, self).__init__()
... self.relu = nn.ReLU()
... def construct(self, x):
... return self.relu(x)
>>>
>>> net = reluNet()
>>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
... super(NetAddN, self).__init__()
... self.addn = ops.AddN()
... def construct(self, *Z):
... return self.addn(Z)
>>> ds_train = create_dataset(num_data=160)
>>> net = NetAddN()
>>> loss = nn.MAELoss()
>>> input_dyn = Tensor(shape=[16, None], dtype=ms.float32)
>>> net.set_inputs(input_dyn)
>>> input1 = Tensor(np.random.random([3, 10]), dtype=ms.float32)
>>> output = net(input1)
>>> loss.set_inputs(None, input_dyn)
>>> model = Model(net, loss)
>>> model.train(epoch=1, train_dataset=ds_train)
NOTE:
This is an experimental interface that is subject to change or deletion.

View File

@ -182,7 +182,6 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
ValueError: If device is CPU, property `loss_scale_manager` is not `None` or `FixedLossScaleManager`
(with property `drop_overflow_update=False` ).
"""
validator.check_value_type('network', network, nn.Cell)
validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt,
nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell))

View File

@ -140,10 +140,9 @@ def _check_inputs(network, dataset_shapes):
raise ValueError(
f"For 'set_inputs', the Length of Tensor must be {dataset_inputs_len}, but got {network_inputs_len}"
)
for tensor_index, ele_tensor in enumerate(dataset_shapes):
i_inputs = dataset_shapes[tensor_index]
set_inputs_shape = list(ele_tensor)
inputs_shape = list(i_inputs)
for tensor_index, ele_dataset_shape in enumerate(dataset_shapes):
set_inputs_shape = list(network_shapes[tensor_index].shape)
inputs_shape = list(ele_dataset_shape)
if len(inputs_shape) != len(set_inputs_shape):
raise ValueError(
f"For 'set_inputs', the Dimension of Tensor shape must be {len(inputs_shape)}, but got "
@ -151,12 +150,12 @@ def _check_inputs(network, dataset_shapes):
)
if network_shapes[tensor_index] is None:
break
for index, ele_shape in enumerate(ele_tensor):
for index, ele_shape in enumerate(ele_dataset_shape):
if network_shapes[tensor_index].shape[index] != -1:
if set_inputs_shape[index] != inputs_shape[index]:
if set_inputs_shape[index] != ele_shape:
raise ValueError(
f"For 'Length of Tensor shape', the value must be the same with that of inputs,but "
f"got {ele_shape}"
f"For 'Tensor shape', the value must be the same with that of inputs, but "
f"got {set_inputs_shape[index]}"
)
else:
dataset_shapes[tensor_index][index] = -1

View File

@ -283,14 +283,16 @@ class Model:
def _build_train_network(self):
"""Build train network"""
network = self._network
Validator.check_value_type('network', network, nn.Cell)
if self._loss_scale_manager is not None and self._optimizer is None:
raise ValueError("The argument 'optimizer' can not be None when set 'loss_scale_manager'.")
if network.get_inputs():
net_set_inputs_temp = network.get_inputs()
net_inputs = network.get_inputs()
if self._loss_fn:
if self._loss_fn.get_inputs():
loss_set_inputs = self._check_loss_fn_set_inputs()
loss_inputs = [self._loss_fn.get_inputs()]
loss_inputs.pop(0)
if net_inputs:
net_inputs = [*net_inputs, *loss_inputs]
if self._optimizer:
amp_config = {}
if self._loss_scale_manager_set:
@ -312,20 +314,10 @@ class Model:
if self._optimizer is None:
# In this case, multiple optimizer(s) is supposed to be included in 'self._network'
_set_multi_subgraphs()
if self._loss_fn:
if self._loss_fn.get_inputs() and network.get_inputs():
network.set_inputs(*net_set_inputs_temp, *loss_set_inputs)
elif network.get_inputs():
network.set_inputs(*net_set_inputs_temp)
if net_inputs is not None:
network.set_inputs(*net_inputs)
return network
def _check_loss_fn_set_inputs(self):
loss_set_inputs = []
for ele in self._loss_fn.get_inputs():
if ele is not None:
loss_set_inputs.append(ele)
return loss_set_inputs
def _build_eval_network(self, metrics, eval_network, eval_indexes):
"""Build the network for evaluation."""
self._metric_fns = get_metrics(metrics)
@ -350,16 +342,14 @@ class Model:
f" framework will automatically build an evaluation network with `network` and"
f" `loss_fn`.")
if self._network.get_inputs() is not None:
net_set_inputs = self._network.get_inputs()
if self._loss_fn.get_inputs() is not None:
loss_set_inputs = self._loss_fn.get_inputs()
net_inputs = self._network.get_inputs()
loss_inputs = [self._loss_fn.get_inputs()]
loss_inputs.pop(0)
if net_inputs:
net_inputs = [*net_inputs, *loss_inputs]
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3", "auto"])
if self._network.get_inputs() is not None:
if self._loss_fn.get_inputs() is not None:
self._eval_network.set_inputs(net_set_inputs, loss_set_inputs)
else:
self._eval_network.set_inputs(net_set_inputs)
if net_inputs is not None:
self._eval_network.set_inputs(*net_inputs)
self._eval_indexes = [0, 1, 2]
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):