set inputs bugfix
This commit is contained in:
parent
2bc6e70fd4
commit
7ee74f7ab2
|
@ -517,7 +517,8 @@
|
||||||
|
|
||||||
.. py:method:: set_inputs(*inputs)
|
.. py:method:: set_inputs(*inputs)
|
||||||
|
|
||||||
设置编译计算图所需的输入,输入需与实例中定义的输入一致。
|
设置编译计算图所需的输入,输入需与数据一致。若使用Model接口,请确保所有传入Model的网络和损失函数都配置了set_inputs。
|
||||||
|
输入可以为动态或静态的Tensor。
|
||||||
|
|
||||||
**参数:**
|
**参数:**
|
||||||
|
|
||||||
|
|
|
@ -868,29 +868,47 @@ class Cell(Cell_):
|
||||||
|
|
||||||
def set_inputs(self, *inputs):
|
def set_inputs(self, *inputs):
|
||||||
"""
|
"""
|
||||||
Save set inputs for computation graph.
|
Save set inputs for computation graph. The number of inputs should be the same with that of the datasets. When
|
||||||
|
using Model for dynamic shape, please make sure that all networks and loss functions passed to the Model are
|
||||||
|
configured with set_inputs. The inputs can be Tensor of either dynamic or static shape.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs (tuple): Inputs of the Cell object.
|
inputs (tuple): Inputs of the Cell object.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import numpy as np
|
|
||||||
>>> import mindspore as ms
|
>>> import mindspore as ms
|
||||||
>>> from mindspore import nn, Tensor
|
>>> from mindspore import nn, ops, Tensor, Model
|
||||||
>>>
|
>>> from mindspore import dataset as ds
|
||||||
>>> ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
|
>>> import numpy as np
|
||||||
>>> class reluNet(nn.Cell):
|
|
||||||
|
>>> def get_data(num, w=2.0, b=3.0):
|
||||||
|
... for _ in range(num):
|
||||||
|
... x = np.random.uniform(-10.0, 10.0)
|
||||||
|
... noise = np.random.normal(0, 1)
|
||||||
|
... y = x * w + b + noise
|
||||||
|
... yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
|
||||||
|
|
||||||
|
>>> def create_dataset(num_data, batch_size=16):
|
||||||
|
... dataset = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
|
||||||
|
... dataset = dataset.batch(batch_size)
|
||||||
|
... return dataset
|
||||||
|
|
||||||
|
>>> class NetAddN(nn.Cell):
|
||||||
... def __init__(self):
|
... def __init__(self):
|
||||||
... super(reluNet, self).__init__()
|
... super(NetAddN, self).__init__()
|
||||||
... self.relu = nn.ReLU()
|
... self.addn = ops.AddN()
|
||||||
... def construct(self, x):
|
|
||||||
... return self.relu(x)
|
... def construct(self, *Z):
|
||||||
>>>
|
... return self.addn(Z)
|
||||||
>>> net = reluNet()
|
|
||||||
>>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
|
>>> ds_train = create_dataset(num_data=160)
|
||||||
|
>>> net = NetAddN()
|
||||||
|
>>> loss = nn.MAELoss()
|
||||||
|
>>> input_dyn = Tensor(shape=[16, None], dtype=ms.float32)
|
||||||
>>> net.set_inputs(input_dyn)
|
>>> net.set_inputs(input_dyn)
|
||||||
>>> input1 = Tensor(np.random.random([3, 10]), dtype=ms.float32)
|
>>> loss.set_inputs(None, input_dyn)
|
||||||
>>> output = net(input1)
|
>>> model = Model(net, loss)
|
||||||
|
>>> model.train(epoch=1, train_dataset=ds_train)
|
||||||
|
|
||||||
NOTE:
|
NOTE:
|
||||||
This is an experimental interface that is subject to change or deletion.
|
This is an experimental interface that is subject to change or deletion.
|
||||||
|
|
|
@ -182,7 +182,6 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
||||||
ValueError: If device is CPU, property `loss_scale_manager` is not `None` or `FixedLossScaleManager`
|
ValueError: If device is CPU, property `loss_scale_manager` is not `None` or `FixedLossScaleManager`
|
||||||
(with property `drop_overflow_update=False` ).
|
(with property `drop_overflow_update=False` ).
|
||||||
"""
|
"""
|
||||||
validator.check_value_type('network', network, nn.Cell)
|
|
||||||
validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt,
|
validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt,
|
||||||
nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell))
|
nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell))
|
||||||
|
|
||||||
|
|
|
@ -140,10 +140,9 @@ def _check_inputs(network, dataset_shapes):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"For 'set_inputs', the Length of Tensor must be {dataset_inputs_len}, but got {network_inputs_len}"
|
f"For 'set_inputs', the Length of Tensor must be {dataset_inputs_len}, but got {network_inputs_len}"
|
||||||
)
|
)
|
||||||
for tensor_index, ele_tensor in enumerate(dataset_shapes):
|
for tensor_index, ele_dataset_shape in enumerate(dataset_shapes):
|
||||||
i_inputs = dataset_shapes[tensor_index]
|
set_inputs_shape = list(network_shapes[tensor_index].shape)
|
||||||
set_inputs_shape = list(ele_tensor)
|
inputs_shape = list(ele_dataset_shape)
|
||||||
inputs_shape = list(i_inputs)
|
|
||||||
if len(inputs_shape) != len(set_inputs_shape):
|
if len(inputs_shape) != len(set_inputs_shape):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"For 'set_inputs', the Dimension of Tensor shape must be {len(inputs_shape)}, but got "
|
f"For 'set_inputs', the Dimension of Tensor shape must be {len(inputs_shape)}, but got "
|
||||||
|
@ -151,12 +150,12 @@ def _check_inputs(network, dataset_shapes):
|
||||||
)
|
)
|
||||||
if network_shapes[tensor_index] is None:
|
if network_shapes[tensor_index] is None:
|
||||||
break
|
break
|
||||||
for index, ele_shape in enumerate(ele_tensor):
|
for index, ele_shape in enumerate(ele_dataset_shape):
|
||||||
if network_shapes[tensor_index].shape[index] != -1:
|
if network_shapes[tensor_index].shape[index] != -1:
|
||||||
if set_inputs_shape[index] != inputs_shape[index]:
|
if set_inputs_shape[index] != ele_shape:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"For 'Length of Tensor shape', the value must be the same with that of inputs,but "
|
f"For 'Tensor shape', the value must be the same with that of inputs, but "
|
||||||
f"got {ele_shape}"
|
f"got {set_inputs_shape[index]}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dataset_shapes[tensor_index][index] = -1
|
dataset_shapes[tensor_index][index] = -1
|
||||||
|
|
|
@ -283,14 +283,16 @@ class Model:
|
||||||
def _build_train_network(self):
|
def _build_train_network(self):
|
||||||
"""Build train network"""
|
"""Build train network"""
|
||||||
network = self._network
|
network = self._network
|
||||||
|
Validator.check_value_type('network', network, nn.Cell)
|
||||||
if self._loss_scale_manager is not None and self._optimizer is None:
|
if self._loss_scale_manager is not None and self._optimizer is None:
|
||||||
raise ValueError("The argument 'optimizer' can not be None when set 'loss_scale_manager'.")
|
raise ValueError("The argument 'optimizer' can not be None when set 'loss_scale_manager'.")
|
||||||
|
|
||||||
if network.get_inputs():
|
net_inputs = network.get_inputs()
|
||||||
net_set_inputs_temp = network.get_inputs()
|
|
||||||
if self._loss_fn:
|
if self._loss_fn:
|
||||||
if self._loss_fn.get_inputs():
|
loss_inputs = [self._loss_fn.get_inputs()]
|
||||||
loss_set_inputs = self._check_loss_fn_set_inputs()
|
loss_inputs.pop(0)
|
||||||
|
if net_inputs:
|
||||||
|
net_inputs = [*net_inputs, *loss_inputs]
|
||||||
if self._optimizer:
|
if self._optimizer:
|
||||||
amp_config = {}
|
amp_config = {}
|
||||||
if self._loss_scale_manager_set:
|
if self._loss_scale_manager_set:
|
||||||
|
@ -312,20 +314,10 @@ class Model:
|
||||||
if self._optimizer is None:
|
if self._optimizer is None:
|
||||||
# In this case, multiple optimizer(s) is supposed to be included in 'self._network'
|
# In this case, multiple optimizer(s) is supposed to be included in 'self._network'
|
||||||
_set_multi_subgraphs()
|
_set_multi_subgraphs()
|
||||||
if self._loss_fn:
|
if net_inputs is not None:
|
||||||
if self._loss_fn.get_inputs() and network.get_inputs():
|
network.set_inputs(*net_inputs)
|
||||||
network.set_inputs(*net_set_inputs_temp, *loss_set_inputs)
|
|
||||||
elif network.get_inputs():
|
|
||||||
network.set_inputs(*net_set_inputs_temp)
|
|
||||||
return network
|
return network
|
||||||
|
|
||||||
def _check_loss_fn_set_inputs(self):
|
|
||||||
loss_set_inputs = []
|
|
||||||
for ele in self._loss_fn.get_inputs():
|
|
||||||
if ele is not None:
|
|
||||||
loss_set_inputs.append(ele)
|
|
||||||
return loss_set_inputs
|
|
||||||
|
|
||||||
def _build_eval_network(self, metrics, eval_network, eval_indexes):
|
def _build_eval_network(self, metrics, eval_network, eval_indexes):
|
||||||
"""Build the network for evaluation."""
|
"""Build the network for evaluation."""
|
||||||
self._metric_fns = get_metrics(metrics)
|
self._metric_fns = get_metrics(metrics)
|
||||||
|
@ -350,16 +342,14 @@ class Model:
|
||||||
f" framework will automatically build an evaluation network with `network` and"
|
f" framework will automatically build an evaluation network with `network` and"
|
||||||
f" `loss_fn`.")
|
f" `loss_fn`.")
|
||||||
|
|
||||||
if self._network.get_inputs() is not None:
|
net_inputs = self._network.get_inputs()
|
||||||
net_set_inputs = self._network.get_inputs()
|
loss_inputs = [self._loss_fn.get_inputs()]
|
||||||
if self._loss_fn.get_inputs() is not None:
|
loss_inputs.pop(0)
|
||||||
loss_set_inputs = self._loss_fn.get_inputs()
|
if net_inputs:
|
||||||
|
net_inputs = [*net_inputs, *loss_inputs]
|
||||||
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3", "auto"])
|
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3", "auto"])
|
||||||
if self._network.get_inputs() is not None:
|
if net_inputs is not None:
|
||||||
if self._loss_fn.get_inputs() is not None:
|
self._eval_network.set_inputs(*net_inputs)
|
||||||
self._eval_network.set_inputs(net_set_inputs, loss_set_inputs)
|
|
||||||
else:
|
|
||||||
self._eval_network.set_inputs(net_set_inputs)
|
|
||||||
self._eval_indexes = [0, 1, 2]
|
self._eval_indexes = [0, 1, 2]
|
||||||
|
|
||||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||||
|
|
Loading…
Reference in New Issue